|
|
package controller |
|
|
|
|
|
import ( |
|
|
"bufio" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"github.com/gin-gonic/gin" |
|
|
"io" |
|
|
"net/http" |
|
|
"one-api/common" |
|
|
"strings" |
|
|
) |
|
|
|
|
|
// ClaudeMetadata mirrors the optional "metadata" object of Anthropic's
// completion API. Only the originating user id is carried.
// NOTE(review): not referenced in this file — presumably serialized into
// upstream requests elsewhere; verify before removing.
type ClaudeMetadata struct {
	UserId string `json:"user_id"`
}
|
|
|
|
|
// ClaudeRequest is the request payload for Anthropic's prompt-based
// text completion endpoint, built from an OpenAI-style chat request by
// requestOpenAI2Claude.
type ClaudeRequest struct {
	Model string `json:"model"`
	// Prompt is the flattened "\n\nHuman: ... \n\nAssistant:" dialogue.
	Prompt            string   `json:"prompt"`
	MaxTokensToSample int      `json:"max_tokens_to_sample"`
	StopSequences     []string `json:"stop_sequences,omitempty"`
	Temperature       float64  `json:"temperature,omitempty"`
	TopP              float64  `json:"top_p,omitempty"`
	TopK              int      `json:"top_k,omitempty"`
	// Stream requests a server-sent-events response when true.
	Stream bool `json:"stream,omitempty"`
}
|
|
|
|
|
// ClaudeError is the error object embedded in Claude API responses.
// A non-empty Type indicates the upstream call failed.
type ClaudeError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
|
|
|
|
|
// ClaudeResponse is a single completion (or stream chunk) returned by
// Anthropic's text completion endpoint.
type ClaudeResponse struct {
	Completion string `json:"completion"`
	// StopReason is Claude's reason for stopping (e.g. "stop_sequence",
	// "max_tokens"); mapped to OpenAI terms by stopReasonClaude2OpenAI.
	StopReason string `json:"stop_reason"`
	Model      string `json:"model"`
	// Error is populated (Type != "") when the upstream call failed.
	Error ClaudeError `json:"error"`
}
|
|
|
|
|
// stopReasonClaude2OpenAI maps an Anthropic stop reason onto the
// corresponding OpenAI finish_reason. Unrecognized reasons (including
// the literal "null" used mid-stream) are passed through unchanged.
func stopReasonClaude2OpenAI(reason string) string {
	if reason == "stop_sequence" {
		return "stop"
	}
	if reason == "max_tokens" {
		return "length"
	}
	return reason
}
|
|
|
|
|
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { |
|
|
claudeRequest := ClaudeRequest{ |
|
|
Model: textRequest.Model, |
|
|
Prompt: "", |
|
|
MaxTokensToSample: textRequest.MaxTokens, |
|
|
StopSequences: nil, |
|
|
Temperature: textRequest.Temperature, |
|
|
TopP: textRequest.TopP, |
|
|
Stream: textRequest.Stream, |
|
|
} |
|
|
if claudeRequest.MaxTokensToSample == 0 { |
|
|
claudeRequest.MaxTokensToSample = 1000000 |
|
|
} |
|
|
prompt := "" |
|
|
for _, message := range textRequest.Messages { |
|
|
if message.Role == "user" { |
|
|
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) |
|
|
} else if message.Role == "assistant" { |
|
|
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) |
|
|
} else if message.Role == "system" { |
|
|
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) |
|
|
} |
|
|
} |
|
|
prompt += "\n\nAssistant:" |
|
|
claudeRequest.Prompt = prompt |
|
|
return &claudeRequest |
|
|
} |
|
|
|
|
|
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { |
|
|
var choice ChatCompletionsStreamResponseChoice |
|
|
choice.Delta.Content = claudeResponse.Completion |
|
|
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) |
|
|
if finishReason != "null" { |
|
|
choice.FinishReason = &finishReason |
|
|
} |
|
|
var response ChatCompletionsStreamResponse |
|
|
response.Object = "chat.completion.chunk" |
|
|
response.Model = claudeResponse.Model |
|
|
response.Choices = []ChatCompletionsStreamResponseChoice{choice} |
|
|
return &response |
|
|
} |
|
|
|
|
|
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { |
|
|
choice := OpenAITextResponseChoice{ |
|
|
Index: 0, |
|
|
Message: Message{ |
|
|
Role: "assistant", |
|
|
Content: strings.TrimPrefix(claudeResponse.Completion, " "), |
|
|
Name: nil, |
|
|
}, |
|
|
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), |
|
|
} |
|
|
fullTextResponse := OpenAITextResponse{ |
|
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), |
|
|
Object: "chat.completion", |
|
|
Created: common.GetTimestamp(), |
|
|
Choices: []OpenAITextResponseChoice{choice}, |
|
|
} |
|
|
return &fullTextResponse |
|
|
} |
|
|
|
|
|
// claudeStreamHandler relays a streaming (SSE) Claude response to the
// client, translating each "completion" event into an OpenAI-style
// chat.completion.chunk. It returns the accumulated completion text so
// the caller can count completion tokens, or an error wrapper if the
// upstream body cannot be closed.
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
	responseText := ""
	// All chunks of one stream share a single id and creation time,
	// matching OpenAI's streaming behavior.
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createdTime := common.GetTimestamp()
	scanner := bufio.NewScanner(resp.Body)
	// Custom split function: events are delimited by a CRLF CRLF blank
	// line; each token is one whole event block (without the delimiter).
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
			return i + 4, data[0:i], nil
		}
		if atEOF {
			// Flush whatever remains as the final event.
			return len(data), data, nil
		}
		// Need more data before a full event is available.
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	// Reader goroutine: keep only "event: completion" events and hand
	// their JSON payload to the render loop below.
	// NOTE(review): scanner.Err() is never checked, so a read error ends
	// the stream silently via stopChan — confirm this is intentional.
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if !strings.HasPrefix(data, "event: completion") {
				continue
			}
			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
			dataChan <- data
		}
		stopChan <- true
	}()
	setEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// Trim a possible trailing carriage return left over from
			// the CRLF-based event framing.
			data = strings.TrimSuffix(data, "\r")
			var claudeResponse ClaudeResponse
			err := json.Unmarshal([]byte(data), &claudeResponse)
			if err != nil {
				// Malformed chunk: log and keep the stream alive.
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			responseText += claudeResponse.Completion
			response := streamResponseClaude2OpenAI(&claudeResponse)
			response.Id = responseId
			response.Created = createdTime
			jsonStr, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case <-stopChan:
			// Upstream finished: emit the OpenAI stream terminator.
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText
}
|
|
|
|
|
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { |
|
|
responseBody, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |
|
|
} |
|
|
err = resp.Body.Close() |
|
|
if err != nil { |
|
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |
|
|
} |
|
|
var claudeResponse ClaudeResponse |
|
|
err = json.Unmarshal(responseBody, &claudeResponse) |
|
|
if err != nil { |
|
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |
|
|
} |
|
|
if claudeResponse.Error.Type != "" { |
|
|
return &OpenAIErrorWithStatusCode{ |
|
|
OpenAIError: OpenAIError{ |
|
|
Message: claudeResponse.Error.Message, |
|
|
Type: claudeResponse.Error.Type, |
|
|
Param: "", |
|
|
Code: claudeResponse.Error.Type, |
|
|
}, |
|
|
StatusCode: resp.StatusCode, |
|
|
}, nil |
|
|
} |
|
|
fullTextResponse := responseClaude2OpenAI(&claudeResponse) |
|
|
completionTokens := countTokenText(claudeResponse.Completion, model) |
|
|
usage := Usage{ |
|
|
PromptTokens: promptTokens, |
|
|
CompletionTokens: completionTokens, |
|
|
TotalTokens: promptTokens + completionTokens, |
|
|
} |
|
|
fullTextResponse.Usage = usage |
|
|
jsonResponse, err := json.Marshal(fullTextResponse) |
|
|
if err != nil { |
|
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |
|
|
} |
|
|
c.Writer.Header().Set("Content-Type", "application/json") |
|
|
c.Writer.WriteHeader(resp.StatusCode) |
|
|
_, err = c.Writer.Write(jsonResponse) |
|
|
return nil, &usage |
|
|
} |
|
|
|