Spaces:
Running
Running
Amlan-109
feat: Initial commit of LocalAI Amlan Edition with premium branding and personalization
750bbe6 | package openai | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "github.com/mudler/LocalAI/core/backend" | |
| "github.com/mudler/LocalAI/core/config" | |
| "github.com/mudler/LocalAI/core/http/endpoints/openai/types" | |
| "github.com/mudler/LocalAI/core/schema" | |
| "github.com/mudler/LocalAI/core/templates" | |
| "github.com/mudler/LocalAI/pkg/functions" | |
| "github.com/mudler/LocalAI/pkg/grpc/proto" | |
| model "github.com/mudler/LocalAI/pkg/model" | |
| "github.com/mudler/xlog" | |
| ) | |
| var ( | |
| _ Model = new(wrappedModel) | |
| _ Model = new(transcriptOnlyModel) | |
| ) | |
| // wrappedModel represent a model which does not support Any-to-Any operations | |
| // This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods | |
| // which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS) | |
| type wrappedModel struct { | |
| TTSConfig *config.ModelConfig | |
| TranscriptionConfig *config.ModelConfig | |
| LLMConfig *config.ModelConfig | |
| VADConfig *config.ModelConfig | |
| appConfig *config.ApplicationConfig | |
| modelLoader *model.ModelLoader | |
| confLoader *config.ModelConfigLoader | |
| evaluator *templates.Evaluator | |
| } | |
| // anyToAnyModel represent a model which supports Any-to-Any operations | |
| // We have to wrap this out as well because we want to load two models one for VAD and one for the actual model. | |
| // In the future there could be models that accept continous audio input only so this design will be useful for that | |
| type anyToAnyModel struct { | |
| LLMConfig *config.ModelConfig | |
| VADConfig *config.ModelConfig | |
| appConfig *config.ApplicationConfig | |
| modelLoader *model.ModelLoader | |
| confLoader *config.ModelConfigLoader | |
| } | |
| type transcriptOnlyModel struct { | |
| TranscriptionConfig *config.ModelConfig | |
| VADConfig *config.ModelConfig | |
| appConfig *config.ApplicationConfig | |
| modelLoader *model.ModelLoader | |
| confLoader *config.ModelConfigLoader | |
| } | |
| func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) { | |
| return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig) | |
| } | |
| func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { | |
| return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) | |
| } | |
| func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { | |
| return nil, fmt.Errorf("predict operation not supported in transcript-only mode") | |
| } | |
| func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) { | |
| return "", nil, fmt.Errorf("TTS not supported in transcript-only mode") | |
| } | |
| func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig { | |
| return nil | |
| } | |
| func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) { | |
| return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig) | |
| } | |
| func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { | |
| return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) | |
| } | |
| func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { | |
| input := schema.OpenAIRequest{ | |
| Messages: messages, | |
| } | |
| var predInput string | |
| var funcs []functions.Function | |
| if !m.LLMConfig.TemplateConfig.UseTokenizerTemplate { | |
| if len(tools) > 0 { | |
| for _, t := range tools { | |
| if t.Function != nil { | |
| var params map[string]any | |
| switch p := t.Function.Parameters.(type) { | |
| case map[string]any: | |
| params = p | |
| case string: | |
| if err := json.Unmarshal([]byte(p), ¶ms); err != nil { | |
| xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name) | |
| } | |
| } | |
| funcs = append(funcs, functions.Function{ | |
| Name: t.Function.Name, | |
| Description: t.Function.Description, | |
| Parameters: params, | |
| }) | |
| } | |
| } | |
| // Add noAction function before templating so it's included in the prompt | |
| // Allow the user to set custom actions via config file | |
| noActionName := "answer" | |
| noActionDescription := "use this action to answer without performing any action" | |
| if m.LLMConfig.FunctionsConfig.NoActionFunctionName != "" { | |
| noActionName = m.LLMConfig.FunctionsConfig.NoActionFunctionName | |
| } | |
| if m.LLMConfig.FunctionsConfig.NoActionDescriptionName != "" { | |
| noActionDescription = m.LLMConfig.FunctionsConfig.NoActionDescriptionName | |
| } | |
| noActionGrammar := functions.Function{ | |
| Name: noActionName, | |
| Description: noActionDescription, | |
| Parameters: map[string]interface{}{ | |
| "properties": map[string]interface{}{ | |
| "message": map[string]interface{}{ | |
| "type": "string", | |
| "description": "The message to reply the user with", | |
| }, | |
| }, | |
| }, | |
| } | |
| if !m.LLMConfig.FunctionsConfig.DisableNoAction { | |
| funcs = append(funcs, noActionGrammar) | |
| } | |
| } | |
| predInput = m.evaluator.TemplateMessages(input, input.Messages, m.LLMConfig, funcs, len(funcs) > 0) | |
| xlog.Debug("Prompt (after templating)", "prompt", predInput) | |
| if m.LLMConfig.Grammar != "" { | |
| xlog.Debug("Grammar", "grammar", m.LLMConfig.Grammar) | |
| } | |
| } | |
| // Handle tool_choice parameter similar to the chat endpoint | |
| if toolChoice != nil { | |
| if toolChoice.Mode != "" { | |
| // String values: "auto", "required", "none" | |
| switch toolChoice.Mode { | |
| case types.ToolChoiceModeRequired: | |
| m.LLMConfig.SetFunctionCallString("required") | |
| case types.ToolChoiceModeNone: | |
| // Don't use tools | |
| m.LLMConfig.SetFunctionCallString("none") | |
| case types.ToolChoiceModeAuto: | |
| // Default behavior - let model decide | |
| } | |
| } else if toolChoice.Function != nil { | |
| // Specific function specified | |
| m.LLMConfig.SetFunctionCallString(toolChoice.Function.Name) | |
| } | |
| } | |
| // Generate grammar for function calling if tools are provided and grammar generation is enabled | |
| shouldUseFn := len(tools) > 0 && m.LLMConfig.ShouldUseFunctions() | |
| if !m.LLMConfig.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn { | |
| // Force picking one of the functions by the request | |
| if m.LLMConfig.FunctionToCall() != "" { | |
| funcs = functions.Functions(funcs).Select(m.LLMConfig.FunctionToCall()) | |
| } | |
| // Generate grammar from function definitions | |
| jsStruct := functions.Functions(funcs).ToJSONStructure(m.LLMConfig.FunctionsConfig.FunctionNameKey, m.LLMConfig.FunctionsConfig.FunctionNameKey) | |
| g, err := jsStruct.Grammar(m.LLMConfig.FunctionsConfig.GrammarOptions()...) | |
| if err == nil { | |
| m.LLMConfig.Grammar = g | |
| xlog.Debug("Generated grammar for function calling", "grammar", g) | |
| } else { | |
| xlog.Error("Failed generating grammar", "error", err) | |
| } | |
| } | |
| var toolsJSON string | |
| if len(tools) > 0 { | |
| b, _ := json.Marshal(tools) | |
| toolsJSON = string(b) | |
| } | |
| var toolChoiceJSON string | |
| if toolChoice != nil { | |
| b, _ := json.Marshal(toolChoice) | |
| toolChoiceJSON = string(b) | |
| } | |
| return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, m.LLMConfig, m.confLoader, m.appConfig, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias) | |
| } | |
| func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) { | |
| return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig) | |
| } | |
| func (m *wrappedModel) PredictConfig() *config.ModelConfig { | |
| return m.LLMConfig | |
| } | |
| func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) { | |
| cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) | |
| if err != nil { | |
| return nil, nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgVAD.Validate(); !valid { | |
| return nil, nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) | |
| if err != nil { | |
| return nil, nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgSST.Validate(); !valid { | |
| return nil, nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| return &transcriptOnlyModel{ | |
| TranscriptionConfig: cfgSST, | |
| VADConfig: cfgVAD, | |
| confLoader: cl, | |
| modelLoader: ml, | |
| appConfig: appConfig, | |
| }, cfgSST, nil | |
| } | |
| // returns and loads either a wrapped model or a model that support audio-to-audio | |
| func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator) (Model, error) { | |
| xlog.Debug("Creating new model pipeline model", "pipeline", pipeline) | |
| cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgVAD.Validate(); !valid { | |
| return nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| // TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process | |
| cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgSST.Validate(); !valid { | |
| return nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| // TODO: Decide when we have a real any-to-any model | |
| // if false { | |
| // | |
| // cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) | |
| // if err != nil { | |
| // | |
| // return nil, fmt.Errorf("failed to load backend config: %w", err) | |
| // } | |
| // | |
| // if valid, _ := cfgAnyToAny.Validate(); !valid { | |
| // return nil, fmt.Errorf("failed to validate config: %w", err) | |
| // } | |
| // | |
| // return &anyToAnyModel{ | |
| // LLMConfig: cfgAnyToAny, | |
| // VADConfig: cfgVAD, | |
| // }, nil | |
| // } | |
| xlog.Debug("Loading a wrapped model") | |
| // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations | |
| cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgLLM.Validate(); !valid { | |
| return nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to load backend config: %w", err) | |
| } | |
| if valid, _ := cfgTTS.Validate(); !valid { | |
| return nil, fmt.Errorf("failed to validate config: %w", err) | |
| } | |
| return &wrappedModel{ | |
| TTSConfig: cfgTTS, | |
| TranscriptionConfig: cfgSST, | |
| LLMConfig: cfgLLM, | |
| VADConfig: cfgVAD, | |
| confLoader: cl, | |
| modelLoader: ml, | |
| appConfig: appConfig, | |
| evaluator: evaluator, | |
| }, nil | |
| } | |