| package common_handler | |
| import ( | |
| "io" | |
| "net/http" | |
| "github.com/QuantumNous/new-api/common" | |
| "github.com/QuantumNous/new-api/constant" | |
| "github.com/QuantumNous/new-api/dto" | |
| "github.com/QuantumNous/new-api/relay/channel/xinference" | |
| relaycommon "github.com/QuantumNous/new-api/relay/common" | |
| "github.com/QuantumNous/new-api/service" | |
| "github.com/QuantumNous/new-api/types" | |
| "github.com/gin-gonic/gin" | |
| ) | |
| func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { | |
| responseBody, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) | |
| } | |
| service.CloseResponseBodyGracefully(resp) | |
| if common.DebugEnabled { | |
| println("reranker response body: ", string(responseBody)) | |
| } | |
| var jinaResp dto.RerankResponse | |
| if info.ChannelType == constant.ChannelTypeXinference { | |
| var xinRerankResponse xinference.XinRerankResponse | |
| err = common.Unmarshal(responseBody, &xinRerankResponse) | |
| if err != nil { | |
| return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) | |
| } | |
| jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) | |
| for i, result := range xinRerankResponse.Results { | |
| respResult := dto.RerankResponseResult{ | |
| Index: result.Index, | |
| RelevanceScore: result.RelevanceScore, | |
| } | |
| if info.ReturnDocuments { | |
| var document any | |
| if result.Document != nil { | |
| if doc, ok := result.Document.(string); ok { | |
| if doc == "" { | |
| document = info.Documents[result.Index] | |
| } else { | |
| document = doc | |
| } | |
| } else { | |
| document = result.Document | |
| } | |
| } | |
| respResult.Document = document | |
| } | |
| jinaRespResults[i] = respResult | |
| } | |
| jinaResp = dto.RerankResponse{ | |
| Results: jinaRespResults, | |
| Usage: dto.Usage{ | |
| PromptTokens: info.PromptTokens, | |
| TotalTokens: info.PromptTokens, | |
| }, | |
| } | |
| } else { | |
| err = common.Unmarshal(responseBody, &jinaResp) | |
| if err != nil { | |
| return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) | |
| } | |
| jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens | |
| } | |
| c.Writer.Header().Set("Content-Type", "application/json") | |
| c.JSON(http.StatusOK, jinaResp) | |
| return &jinaResp.Usage, nil | |
| } | |