Spaces:
Configuration error
Configuration error
| package backend | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "os" | |
| "regexp" | |
| "strings" | |
| "sync" | |
| "unicode/utf8" | |
| "github.com/rs/zerolog/log" | |
| "github.com/mudler/LocalAI/core/config" | |
| "github.com/mudler/LocalAI/core/schema" | |
| "github.com/mudler/LocalAI/core/gallery" | |
| "github.com/mudler/LocalAI/pkg/grpc" | |
| "github.com/mudler/LocalAI/pkg/grpc/proto" | |
| model "github.com/mudler/LocalAI/pkg/model" | |
| "github.com/mudler/LocalAI/pkg/utils" | |
| ) | |
// LLMResponse is the result of a single model inference call: the generated
// text plus the token accounting observed while producing it.
type LLMResponse struct {
	Response string // should this be []byte?
	Usage    TokenUsage
}
// TokenUsage counts the tokens consumed by a request: Prompt is the size of
// the input, Completion the number of tokens generated in response.
type TokenUsage struct {
	Prompt     int
	Completion int
}
| func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { | |
| modelFile := c.Model | |
| var inferenceModel grpc.Backend | |
| var err error | |
| opts := ModelOptions(c, o, []model.Option{}) | |
| if c.Backend != "" { | |
| opts = append(opts, model.WithBackendString(c.Backend)) | |
| } | |
| // Check if the modelFile exists, if it doesn't try to load it from the gallery | |
| if o.AutoloadGalleries { // experimental | |
| if _, err := os.Stat(modelFile); os.IsNotExist(err) { | |
| utils.ResetDownloadTimers() | |
| // if we failed to load the model, we try to download it | |
| err := gallery.InstallModelFromGallery(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans) | |
| if err != nil { | |
| return nil, err | |
| } | |
| } | |
| } | |
| if c.Backend == "" { | |
| inferenceModel, err = loader.GreedyLoader(opts...) | |
| } else { | |
| inferenceModel, err = loader.BackendLoader(opts...) | |
| } | |
| if err != nil { | |
| return nil, err | |
| } | |
| var protoMessages []*proto.Message | |
| // if we are using the tokenizer template, we need to convert the messages to proto messages | |
| // unless the prompt has already been tokenized (non-chat endpoints + functions) | |
| if c.TemplateConfig.UseTokenizerTemplate && s == "" { | |
| protoMessages = make([]*proto.Message, len(messages), len(messages)) | |
| for i, message := range messages { | |
| protoMessages[i] = &proto.Message{ | |
| Role: message.Role, | |
| } | |
| switch ct := message.Content.(type) { | |
| case string: | |
| protoMessages[i].Content = ct | |
| case []interface{}: | |
| // If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here | |
| data, _ := json.Marshal(ct) | |
| resultData := []struct { | |
| Text string `json:"text"` | |
| }{} | |
| json.Unmarshal(data, &resultData) | |
| for _, r := range resultData { | |
| protoMessages[i].Content += r.Text | |
| } | |
| default: | |
| return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) | |
| } | |
| } | |
| } | |
| // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported | |
| fn := func() (LLMResponse, error) { | |
| opts := gRPCPredictOpts(c, loader.ModelPath) | |
| opts.Prompt = s | |
| opts.Messages = protoMessages | |
| opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate | |
| opts.Images = images | |
| opts.Videos = videos | |
| opts.Audios = audios | |
| tokenUsage := TokenUsage{} | |
| // check the per-model feature flag for usage, since tokenCallback may have a cost. | |
| // Defaults to off as for now it is still experimental | |
| if c.FeatureFlag.Enabled("usage") { | |
| userTokenCallback := tokenCallback | |
| if userTokenCallback == nil { | |
| userTokenCallback = func(token string, usage TokenUsage) bool { | |
| return true | |
| } | |
| } | |
| promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) | |
| if pErr == nil && promptInfo.Length > 0 { | |
| tokenUsage.Prompt = int(promptInfo.Length) | |
| } | |
| tokenCallback = func(token string, usage TokenUsage) bool { | |
| tokenUsage.Completion++ | |
| return userTokenCallback(token, tokenUsage) | |
| } | |
| } | |
| if tokenCallback != nil { | |
| ss := "" | |
| var partialRune []byte | |
| err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { | |
| partialRune = append(partialRune, chars...) | |
| for len(partialRune) > 0 { | |
| r, size := utf8.DecodeRune(partialRune) | |
| if r == utf8.RuneError { | |
| // incomplete rune, wait for more bytes | |
| break | |
| } | |
| tokenCallback(string(r), tokenUsage) | |
| ss += string(r) | |
| partialRune = partialRune[size:] | |
| } | |
| }) | |
| return LLMResponse{ | |
| Response: ss, | |
| Usage: tokenUsage, | |
| }, err | |
| } else { | |
| // TODO: Is the chicken bit the only way to get here? is that acceptable? | |
| reply, err := inferenceModel.Predict(ctx, opts) | |
| if err != nil { | |
| return LLMResponse{}, err | |
| } | |
| if tokenUsage.Prompt == 0 { | |
| tokenUsage.Prompt = int(reply.PromptTokens) | |
| } | |
| if tokenUsage.Completion == 0 { | |
| tokenUsage.Completion = int(reply.Tokens) | |
| } | |
| return LLMResponse{ | |
| Response: string(reply.Message), | |
| Usage: tokenUsage, | |
| }, err | |
| } | |
| } | |
| return fn, nil | |
| } | |
// cutstrings caches compiled regular expressions by their pattern source.
// It is shared by both the Cutstrings and ExtractRegex pattern sets.
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)

// mu guards all access to cutstrings.
var mu sync.Mutex = sync.Mutex{}
| func Finetune(config config.BackendConfig, input, prediction string) string { | |
| if config.Echo { | |
| prediction = input + prediction | |
| } | |
| for _, c := range config.Cutstrings { | |
| mu.Lock() | |
| reg, ok := cutstrings[c] | |
| if !ok { | |
| r, err := regexp.Compile(c) | |
| if err != nil { | |
| log.Fatal().Err(err).Msg("failed to compile regex") | |
| } | |
| cutstrings[c] = r | |
| reg = cutstrings[c] | |
| } | |
| mu.Unlock() | |
| prediction = reg.ReplaceAllString(prediction, "") | |
| } | |
| // extract results from the response which can be for instance inside XML tags | |
| var predResult string | |
| for _, r := range config.ExtractRegex { | |
| mu.Lock() | |
| reg, ok := cutstrings[r] | |
| if !ok { | |
| regex, err := regexp.Compile(r) | |
| if err != nil { | |
| log.Fatal().Err(err).Msg("failed to compile regex") | |
| } | |
| cutstrings[r] = regex | |
| reg = regex | |
| } | |
| mu.Unlock() | |
| predResult += reg.FindString(prediction) | |
| } | |
| if predResult != "" { | |
| prediction = predResult | |
| } | |
| for _, c := range config.TrimSpace { | |
| prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) | |
| } | |
| for _, c := range config.TrimSuffix { | |
| prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c)) | |
| } | |
| return prediction | |
| } | |