| | package controller |
| |
|
| | import ( |
| | "encoding/json" |
| | "fmt" |
| | "net/http" |
| | "one-api/common" |
| | "one-api/model" |
| | "strconv" |
| | "strings" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
// OpenAIModel is a single model entry as decoded from an OpenAI-style
// model listing payload. Field names map to JSON keys through the struct tags.
type OpenAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`

	// Permission is the list of permission records carried under the
	// "permission" key of the model entry.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`

	Root   string `json:"root"`
	Parent string `json:"parent"`
}
| |
|
| | type OpenAIModelsResponse struct { |
| | Data []OpenAIModel `json:"data"` |
| | Success bool `json:"success"` |
| | } |
| |
|
| | func GetAllChannels(c *gin.Context) { |
| | p, _ := strconv.Atoi(c.Query("p")) |
| | pageSize, _ := strconv.Atoi(c.Query("page_size")) |
| | if p < 0 { |
| | p = 0 |
| | } |
| | if pageSize < 0 { |
| | pageSize = common.ItemsPerPage |
| | } |
| | channelData := make([]*model.Channel, 0) |
| | idSort, _ := strconv.ParseBool(c.Query("id_sort")) |
| | enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) |
| | if enableTagMode { |
| | tags, err := model.GetPaginatedTags(p*pageSize, pageSize) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | for _, tag := range tags { |
| | if tag != nil && *tag != "" { |
| | tagChannel, err := model.GetChannelsByTag(*tag, idSort) |
| | if err == nil { |
| | channelData = append(channelData, tagChannel...) |
| | } |
| | } |
| | } |
| | } else { |
| | channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | channelData = channels |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": channelData, |
| | }) |
| | return |
| | } |
| |
|
| | func FetchUpstreamModels(c *gin.Context) { |
| | id, err := strconv.Atoi(c.Param("id")) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | channel, err := model.GetChannelById(id, true) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | baseURL := common.ChannelBaseURLs[channel.Type] |
| | if channel.GetBaseURL() != "" { |
| | baseURL = channel.GetBaseURL() |
| | } |
| | url := fmt.Sprintf("%s/v1/models", baseURL) |
| | switch channel.Type { |
| | case common.ChannelTypeGemini: |
| | url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) |
| | case common.ChannelTypeAli: |
| | url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) |
| | } |
| | body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | var result OpenAIModelsResponse |
| | if err = json.Unmarshal(body, &result); err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": fmt.Sprintf("解析响应失败: %s", err.Error()), |
| | }) |
| | return |
| | } |
| |
|
| | var ids []string |
| | for _, model := range result.Data { |
| | id := model.ID |
| | if channel.Type == common.ChannelTypeGemini { |
| | id = strings.TrimPrefix(id, "models/") |
| | } |
| | ids = append(ids, id) |
| | } |
| |
|
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": ids, |
| | }) |
| | } |
| |
|
| | func FixChannelsAbilities(c *gin.Context) { |
| | count, err := model.FixAbility() |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": count, |
| | }) |
| | } |
| |
|
| | func SearchChannels(c *gin.Context) { |
| | keyword := c.Query("keyword") |
| | group := c.Query("group") |
| | modelKeyword := c.Query("model") |
| | idSort, _ := strconv.ParseBool(c.Query("id_sort")) |
| | enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) |
| | channelData := make([]*model.Channel, 0) |
| | if enableTagMode { |
| | tags, err := model.SearchTags(keyword, group, modelKeyword, idSort) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | for _, tag := range tags { |
| | if tag != nil && *tag != "" { |
| | tagChannel, err := model.GetChannelsByTag(*tag, idSort) |
| | if err == nil { |
| | channelData = append(channelData, tagChannel...) |
| | } |
| | } |
| | } |
| | } else { |
| | channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | channelData = channels |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": channelData, |
| | }) |
| | return |
| | } |
| |
|
| | func GetChannel(c *gin.Context) { |
| | id, err := strconv.Atoi(c.Param("id")) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | channel, err := model.GetChannelById(id, false) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": channel, |
| | }) |
| | return |
| | } |
| |
|
| | func AddChannel(c *gin.Context) { |
| | channel := model.Channel{} |
| | err := c.ShouldBindJSON(&channel) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | channel.CreatedTime = common.GetTimestamp() |
| | keys := strings.Split(channel.Key, "\n") |
| | if channel.Type == common.ChannelTypeVertexAi { |
| | if channel.Other == "" { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "部署地区不能为空", |
| | }) |
| | return |
| | } else { |
| | if common.IsJsonStr(channel.Other) { |
| | |
| | regionMap := common.StrToMap(channel.Other) |
| | if regionMap["default"] == nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "部署地区必须包含default字段", |
| | }) |
| | return |
| | } |
| | } |
| | } |
| | keys = []string{channel.Key} |
| | } |
| | channels := make([]model.Channel, 0, len(keys)) |
| | for _, key := range keys { |
| | if key == "" { |
| | continue |
| | } |
| | localChannel := channel |
| | localChannel.Key = key |
| | |
| | models := strings.Split(localChannel.Models, ",") |
| | for _, model := range models { |
| | if len(model) > 255 { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": fmt.Sprintf("模型名称过长: %s", model), |
| | }) |
| | return |
| | } |
| | } |
| | channels = append(channels, localChannel) |
| | } |
| | err = model.BatchInsertChannels(channels) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
| | func DeleteChannel(c *gin.Context) { |
| | id, _ := strconv.Atoi(c.Param("id")) |
| | channel := model.Channel{Id: id} |
| | err := channel.Delete() |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
| | func DeleteDisabledChannel(c *gin.Context) { |
| | rows, err := model.DeleteDisabledChannel() |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": rows, |
| | }) |
| | return |
| | } |
| |
|
// ChannelTag is the JSON request body used by the tag-level channel
// handlers. Tag names the tag being operated on. The pointer fields stay
// nil when their key is absent from the body, so callers can tell an
// omitted field from one that was explicitly provided.
type ChannelTag struct {
	Tag string `json:"tag"`

	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}
