package openai

import (
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/xlog"
)

// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
// @Summary Generate completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	// process runs a single prediction and forwards every partial result to the
	// responses channel; it closes the channel once generation ends.
	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
			created := int(time.Now().Unix())

			usage := schema.OpenAIUsage{
				PromptTokens:     tokenUsage.Prompt,
				CompletionTokens: tokenUsage.Completion,
				TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
			}
			if extraUsage {
				usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
				usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
			}

			resp := schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{
					{
						Index:        0,
						Text:         s,
						FinishReason: nil,
					},
				},
				Object: "text_completion",
				Usage:  usage,
			}
			xlog.Debug("Sending chunk from generation goroutine", "text", s)
			responses <- resp
			return true
		}

		_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
		close(responses)
		return err
	}

	return func(c echo.Context) error {
		created := int(time.Now().Unix())

		// Handle Correlation
		id := c.Request().Header.Get("X-Correlation-ID")
		if id == "" {
			id = uuid.New().String()
		}
		extraUsage := c.Request().Header.Get("Extra-Usage") != ""

		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || config == nil {
			return echo.ErrBadRequest
		}

		if config.ResponseFormatMap != nil {
			d := schema.ChatCompletionResponseFormat{}
			dat, _ := json.Marshal(config.ResponseFormatMap)
			_ = json.Unmarshal(dat, &d)
			if d.Type == "json_object" {
				input.Grammar = functions.JSONBNF
			}
		}
		config.Grammar = input.Grammar

		xlog.Debug("Parameter Config", "config", config)

		if input.Stream {
			xlog.Debug("Stream request received")
			c.Response().Header().Set("Content-Type", "text/event-stream")
			c.Response().Header().Set("Cache-Control", "no-cache")
			c.Response().Header().Set("Connection", "keep-alive")

			if len(config.PromptStrings) > 1 {
				return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
			}

			predInput := config.PromptStrings[0]

			templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
				Input:           predInput,
				SystemPrompt:    config.SystemPrompt,
				ReasoningEffort: input.ReasoningEffort,
				Metadata:        input.Metadata,
			})
			if err == nil {
				predInput = templatedInput
				xlog.Debug("Template found, input modified", "input", predInput)
			}
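
			// Stream mode: generation runs in a background goroutine and pushes
			// partial responses through the channel, while the loop below relays
			// each one to the client as a Server-Sent Events "data:" frame. The
			// `ended` channel reports when (and how) generation finished.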
input modified", "input", predInput) } responses := make(chan schema.OpenAIResponse) ended := make(chan error) go func() { ended <- process(id, predInput, input, config, ml, responses, extraUsage) }() LOOP: for { select { case ev := <-responses: if len(ev.Choices) == 0 { xlog.Debug("No choices in the response, skipping") continue } respData, err := json.Marshal(ev) if err != nil { xlog.Debug("Failed to marshal response", "error", err) continue } xlog.Debug("Sending chunk", "chunk", string(respData)) _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) if err != nil { return err } c.Response().Flush() case err := <-ended: if err == nil { break LOOP } xlog.Error("Stream ended with error", "error", err) stopReason := FinishReasonStop errorResp := schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, Choices: []schema.Choice{ { Index: 0, FinishReason: &stopReason, Text: "Internal error: " + err.Error(), }, }, Object: "text_completion", } errorData, marshalErr := json.Marshal(errorResp) if marshalErr != nil { xlog.Error("Failed to marshal error response", "error", marshalErr) // Send a simple error message as fallback fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n") } else { fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData)) } c.Response().Flush() return nil } } stopReason := FinishReasonStop resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { Index: 0, FinishReason: &stopReason, }, }, Object: "text_completion", } respData, _ := json.Marshal(resp) fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } var result []schema.Choice totalTokenUsage := backend.TokenUsage{} for k, i := range config.PromptStrings { templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ SystemPrompt: config.SystemPrompt, Input: i, ReasoningEffort: input.ReasoningEffort, Metadata: input.Metadata, }) if err == nil { i = templatedInput xlog.Debug("Template found, input modified", "input", i) } r, tokenUsage, err := ComputeChoices( input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) { stopReason := FinishReasonStop *c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k}) }, nil) if err != nil { return err } totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing result = append(result, r...) } usage := schema.OpenAIUsage{ PromptTokens: totalTokenUsage.Prompt, CompletionTokens: totalTokenUsage.Completion, TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing } resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: result, Object: "text_completion", Usage: usage, } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } }