File size: 7,058 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
package templates

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/xlog"
)

// PromptTemplateData is the typed set of variables handed to a prompt template,
// used rather than passing an interface{} to the template.
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file.
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
//
// As populated by TemplateMessages in this file:
//   - SystemPrompt is the system prompt from the model configuration.
//   - SuppressSystemPrompt is set when a system message from the request was
//     rendered by the built-in (non-template) message fallback.
//   - Input is the already rendered chat messages joined into one string;
//     EvaluateTemplateForPrompt returns it unchanged when no template applies.
//   - Functions is the function list passed in by the caller.
//   - ReasoningEffort and Metadata are copied from the incoming request.
//
// Instruction and MessageIndex are not set anywhere in this file.
type PromptTemplateData struct {
	SystemPrompt         string
	SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
	Input                string
	Instruction          string
	Functions            []functions.Function
	MessageIndex         int
	ReasoningEffort      string
	Metadata             map[string]string
}

// ChatMessageTemplateData is the set of variables handed to the per-message
// chat template (ChatMessageTemplate). TemplateMessages fills it once for
// every message of a request:
//   - SystemPrompt is the system prompt from the model configuration.
//   - Role is the value configured in the model's Roles map for the message's
//     role name; RoleName is that role name itself.
//   - FunctionName is the message's Name field.
//   - Content is the message's string content.
//   - MessageIndex is the zero-based position of the message in the request.
//   - Function is true only for the last message, and only when the model
//     configuration has a non-empty Grammar.
//   - FunctionCall holds the message's tool calls when it has any, otherwise
//     its function call (which may be nil).
//   - LastMessage is true for the final message of the request.
type ChatMessageTemplateData struct {
	SystemPrompt string
	Role         string
	RoleName     string
	FunctionName string
	Content      string
	MessageIndex int
	Function     bool
	FunctionCall interface{}
	LastMessage  bool
}

// Kinds of template the Evaluator can render. The TemplateType type itself is
// declared outside this file; values are assigned with iota, so the order of
// the constants below determines their numeric values.
const (
	// ChatPromptTemplate selects the model's Chat template (whole chat prompt).
	ChatPromptTemplate TemplateType = iota
	// ChatMessageTemplate is used when rendering a single chat message.
	ChatMessageTemplate
	// CompletionPromptTemplate selects the model's Completion template.
	CompletionPromptTemplate
	// EditPromptTemplate selects the model's Edit template.
	EditPromptTemplate
	// FunctionsPromptTemplate selects the model's Functions template.
	FunctionsPromptTemplate
)

// Evaluator renders prompt and chat-message templates. All template lookups
// and evaluations are delegated to its template cache.
type Evaluator struct {
	// cache checks for template files in the model path and evaluates templates.
	cache *templateCache
}

// NewEvaluator returns an Evaluator backed by a fresh template cache created
// for modelPath.
func NewEvaluator(modelPath string) *Evaluator {
	evaluator := new(Evaluator)
	evaluator.cache = newTemplateCache(modelPath)
	return evaluator
}

// EvaluateTemplateForPrompt renders the prompt template selected for
// templateType using the variables in in.
//
// Template selection, from lowest to highest precedence:
//   - the model name, when the cache reports that "<config.Model>.tmpl"
//     exists in the model path;
//   - the non-empty config.TemplateConfig entry matching templateType
//     (Completion, Edit, Chat or Functions).
//
// When neither yields a template, in.Input is returned unchanged with a nil
// error. Otherwise the result and error of the cache evaluation are returned.
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) {
	// Default: a "file.bin.tmpl" style template associated with the model, if present.
	var name string
	if modelTemplate := fmt.Sprintf("%s.tmpl", config.Model); e.cache.existsInModelPath(modelTemplate) {
		name = config.Model
	}

	// An explicit per-type template from the model configuration wins over
	// the default. Template types without a case here have no override.
	var override string
	switch templateType {
	case ChatPromptTemplate:
		override = config.TemplateConfig.Chat
	case CompletionPromptTemplate:
		override = config.TemplateConfig.Completion
	case EditPromptTemplate:
		override = config.TemplateConfig.Edit
	case FunctionsPromptTemplate:
		override = config.TemplateConfig.Functions
	}
	if override != "" {
		name = override
	}

	if name != "" {
		return e.cache.evaluateTemplate(templateType, name, in)
	}

	// Nothing to render with: hand the raw input back untouched.
	return in.Input, nil
}

// evaluateTemplateForChatMessage renders one chat message by evaluating the
// template named templateName as a ChatMessageTemplate with messageData.
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
	rendered, err := e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
	return rendered, err
}

