| | package templates |
| |
|
| | import ( |
| | "encoding/json" |
| | "fmt" |
| | "strings" |
| |
|
| | "github.com/mudler/LocalAI/core/config" |
| | "github.com/mudler/LocalAI/core/schema" |
| | "github.com/mudler/LocalAI/pkg/functions" |
| | "github.com/mudler/xlog" |
| | ) |
| |
|
| | |
| | |
| | |
// PromptTemplateData is the value handed to a prompt-level template (chat,
// completion, edit or functions) by EvaluateTemplateForPrompt.
type PromptTemplateData struct {
	// SystemPrompt is the model's configured system prompt; TemplateMessages
	// fills it from config.SystemPrompt.
	SystemPrompt string
	// SuppressSystemPrompt is set by TemplateMessages when a non-empty "system"
	// message was rendered without a per-message template. Presumably templates
	// use it to omit SystemPrompt — confirm against the template files.
	SuppressSystemPrompt bool
	// Input is the text the template wraps. TemplateMessages passes the joined
	// conversation. EvaluateTemplateForPrompt returns it unchanged when no
	// template applies.
	Input string
	// Instruction is not populated in this file.
	// NOTE(review): presumably used by edit-style prompts — confirm with callers.
	Instruction string
	// Functions is the function list TemplateMessages receives as funcs.
	Functions []functions.Function
	// MessageIndex is not set by TemplateMessages.
	// NOTE(review): confirm its meaning with other callers.
	MessageIndex int
	// ReasoningEffort is copied from the incoming request by TemplateMessages.
	ReasoningEffort string
	// Metadata is copied from the incoming request by TemplateMessages.
	Metadata map[string]string
}
| |
|
// ChatMessageTemplateData is the value handed to the per-message chat template
// (config.TemplateConfig.ChatMessage) for each message of a conversation.
type ChatMessageTemplateData struct {
	// SystemPrompt is the model's configured system prompt.
	SystemPrompt string
	// Role is the role prefix looked up in config.Roles for RoleName. It may
	// be empty when no prefix is configured.
	Role string
	// RoleName is the message's role. For assistant messages carrying a
	// function or tool call, it may be rewritten to "assistant_function_call"
	// when that role is configured.
	RoleName string
	// FunctionName is the message's name field.
	FunctionName string
	// Content is the message text.
	Content string
	// MessageIndex is the zero-based position of the message in the conversation.
	MessageIndex int
	// Function is true only for the last message, and only when the model
	// config sets a grammar.
	Function bool
	// FunctionCall holds the message's tool calls when there are any;
	// otherwise it holds the message's legacy function call (possibly nil).
	FunctionCall any
	// LastMessage reports whether this is the final message of the conversation.
	LastMessage bool
}
| |
|
| | const ( |
| | ChatPromptTemplate TemplateType = iota |
| | ChatMessageTemplate |
| | CompletionPromptTemplate |
| | EditPromptTemplate |
| | FunctionsPromptTemplate |
| | ) |
| |
|
// Evaluator renders prompts and individual chat messages using templates
// resolved through its template cache.
type Evaluator struct {
	// cache resolves and evaluates templates by name (see newTemplateCache).
	cache *templateCache
}
| |
|
| | func NewEvaluator(modelPath string) *Evaluator { |
| | return &Evaluator{ |
| | cache: newTemplateCache(modelPath), |
| | } |
| | } |
| |
|
| | func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) { |
| | template := "" |
| |
|
| | |
| | if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { |
| | template = config.Model |
| | } |
| |
|
| | switch templateType { |
| | case CompletionPromptTemplate: |
| | if config.TemplateConfig.Completion != "" { |
| | template = config.TemplateConfig.Completion |
| | } |
| | case EditPromptTemplate: |
| | if config.TemplateConfig.Edit != "" { |
| | template = config.TemplateConfig.Edit |
| | } |
| | case ChatPromptTemplate: |
| | if config.TemplateConfig.Chat != "" { |
| | template = config.TemplateConfig.Chat |
| | } |
| | case FunctionsPromptTemplate: |
| | if config.TemplateConfig.Functions != "" { |
| | template = config.TemplateConfig.Functions |
| | } |
| | } |
| |
|
| | if template == "" { |
| | return in.Input, nil |
| | } |
| |
|
| | return e.cache.evaluateTemplate(templateType, template, in) |
| | } |
| |
|
| | func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { |
| | return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData) |
| | } |
| |
|
| | func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string { |
| | var predInput string |
| | suppressConfigSystemPrompt := false |
| | mess := []string{} |
| | for messageIndex, i := range messages { |
| | var content string |
| | role := i.Role |
| |
|
| | |
| | |
| | if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { |
| | roleFn := "assistant_function_call" |
| | r := config.Roles[roleFn] |
| | if r != "" { |
| | role = roleFn |
| | } |
| | } |
| | r := config.Roles[role] |
| | contentExists := i.Content != nil && i.StringContent != "" |
| |
|
| | fcall := i.FunctionCall |
| | if len(i.ToolCalls) > 0 { |
| | fcall = i.ToolCalls |
| | } |
| |
|
| | |
| | if config.TemplateConfig.ChatMessage != "" { |
| | chatMessageData := ChatMessageTemplateData{ |
| | SystemPrompt: config.SystemPrompt, |
| | Role: r, |
| | RoleName: role, |
| | Content: i.StringContent, |
| | FunctionCall: fcall, |
| | FunctionName: i.Name, |
| | LastMessage: messageIndex == (len(messages) - 1), |
| | Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), |
| | MessageIndex: messageIndex, |
| | } |
| | templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) |
| | if err != nil { |
| | xlog.Error("error processing message with template, skipping", "error", err, "message", chatMessageData, "template", config.TemplateConfig.ChatMessage) |
| | } else { |
| | if templatedChatMessage == "" { |
| | xlog.Warn("template produced blank output, skipping", "template", config.TemplateConfig.ChatMessage, "message", chatMessageData) |
| | continue |
| | } |
| | xlog.Debug("templated message for chat", "message", templatedChatMessage) |
| | content = templatedChatMessage |
| | } |
| | } |
| |
|
| | marshalAnyRole := func(f any) { |
| | j, err := json.Marshal(f) |
| | if err == nil { |
| | if contentExists { |
| | content += "\n" + fmt.Sprint(r, " ", string(j)) |
| | } else { |
| | content = fmt.Sprint(r, " ", string(j)) |
| | } |
| | } |
| | } |
| | marshalAny := func(f any) { |
| | j, err := json.Marshal(f) |
| | if err == nil { |
| | if contentExists { |
| | content += "\n" + string(j) |
| | } else { |
| | content = string(j) |
| | } |
| | } |
| | } |
| | |
| | if content == "" { |
| | if r != "" { |
| | if contentExists { |
| | content = fmt.Sprint(r, i.StringContent) |
| | } |
| |
|
| | if i.FunctionCall != nil { |
| | marshalAnyRole(i.FunctionCall) |
| | } |
| | if i.ToolCalls != nil { |
| | marshalAnyRole(i.ToolCalls) |
| | } |
| | } else { |
| | if contentExists { |
| | content = fmt.Sprint(i.StringContent) |
| | } |
| | if i.FunctionCall != nil { |
| | marshalAny(i.FunctionCall) |
| | } |
| | if i.ToolCalls != nil { |
| | marshalAny(i.ToolCalls) |
| | } |
| | } |
| | |
| | if contentExists && role == "system" { |
| | suppressConfigSystemPrompt = true |
| | } |
| | } |
| |
|
| | mess = append(mess, content) |
| | } |
| |
|
| | joinCharacter := "\n" |
| | if config.TemplateConfig.JoinChatMessagesByCharacter != nil { |
| | joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter |
| | } |
| |
|
| | predInput = strings.Join(mess, joinCharacter) |
| | xlog.Debug("Prompt (before templating)", "prompt", predInput) |
| |
|
| | promptTemplate := ChatPromptTemplate |
| |
|
| | if config.TemplateConfig.Functions != "" && shouldUseFn { |
| | promptTemplate = FunctionsPromptTemplate |
| | } |
| |
|
| | templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{ |
| | SystemPrompt: config.SystemPrompt, |
| | SuppressSystemPrompt: suppressConfigSystemPrompt, |
| | Input: predInput, |
| | Functions: funcs, |
| | ReasoningEffort: input.ReasoningEffort, |
| | Metadata: input.Metadata, |
| | }) |
| | if err == nil { |
| | predInput = templatedInput |
| | xlog.Debug("Template found, input modified", "input", predInput) |
| | } else { |
| | xlog.Debug("Template failed loading", "error", err) |
| | } |
| |
|
| | return predInput |
| | } |
| |
|