| | package relay |
| |
|
| | import ( |
| | "encoding/json" |
| | "fmt" |
| | "github.com/gin-gonic/gin" |
| | "github.com/gorilla/websocket" |
| | "net/http" |
| | "one-api/common" |
| | "one-api/dto" |
| | relaycommon "one-api/relay/common" |
| | "one-api/service" |
| | "one-api/setting" |
| | "one-api/setting/operation_setting" |
| | ) |
| |
|
| | func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { |
| | relayInfo := relaycommon.GenRelayInfoWs(c, ws) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | modelMapping := c.GetString("model_mapping") |
| | |
| | if modelMapping != "" && modelMapping != "{}" { |
| | modelMap := make(map[string]string) |
| | err := json.Unmarshal([]byte(modelMapping), &modelMap) |
| | if err != nil { |
| | return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) |
| | } |
| | if modelMap[relayInfo.OriginModelName] != "" { |
| | relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName] |
| | |
| | |
| | } |
| | } |
| | |
| | modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false) |
| | groupRatio := setting.GetGroupRatio(relayInfo.Group) |
| |
|
| | var preConsumedQuota int |
| | var ratio float64 |
| | var modelRatio float64 |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | if !getModelPriceSuccess { |
| | preConsumedTokens := common.PreConsumedQuota |
| | |
| | |
| | |
| | modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName) |
| | ratio = modelRatio * groupRatio |
| | preConsumedQuota = int(float64(preConsumedTokens) * ratio) |
| | } else { |
| | preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) |
| | relayInfo.UsePrice = true |
| | } |
| |
|
| | |
| | preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) |
| | if openaiErr != nil { |
| | return openaiErr |
| | } |
| |
|
| | defer func() { |
| | if openaiErr != nil { |
| | returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) |
| | } |
| | }() |
| |
|
| | adaptor := GetAdaptor(relayInfo.ApiType) |
| | if adaptor == nil { |
| | return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) |
| | } |
| | adaptor.Init(relayInfo) |
| | |
| | |
| | |
| |
|
| | statusCodeMappingStr := c.GetString("status_code_mapping") |
| | resp, err := adaptor.DoRequest(c, relayInfo, nil) |
| | if err != nil { |
| | return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) |
| | } |
| |
|
| | if resp != nil { |
| | relayInfo.TargetWs = resp.(*websocket.Conn) |
| | defer relayInfo.TargetWs.Close() |
| | } |
| |
|
| | usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) |
| | if openaiErr != nil { |
| | |
| | service.ResetStatusCode(openaiErr, statusCodeMappingStr) |
| | return openaiErr |
| | } |
| | service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, |
| | userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") |
| | return nil |
| | } |
| |
|