// TemplateMessages renders the chat messages of a request into the single
// prompt string that is returned to the caller.
//
// Every message is rendered on its own:
//   - If the model configuration has a ChatMessage template, the message is
//     rendered through it. A blank result drops the message entirely; an
//     evaluation error is logged and the built-in fallback below is used.
//   - Otherwise (or after a template error) the message is rendered as the
//     configured role prefix followed by its text content, with any function
//     call and tool calls appended as JSON, one per line.
//
// An assistant message that carries a function call or at least one tool call
// is rendered under the "assistant_function_call" role when the model
// configuration defines a prefix for that role.
//
// The rendered messages are joined with "\n", or with
// TemplateConfig.JoinChatMessagesByCharacter when that is set, and the result
// is evaluated through the Functions prompt template (when one is configured
// and shouldUseFn is true) or the Chat prompt template. If that evaluation
// fails, the joined, un-templated text is returned instead.
//
// Parameters:
//   - input: only ReasoningEffort and Metadata are read and forwarded to the
//     prompt template.
//   - messages: the messages to render, in order.
//   - config: model configuration; must not be nil.
//   - funcs: functions exposed to the prompt template.
//   - shouldUseFn: prefer the Functions template over the Chat template.
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string {
	var predInput string
	suppressConfigSystemPrompt := false
	mess := make([]string, 0, len(messages))
	lastIndex := len(messages) - 1

	for messageIndex, i := range messages {
		var content string
		role := i.Role

		// An empty but non-nil tool call list (for example one decoded from
		// `"tool_calls": []`) carries no calls, so test the length rather than
		// nil-ness everywhere tool calls are inspected.
		hasToolCalls := len(i.ToolCalls) > 0

		// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
		// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
		if (i.FunctionCall != nil || hasToolCalls) && i.Role == "assistant" {
			roleFn := "assistant_function_call"
			if config.Roles[roleFn] != "" {
				role = roleFn
			}
		}
		r := config.Roles[role]
		contentExists := i.Content != nil && i.StringContent != ""

		// Tool calls take precedence over the legacy single function call.
		fcall := i.FunctionCall
		if hasToolCalls {
			fcall = i.ToolCalls
		}

		// First attempt to populate content via a chat message specific template
		if config.TemplateConfig.ChatMessage != "" {
			chatMessageData := ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         r,
				RoleName:     role,
				Content:      i.StringContent,
				FunctionCall: fcall,
				FunctionName: i.Name,
				LastMessage:  messageIndex == lastIndex,
				Function:     config.Grammar != "" && messageIndex == lastIndex,
				MessageIndex: messageIndex,
			}
			templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
			if err != nil {
				// content stays empty, so the built-in fallback below renders the message.
				xlog.Error("error processing message with template, skipping", "error", err, "message", chatMessageData, "template", config.TemplateConfig.ChatMessage)
			} else {
				if templatedChatMessage == "" {
					xlog.Warn("template produced blank output, skipping", "template", config.TemplateConfig.ChatMessage, "message", chatMessageData)
					continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
				}
				xlog.Debug("templated message for chat", "message", templatedChatMessage)
				content = templatedChatMessage
			}
		}

		// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
		if content == "" {
			// appendCall serializes call as JSON and adds it to content on a
			// line of its own, after whatever has already been rendered for
			// this message. Basing the decision on the current content (not on
			// whether the message had text) keeps both a function call and
			// tool calls when a message without text carries both.
			appendCall := func(prefix string, call any) {
				j, err := json.Marshal(call)
				if err != nil {
					xlog.Warn("could not marshal function call, omitting it from the prompt", "error", err)
					return
				}
				rendered := prefix + string(j)
				if content != "" {
					content += "\n" + rendered
				} else {
					content = rendered
				}
			}

			// With no configured role prefix r is empty, which leaves just the text.
			if contentExists {
				content = fmt.Sprint(r, i.StringContent)
			}

			// Serialized calls are introduced by the role prefix and a space
			// when a prefix is configured, and stand alone otherwise.
			callPrefix := ""
			if r != "" {
				callPrefix = r + " "
			}
			if i.FunctionCall != nil {
				appendCall(callPrefix, i.FunctionCall)
			}
			if hasToolCalls {
				appendCall(callPrefix, i.ToolCalls)
			}

			// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
			if contentExists && role == "system" {
				suppressConfigSystemPrompt = true
			}
		}

		mess = append(mess, content)
	}

	joinCharacter := "\n"
	if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
		joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
	}

	predInput = strings.Join(mess, joinCharacter)
	xlog.Debug("Prompt (before templating)", "prompt", predInput)

	promptTemplate := ChatPromptTemplate

	if config.TemplateConfig.Functions != "" && shouldUseFn {
		promptTemplate = FunctionsPromptTemplate
	}

	templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
		SystemPrompt:         config.SystemPrompt,
		SuppressSystemPrompt: suppressConfigSystemPrompt,
		Input:                predInput,
		Functions:            funcs,
		ReasoningEffort:      input.ReasoningEffort,
		Metadata:             input.Metadata,
	})
	if err == nil {
		predInput = templatedInput
		xlog.Debug("Template found, input modified", "input", predInput)
	} else {
		// Best effort: keep the joined, un-templated messages as the prompt.
		xlog.Debug("Template failed loading", "error", err)
	}

	return predInput
}