| |
|
| | func DisableTagChannels(c *gin.Context) { |
| | channelTag := ChannelTag{} |
| | err := c.ShouldBindJSON(&channelTag) |
| | if err != nil || channelTag.Tag == "" { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "参数错误", |
| | }) |
| | return |
| | } |
| | err = model.DisableChannelByTag(channelTag.Tag) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
| | func EnableTagChannels(c *gin.Context) { |
| | channelTag := ChannelTag{} |
| | err := c.ShouldBindJSON(&channelTag) |
| | if err != nil || channelTag.Tag == "" { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "参数错误", |
| | }) |
| | return |
| | } |
| | err = model.EnableChannelByTag(channelTag.Tag) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
| | func EditTagChannels(c *gin.Context) { |
| | channelTag := ChannelTag{} |
| | err := c.ShouldBindJSON(&channelTag) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "参数错误", |
| | }) |
| | return |
| | } |
| | if channelTag.Tag == "" { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "tag不能为空", |
| | }) |
| | return |
| | } |
| | err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
// ChannelBatch is the JSON request body of the batch channel handlers:
// the channel IDs to act on and an optional tag, which is nil when the
// "tag" key is absent from the body.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
| |
|
| | func DeleteChannelBatch(c *gin.Context) { |
| | channelBatch := ChannelBatch{} |
| | err := c.ShouldBindJSON(&channelBatch) |
| | if err != nil || len(channelBatch.Ids) == 0 { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "参数错误", |
| | }) |
| | return |
| | } |
| | err = model.BatchDeleteChannels(channelBatch.Ids) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": len(channelBatch.Ids), |
| | }) |
| | return |
| | } |
| |
|
| | func UpdateChannel(c *gin.Context) { |
| | channel := model.Channel{} |
| | err := c.ShouldBindJSON(&channel) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | if channel.Type == common.ChannelTypeVertexAi { |
| | if channel.Other == "" { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "部署地区不能为空", |
| | }) |
| | return |
| | } else { |
| | if common.IsJsonStr(channel.Other) { |
| | |
| | regionMap := common.StrToMap(channel.Other) |
| | if regionMap["default"] == nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "部署地区必须包含default字段", |
| | }) |
| | return |
| | } |
| | } |
| | } |
| | } |
| | err = channel.Update() |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": channel, |
| | }) |
| | return |
| | } |
| |
|
| | func FetchModels(c *gin.Context) { |
| | var req struct { |
| | BaseURL string `json:"base_url"` |
| | Type int `json:"type"` |
| | Key string `json:"key"` |
| | } |
| |
|
| | if err := c.ShouldBindJSON(&req); err != nil { |
| | c.JSON(http.StatusBadRequest, gin.H{ |
| | "success": false, |
| | "message": "Invalid request", |
| | }) |
| | return |
| | } |
| |
|
| | baseURL := req.BaseURL |
| | if baseURL == "" { |
| | baseURL = common.ChannelBaseURLs[req.Type] |
| | } |
| |
|
| | client := &http.Client{} |
| | url := fmt.Sprintf("%s/v1/models", baseURL) |
| |
|
| | request, err := http.NewRequest("GET", url, nil) |
| | if err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | |
| | key := strings.TrimSpace(req.Key) |
| | |
| | key = strings.Split(key, "\n")[0] |
| | request.Header.Set("Authorization", "Bearer "+key) |
| |
|
| | response, err := client.Do(request) |
| | if err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | |
| | if response.StatusCode != http.StatusOK { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "success": false, |
| | "message": "Failed to fetch models", |
| | }) |
| | return |
| | } |
| | defer response.Body.Close() |
| |
|
| | var result struct { |
| | Data []struct { |
| | ID string `json:"id"` |
| | } `json:"data"` |
| | } |
| |
|
| | if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | var models []string |
| | for _, model := range result.Data { |
| | models = append(models, model.ID) |
| | } |
| |
|
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "data": models, |
| | }) |
| | } |
| |
|
| | func BatchSetChannelTag(c *gin.Context) { |
| | channelBatch := ChannelBatch{} |
| | err := c.ShouldBindJSON(&channelBatch) |
| | if err != nil || len(channelBatch.Ids) == 0 { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": "参数错误", |
| | }) |
| | return |
| | } |
| | err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": len(channelBatch.Ids), |
| | }) |
| | return |
| | } |
| |
|