|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"net/http" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/google/uuid" |
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
"github.com/mudler/LocalAI/core/services" |
|
|
"github.com/mudler/LocalAI/core/templates" |
|
|
"github.com/mudler/LocalAI/pkg/functions" |
|
|
"github.com/mudler/LocalAI/pkg/model" |
|
|
"github.com/mudler/LocalAI/pkg/utils" |
|
|
"github.com/mudler/xlog" |
|
|
) |
|
|
|
|
|
// correlationIDKeyType is the unexported string type of the context key
// declared below; giving the key its own type keeps it distinct from plain
// string keys.
type correlationIDKeyType string

// CorrelationIDKey is the context key under which the request correlation ID
// is stored.
const CorrelationIDKey = correlationIDKeyType("correlationID")
|
|
|
|
|
type RequestExtractor struct { |
|
|
modelConfigLoader *config.ModelConfigLoader |
|
|
modelLoader *model.ModelLoader |
|
|
applicationConfig *config.ApplicationConfig |
|
|
} |
|
|
|
|
|
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { |
|
|
return &RequestExtractor{ |
|
|
modelConfigLoader: modelConfigLoader, |
|
|
modelLoader: modelLoader, |
|
|
applicationConfig: applicationConfig, |
|
|
} |
|
|
} |
|
|
|
|
|
// Keys used to share values between middlewares through the echo context.
const (
	// CONTEXT_LOCALS_KEY_MODEL_NAME stores the resolved model name (string).
	CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
	// CONTEXT_LOCALS_KEY_LOCALAI_REQUEST stores the parsed request body.
	CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
	// CONTEXT_LOCALS_KEY_MODEL_CONFIG stores the model configuration loaded for the request.
	CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
)
|
|
|
|
|
|
|
|
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { |
|
|
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) |
|
|
if ok && model != "" { |
|
|
return |
|
|
} |
|
|
model = c.Param("model") |
|
|
|
|
|
if model == "" { |
|
|
model = c.QueryParam("model") |
|
|
} |
|
|
|
|
|
|
|
|
if model == "" { |
|
|
model = c.FormValue("model") |
|
|
} |
|
|
|
|
|
if model == "" { |
|
|
|
|
|
auth := c.Request().Header.Get("Authorization") |
|
|
bearer := strings.TrimPrefix(auth, "Bearer ") |
|
|
if bearer != "" && bearer != auth { |
|
|
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) |
|
|
if err == nil && exists { |
|
|
model = bearer |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) |
|
|
} |
|
|
|
|
|
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { |
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
re.setModelNameFromRequest(c) |
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) |
|
|
if !ok || localModelName == "" { |
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) |
|
|
xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) |
|
|
} |
|
|
return next(c) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { |
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
re.setModelNameFromRequest(c) |
|
|
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) |
|
|
if localModelName != "" { |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) |
|
|
if err != nil { |
|
|
xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
if len(modelNames) == 0 { |
|
|
xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") |
|
|
|
|
|
|
|
|
return next(c) |
|
|
} |
|
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) |
|
|
xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) |
|
|
return next(c) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { |
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
input := initializer() |
|
|
if input == nil { |
|
|
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") |
|
|
} |
|
|
if err := c.Bind(input); err != nil { |
|
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) |
|
|
} |
|
|
|
|
|
|
|
|
if input.ModelName(nil) == "" { |
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) |
|
|
if ok && localModelName != "" { |
|
|
xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) |
|
|
input.ModelName(&localModelName) |
|
|
} |
|
|
} |
|
|
|
|
|
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) |
|
|
|
|
|
if err != nil { |
|
|
xlog.Warn("Model Configuration File not found", "model", input.ModelName(nil), "error", err) |
|
|
} else if cfg.Model == "" && input.ModelName(nil) != "" { |
|
|
xlog.Debug("config does not include model, using input", "input.ModelName", input.ModelName(nil)) |
|
|
cfg.Model = input.ModelName(nil) |
|
|
} |
|
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) |
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) |
|
|
|
|
|
return next(c) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { |
|
|
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) |
|
|
if !ok || input.Model == "" { |
|
|
return echo.ErrBadRequest |
|
|
} |
|
|
|
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) |
|
|
if !ok || cfg == nil { |
|
|
return echo.ErrBadRequest |
|
|
} |
|
|
|
|
|
|
|
|
correlationID := c.Request().Header.Get("X-Correlation-ID") |
|
|
if correlationID == "" { |
|
|
correlationID = uuid.New().String() |
|
|
} |
|
|
c.Response().Header().Set("X-Correlation-ID", correlationID) |
|
|
|
|
|
|
|
|
|
|
|
reqCtx := c.Request().Context() |
|
|
c1, cancel := context.WithCancel(re.applicationConfig.Context) |
|
|
|
|
|
|
|
|
go func() { |
|
|
select { |
|
|
case <-reqCtx.Done(): |
|
|
cancel() |
|
|
case <-c1.Done(): |
|
|
|
|
|
} |
|
|
}() |
|
|
|
|
|
|
|
|
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) |
|
|
|
|
|
input.Context = ctxWithCorrelationID |
|
|
input.Cancel = cancel |
|
|
|
|
|
err := mergeOpenAIRequestAndModelConfig(cfg, input) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
if cfg.Model == "" { |
|
|
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) |
|
|
cfg.Model = input.Model |
|
|
} |
|
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) |
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { |
|
|
if input.Echo { |
|
|
config.Echo = input.Echo |
|
|
} |
|
|
if input.TopK != nil { |
|
|
config.TopK = input.TopK |
|
|
} |
|
|
if input.TopP != nil { |
|
|
config.TopP = input.TopP |
|
|
} |
|
|
|
|
|
if input.Backend != "" { |
|
|
config.Backend = input.Backend |
|
|
} |
|
|
|
|
|
if input.ClipSkip != 0 { |
|
|
config.Diffusers.ClipSkip = input.ClipSkip |
|
|
} |
|
|
|
|
|
if input.NegativePromptScale != 0 { |
|
|
config.NegativePromptScale = input.NegativePromptScale |
|
|
} |
|
|
|
|
|
if input.NegativePrompt != "" { |
|
|
config.NegativePrompt = input.NegativePrompt |
|
|
} |
|
|
|
|
|
if input.RopeFreqBase != 0 { |
|
|
config.RopeFreqBase = input.RopeFreqBase |
|
|
} |
|
|
|
|
|
if input.RopeFreqScale != 0 { |
|
|
config.RopeFreqScale = input.RopeFreqScale |
|
|
} |
|
|
|
|
|
if input.Grammar != "" { |
|
|
config.Grammar = input.Grammar |
|
|
} |
|
|
|
|
|
if input.Temperature != nil { |
|
|
config.Temperature = input.Temperature |
|
|
} |
|
|
|
|
|
if input.Maxtokens != nil { |
|
|
config.Maxtokens = input.Maxtokens |
|
|
} |
|
|
|
|
|
if input.ResponseFormat != nil { |
|
|
switch responseFormat := input.ResponseFormat.(type) { |
|
|
case string: |
|
|
config.ResponseFormat = responseFormat |
|
|
case map[string]interface{}: |
|
|
config.ResponseFormatMap = responseFormat |
|
|
} |
|
|
} |
|
|
|
|
|
switch stop := input.Stop.(type) { |
|
|
case string: |
|
|
if stop != "" { |
|
|
config.StopWords = append(config.StopWords, stop) |
|
|
} |
|
|
case []interface{}: |
|
|
for _, pp := range stop { |
|
|
if s, ok := pp.(string); ok { |
|
|
config.StopWords = append(config.StopWords, s) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if len(input.Tools) > 0 { |
|
|
for _, tool := range input.Tools { |
|
|
input.Functions = append(input.Functions, tool.Function) |
|
|
} |
|
|
} |
|
|
|
|
|
if input.ToolsChoice != nil { |
|
|
var toolChoice functions.Tool |
|
|
|
|
|
switch content := input.ToolsChoice.(type) { |
|
|
case string: |
|
|
_ = json.Unmarshal([]byte(content), &toolChoice) |
|
|
case map[string]interface{}: |
|
|
dat, _ := json.Marshal(content) |
|
|
_ = json.Unmarshal(dat, &toolChoice) |
|
|
} |
|
|
input.FunctionCall = map[string]interface{}{ |
|
|
"name": toolChoice.Function.Name, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
imgIndex, vidIndex, audioIndex := 0, 0, 0 |
|
|
for i, m := range input.Messages { |
|
|
nrOfImgsInMessage := 0 |
|
|
nrOfVideosInMessage := 0 |
|
|
nrOfAudiosInMessage := 0 |
|
|
|
|
|
switch content := m.Content.(type) { |
|
|
case string: |
|
|
input.Messages[i].StringContent = content |
|
|
case []interface{}: |
|
|
dat, _ := json.Marshal(content) |
|
|
c := []schema.Content{} |
|
|
json.Unmarshal(dat, &c) |
|
|
|
|
|
textContent := "" |
|
|
|
|
|
|
|
|
CONTENT: |
|
|
for _, pp := range c { |
|
|
switch pp.Type { |
|
|
case "text": |
|
|
textContent += pp.Text |
|
|
|
|
|
case "video", "video_url": |
|
|
|
|
|
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) |
|
|
if err != nil { |
|
|
xlog.Error("Failed encoding video", "error", err) |
|
|
continue CONTENT |
|
|
} |
|
|
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) |
|
|
vidIndex++ |
|
|
nrOfVideosInMessage++ |
|
|
case "audio_url", "audio": |
|
|
|
|
|
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) |
|
|
if err != nil { |
|
|
xlog.Error("Failed encoding audio", "error", err) |
|
|
continue CONTENT |
|
|
} |
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) |
|
|
audioIndex++ |
|
|
nrOfAudiosInMessage++ |
|
|
case "input_audio": |
|
|
|
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) |
|
|
audioIndex++ |
|
|
nrOfAudiosInMessage++ |
|
|
case "image_url", "image": |
|
|
|
|
|
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) |
|
|
if err != nil { |
|
|
xlog.Error("Failed encoding image", "error", err) |
|
|
continue CONTENT |
|
|
} |
|
|
|
|
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) |
|
|
|
|
|
imgIndex++ |
|
|
nrOfImgsInMessage++ |
|
|
} |
|
|
} |
|
|
|
|
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ |
|
|
TotalImages: imgIndex, |
|
|
TotalVideos: vidIndex, |
|
|
TotalAudios: audioIndex, |
|
|
ImagesInMessage: nrOfImgsInMessage, |
|
|
VideosInMessage: nrOfVideosInMessage, |
|
|
AudiosInMessage: nrOfAudiosInMessage, |
|
|
}, textContent) |
|
|
} |
|
|
} |
|
|
|
|
|
if input.RepeatPenalty != 0 { |
|
|
config.RepeatPenalty = input.RepeatPenalty |
|
|
} |
|
|
|
|
|
if input.FrequencyPenalty != 0 { |
|
|
config.FrequencyPenalty = input.FrequencyPenalty |
|
|
} |
|
|
|
|
|
if input.PresencePenalty != 0 { |
|
|
config.PresencePenalty = input.PresencePenalty |
|
|
} |
|
|
|
|
|
if input.Keep != 0 { |
|
|
config.Keep = input.Keep |
|
|
} |
|
|
|
|
|
if input.Batch != 0 { |
|
|
config.Batch = input.Batch |
|
|
} |
|
|
|
|
|
if input.IgnoreEOS { |
|
|
config.IgnoreEOS = input.IgnoreEOS |
|
|
} |
|
|
|
|
|
if input.Seed != nil { |
|
|
config.Seed = input.Seed |
|
|
} |
|
|
|
|
|
if input.TypicalP != nil { |
|
|
config.TypicalP = input.TypicalP |
|
|
} |
|
|
|
|
|
xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input)) |
|
|
|
|
|
switch inputs := input.Input.(type) { |
|
|
case string: |
|
|
if inputs != "" { |
|
|
config.InputStrings = append(config.InputStrings, inputs) |
|
|
} |
|
|
case []any: |
|
|
for _, pp := range inputs { |
|
|
switch i := pp.(type) { |
|
|
case string: |
|
|
config.InputStrings = append(config.InputStrings, i) |
|
|
case []any: |
|
|
tokens := []int{} |
|
|
inputStrings := []string{} |
|
|
for _, ii := range i { |
|
|
switch ii := ii.(type) { |
|
|
case int: |
|
|
tokens = append(tokens, ii) |
|
|
case float64: |
|
|
tokens = append(tokens, int(ii)) |
|
|
case string: |
|
|
inputStrings = append(inputStrings, ii) |
|
|
default: |
|
|
xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) |
|
|
} |
|
|
} |
|
|
config.InputToken = append(config.InputToken, tokens) |
|
|
config.InputStrings = append(config.InputStrings, inputStrings...) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
switch fnc := input.FunctionCall.(type) { |
|
|
case string: |
|
|
if fnc != "" { |
|
|
config.SetFunctionCallString(fnc) |
|
|
} |
|
|
case map[string]interface{}: |
|
|
var name string |
|
|
n, exists := fnc["name"] |
|
|
if exists { |
|
|
nn, e := n.(string) |
|
|
if e { |
|
|
name = nn |
|
|
} |
|
|
} |
|
|
config.SetFunctionCallNameString(name) |
|
|
} |
|
|
|
|
|
switch p := input.Prompt.(type) { |
|
|
case string: |
|
|
config.PromptStrings = append(config.PromptStrings, p) |
|
|
case []interface{}: |
|
|
for _, pp := range p { |
|
|
if s, ok := pp.(string); ok { |
|
|
config.PromptStrings = append(config.PromptStrings, s) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if input.Quality != "" { |
|
|
q, err := strconv.Atoi(input.Quality) |
|
|
if err == nil { |
|
|
config.Step = q |
|
|
} |
|
|
} |
|
|
|
|
|
if valid, _ := config.Validate(); valid { |
|
|
return nil |
|
|
} |
|
|
return fmt.Errorf("unable to validate configuration after merging") |
|
|
} |
|
|
|