| package openai |
|
|
| import ( |
| "context" |
| "encoding/base64" |
| "encoding/json" |
| "fmt" |
| "os" |
| "strings" |
| "sync" |
| "time" |
|
|
| "net/http" |
|
|
| "github.com/go-audio/audio" |
| "github.com/gorilla/websocket" |
| "github.com/labstack/echo/v4" |
| "github.com/mudler/LocalAI/core/application" |
| "github.com/mudler/LocalAI/core/config" |
| "github.com/mudler/LocalAI/core/http/endpoints/openai/types" |
| "github.com/mudler/LocalAI/core/templates" |
| laudio "github.com/mudler/LocalAI/pkg/audio" |
| "github.com/mudler/LocalAI/pkg/functions" |
| "github.com/mudler/LocalAI/pkg/grpc/proto" |
| model "github.com/mudler/LocalAI/pkg/model" |
| "github.com/mudler/LocalAI/pkg/sound" |
|
|
| "google.golang.org/grpc" |
|
|
| "github.com/mudler/xlog" |
| ) |
|
|
// vadModel is the model name placed in the VAD slot of the pipelines built
// in this file.
const vadModel = "silero-vad-ggml"

// Sample rates, in Hz, used when preparing buffered audio for VAD.
const (
	// remoteSampleRate is the source rate used when resampling buffered
	// input audio.
	remoteSampleRate = 24000
	// localSampleRate is the target rate of that resampling and the rate
	// declared on the buffer handed to the VAD model.
	localSampleRate = 16000
)
|
|
| |
| |
|
|
| |
| type Session struct { |
| ID string |
| TranscriptionOnly bool |
| Model string |
| Voice string |
| TurnDetection *types.ServerTurnDetection `json:"turn_detection"` |
| InputAudioTranscription *types.InputAudioTranscription |
| Functions functions.Functions |
| Conversations map[string]*Conversation |
| InputAudioBuffer []byte |
| AudioBufferLock sync.Mutex |
| Instructions string |
| DefaultConversationID string |
| ModelInterface Model |
| } |
|
|
// FromClient is currently a no-op: the body is empty, so nothing from the
// supplied client session is copied into s.
func (s *Session) FromClient(session *types.ClientSession) {
}
|
|
| func (s *Session) ToServer() types.ServerSession { |
| return types.ServerSession{ |
| ID: s.ID, |
| Object: func() string { |
| if s.TranscriptionOnly { |
| return "realtime.transcription_session" |
| } else { |
| return "realtime.session" |
| } |
| }(), |
| Model: s.Model, |
| Modalities: []types.Modality{types.ModalityText, types.ModalityAudio}, |
| Instructions: s.Instructions, |
| Voice: s.Voice, |
| InputAudioFormat: types.AudioFormatPcm16, |
| OutputAudioFormat: types.AudioFormatPcm16, |
| TurnDetection: s.TurnDetection, |
| InputAudioTranscription: s.InputAudioTranscription, |
| |
| Tools: []types.Tool{}, |
| |
| |
| |
| |
| } |
| } |
|
|
| |
| |
// FunctionCall describes a function invocation: the function's name and its
// arguments keyed by parameter name.
type FunctionCall struct {
	// Name is the function to invoke.
	Name string `json:"name"`
	// Arguments maps parameter names to their values.
	Arguments map[string]interface{} `json:"arguments"`
}
|
|
| |
| type Conversation struct { |
| ID string |
| Items []*types.MessageItem |
| Lock sync.Mutex |
| } |
|
|
| func (c *Conversation) ToServer() types.Conversation { |
| return types.Conversation{ |
| ID: c.ID, |
| Object: "realtime.conversation", |
| } |
| } |
|
|
| |
| type Item struct { |
| ID string `json:"id"` |
| Object string `json:"object"` |
| Type string `json:"type"` |
| Status string `json:"status"` |
| Role string `json:"role"` |
| Content []ConversationContent `json:"content,omitempty"` |
| FunctionCall *FunctionCall `json:"function_call,omitempty"` |
| } |
|
|
| |
// ConversationContent is one content part of an Item. Audio and Text are
// both optional and omitted from JSON when empty.
type ConversationContent struct {
	Type  string `json:"type"`
	Audio string `json:"audio,omitempty"`
	Text  string `json:"text,omitempty"`
}
|
|
| |
| type IncomingMessage struct { |
| Type types.ClientEventType `json:"type"` |
| Session json.RawMessage `json:"session,omitempty"` |
| Item json.RawMessage `json:"item,omitempty"` |
| Audio string `json:"audio,omitempty"` |
| Response json.RawMessage `json:"response,omitempty"` |
| Error *ErrorMessage `json:"error,omitempty"` |
| |
| } |
|
|
| |
// ErrorMessage is the JSON shape of an error payload. Param and EventID are
// omitted from the encoded form when empty.
type ErrorMessage struct {
	Type    string `json:"type"`
	Code    string `json:"code"`
	Message string `json:"message"`
	Param   string `json:"param,omitempty"`
	EventID string `json:"event_id,omitempty"`
}
|
|
| |
| type OutgoingMessage struct { |
| Type string `json:"type"` |
| Session *Session `json:"session,omitempty"` |
| Conversation *Conversation `json:"conversation,omitempty"` |
| Item *Item `json:"item,omitempty"` |
| Content string `json:"content,omitempty"` |
| Audio string `json:"audio,omitempty"` |
| Error *ErrorMessage `json:"error,omitempty"` |
| } |
|
|
| |
| var sessions = make(map[string]*Session) |
| var sessionLock sync.Mutex |
|
|
| |
| type Model interface { |
| VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) |
| Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) |
| Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) |
| PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error |
| } |
|
|
| var upgrader = websocket.Upgrader{ |
| CheckOrigin: func(r *http.Request) bool { |
| return true |
| }, |
| } |
|
|
| |
| func RealtimeSessions(application *application.Application) echo.HandlerFunc { |
| return func(c echo.Context) error { |
| return c.NoContent(501) |
| } |
| } |
|
|
| func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { |
| return func(c echo.Context) error { |
| return c.NoContent(501) |
| } |
| } |
|
|
| func Realtime(application *application.Application) echo.HandlerFunc { |
| return func(c echo.Context) error { |
| ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) |
| if err != nil { |
| return err |
| } |
| defer ws.Close() |
|
|
| |
| model := c.QueryParam("model") |
| if model == "" { |
| model = "gpt-4o" |
| } |
| intent := c.QueryParam("intent") |
|
|
| registerRealtime(application, model, intent)(ws) |
| return nil |
| } |
| } |
|
|
| func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) { |
| return func(c *websocket.Conn) { |
|
|
| evaluator := application.TemplatesEvaluator() |
| xlog.Debug("WebSocket connection established", "address", c.RemoteAddr().String()) |
| if intent != "transcription" { |
| sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter") |
| } |
|
|
| xlog.Debug("Realtime params", "model", model, "intent", intent) |
|
|
| sessionID := generateSessionID() |
| session := &Session{ |
| ID: sessionID, |
| TranscriptionOnly: true, |
| Model: model, |
| Voice: "alloy", |
| TurnDetection: &types.ServerTurnDetection{ |
| Type: types.ServerTurnDetectionTypeServerVad, |
| TurnDetectionParams: types.TurnDetectionParams{ |
| |
| Threshold: 0.5, |
| |
| PrefixPaddingMs: 30, |
| SilenceDurationMs: 500, |
| CreateResponse: func() *bool { t := true; return &t }(), |
| }, |
| }, |
| InputAudioTranscription: &types.InputAudioTranscription{ |
| Model: "whisper-1", |
| }, |
| Conversations: make(map[string]*Conversation), |
| } |
|
|
| |
| conversationID := generateConversationID() |
| conversation := &Conversation{ |
| ID: conversationID, |
| Items: []*types.MessageItem{}, |
| } |
| session.Conversations[conversationID] = conversation |
| session.DefaultConversationID = conversationID |
|
|
| |
| |
| pipeline := config.Pipeline{ |
| VAD: vadModel, |
| Transcription: session.InputAudioTranscription.Model, |
| } |
|
|
| m, cfg, err := newTranscriptionOnlyModel( |
| &pipeline, |
| application.ModelConfigLoader(), |
| application.ModelLoader(), |
| application.ApplicationConfig(), |
| ) |
| if err != nil { |
| xlog.Error("failed to load model", "error", err) |
| sendError(c, "model_load_error", "Failed to load model", "", "") |
| return |
| } |
| session.ModelInterface = m |
|
|
| |
| sessionLock.Lock() |
| sessions[sessionID] = session |
| sessionLock.Unlock() |
|
|
| sendEvent(c, types.TranscriptionSessionCreatedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeTranscriptionSessionCreated, |
| }, |
| Session: session.ToServer(), |
| }) |
|
|
| var ( |
| |
| msg []byte |
| wg sync.WaitGroup |
| done = make(chan struct{}) |
| ) |
|
|
| vadServerStarted := true |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| conversation := session.Conversations[session.DefaultConversationID] |
| handleVAD(cfg, evaluator, session, conversation, c, done) |
| }() |
|
|
| for { |
| if _, msg, err = c.ReadMessage(); err != nil { |
| xlog.Error("read error", "error", err) |
| break |
| } |
|
|
| |
| var incomingMsg IncomingMessage |
| if err := json.Unmarshal(msg, &incomingMsg); err != nil { |
| xlog.Error("invalid json", "error", err) |
| sendError(c, "invalid_json", "Invalid JSON format", "", "") |
| continue |
| } |
|
|
| var sessionUpdate types.ClientSession |
| switch incomingMsg.Type { |
| case types.ClientEventTypeTranscriptionSessionUpdate: |
| xlog.Debug("recv", "message", string(msg)) |
|
|
| if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil { |
| xlog.Error("failed to unmarshal 'transcription_session.update'", "error", err) |
| sendError(c, "invalid_session_update", "Invalid session update format", "", "") |
| continue |
| } |
| if err := updateTransSession( |
| session, |
| &sessionUpdate, |
| application.ModelConfigLoader(), |
| application.ModelLoader(), |
| application.ApplicationConfig(), |
| ); err != nil { |
| xlog.Error("failed to update session", "error", err) |
| sendError(c, "session_update_error", "Failed to update session", "", "") |
| continue |
| } |
|
|
| sendEvent(c, types.SessionUpdatedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeTranscriptionSessionUpdated, |
| }, |
| Session: session.ToServer(), |
| }) |
|
|
| case types.ClientEventTypeSessionUpdate: |
| xlog.Debug("recv", "message", string(msg)) |
|
|
| |
| if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil { |
| xlog.Error("failed to unmarshal 'session.update'", "error", err) |
| sendError(c, "invalid_session_update", "Invalid session update format", "", "") |
| continue |
| } |
| if err := updateSession( |
| session, |
| &sessionUpdate, |
| application.ModelConfigLoader(), |
| application.ModelLoader(), |
| application.ApplicationConfig(), |
| ); err != nil { |
| xlog.Error("failed to update session", "error", err) |
| sendError(c, "session_update_error", "Failed to update session", "", "") |
| continue |
| } |
|
|
| sendEvent(c, types.SessionUpdatedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeSessionUpdated, |
| }, |
| Session: session.ToServer(), |
| }) |
|
|
| if session.TurnDetection.Type == types.ServerTurnDetectionTypeServerVad && !vadServerStarted { |
| xlog.Debug("Starting VAD goroutine...") |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| conversation := session.Conversations[session.DefaultConversationID] |
| handleVAD(cfg, evaluator, session, conversation, c, done) |
| }() |
| vadServerStarted = true |
| } else if session.TurnDetection.Type != types.ServerTurnDetectionTypeServerVad && vadServerStarted { |
| xlog.Debug("Stopping VAD goroutine...") |
|
|
| wg.Add(-1) |
| go func() { |
| done <- struct{}{} |
| }() |
| vadServerStarted = false |
| } |
| case types.ClientEventTypeInputAudioBufferAppend: |
| |
| if incomingMsg.Audio == "" { |
| xlog.Error("Audio data is missing in 'input_audio_buffer.append'") |
| sendError(c, "missing_audio_data", "Audio data is missing", "", "") |
| continue |
| } |
|
|
| |
| decodedAudio, err := base64.StdEncoding.DecodeString(incomingMsg.Audio) |
| if err != nil { |
| xlog.Error("failed to decode audio data", "error", err) |
| sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "") |
| continue |
| } |
|
|
| |
| session.AudioBufferLock.Lock() |
| session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) |
| session.AudioBufferLock.Unlock() |
|
|
| case types.ClientEventTypeInputAudioBufferCommit: |
| xlog.Debug("recv", "message", string(msg)) |
|
|
| |
| |
|
|
| if session.TranscriptionOnly { |
| continue |
| } |
|
|
| |
| item := &types.MessageItem{ |
| ID: generateItemID(), |
| Type: "message", |
| Status: "completed", |
| Role: "user", |
| Content: []types.MessageContentPart{ |
| { |
| Type: "input_audio", |
| Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer), |
| }, |
| }, |
| } |
|
|
| |
| conversation.Lock.Lock() |
| conversation.Items = append(conversation.Items, item) |
| conversation.Lock.Unlock() |
|
|
| |
| session.AudioBufferLock.Lock() |
| session.InputAudioBuffer = nil |
| session.AudioBufferLock.Unlock() |
|
|
| |
| sendEvent(c, types.ConversationItemCreatedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: "conversation.item.created", |
| }, |
| Item: types.ResponseMessageItem{ |
| Object: "realtime.item", |
| MessageItem: *item, |
| }, |
| }) |
|
|
| case types.ClientEventTypeConversationItemCreate: |
| xlog.Debug("recv", "message", string(msg)) |
|
|
| |
| var item types.ConversationItemCreateEvent |
| if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { |
| xlog.Error("failed to unmarshal 'conversation.item.create'", "error", err) |
| sendError(c, "invalid_item", "Invalid item format", "", "") |
| continue |
| } |
|
|
| sendNotImplemented(c, "conversation.item.create") |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| case types.ClientEventTypeConversationItemDelete: |
| sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO") |
|
|
| case types.ClientEventTypeResponseCreate: |
| |
| var responseCreate types.ResponseCreateEvent |
| if len(incomingMsg.Response) > 0 { |
| if err := json.Unmarshal(incomingMsg.Response, &responseCreate); err != nil { |
| xlog.Error("failed to unmarshal 'response.create' response object", "error", err) |
| sendError(c, "invalid_response_create", "Invalid response create format", "", "") |
| continue |
| } |
| } |
|
|
| |
| if len(responseCreate.Response.Tools) > 0 { |
| |
| } |
|
|
| sendNotImplemented(c, "response.create") |
|
|
| |
| |
| |
| |
| |
| |
|
|
| case types.ClientEventTypeResponseCancel: |
| xlog.Debug("recv", "message", string(msg)) |
|
|
| |
| |
| sendNotImplemented(c, "response.cancel") |
|
|
| default: |
| xlog.Error("unknown message type", "type", incomingMsg.Type) |
| sendError(c, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "") |
| } |
| } |
|
|
| |
| close(done) |
| wg.Wait() |
|
|
| |
| sessionLock.Lock() |
| delete(sessions, sessionID) |
| sessionLock.Unlock() |
| } |
| } |
|
|
| |
| func sendEvent(c *websocket.Conn, event types.ServerEvent) { |
| eventBytes, err := json.Marshal(event) |
| if err != nil { |
| xlog.Error("failed to marshal event", "error", err) |
| return |
| } |
| if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil { |
| xlog.Error("write error", "error", err) |
| } |
| } |
|
|
| |
| func sendError(c *websocket.Conn, code, message, param, eventID string) { |
| errorEvent := types.ErrorEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| Type: types.ServerEventTypeError, |
| EventID: eventID, |
| }, |
| Error: types.Error{ |
| Type: "invalid_request_error", |
| Code: code, |
| Message: message, |
| EventID: eventID, |
| }, |
| } |
|
|
| sendEvent(c, errorEvent) |
| } |
|
|
| func sendNotImplemented(c *websocket.Conn, message string) { |
| sendError(c, "not_implemented", message, "", "event_TODO") |
| } |
|
|
| func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { |
| sessionLock.Lock() |
| defer sessionLock.Unlock() |
|
|
| trUpd := update.InputAudioTranscription |
| trCur := session.InputAudioTranscription |
|
|
| if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model { |
| pipeline := config.Pipeline{ |
| VAD: vadModel, |
| Transcription: trUpd.Model, |
| } |
|
|
| m, _, err := newTranscriptionOnlyModel(&pipeline, cl, ml, appConfig) |
| if err != nil { |
| return err |
| } |
|
|
| session.ModelInterface = m |
| } |
|
|
| if trUpd != nil { |
| trCur.Language = trUpd.Language |
| trCur.Prompt = trUpd.Prompt |
| } |
|
|
| if update.TurnDetection != nil && update.TurnDetection.Type != "" { |
| session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type) |
| session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams |
| } |
|
|
| return nil |
| } |
|
|
| |
| func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { |
| sessionLock.Lock() |
| defer sessionLock.Unlock() |
|
|
| if update.Model != "" { |
| pipeline := config.Pipeline{ |
| LLM: update.Model, |
| |
| } |
| m, err := newModel(&pipeline, cl, ml, appConfig) |
| if err != nil { |
| return err |
| } |
| session.ModelInterface = m |
| session.Model = update.Model |
| } |
|
|
| if update.Voice != "" { |
| session.Voice = update.Voice |
| } |
| if update.TurnDetection != nil && update.TurnDetection.Type != "" { |
| session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type) |
| session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams |
| } |
| |
| if update.Instructions != "" { |
| session.Instructions = update.Instructions |
| } |
| if update.Tools != nil { |
| return fmt.Errorf("Haven't implemented tools") |
| } |
|
|
| session.InputAudioTranscription = update.InputAudioTranscription |
|
|
| return nil |
| } |
|
|
| |
| |
| func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) { |
| vadContext, cancel := context.WithCancel(context.Background()) |
| go func() { |
| <-done |
| cancel() |
| }() |
|
|
| silenceThreshold := float64(session.TurnDetection.SilenceDurationMs) / 1000 |
| speechStarted := false |
| startTime := time.Now() |
|
|
| ticker := time.NewTicker(300 * time.Millisecond) |
| defer ticker.Stop() |
|
|
| for { |
| select { |
| case <-done: |
| return |
| case <-ticker.C: |
| session.AudioBufferLock.Lock() |
| allAudio := make([]byte, len(session.InputAudioBuffer)) |
| copy(allAudio, session.InputAudioBuffer) |
| session.AudioBufferLock.Unlock() |
|
|
| aints := sound.BytesToInt16sLE(allAudio) |
| if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate { |
| continue |
| } |
|
|
| |
| aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate) |
|
|
| segments, err := runVAD(vadContext, session, aints) |
| if err != nil { |
| if err.Error() == "unexpected speech end" { |
| xlog.Debug("VAD cancelled") |
| continue |
| } |
| xlog.Error("failed to process audio", "error", err) |
| sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") |
| continue |
| } |
|
|
| audioLength := float64(len(aints)) / localSampleRate |
|
|
| |
| |
| if len(segments) == 0 && audioLength > silenceThreshold { |
| session.AudioBufferLock.Lock() |
| session.InputAudioBuffer = nil |
| session.AudioBufferLock.Unlock() |
| xlog.Debug("Detected silence for a while, clearing audio buffer") |
|
|
| sendEvent(c, types.InputAudioBufferClearedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeInputAudioBufferCleared, |
| }, |
| }) |
|
|
| continue |
| } else if len(segments) == 0 { |
| continue |
| } |
|
|
| if !speechStarted { |
| sendEvent(c, types.InputAudioBufferSpeechStartedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeInputAudioBufferSpeechStarted, |
| }, |
| AudioStartMs: time.Now().Sub(startTime).Milliseconds(), |
| }) |
| speechStarted = true |
| } |
|
|
| |
| segEndTime := segments[len(segments)-1].GetEnd() |
| if segEndTime == 0 { |
| continue |
| } |
|
|
| if float32(audioLength)-segEndTime > float32(silenceThreshold) { |
| xlog.Debug("Detected end of speech segment") |
| session.AudioBufferLock.Lock() |
| session.InputAudioBuffer = nil |
| session.AudioBufferLock.Unlock() |
|
|
| sendEvent(c, types.InputAudioBufferSpeechStoppedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeInputAudioBufferSpeechStopped, |
| }, |
| AudioEndMs: time.Now().Sub(startTime).Milliseconds(), |
| }) |
| speechStarted = false |
|
|
| sendEvent(c, types.InputAudioBufferCommittedEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| EventID: "event_TODO", |
| Type: types.ServerEventTypeInputAudioBufferCommitted, |
| }, |
| ItemID: generateItemID(), |
| PreviousItemID: "TODO", |
| }) |
|
|
| abytes := sound.Int16toBytesLE(aints) |
| |
| go commitUtterance(vadContext, abytes, cfg, evaluator, session, conv, c) |
| } |
| } |
| } |
| } |
|
|
| func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) { |
| if len(utt) == 0 { |
| return |
| } |
|
|
| |
|
|
| f, err := os.CreateTemp("", "realtime-audio-chunk-*.wav") |
| if err != nil { |
| xlog.Error("failed to create temp file", "error", err) |
| return |
| } |
| defer f.Close() |
| defer os.Remove(f.Name()) |
| xlog.Debug("Writing to file", "file", f.Name()) |
|
|
| hdr := laudio.NewWAVHeader(uint32(len(utt))) |
| if err := hdr.Write(f); err != nil { |
| xlog.Error("Failed to write WAV header", "error", err) |
| return |
| } |
|
|
| if _, err := f.Write(utt); err != nil { |
| xlog.Error("Failed to write audio data", "error", err) |
| return |
| } |
|
|
| f.Sync() |
|
|
| if session.InputAudioTranscription != nil { |
| tr, err := session.ModelInterface.Transcribe(ctx, &proto.TranscriptRequest{ |
| Dst: f.Name(), |
| Language: session.InputAudioTranscription.Language, |
| Translate: false, |
| Threads: uint32(*cfg.Threads), |
| Prompt: session.InputAudioTranscription.Prompt, |
| }) |
| if err != nil { |
| sendError(c, "transcription_failed", err.Error(), "", "event_TODO") |
| } |
|
|
| sendEvent(c, types.ResponseAudioTranscriptDoneEvent{ |
| ServerEventBase: types.ServerEventBase{ |
| Type: types.ServerEventTypeResponseAudioTranscriptDone, |
| EventID: "event_TODO", |
| }, |
|
|
| ItemID: generateItemID(), |
| ResponseID: "resp_TODO", |
| OutputIndex: 0, |
| ContentIndex: 0, |
| Transcript: tr.GetText(), |
| }) |
| |
| } |
|
|
| if !session.TranscriptionOnly { |
| sendNotImplemented(c, "Commiting items to the conversation not implemented") |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| } |
|
|
| func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADSegment, error) { |
| soundIntBuffer := &audio.IntBuffer{ |
| Format: &audio.Format{SampleRate: localSampleRate, NumChannels: 1}, |
| SourceBitDepth: 16, |
| Data: sound.ConvertInt16ToInt(adata), |
| } |
|
|
| float32Data := soundIntBuffer.AsFloat32Buffer().Data |
|
|
| resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{ |
| Audio: float32Data, |
| }) |
| if err != nil { |
| return nil, err |
| } |
|
|
| |
| return resp.Segments, nil |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) { |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| if strings.Contains(prompt, "weather") { |
| functionCall := &FunctionCall{ |
| Name: "get_weather", |
| Arguments: map[string]interface{}{ |
| "location": "New York", |
| "scale": "celsius", |
| }, |
| } |
| return "", functionCall, nil |
| } |
|
|
| |
| return "This is a generated response based on the conversation.", nil, nil |
| } |
|
|
| |
| func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{ |
| Prompt: "What's the weather in New York?", |
| }) |
|
|
| if err != nil { |
| return "", nil, nil, err |
| } |
|
|
| generatedAudio := reply.Audio |
|
|
| transcribedText := "What's the weather in New York?" |
| var functionCall *FunctionCall |
|
|
| |
| if strings.Contains(transcribedText, "weather") { |
| functionCall = &FunctionCall{ |
| Name: "get_weather", |
| Arguments: map[string]interface{}{ |
| "location": "New York", |
| "scale": "celsius", |
| }, |
| } |
| return "", nil, functionCall, nil |
| } |
|
|
| |
| generatedText := "This is a response to your speech input." |
|
|
| return generatedText, generatedAudio, nil, nil |
| } |
|
|
| |
// splitResponseIntoChunks splits response into consecutive chunks of at most
// 50 bytes each. Concatenating the chunks always reproduces response, and an
// empty response yields a nil slice.
//
// A chunk never ends in the middle of a multi-byte UTF-8 character: when the
// 50-byte boundary falls inside one, the chunk is shortened by up to three
// bytes so the character starts the next chunk. ASCII input is therefore
// split at exactly 50 bytes. Input that is not valid UTF-8 at the boundary
// falls back to a plain 50-byte cut.
func splitResponseIntoChunks(response string) []string {
	const (
		chunkSize = 50
		// maxBackoff is the largest number of continuation bytes a valid
		// UTF-8 character can have.
		maxBackoff = 3
	)

	var chunks []string
	for len(response) > chunkSize {
		cut := chunkSize
		// response[cut] would start the next chunk; step back while it is a
		// continuation byte (bit pattern 10xxxxxx).
		for cut > chunkSize-maxBackoff && response[cut]&0xC0 == 0x80 {
			cut--
		}
		if response[cut]&0xC0 == 0x80 {
			// Still inside a run of continuation bytes: not valid UTF-8.
			cut = chunkSize
		}
		chunks = append(chunks, response[:cut])
		response = response[cut:]
	}
	if len(response) > 0 {
		chunks = append(chunks, response)
	}
	return chunks
}
|
|
| |
| func generateSessionID() string { |
| |
| |
| return "sess_" + generateUniqueID() |
| } |
|
|
| func generateConversationID() string { |
| |
| |
| return "conv_" + generateUniqueID() |
| } |
|
|
| func generateItemID() string { |
| |
| |
| return "item_" + generateUniqueID() |
| } |
|
|
// uniqueIDMu guards uniqueIDSeq, the per-process counter that makes every
// generated identifier distinct.
var (
	uniqueIDMu  sync.Mutex
	uniqueIDSeq uint64
)

// generateUniqueID returns a hexadecimal identifier that is unique within
// this process. It may be called from multiple goroutines.
//
// The identifier is a fixed-width 16-digit timestamp (nanoseconds since the
// Unix epoch) followed by a counter that increases on every call. The
// counter guarantees distinct values even for calls within the same
// nanosecond; the timestamp keeps values from different process runs apart.
// The result is not intended to be unpredictable.
func generateUniqueID() string {
	uniqueIDMu.Lock()
	uniqueIDSeq++
	seq := uniqueIDSeq
	uniqueIDMu.Unlock()

	return fmt.Sprintf("%016x%08x", uint64(time.Now().UnixNano()), seq)
}
|
|
| |
| type ResponseCreate struct { |
| Modalities []string `json:"modalities,omitempty"` |
| Instructions string `json:"instructions,omitempty"` |
| Functions functions.Functions `json:"functions,omitempty"` |
| |
| } |
|
|