package backend

import (
	"context"
	"encoding/json"
	"regexp"
	"slices"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/mudler/xlog"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	model "github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
)
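
// LLMResponse carries the text produced by an inference call together with
// token usage statistics and, when requested, per-token log probabilities.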
type LLMResponse struct {
	Response    string
	Usage       TokenUsage
	AudioOutput string
	Logprobs    *schema.Logprobs
}
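
// TokenUsage tracks prompt and completion token counts along with the
// backend-reported processing timings.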
type TokenUsage struct {
	Prompt                 int
	Completion             int
	TimingPromptProcessing float64
	TimingTokenGeneration  float64
}
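
// ModelInference loads the model described by c and returns a function that
// performs a single inference request when invoked. If tokenCallback is
// non-nil the prediction is streamed token by token through the callback;
// otherwise a single blocking Predict call is made.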
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
	modelFile := c.Model
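
	// When gallery autoloading is enabled and the requested model is not yet
	// known, try to install it from the configured galleries. Installation
	// failures are logged but do not abort the request.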
	if o.AutoloadGalleries {
		modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
		if err != nil {
			return nil, err
		}
		if !slices.Contains(modelNames, c.Name) {
			utils.ResetDownloadTimers()

			err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
			if err != nil {
				xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
			}
		}
	}

	opts := ModelOptions(*c, o)
	inferenceModel, err := loader.Load(opts...)
	if err != nil {
		return nil, err
	}
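
	// When the backend applies the tokenizer's chat template itself, pass the
	// structured messages through instead of a pre-rendered prompt string.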
	var protoMessages []*proto.Message
	if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 {
		protoMessages = messages.ToProto()
	}
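
	// fn performs the actual inference when invoked by the caller.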
	fn := func() (LLMResponse, error) {
		opts := gRPCPredictOpts(*c, loader.ModelPath)
		opts.Prompt = s
		opts.Messages = protoMessages
		opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
		opts.Images = images
		opts.Videos = videos
		opts.Audios = audios
		opts.Tools = tools
		opts.ToolChoice = toolChoice
		if logprobs != nil {
			opts.Logprobs = int32(*logprobs)
		}
		if topLogprobs != nil {
			opts.TopLogprobs = int32(*topLogprobs)
		}
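
		// The logit bias map is forwarded to the backend as a JSON object of
		// token -> bias; if marshalling fails the bias is silently left unset.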
		if len(logitBias) > 0 {
			logitBiasJSON, err := json.Marshal(logitBias)
			if err == nil {
				opts.LogitBias = string(logitBiasJSON)
			}
		}

		tokenUsage := TokenUsage{}
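
		// The "usage" feature flag enables token accounting even for backends
		// that do not report it: the prompt is tokenized up front and
		// completion tokens are counted as the callback fires.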
		if c.FeatureFlag.Enabled("usage") {
			userTokenCallback := tokenCallback
			if userTokenCallback == nil {
				userTokenCallback = func(token string, usage TokenUsage) bool {
					return true
				}
			}

			promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
			if pErr == nil && promptInfo.Length > 0 {
				tokenUsage.Prompt = int(promptInfo.Length)
			}

			tokenCallback = func(token string, usage TokenUsage) bool {
				tokenUsage.Completion++
				return userTokenCallback(token, tokenUsage)
			}
		}
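
		// With a callback we stream tokens as they arrive; without one we fall
		// back to a single blocking Predict call.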
		if tokenCallback != nil {
			if c.TemplateConfig.ReplyPrefix != "" {
				tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
			}

			ss := ""
			var logprobs *schema.Logprobs
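
			// Buffer raw bytes until they form complete UTF-8 runes, so that
			// multi-byte characters split across stream chunks are never
			// emitted as garbled partial sequences.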
			var partialRune []byte
			err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
				msg := reply.Message
				partialRune = append(partialRune, msg...)

				tokenUsage.Prompt = int(reply.PromptTokens)
				tokenUsage.Completion = int(reply.Tokens)
				tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
				tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

				if len(reply.Logprobs) > 0 {
					var parsedLogprobs schema.Logprobs
					if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
						logprobs = &parsedLogprobs
					}
				}

				var completeRunes []byte
				for len(partialRune) > 0 {
					r, size := utf8.DecodeRune(partialRune)
					if r == utf8.RuneError {
						// Incomplete rune: keep the bytes and wait for the
						// next chunk.
						break
					}
					completeRunes = append(completeRunes, partialRune[:size]...)
					partialRune = partialRune[size:]
				}

				if len(completeRunes) > 0 {
					tokenCallback(string(completeRunes), tokenUsage)
					ss += string(completeRunes)
				}

				// Forward empty messages too, so the callback still observes
				// usage updates.
				if len(msg) == 0 {
					tokenCallback("", tokenUsage)
				}
			})
			return LLMResponse{
				Response: ss,
				Usage:    tokenUsage,
				Logprobs: logprobs,
			}, err
		} else {
			// Blocking path: run a single prediction and post-process the
			// complete reply.
			reply, err := inferenceModel.Predict(ctx, opts)
			if err != nil {
				return LLMResponse{}, err
			}
			// Prefer usage numbers reported by the backend, but keep counts
			// already gathered via the "usage" feature flag.
			if tokenUsage.Prompt == 0 {
				tokenUsage.Prompt = int(reply.PromptTokens)
			}
			if tokenUsage.Completion == 0 {
				tokenUsage.Completion = int(reply.Tokens)
			}

			tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
			tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

			response := string(reply.Message)
			if c.TemplateConfig.ReplyPrefix != "" {
				response = c.TemplateConfig.ReplyPrefix + response
			}

			var logprobs *schema.Logprobs
			if len(reply.Logprobs) > 0 {
				var parsedLogprobs schema.Logprobs
				if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
					logprobs = &parsedLogprobs
				}
			}

			return LLMResponse{
				Response: response,
				Usage:    tokenUsage,
				Logprobs: logprobs,
			}, err
		}
	}

	return fn, nil
}
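
// cutstrings caches compiled regular expressions from the Cutstrings and
// ExtractRegex config entries, guarded by mu, so each pattern is compiled
// only once.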
var cutstrings = make(map[string]*regexp.Regexp)
var mu sync.Mutex
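
// Finetune applies the post-processing rules from the model config to a raw
// prediction: echoing the input, deleting Cutstrings matches, extracting
// ExtractRegex matches, and trimming configured prefixes and suffixes.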
func Finetune(config config.ModelConfig, input, prediction string) string {
	if config.Echo {
		prediction = input + prediction
	}

	for _, c := range config.Cutstrings {
		mu.Lock()
		reg, ok := cutstrings[c]
		if !ok {
			r, err := regexp.Compile(c)
			if err != nil {
				xlog.Fatal("failed to compile regex", "error", err)
			}
			cutstrings[c] = r
			reg = r
		}
		mu.Unlock()
		prediction = reg.ReplaceAllString(prediction, "")
	}
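
	// ExtractRegex matches are concatenated; if anything was extracted it
	// replaces the prediction wholesale.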
	var predResult string
	for _, r := range config.ExtractRegex {
		mu.Lock()
		reg, ok := cutstrings[r]
		if !ok {
			regex, err := regexp.Compile(r)
			if err != nil {
				xlog.Fatal("failed to compile regex", "error", err)
			}
			cutstrings[r] = regex
			reg = regex
		}
		mu.Unlock()
		predResult += reg.FindString(prediction)
	}
	if predResult != "" {
		prediction = predResult
	}

	for _, c := range config.TrimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}

	for _, c := range config.TrimSuffix {
		prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
	}
	return prediction
}