|
|
package openai |
|
|
|
|
|
import ( |
|
|
"encoding/json" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/backend" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
model "github.com/mudler/LocalAI/pkg/model" |
|
|
) |
|
|
|
|
|
func ComputeChoices( |
|
|
req *schema.OpenAIRequest, |
|
|
predInput string, |
|
|
config *config.ModelConfig, |
|
|
bcl *config.ModelConfigLoader, |
|
|
o *config.ApplicationConfig, |
|
|
loader *model.ModelLoader, |
|
|
cb func(string, *[]schema.Choice), |
|
|
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { |
|
|
n := req.N |
|
|
result := []schema.Choice{} |
|
|
|
|
|
if n == 0 { |
|
|
n = 1 |
|
|
} |
|
|
|
|
|
images := []string{} |
|
|
for _, m := range req.Messages { |
|
|
images = append(images, m.StringImages...) |
|
|
} |
|
|
videos := []string{} |
|
|
for _, m := range req.Messages { |
|
|
videos = append(videos, m.StringVideos...) |
|
|
} |
|
|
audios := []string{} |
|
|
for _, m := range req.Messages { |
|
|
audios = append(audios, m.StringAudios...) |
|
|
} |
|
|
|
|
|
|
|
|
toolsJSON := "" |
|
|
if len(req.Tools) > 0 { |
|
|
toolsBytes, err := json.Marshal(req.Tools) |
|
|
if err == nil { |
|
|
toolsJSON = string(toolsBytes) |
|
|
} |
|
|
} |
|
|
toolChoiceJSON := "" |
|
|
if req.ToolsChoice != nil { |
|
|
toolChoiceBytes, err := json.Marshal(req.ToolsChoice) |
|
|
if err == nil { |
|
|
toolChoiceJSON = string(toolChoiceBytes) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var logprobs *int |
|
|
var topLogprobs *int |
|
|
if req.Logprobs.IsEnabled() { |
|
|
|
|
|
if req.TopLogprobs != nil { |
|
|
topLogprobs = req.TopLogprobs |
|
|
|
|
|
logprobs = req.TopLogprobs |
|
|
} else { |
|
|
|
|
|
val := 1 |
|
|
logprobs = &val |
|
|
topLogprobs = &val |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var logitBias map[string]float64 |
|
|
if len(req.LogitBias) > 0 { |
|
|
logitBias = req.LogitBias |
|
|
} |
|
|
|
|
|
|
|
|
predFunc, err := backend.ModelInference( |
|
|
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias) |
|
|
if err != nil { |
|
|
return result, backend.TokenUsage{}, err |
|
|
} |
|
|
|
|
|
tokenUsage := backend.TokenUsage{} |
|
|
|
|
|
for i := 0; i < n; i++ { |
|
|
prediction, err := predFunc() |
|
|
if err != nil { |
|
|
return result, backend.TokenUsage{}, err |
|
|
} |
|
|
|
|
|
tokenUsage.Prompt += prediction.Usage.Prompt |
|
|
tokenUsage.Completion += prediction.Usage.Completion |
|
|
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing |
|
|
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration |
|
|
|
|
|
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) |
|
|
cb(finetunedResponse, &result) |
|
|
|
|
|
|
|
|
if prediction.Logprobs != nil && len(result) > 0 { |
|
|
result[len(result)-1].Logprobs = prediction.Logprobs |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
return result, tokenUsage, err |
|
|
} |
|
|
|