| package helper |
|
|
| import ( |
| "encoding/json" |
| "errors" |
| "fmt" |
|
|
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/relay/common" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
| func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error { |
| |
| modelMapping := c.GetString("model_mapping") |
| if modelMapping != "" && modelMapping != "{}" { |
| modelMap := make(map[string]string) |
| err := json.Unmarshal([]byte(modelMapping), &modelMap) |
| if err != nil { |
| return fmt.Errorf("unmarshal_model_mapping_failed") |
| } |
|
|
| |
| currentModel := info.OriginModelName |
| visitedModels := map[string]bool{ |
| currentModel: true, |
| } |
| for { |
| if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" { |
| |
| if visitedModels[mappedModel] { |
| if mappedModel == currentModel { |
| if currentModel == info.OriginModelName { |
| info.IsModelMapped = false |
| return nil |
| } else { |
| info.IsModelMapped = true |
| break |
| } |
| } |
| return errors.New("model_mapping_contains_cycle") |
| } |
| visitedModels[mappedModel] = true |
| currentModel = mappedModel |
| info.IsModelMapped = true |
| } else { |
| break |
| } |
| } |
| if info.IsModelMapped { |
| info.UpstreamModelName = currentModel |
| } |
| } |
| if request != nil { |
| request.SetModelName(info.UpstreamModelName) |
| } |
| return nil |
| } |
|
|