| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| package controller |
|
|
| import ( |
| "encoding/json" |
| "fmt" |
| "net/http" |
| "strconv" |
| "strings" |
| "veloera/common" |
| "veloera/middleware" |
| "veloera/model" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// OpenAIModel is one entry of an OpenAI-style model listing, i.e. one element
// of the "data" array decoded through OpenAIModelsResponse.
type OpenAIModel struct {
	// ID is the model identifier; it is the only field the handlers in this
	// file read back after decoding.
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Permission carries the permission records attached to the model. The
	// element type is intentionally anonymous; it is only ever filled by the
	// JSON decoder.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
|
|
| type OpenAIModelsResponse struct { |
| Data []OpenAIModel `json:"data"` |
| Success bool `json:"success"` |
| } |
| |
| func GetChannelTags(c *gin.Context) { |
| offset, _ := strconv.Atoi(c.Query("offset")) |
| limit, _ := strconv.Atoi(c.Query("limit")) |
| if limit <= 0 || limit > 200 { |
| limit = 100 |
| } |
| tags, err := model.GetPaginatedTags(offset, limit) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": tags, |
| }) |
| } |
|
|
|
|
| func GetAllChannels(c *gin.Context) { |
| p, _ := strconv.Atoi(c.Query("p")) |
| pageSize, _ := strconv.Atoi(c.Query("page_size")) |
| if p < 0 { |
| p = 0 |
| } |
| if pageSize < 0 { |
| pageSize = common.ItemsPerPage |
| } |
| channelData := make([]*model.Channel, 0) |
| idSort, _ := strconv.ParseBool(c.Query("id_sort")) |
| enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) |
| if enableTagMode { |
| tags, err := model.GetPaginatedTags(p*pageSize, pageSize) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| for _, tag := range tags { |
| if tag != nil && *tag != "" { |
| tagChannel, err := model.GetChannelsByTag(*tag, idSort) |
| if err == nil { |
| channelData = append(channelData, tagChannel...) |
| } |
| } |
| } |
| } else { |
| channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| channelData = channels |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": channelData, |
| }) |
| return |
| } |
|
|
| func FetchUpstreamModels(c *gin.Context) { |
| id, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| channel, err := model.GetChannelById(id, true) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
|
|
| baseURL := common.ChannelBaseURLs[channel.Type] |
| if channel.GetBaseURL() != "" { |
| baseURL = channel.GetBaseURL() |
| } |
| url := fmt.Sprintf("%s/v1/models", baseURL) |
|
|
| if strings.HasSuffix(baseURL, "/chat/completions") { |
| url = strings.TrimSuffix(baseURL, "/chat/completions") + "/models" |
| } |
|
|
| if channel.Type == common.ChannelTypeGemini { |
| url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) |
| } |
| if channel.Type == common.ChannelTypeGitHub { |
| url = strings.Replace(baseURL, "/inference", "/catalog/models", 1) |
| } |
| key := strings.Split(channel.Key, ",")[0] |
| body, err := GetResponseBody("GET", url, channel, GetAuthHeader(key)) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| var ids []string |
| |
| if channel.Type == common.ChannelTypeGitHub { |
| var arr []struct { ID string `json:"id"` } |
| if err = json.Unmarshal(body, &arr); err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": fmt.Sprintf("解析 GitHub 响应失败: %s", err.Error()), |
| }) |
| return |
| } |
| for _, m := range arr { |
| ids = append(ids, m.ID) |
| } |
| } else { |
| var result OpenAIModelsResponse |
| if err = json.Unmarshal(body, &result); err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": fmt.Sprintf("解析响应失败: %s", err.Error()), |
| }) |
| return |
| } |
| for _, model := range result.Data { |
| id := model.ID |
| if channel.Type == common.ChannelTypeGemini { |
| id = strings.TrimPrefix(id, "models/") |
| } |
| ids = append(ids, id) |
| } |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": ids, |
| }) |
| } |
|
|
| func FixChannelsAbilities(c *gin.Context) { |
| count, err := model.FixAbility() |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": count, |
| }) |
| } |
|
|
| func SearchChannels(c *gin.Context) { |
| keyword := c.Query("keyword") |
| group := c.Query("group") |
| modelKeyword := c.Query("model") |
| idSort, _ := strconv.ParseBool(c.Query("id_sort")) |
| enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) |
|
|
| |
| parseCSVInts := func(s string) []int { |
| if s == "" { |
| return nil |
| } |
| parts := strings.Split(s, ",") |
| res := make([]int, 0, len(parts)) |
| for _, p := range parts { |
| p = strings.TrimSpace(p) |
| if p == "" { |
| continue |
| } |
| if v, err := strconv.Atoi(p); err == nil { |
| res = append(res, v) |
| } |
| } |
| return res |
| } |
| parseCSVStr := func(s string) []string { |
| if s == "" { |
| return nil |
| } |
| parts := strings.Split(s, ",") |
| res := make([]string, 0, len(parts)) |
| for _, p := range parts { |
| p = strings.TrimSpace(p) |
| if p != "" { |
| res = append(res, p) |
| } |
| } |
| return res |
| } |
| types := parseCSVInts(c.Query("types")) |
| statuses := parseCSVInts(c.Query("statuses")) |
| tagsFilter := parseCSVStr(c.Query("tags")) |
|
|
| channelData := make([]*model.Channel, 0) |
| if enableTagMode { |
| |
| tags, err := model.SearchTags(keyword, group, modelKeyword, idSort, types, statuses, tagsFilter) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| for _, tag := range tags { |
| if tag != nil && *tag != "" { |
| tagChannel, err := model.GetChannelsByTag(*tag, idSort) |
| if err == nil { |
| |
| for _, ch := range tagChannel { |
| if len(types) > 0 { |
| ok := false |
| for _, t := range types { |
| if ch.Type == t { |
| ok = true |
| break |
| } |
| } |
| if !ok { |
| continue |
| } |
| } |
| if len(statuses) > 0 { |
| ok := false |
| for _, s := range statuses { |
| if ch.Status == s { |
| ok = true |
| break |
| } |
| } |
| if !ok { |
| continue |
| } |
| } |
| if len(tagsFilter) > 0 { |
| ok := false |
| for _, tg := range tagsFilter { |
| if ch.GetTag() == tg { |
| ok = true |
| break |
| } |
| } |
| if !ok { |
| continue |
| } |
| } |
| channelData = append(channelData, ch) |
| } |
| } |
| } |
| } |
| } else { |
| channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort, types, statuses, tagsFilter) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| channelData = channels |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": channelData, |
| }) |
| return |
| } |
|
|
| func GetChannel(c *gin.Context) { |
| id, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| channel, err := model.GetChannelById(id, true) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": channel, |
| }) |
| return |
| } |
|
|
| func AddChannel(c *gin.Context) { |
| channel := model.Channel{} |
| err := c.ShouldBindJSON(&channel) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| channel.CreatedTime = common.GetTimestamp() |
|
|
| |
| if channel.Type == common.ChannelTypeVertexAi { |
| if channel.Other == "" { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "部署地区不能为空", |
| }) |
| return |
| } else { |
| if common.IsJsonStr(channel.Other) { |
| regionMap := common.StrToMap(channel.Other) |
| if regionMap["default"] == nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "部署地区必须包含default字段", |
| }) |
| return |
| } |
| } |
| } |
| } |
|
|
| |
| models := strings.Split(channel.Models, ",") |
| for _, model := range models { |
| if len(model) > 255 { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": fmt.Sprintf("模型名称过长: %s", model), |
| }) |
| return |
| } |
| } |
|
|
| err = channel.Insert() |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| |
| middleware.RefreshPrefixChannelsCache(channel.Group) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| return |
| } |
|
|
| func DeleteChannel(c *gin.Context) { |
| id, _ := strconv.Atoi(c.Param("id")) |
| channel := model.Channel{Id: id} |
| err := channel.Delete() |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| return |
| } |
|
|
| func DeleteDisabledChannel(c *gin.Context) { |
| rows, err := model.DeleteDisabledChannel() |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": rows, |
| }) |
| return |
| } |
|
|
// ChannelTag is the JSON request body of the tag-level channel endpoints
// (DisableTagChannels, EnableTagChannels, EditTagChannels).
//
// Tag selects the channels to act on. Every other field is a pointer and is
// nil when its key is absent from the body; EditTagChannels forwards them
// unchanged to model.EditChannelByTag.
type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}
|
|
| func DisableTagChannels(c *gin.Context) { |
| channelTag := ChannelTag{} |
| err := c.ShouldBindJSON(&channelTag) |
| if err != nil || channelTag.Tag == "" { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "参数错误", |
| }) |
| return |
| } |
| err = model.DisableChannelByTag(channelTag.Tag) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| if channels, err := model.GetChannelsByTag(channelTag.Tag, false); err == nil { |
| groupSet := make(map[string]struct{}) |
| for _, ch := range channels { |
| for _, g := range strings.Split(ch.Group, ",") { |
| g = strings.TrimSpace(g) |
| if g != "" { |
| groupSet[g] = struct{}{} |
| } |
| } |
| } |
| for g := range groupSet { |
| middleware.RefreshPrefixChannelsCache(g) |
| } |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| return |
| } |
|
|
| func EnableTagChannels(c *gin.Context) { |
| channelTag := ChannelTag{} |
| err := c.ShouldBindJSON(&channelTag) |
| if err != nil || channelTag.Tag == "" { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "参数错误", |
| }) |
| return |
| } |
| err = model.EnableChannelByTag(channelTag.Tag) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| if channels, err := model.GetChannelsByTag(channelTag.Tag, false); err == nil { |
| groupSet := make(map[string]struct{}) |
| for _, ch := range channels { |
| for _, g := range strings.Split(ch.Group, ",") { |
| g = strings.TrimSpace(g) |
| if g != "" { |
| groupSet[g] = struct{}{} |
| } |
| } |
| } |
| for g := range groupSet { |
| middleware.RefreshPrefixChannelsCache(g) |
| } |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| return |
| } |
|
|
| func EditTagChannels(c *gin.Context) { |
| channelTag := ChannelTag{} |
| err := c.ShouldBindJSON(&channelTag) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "参数错误", |
| }) |
| return |
| } |
| if channelTag.Tag == "" { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "tag不能为空", |
| }) |
| return |
| } |
| err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| return |
| } |
|
|
// ChannelBatch is the JSON request body of the batch channel endpoints
// (DeleteChannelBatch, BatchSetChannelTag).
//
// Ids lists the channels to act on. Tag is used only by BatchSetChannelTag
// and is nil when the "tag" key is absent or null.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
|
|
| func DeleteChannelBatch(c *gin.Context) { |
| channelBatch := ChannelBatch{} |
| err := c.ShouldBindJSON(&channelBatch) |
| if err != nil || len(channelBatch.Ids) == 0 { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "参数错误", |
| }) |
| return |
| } |
| err = model.BatchDeleteChannels(channelBatch.Ids) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": len(channelBatch.Ids), |
| }) |
| return |
| } |
|
|
| func UpdateChannel(c *gin.Context) { |
| channel := model.Channel{} |
| err := c.ShouldBindJSON(&channel) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| if channel.Type == common.ChannelTypeVertexAi { |
| if channel.Other == "" { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "部署地区不能为空", |
| }) |
| return |
| } else { |
| if common.IsJsonStr(channel.Other) { |
| |
| regionMap := common.StrToMap(channel.Other) |
| if regionMap["default"] == nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "部署地区必须包含default字段", |
| }) |
| return |
| } |
| } |
| } |
| } |
| err = channel.Update() |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| |
| middleware.RefreshPrefixChannelsCache(channel.Group) |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": channel, |
| }) |
| return |
| } |
|
|
| func FetchModels(c *gin.Context) { |
| var req struct { |
| BaseURL string `json:"base_url"` |
| Type int `json:"type"` |
| Key string `json:"key"` |
| } |
|
|
| if err := c.ShouldBindJSON(&req); err != nil { |
| c.JSON(http.StatusBadRequest, gin.H{ |
| "success": false, |
| "message": "Invalid request", |
| }) |
| return |
| } |
|
|
| baseURL := req.BaseURL |
| if baseURL == "" { |
| baseURL = common.ChannelBaseURLs[req.Type] |
| } |
|
|
| client := &http.Client{} |
| url := fmt.Sprintf("%s/v1/models", baseURL) |
|
|
| if strings.HasSuffix(baseURL, "/chat/completions") { |
| url = strings.TrimSuffix(baseURL, "/chat/completions") + "/models" |
| } |
|
|
| request, err := http.NewRequest("GET", url, nil) |
| if err != nil { |
| c.JSON(http.StatusInternalServerError, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| |
| key := strings.TrimSpace(req.Key) |
| |
| key = strings.Split(key, "\n")[0] |
| request.Header.Set("Authorization", "Bearer "+key) |
|
|
| response, err := client.Do(request) |
| if err != nil { |
| c.JSON(http.StatusInternalServerError, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| |
| if response.StatusCode != http.StatusOK { |
| c.JSON(http.StatusInternalServerError, gin.H{ |
| "success": false, |
| "message": "Failed to fetch models", |
| }) |
| return |
| } |
| defer response.Body.Close() |
|
|
| var result struct { |
| Data []struct { |
| ID string `json:"id"` |
| } `json:"data"` |
| } |
|
|
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
| c.JSON(http.StatusInternalServerError, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| var models []string |
| for _, model := range result.Data { |
| models = append(models, model.ID) |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "data": models, |
| }) |
| } |
|
|
| func BatchSetChannelTag(c *gin.Context) { |
| channelBatch := ChannelBatch{} |
| err := c.ShouldBindJSON(&channelBatch) |
| if err != nil || len(channelBatch.Ids) == 0 { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": "参数错误", |
| }) |
| return |
| } |
| err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": len(channelBatch.Ids), |
| }) |
| return |
| } |
|
|