| | package openai |
| |
|
| | import ( |
| | "context" |
| | "fmt" |
| |
|
| | "github.com/mudler/LocalAI/core/backend" |
| | "github.com/mudler/LocalAI/core/config" |
| | grpcClient "github.com/mudler/LocalAI/pkg/grpc" |
| | "github.com/mudler/LocalAI/pkg/grpc/proto" |
| | model "github.com/mudler/LocalAI/pkg/model" |
| | "github.com/mudler/xlog" |
| | "google.golang.org/grpc" |
| | ) |
| |
|
| | var ( |
| | _ Model = new(wrappedModel) |
| | _ Model = new(anyToAnyModel) |
| | ) |
| |
|
| | |
| | |
| | |
// wrappedModel composes dedicated single-purpose backends (TTS,
// transcription, LLM, VAD) behind the Model interface: each request
// kind is routed to its own gRPC client.
type wrappedModel struct {
	TTSConfig           *config.ModelConfig
	TranscriptionConfig *config.ModelConfig
	LLMConfig           *config.ModelConfig
	TTSClient           grpcClient.Backend
	TranscriptionClient grpcClient.Backend
	LLMClient           grpcClient.Backend

	// Voice-activity-detection backend, present in every pipeline variant.
	VADConfig *config.ModelConfig
	VADClient grpcClient.Backend
}
| |
|
| | |
| | |
| | |
// anyToAnyModel routes every operation — transcription included — to a
// single multimodal LLM backend, plus a separate VAD backend.
type anyToAnyModel struct {
	LLMConfig *config.ModelConfig
	LLMClient grpcClient.Backend

	VADConfig *config.ModelConfig
	VADClient grpcClient.Backend
}
| |
|
// transcriptOnlyModel supports only VAD and transcription; the text
// generation entry points (Predict/PredictStream) return
// "not supported" errors.
type transcriptOnlyModel struct {
	TranscriptionConfig *config.ModelConfig
	TranscriptionClient grpcClient.Backend
	VADConfig           *config.ModelConfig
	VADClient           grpcClient.Backend
}
| |
|
| | func (m *transcriptOnlyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) { |
| | return m.VADClient.VAD(ctx, in) |
| | } |
| |
|
| | func (m *transcriptOnlyModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) { |
| | return m.TranscriptionClient.AudioTranscription(ctx, in, opts...) |
| | } |
| |
|
| | func (m *transcriptOnlyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) { |
| | return nil, fmt.Errorf("predict operation not supported in transcript-only mode") |
| | } |
| |
|
| | func (m *transcriptOnlyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { |
| | return fmt.Errorf("predict stream operation not supported in transcript-only mode") |
| | } |
| |
|
| | func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) { |
| | return m.VADClient.VAD(ctx, in) |
| | } |
| |
|
| | func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) { |
| | return m.VADClient.VAD(ctx, in) |
| | } |
| |
|
| | func (m *wrappedModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) { |
| | return m.TranscriptionClient.AudioTranscription(ctx, in, opts...) |
| | } |
| |
|
| | func (m *anyToAnyModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) { |
| | |
| | return m.LLMClient.AudioTranscription(ctx, in, opts...) |
| | } |
| |
|
| | func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) { |
| | |
| | |
| |
|
| | return m.LLMClient.Predict(ctx, in) |
| | } |
| |
|
| | func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { |
| | |
| |
|
| | return m.LLMClient.PredictStream(ctx, in, f) |
| | } |
| |
|
| | func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) { |
| | return m.LLMClient.Predict(ctx, in) |
| | } |
| |
|
| | func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { |
| | return m.LLMClient.PredictStream(ctx, in, f) |
| | } |
| |
|
| | func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) { |
| | cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgVAD.Validate(); !valid { |
| | return nil, nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts := backend.ModelOptions(*cfgVAD, appConfig) |
| | VADClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, nil, fmt.Errorf("failed to load tts model: %w", err) |
| | } |
| |
|
| | cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgSST.Validate(); !valid { |
| | return nil, nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts = backend.ModelOptions(*cfgSST, appConfig) |
| | transcriptionClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, nil, fmt.Errorf("failed to load SST model: %w", err) |
| | } |
| |
|
| | return &transcriptOnlyModel{ |
| | VADConfig: cfgVAD, |
| | VADClient: VADClient, |
| | TranscriptionConfig: cfgSST, |
| | TranscriptionClient: transcriptionClient, |
| | }, cfgSST, nil |
| | } |
| |
|
| | |
| | func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) { |
| |
|
| | cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgVAD.Validate(); !valid { |
| | return nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts := backend.ModelOptions(*cfgVAD, appConfig) |
| | VADClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to load tts model: %w", err) |
| | } |
| |
|
| | |
| | cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgSST.Validate(); !valid { |
| | return nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts = backend.ModelOptions(*cfgSST, appConfig) |
| | transcriptionClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to load SST model: %w", err) |
| | } |
| |
|
| | |
| | if false { |
| |
|
| | cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgAnyToAny.Validate(); !valid { |
| | return nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts := backend.ModelOptions(*cfgAnyToAny, appConfig) |
| | anyToAnyClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to load tts model: %w", err) |
| | } |
| |
|
| | return &anyToAnyModel{ |
| | LLMConfig: cfgAnyToAny, |
| | LLMClient: anyToAnyClient, |
| | VADConfig: cfgVAD, |
| | VADClient: VADClient, |
| | }, nil |
| | } |
| |
|
| | xlog.Debug("Loading a wrapped model") |
| |
|
| | |
| | cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgLLM.Validate(); !valid { |
| | return nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) |
| | if err != nil { |
| |
|
| | return nil, fmt.Errorf("failed to load backend config: %w", err) |
| | } |
| |
|
| | if valid, _ := cfgTTS.Validate(); !valid { |
| | return nil, fmt.Errorf("failed to validate config: %w", err) |
| | } |
| |
|
| | opts = backend.ModelOptions(*cfgTTS, appConfig) |
| | ttsClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to load tts model: %w", err) |
| | } |
| |
|
| | opts = backend.ModelOptions(*cfgLLM, appConfig) |
| | llmClient, err := ml.Load(opts...) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to load LLM model: %w", err) |
| | } |
| |
|
| | return &wrappedModel{ |
| | TTSConfig: cfgTTS, |
| | TranscriptionConfig: cfgSST, |
| | LLMConfig: cfgLLM, |
| | TTSClient: ttsClient, |
| | TranscriptionClient: transcriptionClient, |
| | LLMClient: llmClient, |
| |
|
| | VADConfig: cfgVAD, |
| | VADClient: VADClient, |
| | }, nil |
| | } |
| |
|