Spaces:
Sleeping
Sleeping
| package controller | |
import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"one-api/common"
	"one-api/model"

	"github.com/gin-gonic/gin"
)
| type OpenAIModel struct { | |
| ID string `json:"id"` | |
| Object string `json:"object"` | |
| Created int64 `json:"created"` | |
| OwnedBy string `json:"owned_by"` | |
| Permission []struct { | |
| ID string `json:"id"` | |
| Object string `json:"object"` | |
| Created int64 `json:"created"` | |
| AllowCreateEngine bool `json:"allow_create_engine"` | |
| AllowSampling bool `json:"allow_sampling"` | |
| AllowLogprobs bool `json:"allow_logprobs"` | |
| AllowSearchIndices bool `json:"allow_search_indices"` | |
| AllowView bool `json:"allow_view"` | |
| AllowFineTuning bool `json:"allow_fine_tuning"` | |
| Organization string `json:"organization"` | |
| Group string `json:"group"` | |
| IsBlocking bool `json:"is_blocking"` | |
| } `json:"permission"` | |
| Root string `json:"root"` | |
| Parent string `json:"parent"` | |
| } | |
| type OpenAIModelsResponse struct { | |
| Data []OpenAIModel `json:"data"` | |
| Success bool `json:"success"` | |
| } | |
| func GetAllChannels(c *gin.Context) { | |
| p, _ := strconv.Atoi(c.Query("p")) | |
| pageSize, _ := strconv.Atoi(c.Query("page_size")) | |
| if p < 0 { | |
| p = 0 | |
| } | |
| if pageSize < 0 { | |
| pageSize = common.ItemsPerPage | |
| } | |
| channelData := make([]*model.Channel, 0) | |
| idSort, _ := strconv.ParseBool(c.Query("id_sort")) | |
| enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) | |
| if enableTagMode { | |
| tags, err := model.GetPaginatedTags(p*pageSize, pageSize) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| for _, tag := range tags { | |
| if tag != nil && *tag != "" { | |
| tagChannel, err := model.GetChannelsByTag(*tag, idSort) | |
| if err == nil { | |
| channelData = append(channelData, tagChannel...) | |
| } | |
| } | |
| } | |
| } else { | |
| channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| channelData = channels | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": channelData, | |
| }) | |
| return | |
| } | |
| func FetchUpstreamModels(c *gin.Context) { | |
| id, err := strconv.Atoi(c.Param("id")) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| channel, err := model.GetChannelById(id, true) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| //if channel.Type != common.ChannelTypeOpenAI { | |
| // c.JSON(http.StatusOK, gin.H{ | |
| // "success": false, | |
| // "message": "仅支持 OpenAI 类型渠道", | |
| // }) | |
| // return | |
| //} | |
| baseURL := common.ChannelBaseURLs[channel.Type] | |
| if channel.GetBaseURL() != "" { | |
| baseURL = channel.GetBaseURL() | |
| } | |
| url := fmt.Sprintf("%s/v1/models", baseURL) | |
| body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| var result OpenAIModelsResponse | |
| if err = json.Unmarshal(body, &result); err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": fmt.Sprintf("解析响应失败: %s", err.Error()), | |
| }) | |
| return | |
| } | |
| var ids []string | |
| for _, model := range result.Data { | |
| ids = append(ids, model.ID) | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": ids, | |
| }) | |
| } | |
| func FixChannelsAbilities(c *gin.Context) { | |
| count, err := model.FixAbility() | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": count, | |
| }) | |
| } | |
| func SearchChannels(c *gin.Context) { | |
| keyword := c.Query("keyword") | |
| group := c.Query("group") | |
| modelKeyword := c.Query("model") | |
| idSort, _ := strconv.ParseBool(c.Query("id_sort")) | |
| enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) | |
| channelData := make([]*model.Channel, 0) | |
| if enableTagMode { | |
| tags, err := model.SearchTags(keyword, group, modelKeyword, idSort) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| for _, tag := range tags { | |
| if tag != nil && *tag != "" { | |
| tagChannel, err := model.GetChannelsByTag(*tag, idSort) | |
| if err == nil { | |
| channelData = append(channelData, tagChannel...) | |
| } | |
| } | |
| } | |
| } else { | |
| channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| channelData = channels | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": channelData, | |
| }) | |
| return | |
| } | |
| func GetChannel(c *gin.Context) { | |
| id, err := strconv.Atoi(c.Param("id")) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| channel, err := model.GetChannelById(id, false) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": channel, | |
| }) | |
| return | |
| } | |
| func AddChannel(c *gin.Context) { | |
| channel := model.Channel{} | |
| err := c.ShouldBindJSON(&channel) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| channel.CreatedTime = common.GetTimestamp() | |
| keys := strings.Split(channel.Key, "\n") | |
| if channel.Type == common.ChannelTypeVertexAi { | |
| if channel.Other == "" { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "部署地区不能为空", | |
| }) | |
| return | |
| } else { | |
| if common.IsJsonStr(channel.Other) { | |
| // must have default | |
| regionMap := common.StrToMap(channel.Other) | |
| if regionMap["default"] == nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "部署地区必须包含default字段", | |
| }) | |
| return | |
| } | |
| } | |
| } | |
| keys = []string{channel.Key} | |
| } | |
| channels := make([]model.Channel, 0, len(keys)) | |
| for _, key := range keys { | |
| if key == "" { | |
| continue | |
| } | |
| localChannel := channel | |
| localChannel.Key = key | |
| // Validate the length of the model name | |
| models := strings.Split(localChannel.Models, ",") | |
| for _, model := range models { | |
| if len(model) > 255 { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": fmt.Sprintf("模型名称过长: %s", model), | |
| }) | |
| return | |
| } | |
| } | |
| channels = append(channels, localChannel) | |
| } | |
| err = model.BatchInsertChannels(channels) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| }) | |
| return | |
| } | |
| func DeleteChannel(c *gin.Context) { | |
| id, _ := strconv.Atoi(c.Param("id")) | |
| channel := model.Channel{Id: id} | |
| err := channel.Delete() | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| }) | |
| return | |
| } | |
| func DeleteDisabledChannel(c *gin.Context) { | |
| rows, err := model.DeleteDisabledChannel() | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": rows, | |
| }) | |
| return | |
| } | |
| type ChannelTag struct { | |
| Tag string `json:"tag"` | |
| NewTag *string `json:"new_tag"` | |
| Priority *int64 `json:"priority"` | |
| Weight *uint `json:"weight"` | |
| ModelMapping *string `json:"model_mapping"` | |
| Models *string `json:"models"` | |
| Groups *string `json:"groups"` | |
| } | |
| func DisableTagChannels(c *gin.Context) { | |
| channelTag := ChannelTag{} | |
| err := c.ShouldBindJSON(&channelTag) | |
| if err != nil || channelTag.Tag == "" { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "参数错误", | |
| }) | |
| return | |
| } | |
| err = model.DisableChannelByTag(channelTag.Tag) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| }) | |
| return | |
| } | |
| func EnableTagChannels(c *gin.Context) { | |
| channelTag := ChannelTag{} | |
| err := c.ShouldBindJSON(&channelTag) | |
| if err != nil || channelTag.Tag == "" { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "参数错误", | |
| }) | |
| return | |
| } | |
| err = model.EnableChannelByTag(channelTag.Tag) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| }) | |
| return | |
| } | |
| func EditTagChannels(c *gin.Context) { | |
| channelTag := ChannelTag{} | |
| err := c.ShouldBindJSON(&channelTag) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "参数错误", | |
| }) | |
| return | |
| } | |
| if channelTag.Tag == "" { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "tag不能为空", | |
| }) | |
| return | |
| } | |
| err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| }) | |
| return | |
| } | |
| type ChannelBatch struct { | |
| Ids []int `json:"ids"` | |
| Tag *string `json:"tag"` | |
| } | |
| func DeleteChannelBatch(c *gin.Context) { | |
| channelBatch := ChannelBatch{} | |
| err := c.ShouldBindJSON(&channelBatch) | |
| if err != nil || len(channelBatch.Ids) == 0 { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "参数错误", | |
| }) | |
| return | |
| } | |
| err = model.BatchDeleteChannels(channelBatch.Ids) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": len(channelBatch.Ids), | |
| }) | |
| return | |
| } | |
| func UpdateChannel(c *gin.Context) { | |
| channel := model.Channel{} | |
| err := c.ShouldBindJSON(&channel) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| if channel.Type == common.ChannelTypeVertexAi { | |
| if channel.Other == "" { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "部署地区不能为空", | |
| }) | |
| return | |
| } else { | |
| if common.IsJsonStr(channel.Other) { | |
| // must have default | |
| regionMap := common.StrToMap(channel.Other) | |
| if regionMap["default"] == nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "部署地区必须包含default字段", | |
| }) | |
| return | |
| } | |
| } | |
| } | |
| } | |
| err = channel.Update() | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": channel, | |
| }) | |
| return | |
| } | |
| func FetchModels(c *gin.Context) { | |
| var req struct { | |
| BaseURL string `json:"base_url"` | |
| Type int `json:"type"` | |
| Key string `json:"key"` | |
| } | |
| if err := c.ShouldBindJSON(&req); err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{ | |
| "success": false, | |
| "message": "Invalid request", | |
| }) | |
| return | |
| } | |
| baseURL := req.BaseURL | |
| if baseURL == "" { | |
| baseURL = common.ChannelBaseURLs[req.Type] | |
| } | |
| client := &http.Client{} | |
| url := fmt.Sprintf("%s/v1/models", baseURL) | |
| request, err := http.NewRequest("GET", url, nil) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| // remove line breaks and extra spaces. | |
| key := strings.TrimSpace(req.Key) | |
| // If the key contains a line break, only take the first part. | |
| key = strings.Split(key, "\n")[0] | |
| request.Header.Set("Authorization", "Bearer "+key) | |
| response, err := client.Do(request) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| //check status code | |
| if response.StatusCode != http.StatusOK { | |
| c.JSON(http.StatusInternalServerError, gin.H{ | |
| "success": false, | |
| "message": "Failed to fetch models", | |
| }) | |
| return | |
| } | |
| defer response.Body.Close() | |
| var result struct { | |
| Data []struct { | |
| ID string `json:"id"` | |
| } `json:"data"` | |
| } | |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| var models []string | |
| for _, model := range result.Data { | |
| models = append(models, model.ID) | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "data": models, | |
| }) | |
| } | |
| func BatchSetChannelTag(c *gin.Context) { | |
| channelBatch := ChannelBatch{} | |
| err := c.ShouldBindJSON(&channelBatch) | |
| if err != nil || len(channelBatch.Ids) == 0 { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": "参数错误", | |
| }) | |
| return | |
| } | |
| err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": false, | |
| "message": err.Error(), | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": len(channelBatch.Ids), | |
| }) | |
| return | |
| } | |