|
|
package openai |
|
|
|
|
|
import ( |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/google/uuid" |
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/mudler/LocalAI/core/backend" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
"github.com/mudler/LocalAI/core/http/middleware" |
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
"github.com/mudler/LocalAI/pkg/functions" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/templates" |
|
|
"github.com/mudler/LocalAI/pkg/model" |
|
|
|
|
|
"github.com/mudler/xlog" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc { |
|
|
var id, textContentToReturn string |
|
|
var created int |
|
|
|
|
|
// process streams a plain (non-tool) chat completion: it forwards the model's
// token stream over responses as OpenAI "chat.completion.chunk" objects,
// incrementally splitting <think>-style reasoning from regular content so each
// chunk carries the right delta fields. It closes responses when generation
// ends and returns the generation error, if any.
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
	// First chunk announces the assistant role, per the OpenAI streaming protocol.
	initialMessage := schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Model:   req.Model,
		Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
		// Fix: set Object on the initial chunk too, consistent with every other
		// chunk emitted here and with the equivalent chunk in processTools.
		Object: "chat.completion.chunk",
	}
	responses <- initialMessage

	accumulatedContent := ""
	lastEmittedReasoning := ""
	lastEmittedCleanedContent := ""

	_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
		accumulatedContent += s

		// Re-extract reasoning from the whole accumulated text on every token:
		// reasoning markers can straddle token boundaries, so parsing only the
		// new token would miss them.
		currentReasoning, cleanedContent := functions.ExtractReasoning(accumulatedContent)

		var reasoningDelta *string
		if currentReasoning != lastEmittedReasoning {
			if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
				// Reasoning grew append-only: emit just the new suffix.
				newReasoning := currentReasoning[len(lastEmittedReasoning):]
				reasoningDelta = &newReasoning
				lastEmittedReasoning = currentReasoning
			} else if currentReasoning != "" {
				// Reasoning changed in a non-append way (the extractor
				// re-interpreted earlier text): emit it wholesale.
				// Fix: was mojibake "¤tReasoning" (HTML-entity-corrupted
				// "&currentReasoning"), which did not compile.
				reasoningDelta = &currentReasoning
				lastEmittedReasoning = currentReasoning
			}
		}

		var deltaContent string
		if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
			// Content grew append-only: emit just the new suffix.
			deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
			lastEmittedCleanedContent = cleanedContent
		} else if cleanedContent != lastEmittedCleanedContent {
			// Content changed in a non-append way: re-emit it wholesale.
			// (The original had an if/else here whose two branches were
			// byte-identical; collapsed into one.)
			deltaContent = cleanedContent
			lastEmittedCleanedContent = cleanedContent
		}

		usage := schema.OpenAIUsage{
			PromptTokens:     tokenUsage.Prompt,
			CompletionTokens: tokenUsage.Completion,
			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
		}
		if extraUsage {
			usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
			usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
		}

		delta := &schema.Message{}
		if deltaContent != "" {
			delta.Content = &deltaContent
		}
		if reasoningDelta != nil && *reasoningDelta != "" {
			delta.Reasoning = reasoningDelta
		}

		resp := schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   req.Model,
			Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
			Object:  "chat.completion.chunk",
			Usage:   usage,
		}

		responses <- resp
		return true
	})
	close(responses)
	return err
}
|
|
// processTools streams a chat completion that may contain tool calls. While
// tokens arrive it incrementally parses the partial output for tool calls
// (XML format if configured, JSON otherwise) and emits a chunk per newly
// detected call; after generation completes it re-parses the full output and
// emits the final answer or the definitive tool-call chunks. Closes responses
// when done and returns the generation error, if any.
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
	result := ""
	// lastEmittedCount tracks how many tool calls have already been streamed,
	// so each call is announced at most once during incremental parsing.
	lastEmittedCount := 0
	_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
		result += s

		// Clean the accumulated output before each parse attempt; markers can
		// span token boundaries so the whole buffer is re-parsed every time.
		cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)

		// Resolve the XML tool-call format: explicit config wins over preset.
		var xmlFormat *functions.XMLToolCallFormat
		if config.FunctionsConfig.XMLFormat != nil {
			xmlFormat = config.FunctionsConfig.XMLFormat
		} else if config.FunctionsConfig.XMLFormatPreset != "" {
			xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset)
		}

		// Try XML-style iterative parsing first (partial mode).
		partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
		if parseErr == nil && len(partialResults) > 0 {
			if len(partialResults) > lastEmittedCount {
				// Announce only the calls discovered since the last token.
				for i := lastEmittedCount; i < len(partialResults); i++ {
					toolCall := partialResults[i]
					initialMessage := schema.OpenAIResponse{
						ID:      id,
						Created: created,
						Model:   req.Model,
						Choices: []schema.Choice{{
							Delta: &schema.Message{
								Role: "assistant",
								ToolCalls: []schema.ToolCall{
									{
										Index: i,
										ID:    id,
										Type:  "function",
										FunctionCall: schema.FunctionCall{
											Name: toolCall.Name,
										},
									},
								},
							},
							Index:        0,
							FinishReason: nil,
						}},
						Object: "chat.completion.chunk",
					}
					// Non-blocking send: keeps token generation from stalling,
					// but NOTE(review): on an unbuffered channel this silently
					// drops the chunk whenever the consumer isn't ready —
					// confirm this best-effort behavior is intended.
					select {
					case responses <- initialMessage:
					default:
					}
				}
				lastEmittedCount = len(partialResults)
			}
		} else {
			// Fall back to JSON-style iterative parsing (partial mode).
			jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
			if jsonErr == nil && len(jsonResults) > 0 {
				for _, jsonObj := range jsonResults {
					if name, ok := jsonObj["name"].(string); ok && name != "" {
						// Arguments may be a pre-serialized string or an
						// arbitrary JSON value; normalize to a JSON string.
						args := "{}"
						if argsVal, ok := jsonObj["arguments"]; ok {
							if argsStr, ok := argsVal.(string); ok {
								args = argsStr
							} else {
								argsBytes, _ := json.Marshal(argsVal)
								args = string(argsBytes)
							}
						}

						initialMessage := schema.OpenAIResponse{
							ID:      id,
							Created: created,
							Model:   req.Model,
							Choices: []schema.Choice{{
								Delta: &schema.Message{
									Role: "assistant",
									ToolCalls: []schema.ToolCall{
										{
											Index: lastEmittedCount,
											ID:    id,
											Type:  "function",
											FunctionCall: schema.FunctionCall{
												Name:      name,
												Arguments: args,
											},
										},
									},
								},
								Index:        0,
								FinishReason: nil,
							}},
							Object: "chat.completion.chunk",
						}
						// Same best-effort, possibly-lossy send as above.
						select {
						case responses <- initialMessage:
						default:
						}
						// NOTE(review): counts every parsed object on every
						// token pass, unlike the XML branch which dedupes by
						// index — verify duplicates cannot be re-emitted here.
						lastEmittedCount++
					}
				}
			}
		}
		return true
	})
	if err != nil {
		return err
	}

	// Generation finished: strip reasoning, then do the authoritative parse of
	// the complete output.
	reasoning, cleanedResult := functions.ExtractReasoning(result)
	result = cleanedResult

	textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
	result = functions.CleanupLLMResult(result, config.FunctionsConfig)
	functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
	xlog.Debug("Text content to return", "text", textContentToReturn)
	// No tool to run when nothing was parsed, or the first call is the
	// configured no-action pseudo-function.
	noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0

	switch {
	case noActionToRun:
		// Plain-answer path: announce the role, then emit the final reply
		// (possibly computed by a follow-up inference in handleQuestion).
		initialMessage := schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   req.Model,
			Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
			Object:  "chat.completion.chunk",
		}
		responses <- initialMessage

		result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
		if err != nil {
			xlog.Error("error handling question", "error", err)
			return err
		}
		usage := schema.OpenAIUsage{
			PromptTokens:     tokenUsage.Prompt,
			CompletionTokens: tokenUsage.Completion,
			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
		}
		if extraUsage {
			usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
			usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
		}

		var deltaReasoning *string
		if reasoning != "" {
			deltaReasoning = &reasoning
		}
		delta := &schema.Message{Content: &result}
		if deltaReasoning != nil {
			delta.Reasoning = deltaReasoning
		}

		resp := schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   req.Model,
			Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
			Object:  "chat.completion.chunk",
			Usage:   usage,
		}

		responses <- resp

	default:
		// Tool-call path: for each parsed call emit two chunks — one carrying
		// the function name, one carrying the arguments (and any leading text).
		for i, ss := range functionResults {
			name, args := ss.Name, ss.Arguments

			initialMessage := schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model,
				Choices: []schema.Choice{{
					Delta: &schema.Message{
						Role: "assistant",
						ToolCalls: []schema.ToolCall{
							{
								Index: i,
								ID:    id,
								Type:  "function",
								FunctionCall: schema.FunctionCall{
									Name: name,
								},
							},
						},
					},
					Index:        0,
					FinishReason: nil,
				}},
				Object: "chat.completion.chunk",
			}
			responses <- initialMessage

			responses <- schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model,
				Choices: []schema.Choice{{
					Delta: &schema.Message{
						Role:    "assistant",
						Content: &textContentToReturn,
						ToolCalls: []schema.ToolCall{
							{
								Index: i,
								ID:    id,
								Type:  "function",
								FunctionCall: schema.FunctionCall{
									Arguments: args,
								},
							},
						},
					},
					Index:        0,
					FinishReason: nil,
				}},
				Object: "chat.completion.chunk",
			}
		}
	}

	close(responses)
	return err
}
|
|
|
|
|
// The returned echo handler serves a single /v1/chat/completions request:
// it reads the validated request and model config from middleware, prepares
// the function-calling grammar if needed, templates the prompt, and then
// either streams SSE chunks or returns a single JSON completion.
return func(c echo.Context) error {
	// Per-request state. NOTE(review): id, created and textContentToReturn
	// are captured from the enclosing ChatEndpoint scope and shared by every
	// request served by this handler — concurrent requests would race on
	// them; confirm whether they should be request-local.
	textContentToReturn = ""
	id = uuid.New().String()
	created = int(time.Now().Unix())

	// Request was parsed and stashed by middleware.
	input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
	if !ok || input.Model == "" {
		return echo.ErrBadRequest
	}

	// Opt-in header for timing fields in the usage block.
	extraUsage := c.Request().Header.Get("Extra-Usage") != ""

	config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
	if !ok || config == nil {
		return echo.ErrBadRequest
	}

	xlog.Debug("Chat endpoint configuration read", "config", config)

	funcs := input.Functions
	shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
	// strictMode is on if any declared function requests strict schemas.
	strictMode := false

	for _, f := range input.Functions {
		if f.Strict {
			strictMode = true
			break
		}
	}

	// The "no action" pseudo-function lets the model answer in plain text
	// instead of calling a tool; names are overridable via config.
	noActionName := "answer"
	noActionDescription := "use this action to answer without performing any action"

	if config.FunctionsConfig.NoActionFunctionName != "" {
		noActionName = config.FunctionsConfig.NoActionFunctionName
	}
	if config.FunctionsConfig.NoActionDescriptionName != "" {
		noActionDescription = config.FunctionsConfig.NoActionDescriptionName
	}

	// response_format handling: round-trip the config map through JSON to get
	// a typed view, then derive a grammar for json_object / json_schema.
	if config.ResponseFormatMap != nil {
		d := schema.ChatCompletionResponseFormat{}
		dat, err := json.Marshal(config.ResponseFormatMap)
		if err != nil {
			return err
		}
		err = json.Unmarshal(dat, &d)
		if err != nil {
			return err
		}

		switch d.Type {
		case "json_object":
			input.Grammar = functions.JSONBNF
		case "json_schema":
			// Re-decode as a schema request to reach the embedded JSON schema.
			d := schema.JsonSchemaRequest{}
			dat, err := json.Marshal(config.ResponseFormatMap)
			if err != nil {
				return err
			}
			err = json.Unmarshal(dat, &d)
			if err != nil {
				return err
			}
			fs := &functions.JSONFunctionStructure{
				AnyOf: []functions.Item{d.JsonSchema.Schema},
			}
			g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
			if err == nil {
				input.Grammar = g
			} else {
				// Grammar failure is non-fatal: proceed unconstrained.
				xlog.Error("Failed generating grammar", "error", err)
			}
		}
	}

	config.Grammar = input.Grammar

	if shouldUseFn {
		xlog.Debug("Response needs to process functions")
	}

	// Decide how to constrain generation for function calling.
	switch {
	case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
		// Grammar-constrained function calling: build a grammar over the
		// declared functions (plus the no-action escape hatch).
		noActionGrammar := functions.Function{
			Name:        noActionName,
			Description: noActionDescription,
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"message": map[string]interface{}{
						"type":        "string",
						"description": "The message to reply the user with",
					}},
			},
		}

		if !config.FunctionsConfig.DisableNoAction && !strictMode {
			funcs = append(funcs, noActionGrammar)
		}

		// Force a specific function if the config demands one.
		if config.FunctionToCall() != "" {
			funcs = funcs.Select(config.FunctionToCall())
		}

		// NOTE(review): FunctionNameKey is passed twice; confirm whether the
		// second argument was meant to be the arguments key.
		jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
		g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
		if err == nil {
			config.Grammar = g
		} else {
			xlog.Error("Failed generating grammar", "error", err)
		}
	case input.JSONFunctionGrammarObject != nil:
		// Caller supplied a pre-built grammar object.
		g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...)
		if err == nil {
			config.Grammar = g
		} else {
			xlog.Error("Failed generating grammar", "error", err)
		}

	default:
		// No grammar: still honor a forced function selection.
		if config.FunctionToCall() != "" {
			funcs = funcs.Select(config.FunctionToCall())
		}
	}

	toStream := input.Stream

	xlog.Debug("Parameters", "config", config)

	var predInput string

	// Template the conversation into a prompt unless the backend applies the
	// tokenizer's own chat template (then predInput stays empty).
	if !config.TemplateConfig.UseTokenizerTemplate {
		predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)

		xlog.Debug("Prompt (after templating)", "prompt", predInput)
		if config.Grammar != "" {
			xlog.Debug("Grammar", "grammar", config.Grammar)
		}
	}

	switch {
	case toStream:
		// ---- Streaming (SSE) path ----
		xlog.Debug("Stream request received")
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().Header().Set("X-Correlation-ID", id)

		// Producer goroutine pushes chunks on responses and its final error
		// on ended (buffered so it never blocks after the consumer exits).
		responses := make(chan schema.OpenAIResponse)
		ended := make(chan error, 1)

		go func() {
			if !shouldUseFn {
				ended <- process(predInput, input, config, ml, responses, extraUsage)
			} else {
				ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
			}
		}()

		usage := &schema.OpenAIUsage{}
		toolsCalled := false

	LOOP:
		for {
			select {
			case <-input.Context.Done():
				// Client went away: cancel generation and finish up.
				xlog.Debug("Request context cancelled, stopping stream")
				input.Cancel()
				break LOOP
			case ev := <-responses:
				if len(ev.Choices) == 0 {
					xlog.Debug("No choices in the response, skipping")
					continue
				}
				// Keep the most recent usage; the last chunk's totals win.
				usage = &ev.Usage
				if len(ev.Choices[0].Delta.ToolCalls) > 0 {
					toolsCalled = true
				}
				respData, err := json.Marshal(ev)
				if err != nil {
					xlog.Debug("Failed to marshal response", "error", err)
					input.Cancel()
					continue
				}
				xlog.Debug("Sending chunk", "chunk", string(respData))
				_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
				if err != nil {
					// Write failure means the client is gone; stop generating.
					xlog.Debug("Sending chunk failed", "error", err)
					input.Cancel()
					return err
				}
				c.Response().Flush()
			case err := <-ended:
				if err == nil {
					break LOOP
				}
				// Generation failed mid-stream: surface the error as a final
				// chunk so the client sees something, then terminate the SSE.
				xlog.Error("Stream ended with error", "error", err)

				stopReason := FinishReasonStop
				resp := &schema.OpenAIResponse{
					ID:      id,
					Created: created,
					Model:   input.Model,
					Choices: []schema.Choice{
						{
							FinishReason: &stopReason,
							Index:        0,
							Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
						}},
					Object: "chat.completion.chunk",
					Usage:  *usage,
				}
				respData, marshalErr := json.Marshal(resp)
				if marshalErr != nil {
					xlog.Error("Failed to marshal error response", "error", marshalErr)
					fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
				} else {
					fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
				}
				fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
				c.Response().Flush()

				return nil
			}
		}

		// Clean end of stream: send the closing chunk with the finish reason.
		finishReason := FinishReasonStop
		if toolsCalled && len(input.Tools) > 0 {
			finishReason = FinishReasonToolCalls
		} else if toolsCalled {
			finishReason = FinishReasonFunctionCall
		}

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model,
			Choices: []schema.Choice{
				{
					FinishReason: &finishReason,
					Index:        0,
					Delta:        &schema.Message{},
				}},
			Object: "chat.completion.chunk",
			Usage:  *usage,
		}
		respData, _ := json.Marshal(resp)

		fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
		fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
		c.Response().Flush()
		xlog.Debug("Stream ended")
		return nil

	default:
		// ---- Non-streaming path ----
		// tokenCallback converts the model's full output into the final
		// Choices slice (plain answer, or tool/function calls).
		tokenCallback := func(s string, c *[]schema.Choice) {
			reasoning, cleanedS := functions.ExtractReasoning(s)
			s = cleanedS

			if !shouldUseFn {
				// Plain completion: single assistant message.
				stopReason := FinishReasonStop
				message := &schema.Message{Role: "assistant", Content: &s}
				if reasoning != "" {
					message.Reasoning = &reasoning
				}
				*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message})
				return
			}

			// Function-calling completion: parse tool calls out of the output.
			textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
			s = functions.CleanupLLMResult(s, config.FunctionsConfig)
			results := functions.ParseFunctionCall(s, config.FunctionsConfig)
			xlog.Debug("Text content to return", "text", textContentToReturn)
			noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0

			switch {
			case noActionsToRun:
				// Model chose not to call a tool: compute a plain reply.
				result, err := handleQuestion(config, cl, input, ml, startupOptions, results, s, predInput)
				if err != nil {
					xlog.Error("error handling question", "error", err)
					return
				}

				stopReason := FinishReasonStop
				message := &schema.Message{Role: "assistant", Content: &result}
				if reasoning != "" {
					message.Reasoning = &reasoning
				}
				*c = append(*c, schema.Choice{
					FinishReason: &stopReason,
					Message:      message})
			default:
				// Tools API: all calls are collected into one choice.
				toolCallsReason := FinishReasonToolCalls
				toolChoice := schema.Choice{
					FinishReason: &toolCallsReason,
					Message: &schema.Message{
						Role: "assistant",
					},
				}
				if reasoning != "" {
					toolChoice.Message.Reasoning = &reasoning
				}

				for _, ss := range results {
					name, args := ss.Name, ss.Arguments
					if len(input.Tools) > 0 {
						// Modern tools API shape.
						toolChoice.Message.Content = textContentToReturn
						toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
							schema.ToolCall{
								ID:   id,
								Type: "function",
								FunctionCall: schema.FunctionCall{
									Name:      name,
									Arguments: args,
								},
							},
						)
					} else {
						// Legacy function_call shape: one choice per call.
						functionCallReason := FinishReasonFunctionCall
						message := &schema.Message{
							Role:    "assistant",
							Content: &textContentToReturn,
							FunctionCall: map[string]interface{}{
								"name":      name,
								"arguments": args,
							},
						}
						if reasoning != "" {
							message.Reasoning = &reasoning
						}
						*c = append(*c, schema.Choice{
							FinishReason: &functionCallReason,
							Message:      message,
						})
					}
				}

				if len(input.Tools) > 0 {
					*c = append(*c, toolChoice)
				}
			}

		}

		result, tokenUsage, err := ComputeChoices(
			input,
			predInput,
			config,
			cl,
			startupOptions,
			ml,
			tokenCallback,
			nil,
		)
		if err != nil {
			return err
		}
		usage := schema.OpenAIUsage{
			PromptTokens:     tokenUsage.Prompt,
			CompletionTokens: tokenUsage.Completion,
			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
		}
		if extraUsage {
			usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
			usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
		}

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model,
			Choices: result,
			Object:  "chat.completion",
			Usage:   usage,
		}
		respData, _ := json.Marshal(resp)
		xlog.Debug("Response", "response", string(respData))

		return c.JSON(200, resp)
	}
}
|
|
} |
|
|
|
|
|
func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) { |
|
|
|
|
|
if len(funcResults) == 0 && result != "" { |
|
|
xlog.Debug("nothing function results but we had a message from the LLM") |
|
|
|
|
|
return result, nil |
|
|
} |
|
|
|
|
|
xlog.Debug("nothing to do, computing a reply") |
|
|
arg := "" |
|
|
if len(funcResults) > 0 { |
|
|
arg = funcResults[0].Arguments |
|
|
} |
|
|
|
|
|
arguments := map[string]interface{}{} |
|
|
if err := json.Unmarshal([]byte(arg), &arguments); err != nil { |
|
|
xlog.Debug("handleQuestion: function result did not contain a valid JSON object") |
|
|
} |
|
|
m, exists := arguments["message"] |
|
|
if exists { |
|
|
switch message := m.(type) { |
|
|
case string: |
|
|
if message != "" { |
|
|
xlog.Debug("Reply received from LLM", "message", message) |
|
|
message = backend.Finetune(*config, prompt, message) |
|
|
xlog.Debug("Reply received from LLM(finetuned)", "message", message) |
|
|
|
|
|
return message, nil |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
xlog.Debug("No action received from LLM, without a message, computing a reply") |
|
|
|
|
|
|
|
|
config.Grammar = "" |
|
|
images := []string{} |
|
|
for _, m := range input.Messages { |
|
|
images = append(images, m.StringImages...) |
|
|
} |
|
|
videos := []string{} |
|
|
for _, m := range input.Messages { |
|
|
videos = append(videos, m.StringVideos...) |
|
|
} |
|
|
audios := []string{} |
|
|
for _, m := range input.Messages { |
|
|
audios = append(audios, m.StringAudios...) |
|
|
} |
|
|
|
|
|
|
|
|
toolsJSON := "" |
|
|
if len(input.Tools) > 0 { |
|
|
toolsBytes, err := json.Marshal(input.Tools) |
|
|
if err == nil { |
|
|
toolsJSON = string(toolsBytes) |
|
|
} |
|
|
} |
|
|
toolChoiceJSON := "" |
|
|
if input.ToolsChoice != nil { |
|
|
toolChoiceBytes, err := json.Marshal(input.ToolsChoice) |
|
|
if err == nil { |
|
|
toolChoiceJSON = string(toolChoiceBytes) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var logprobs *int |
|
|
var topLogprobs *int |
|
|
if input.Logprobs.IsEnabled() { |
|
|
|
|
|
if input.TopLogprobs != nil { |
|
|
topLogprobs = input.TopLogprobs |
|
|
|
|
|
logprobs = input.TopLogprobs |
|
|
} else { |
|
|
|
|
|
val := 1 |
|
|
logprobs = &val |
|
|
topLogprobs = &val |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var logitBias map[string]float64 |
|
|
if len(input.LogitBias) > 0 { |
|
|
logitBias = input.LogitBias |
|
|
} |
|
|
|
|
|
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias) |
|
|
if err != nil { |
|
|
xlog.Error("model inference failed", "error", err) |
|
|
return "", err |
|
|
} |
|
|
|
|
|
prediction, err := predFunc() |
|
|
if err != nil { |
|
|
xlog.Error("prediction failed", "error", err) |
|
|
return "", err |
|
|
} |
|
|
return backend.Finetune(*config, prompt, prediction.Response), nil |
|
|
} |
|
|
|