package controller

import (
	"bufio"
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"one-api/common"
)
|
// openaiStreamHandler relays an upstream OpenAI-style SSE response to the
// client while accumulating the streamed text, so the caller can account
// for what was actually generated.
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
	responseText := ""
	scanner := bufio.NewScanner(resp.Body)
	// Split on "\n" only: SSE events are newline-delimited. Any trailing
	// "\r" is trimmed later, just before the event is rendered.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
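	// Producer side: a goroutine reads the upstream body line by line,
	// forwards each raw event line to dataChan for relaying, and parses the
	// JSON payloads along the way to accumulate the generated text.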
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 6 { // skip blank lines and obviously malformed ones
				continue
			}
			// Only relay "data: ..." payload lines and the "[DONE]" sentinel.
			if data[:6] != "data: " && data[:6] != "[DONE]" {
				continue
			}
			dataChan <- data
			data = data[6:]
			if !strings.HasPrefix(data, "[DONE]") {
				switch relayMode {
				case RelayModeChatCompletions:
					var streamResponse ChatCompletionsStreamResponse
					err := json.Unmarshal([]byte(data), &streamResponse)
					if err != nil {
						common.SysError("error unmarshalling stream response: " + err.Error())
						continue
					}
					for _, choice := range streamResponse.Choices {
						responseText += choice.Delta.Content
					}
				case RelayModeCompletions:
					var streamResponse CompletionsStreamResponse
					err := json.Unmarshal([]byte(data), &streamResponse)
					if err != nil {
						common.SysError("error unmarshalling stream response: " + err.Error())
						continue
					}
					for _, choice := range streamResponse.Choices {
						responseText += choice.Text
					}
				}
			}
		}
		stopChan <- true
	}()
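	// Consumer side: relay events to the client as they arrive. c.Stream
	// keeps calling the callback until it returns false, which happens once
	// the producer goroutine signals stopChan.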
	setEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			if strings.HasPrefix(data, "data: [DONE]") {
				// Truncate so the client receives exactly "data: [DONE]".
				data = data[:12]
			}
			data = strings.TrimSuffix(data, "\r")
			c.Render(-1, common.CustomEvent{Data: data})
			return true
		case <-stopChan:
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText
}
|
// openaiHandler relays a non-streaming upstream response to the client. When
// consumeQuota is set, the body is read and parsed first so that usage can be
// extracted (or recomputed) for billing.
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
	var textResponse TextResponse
	if consumeQuota {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
		}
		err = resp.Body.Close()
		if err != nil {
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
		}
		err = json.Unmarshal(responseBody, &textResponse)
		if err != nil {
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
		}
		if textResponse.Error.Type != "" {
			return &OpenAIErrorWithStatusCode{
				OpenAIError: textResponse.Error,
				StatusCode:  resp.StatusCode,
			}, nil
		}
		// The body has been consumed; reset it so it can still be copied to
		// the client below.
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	}
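	// Headers are copied only after the body has been parsed: once
	// WriteHeader has been called, it is too late to replace the response
	// with our own error, so parsing has to succeed (or fail) first.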
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	_, err := io.Copy(c.Writer, resp.Body)
	if err != nil {
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
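	// Some upstreams omit usage in the response; fall back to counting the
	// completion tokens locally so quota can still be charged.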
	if textResponse.Usage.TotalTokens == 0 {
		completionTokens := 0
		for _, choice := range textResponse.Choices {
			completionTokens += countTokenText(choice.Message.Content, model)
		}
		textResponse.Usage = Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		}
	}
	return nil, &textResponse.Usage
}
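
// Illustrative only: example upstream SSE lines as handled by
// openaiStreamHandler (payloads shortened for readability):
//
//	data: {"choices":[{"delta":{"content":"Hel"}}]}
//	data: {"choices":[{"delta":{"content":"lo"}}]}
//	data: [DONE]
//
// Each line is forwarded to the client verbatim, while the delta fragments
// are concatenated into responseText ("Hello") for quota accounting.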