| package controller | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "net/http" | |
| "strconv" | |
| "strings" | |
| "time" | |
| "github.com/QuantumNous/new-api/common" | |
| "github.com/QuantumNous/new-api/constant" | |
| "github.com/QuantumNous/new-api/model" | |
| "github.com/QuantumNous/new-api/relay/channel/codex" | |
| "github.com/QuantumNous/new-api/service" | |
| "github.com/gin-gonic/gin" | |
| ) | |
| func GetCodexChannelUsage(c *gin.Context) { | |
| channelId, err := strconv.Atoi(c.Param("id")) | |
| if err != nil { | |
| common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) | |
| return | |
| } | |
| ch, err := model.GetChannelById(channelId, true) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if ch == nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"}) | |
| return | |
| } | |
| if ch.Type != constant.ChannelTypeCodex { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) | |
| return | |
| } | |
| if ch.ChannelInfo.IsMultiKey { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"}) | |
| return | |
| } | |
| oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key)) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) | |
| return | |
| } | |
| accessToken := strings.TrimSpace(oauthKey.AccessToken) | |
| accountID := strings.TrimSpace(oauthKey.AccountID) | |
| if accessToken == "" { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"}) | |
| return | |
| } | |
| if accountID == "" { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"}) | |
| return | |
| } | |
| client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) | |
| defer cancel() | |
| statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) | |
| return | |
| } | |
| if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" { | |
| refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second) | |
| defer refreshCancel() | |
| res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) | |
| if refreshErr == nil { | |
| oauthKey.AccessToken = res.AccessToken | |
| oauthKey.RefreshToken = res.RefreshToken | |
| oauthKey.LastRefresh = time.Now().Format(time.RFC3339) | |
| oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339) | |
| if strings.TrimSpace(oauthKey.Type) == "" { | |
| oauthKey.Type = "codex" | |
| } | |
| encoded, encErr := common.Marshal(oauthKey) | |
| if encErr == nil { | |
| _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error | |
| model.InitChannelCache() | |
| service.ResetProxyClientCache() | |
| } | |
| ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second) | |
| defer cancel2() | |
| statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) | |
| return | |
| } | |
| } | |
| } | |
| var payload any | |
| if json.Unmarshal(body, &payload) != nil { | |
| payload = string(body) | |
| } | |
| ok := statusCode >= 200 && statusCode < 300 | |
| resp := gin.H{ | |
| "success": ok, | |
| "message": "", | |
| "upstream_status": statusCode, | |
| "data": payload, | |
| } | |
| if !ok { | |
| resp["message"] = fmt.Sprintf("upstream status: %d", statusCode) | |
| } | |
| c.JSON(http.StatusOK, resp) | |
| } | |