dvc890's picture
Upload 42 files
581b6d4 verified
package unofficialapi
import (
"WarpGPT/pkg/common"
"WarpGPT/pkg/funcaptcha"
"WarpGPT/pkg/plugins"
"WarpGPT/pkg/tools"
"bytes"
"encoding/json"
"errors"
"fmt"
http "github.com/bogdanfinn/fhttp"
tls_client "github.com/bogdanfinn/tls-client"
"github.com/pkoukk/tiktoken-go"
"io"
shttp "net/http"
"strings"
"time"
"WarpGPT/pkg/logger"
"WarpGPT/pkg/plugins/service/wsstostream"
"github.com/gin-gonic/gin"
)
var (
	// context is the plugin component handed to Run; the rest of this file
	// reads its Logger, Env and Engine.
	context *plugins.Component
	// UnofficialApiProcessInstance is a package-level processor value.
	// NOTE(review): not referenced in this file — confirm it is still needed.
	UnofficialApiProcessInstance UnofficialApiProcess
	// tke is the cl100k_base tokenizer used for usage accounting. The error
	// from GetEncoding is discarded; presumably tke is nil if it failed.
	tke, _ = tiktoken.GetEncoding("cl100k_base")
)
// WsResponse is the JSON document decoded from the upstream HTTP body when
// the request path contains "/ws". Its ResponseId and ConversationId are
// copied onto the websocket bridge before events are read from it.
type WsResponse struct {
// ConversationId identifies the upstream conversation.
ConversationId string `json:"conversation_id"`
// ExpiresAt is not read in this file.
ExpiresAt time.Time `json:"expires_at"`
// ResponseId identifies the upstream response.
ResponseId string `json:"response_id"`
// WssUrl is not read in this file.
WssUrl string `json:"wss_url"`
}
// Context carries the per-request state the processor operates on.
type Context struct {
// GinContext is the client's request/response being served.
GinContext *gin.Context
// RequestUrl is the upstream URL the rebuilt request is sent to.
RequestUrl string
// RequestClient sends the upstream request.
RequestClient tls_client.HttpClient
// RequestBody is the client's request body; compared against net/http.NoBody
// to detect an absent body.
RequestBody io.ReadCloser
// RequestParam is substring-matched for "chat/completions",
// "images/generations" and "/ws" — presumably the path captured by
// /r/*path; confirm in common.GetContextPack.
RequestParam string
// RequestMethod is the HTTP method used for the upstream request.
RequestMethod string
// RequestHeaders are read for "Authorization" and "puid".
RequestHeaders http.Header
}
// UnofficialApiProcess translates one OpenAI-style API request into an
// upstream ChatGPT backend conversation and relays the answer back.
type UnofficialApiProcess struct {
// Context is the per-request state set through SetContext.
Context Context
// WS is the websocket bridge created in MakeRequest for each request.
WS *wsstostream.WssToStream
// Response is not assigned in this file. NOTE(review): confirm it is used elsewhere.
Response *http.Response
// ID is the identifier stamped onto generated response objects.
ID string
// Model is the model name requested by the client.
Model string
// PromptTokens accumulates token counts of the client's text messages.
PromptTokens int
// CompletionTokens is the token count of the final non-streamed answer.
CompletionTokens int
// OldString is the latest full message text seen; used to compute stream
// deltas and as the final answer in non-stream mode.
OldString string
// Mode is "chat" or "image", chosen from the request path.
Mode string
// ImagePointerList collects DALL·E asset pointers found in the reply.
ImagePointerList []ImagePointer
}
// ImagePointer identifies one generated image in the upstream reply.
type ImagePointer struct {
// Pointer is the asset pointer with its "file-service://" scheme removed;
// it is used as the file id in the download-URL request.
Pointer string
// Prompt is the DALL·E prompt reported upstream; returned as RevisedPrompt.
Prompt string
}
// Result holds what one upstream payload was converted into.
type Result struct {
// ApiRespStrStream is set (non-empty Id) for a content chunk.
ApiRespStrStream ApiRespStrStream
// ApiRespStrStreamEnd is set (non-empty Id) for a completion marker.
ApiRespStrStreamEnd ApiRespStrStreamEnd
// ApiImageGenerationRespStr collects resolved image URLs in image mode.
ApiImageGenerationRespStr ApiImageGenerationRespStr
// Pass is true when the payload should not be forwarded to the client.
Pass bool
}
// SetContext replaces the processor's per-request context.
func (p *UnofficialApiProcess) SetContext(conversation Context) {
	p.Context = conversation
}

// GetContext returns a copy of the processor's per-request context.
func (p *UnofficialApiProcess) GetContext() Context {
	current := p.Context
	return current
}
// ProcessMethod is the entry point for one request. It decodes the client
// body, requires a "model" field, and dispatches to the chat or image
// handler depending on the request path. Handler errors are logged here.
func (p *UnofficialApiProcess) ProcessMethod() {
	context.Logger.Debug("UnofficialApiProcess")
	var requestBody map[string]interface{}
	if decodeErr := p.decodeRequestBody(&requestBody); decodeErr != nil {
		return // decodeRequestBody already answered with 400
	}
	p.ID = IdGenerator()

	model, hasModel := requestBody["model"]
	if !hasModel {
		p.GetContext().GinContext.JSON(400, gin.H{"error": "Model not provided"})
		return
	}
	p.Model, _ = model.(string)

	route := p.GetContext().RequestParam
	if strings.Contains(route, "chat/completions") {
		p.Mode = "chat"
		if chatErr := p.chatApiProcess(requestBody); chatErr != nil {
			logger.Log.Error(chatErr)
			return
		}
	}
	if strings.Contains(route, "images/generations") {
		p.Mode = "image"
		if imageErr := p.imageApiProcess(requestBody); imageErr != nil {
			logger.Log.Error(imageErr)
			return
		}
	}
}
// imageApiProcess handles the images/generations endpoint. It sends the
// prompt upstream, collects the DALL·E asset pointers from the reply,
// resolves each one to a download URL and answers with an OpenAI-style image
// response. Every path after MakeRequest writes exactly one JSON reply.
func (p *UnofficialApiProcess) imageApiProcess(requestBody map[string]interface{}) error {
	context.Logger.Debug("UnofficialApiProcess imageApiProcess")
	response, err := p.MakeRequest(requestBody)
	if err != nil {
		return err // MakeRequest already answered the client
	}
	defer response.Body.Close()

	err = p.response(response, func(p *UnofficialApiProcess, a string) bool {
		p.jsonImageProcess(a)
		return false
	})

	ginContext := p.GetContext().GinContext
	// Upstream headers copied in MakeRequest, or the text/event-stream set by
	// response() on the /ws path, must not label the JSON written below.
	ginContext.Header("Content-Type", "application/json")
	if err != nil {
		ginContext.JSON(500, gin.H{"error": err.Error()})
		return err
	}

	result := new(Result)
	if err = p.getImageUrlByPointer(&p.ImagePointerList, result); err != nil {
		ginContext.JSON(500, gin.H{"error": "get image url failed"})
		return err
	}
	// Created stays zero when no download URL was resolved; answer explicitly
	// instead of leaving the client with an empty 200.
	if result.ApiImageGenerationRespStr.Created == 0 {
		err = errors.New("no image generated")
		ginContext.JSON(500, gin.H{"error": err.Error()})
		return err
	}
	ginContext.JSON(response.StatusCode, result.ApiImageGenerationRespStr)
	return nil
}
// chatApiProcess handles the chat/completions endpoint. It forwards the
// request upstream and relays the answer either as SSE frames (when the
// client sent "stream": true) or as a single JSON completion carrying token
// usage.
func (p *UnofficialApiProcess) chatApiProcess(requestBody map[string]interface{}) error {
	context.Logger.Debug("UnofficialApiProcess chatApiProcess")
	response, err := p.MakeRequest(requestBody)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// A missing or non-boolean "stream" (e.g. the string "true") is treated as
	// a non-streaming request rather than panicking on a type assertion.
	if stream, _ := requestBody["stream"].(bool); stream {
		return p.response(response, func(p *UnofficialApiProcess, a string) bool {
			data := p.streamChatProcess(a)
			if _, writeErr := p.GetContext().GinContext.Writer.Write([]byte(data)); writeErr != nil {
				context.Logger.Warning(writeErr)
				return true // writing to the client failed; stop relaying
			}
			p.GetContext().GinContext.Writer.Flush()
			return false
		})
	}

	replied := false
	err = p.response(response, func(p *UnofficialApiProcess, a string) bool {
		data := p.jsonChatProcess(a)
		if data == nil {
			return false // keep accumulating until [DONE]
		}
		context.Logger.Debug("Counting the number of tokens")
		p.CompletionTokens = len(tke.Encode(data.Choices[0].Message.Content, nil, nil))
		data.Usage.PromptTokens = p.PromptTokens
		data.Usage.CompletionTokens = p.CompletionTokens
		data.Usage.TotalTokens = p.PromptTokens + p.CompletionTokens
		p.GetContext().GinContext.Header("Content-Type", "application/json")
		p.GetContext().GinContext.JSON(response.StatusCode, data)
		replied = true
		return true
	})
	if err != nil {
		return err
	}
	if !replied {
		// Upstream ended without [DONE]; answer instead of an empty 200.
		err = errors.New("upstream ended without a complete answer")
		p.GetContext().GinContext.Header("Content-Type", "application/json")
		p.GetContext().GinContext.JSON(502, gin.H{"error": err.Error()})
		return err
	}
	return nil
}
// MakeRequest converts the OpenAI-style requestBody into the upstream
// conversation request, opens the websocket bridge and sends the request.
//
// On success it returns the upstream response, which has a 2xx status; the
// caller owns and must close its body. On failure it returns a nil response
// and an error. In that case a JSON error has been written to the client,
// except for createRequest failures, which may or may not have answered.
func (p *UnofficialApiProcess) MakeRequest(requestBody map[string]interface{}) (*http.Response, error) {
	ginContext := p.GetContext().GinContext
	reqModel, err := p.checkModel(p.Model)
	if err != nil {
		ginContext.JSON(400, gin.H{"error": err.Error()})
		return nil, err
	}
	req := GetChatReqStr(reqModel)
	if err = p.generateBody(req, requestBody); err != nil {
		ginContext.JSON(400, gin.H{"error": err.Error()})
		return nil, err
	}
	// Round-trip through JSON to obtain a generic map that createRequest can
	// extend (e.g. with an arkose_token).
	jsonData, err := json.Marshal(req)
	if err != nil {
		ginContext.JSON(500, gin.H{"error": err.Error()})
		return nil, err
	}
	var requestData map[string]interface{}
	if err = json.Unmarshal(jsonData, &requestData); err != nil {
		ginContext.JSON(400, gin.H{"error": err.Error()})
		return nil, err
	}
	request, err := p.createRequest(requestData) // build the request
	if err != nil {
		return nil, err
	}
	ws := wsstostream.NewWssToStream(p.GetContext().RequestHeaders.Get("Authorization"))
	err = ws.InitConnect()
	p.WS = ws
	if err != nil {
		logger.Log.Error(err)
		ginContext.JSON(500, gin.H{"error": err.Error()})
		return nil, err
	}
	response, err := p.GetContext().RequestClient.Do(request) // send the request
	if err != nil {
		// Transport failure: there is no response (nor body) to relay.
		ginContext.JSON(500, gin.H{"error": err.Error()})
		return nil, err
	}
	common.CopyResponseHeaders(response, ginContext) // mirror upstream response headers
	if response.StatusCode < 200 || response.StatusCode > 299 {
		defer response.Body.Close()
		// The copied upstream Content-Type may not be JSON; the reply below is.
		ginContext.Header("Content-Type", "application/json")
		statusErr := fmt.Errorf("upstream returned status %d", response.StatusCode)
		var responseBody interface{}
		if err = json.NewDecoder(response.Body).Decode(&responseBody); err != nil {
			// Non-JSON error body: report the status instead.
			ginContext.JSON(response.StatusCode, gin.H{"error": statusErr.Error()})
			return nil, statusErr
		}
		ginContext.JSON(response.StatusCode, responseBody)
		return nil, statusErr
	}
	return response, nil
}
// createRequest builds the upstream HTTP request from requestBody. It adds
// an Arkose token to the body (and header) when one is required. It also
// applies the standard headers and the client's cookies. The body is omitted
// when the client request had no body.
func (p *UnofficialApiProcess) createRequest(requestBody map[string]interface{}) (*http.Request, error) {
	context.Logger.Debug("UnofficialApiProcess createRequest")
	token, err := p.addArkoseTokenIfNeeded(&requestBody)
	if err != nil {
		return nil, err
	}
	payload, err := json.Marshal(requestBody)
	if err != nil {
		return nil, err
	}

	// Left as a nil interface when there is no client body, exactly like
	// passing a literal nil to NewRequest.
	var body io.Reader
	if p.Context.RequestBody != shttp.NoBody {
		body = bytes.NewBuffer(payload)
	}
	request, err := http.NewRequest(p.Context.RequestMethod, p.Context.RequestUrl, body)
	if err != nil {
		return nil, err
	}

	if token != "" {
		p.addArkoseTokenInHeaderIfNeeded(request, token)
	}
	p.buildHeaders(request)
	p.setCookies(request)
	return request, nil
}
// setCookies copies the name and value of every cookie on the client request
// onto the upstream request. Other cookie attributes are dropped.
func (p *UnofficialApiProcess) setCookies(request *http.Request) {
	context.Logger.Debug("UnofficialApiProcess setCookies")
	incoming := p.GetContext().GinContext.Request.Cookies()
	for i := range incoming {
		request.AddCookie(&http.Cookie{Name: incoming[i].Name, Value: incoming[i].Value})
	}
}
// buildHeaders sets the upstream request headers.
//   - Host and Origin derive from the configured OpenAI host.
//   - User-Agent comes from the environment.
//   - Connection is fixed to keep-alive.
//   - Authorization and Content-Type are copied from the client request.
//   - A client "PUID" header is forwarded as the _puid cookie.
func (p *UnofficialApiProcess) buildHeaders(request *http.Request) {
	context.Logger.Debug("UnofficialApiProcess buildHeaders")
	incoming := p.GetContext().GinContext.Request.Header
	// An Origin is a bare scheme://host[:port] (RFC 6454) with no path, so no
	// "/chat" suffix is appended.
	headers := map[string]string{
		"Host":          context.Env.OpenaiHost,
		"Origin":        "https://" + context.Env.OpenaiHost,
		"Authorization": incoming.Get("Authorization"),
		"Connection":    "keep-alive",
		"User-Agent":    context.Env.UserAgent,
		"Content-Type":  incoming.Get("Content-Type"),
	}
	for key, value := range headers {
		request.Header.Set(key, value)
	}
	if puid := incoming.Get("PUID"); puid != "" {
		request.Header.Set("cookie", "_puid="+puid+";")
	}
}
// addArkoseTokenInHeaderIfNeeded sets the Arkose token header on request.
// Callers only invoke it with a non-empty token.
func (p *UnofficialApiProcess) addArkoseTokenInHeaderIfNeeded(request *http.Request, token string) {
	context.Logger.Debug("UnofficialApiProcess addArkoseTokenInHeaderIfNeeded")
	const arkoseHeader = "Openai-Sentinel-Arkose-Token"
	request.Header.Set(arkoseHeader, token)
}
// addArkoseTokenIfNeeded fetches an Arkose token when the upstream model
// starts with "gpt-4" or the environment forces it (ArkoseMust). It stores
// the token under "arkose_token" in *requestBody and returns it. It returns
// "" without fetching when the body has no "model" key. On fetch failure it
// answers the client with 500 and returns the error.
func (p *UnofficialApiProcess) addArkoseTokenIfNeeded(requestBody *map[string]interface{}) (string, error) {
	context.Logger.Debug("UnofficialApiProcess addArkoseTokenIfNeeded")
	body := *requestBody // same underlying map; writes are visible to the caller
	model, exists := body["model"]
	if !exists {
		return "", nil
	}
	needsToken := strings.HasPrefix(model.(string), "gpt-4") || context.Env.ArkoseMust
	if !needsToken {
		return "", nil
	}
	token, err := funcaptcha.GetOpenAIArkoseToken(4, p.GetContext().RequestHeaders.Get("puid"))
	if err != nil {
		p.GetContext().GinContext.JSON(500, gin.H{"error": "Get ArkoseToken Failed"})
		logger.Log.Error(err)
		return "", err
	}
	body["arkose_token"] = token
	return token, nil
}
// streamChatProcess converts one upstream payload into an SSE frame
// ("data: ...\n\n") for the client. The "[DONE]" marker is passed through
// verbatim. Payloads with nothing to forward yield "". The payload is always
// parsed first, because parsing updates p.OldString.
func (p *UnofficialApiProcess) streamChatProcess(raw string) string {
	parsed := p.getStreamResp(raw)
	var payload interface{}
	switch {
	case strings.Contains(raw, "[DONE]"):
		return "data: " + raw + "\n\n"
	case parsed.Pass:
		return ""
	case parsed.ApiRespStrStreamEnd.Id != "":
		payload = parsed.ApiRespStrStreamEnd
	case parsed.ApiRespStrStream.Id != "":
		payload = parsed.ApiRespStrStream
	default:
		return ""
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		context.Logger.Warning(err)
	}
	return "data: " + string(encoded) + "\n\n"
}
// response drains the upstream answer and passes the data of every SSE
// "message" event to mid. When mid returns true, reading stops.
//
// For request paths containing "/ws", the HTTP body is a WsResponse
// document. Its ids are stored on p.WS, the client reply is switched to
// text/event-stream, and events are read from the websocket bridge.
// Otherwise events are read directly from response.Body.
func (p *UnofficialApiProcess) response(response *http.Response, mid func(p *UnofficialApiProcess, a string) bool) error {
	context.Logger.Debug("UnofficialApiProcess streamResponse")
	var client *tools.SSEClient
	if strings.Contains(p.Context.RequestParam, "/ws") {
		var jsonData WsResponse
		if err := json.NewDecoder(response.Body).Decode(&jsonData); err != nil {
			logger.Log.Error(err)
			return err
		}
		p.WS.ResponseId = jsonData.ResponseId
		p.WS.ConversationId = jsonData.ConversationId
		header := p.GetContext().GinContext.Writer.Header()
		header.Set("Content-Type", "text/event-stream")
		header.Set("Cache-Control", "no-cache")
		header.Set("Connection", "keep-alive")
		logger.Log.Debug("wss to stream")
		client = tools.NewSSEClient(p.WS)
	} else {
		client = tools.NewSSEClient(response.Body)
	}
	// Deferred before the loop so the client is also closed when mid stops
	// reading early via the return inside the loop.
	defer client.Close()
	for event := range client.Read() {
		if event.Event == "message" && mid(p, event.Data) {
			return nil
		}
	}
	return nil
}
// jsonChatProcess accumulates a non-streamed answer. Every payload is parsed,
// which updates p.OldString. When the "[DONE]" marker arrives, a complete
// response is returned whose single choice carries the accumulated text.
// Other payloads return nil.
func (p *UnofficialApiProcess) jsonChatProcess(raw string) *ApiRespStr {
	p.getStreamResp(raw) // called for its side effect on p.OldString
	if !strings.Contains(raw, "[DONE]") {
		return nil
	}
	completion := GetApiRespStr(p.ID)
	finalChoice := GetStrChoices()
	finalChoice.Message.Content = p.OldString
	completion.Choices = append(completion.Choices, *finalChoice)
	completion.Model = p.Model
	return completion
}
// jsonImageProcess inspects one upstream payload. When the payload is a
// dalle.text2im message with multimodal_text content, it records each image
// asset pointer (with the "file-service://" scheme removed) and its prompt in
// p.ImagePointerList.
//
// Empty pointers are skipped, because they would produce an invalid
// "/backend-api/files//download" request that fails the whole reply. Pointers
// already recorded are skipped, so repeated emissions of the same message do
// not produce duplicate images.
func (p *UnofficialApiProcess) jsonImageProcess(stream string) {
	context.Logger.Debug("getImageResp")
	var dalleRespStr DALLERespStr
	// Best effort: non-JSON payloads such as "[DONE]" leave the struct zeroed
	// and are ignored by the check below.
	_ = json.Unmarshal([]byte(stream), &dalleRespStr)
	if dalleRespStr.Message.Author.Name != "dalle.text2im" || dalleRespStr.Message.Content.ContentType != "multimodal_text" {
		return
	}
	context.Logger.Debug("found image")
	for _, v := range dalleRespStr.Message.Content.Parts {
		pointer := strings.ReplaceAll(v.AssetPointer, "file-service://", "")
		if pointer == "" || p.hasImagePointer(pointer) {
			continue
		}
		p.ImagePointerList = append(p.ImagePointerList, ImagePointer{Pointer: pointer, Prompt: v.Metadata.Dalle.Prompt})
	}
}

// hasImagePointer reports whether pointer is already in p.ImagePointerList.
func (p *UnofficialApiProcess) hasImagePointer(pointer string) bool {
	for i := range p.ImagePointerList {
		if p.ImagePointerList[i].Pointer == pointer {
			return true
		}
	}
	return false
}
// getImageUrlByPointer resolves every recorded asset pointer to a download
// URL through the backend files API and appends one image item per resolved
// URL to result. Created is set to the current Unix time whenever an item is
// added. Pointers without a URL are skipped. The first request error aborts
// and is returned.
func (p *UnofficialApiProcess) getImageUrlByPointer(imagePointerList *[]ImagePointer, result *Result) error {
	context.Logger.Debug("getImageUrlByPointer")
	authorization := p.GetContext().RequestHeaders.Get("Authorization")
	pointers := *imagePointerList
	for i := range pointers {
		entry := pointers[i]
		download, err := common.RequestOpenAI[ImageDownloadUrl]("/backend-api/files/"+entry.Pointer+"/download", nil, "GET", authorization)
		if err != nil {
			return err
		}
		if download == nil || download.DownloadUrl == "" {
			continue
		}
		context.Logger.Debug("getDownloadUrl")
		result.ApiImageGenerationRespStr.Created = time.Now().Unix()
		result.ApiImageGenerationRespStr.Data = append(result.ApiImageGenerationRespStr.Data, ApiImageItem{
			Url:           download.DownloadUrl,
			RevisedPrompt: entry.Prompt,
		})
	}
	return nil
}
// getStreamResp parses one upstream payload into a Result.
//   - A message with an id and a parent id becomes an ApiRespStrStream whose
//     delta is the text added since the previous message text (p.OldString,
//     which is then updated).
//   - A message without a parent id is skipped (Pass).
//   - A payload with is_completion set becomes an ApiRespStrStreamEnd.
//   - Anything producing neither sets Pass so the caller forwards nothing.
func (p *UnofficialApiProcess) getStreamResp(stream string) *Result {
	context.Logger.Debug("getStreamResp")
	var chatRespStr ChatRespStr
	var chatEndRespStr ChatEndRespStr
	result := new(Result)
	// Best effort: non-JSON payloads such as "[DONE]" leave the structs at
	// their zero values, which the checks below treat as nothing to emit.
	_ = json.Unmarshal([]byte(stream), &chatRespStr)
	if chatRespStr.Message.Id != "" {
		if chatRespStr.Message.Metadata.ParentId == "" {
			result.Pass = true
			return result
		}
		// A message with no parts carries no text; indexing Parts[0]
		// unconditionally would panic.
		if parts := chatRespStr.Message.Content.Parts; len(parts) > 0 {
			context.Logger.Debug("chatRespStr")
			current := parts[0]
			resp := GetApiRespStrStream(p.ID)
			choice := GetStreamChoice()
			resp.Model = p.Model
			// Each message is treated as carrying the full text so far, so
			// the delta is whatever follows the previous text. TrimPrefix
			// strips only that leading copy. ReplaceAll would also delete
			// later repetitions of it inside the new text.
			choice.Delta.Content = strings.TrimPrefix(current, p.OldString)
			p.OldString = current
			resp.Choices = append(resp.Choices[:0], *choice)
			result.ApiRespStrStream = *resp
		}
	}
	_ = json.Unmarshal([]byte(stream), &chatEndRespStr)
	if chatEndRespStr.IsCompletion {
		context.Logger.Debug("chatEndRespStr")
		resp := GetApiRespStrStreamEnd(p.ID)
		resp.Model = p.Model
		result.ApiRespStrStreamEnd = *resp
	}
	if result.ApiRespStrStream.Id == "" && result.ApiRespStrStreamEnd.Id == "" {
		result.Pass = true
	}
	return result
}
// checkModel maps the client's model name to the upstream model slug.
//   - "dall-e*" and "gpt-4-vision*" map to "gpt-4".
//   - "gpt-3*" maps to "text-davinci-002-render-sha".
//   - Any other "gpt-4*" maps to "gpt-4-gizmo".
//
// Everything else is rejected as unsupported.
func (p *UnofficialApiProcess) checkModel(model string) (string, error) {
	context.Logger.Debug("UnofficialApiProcess checkModel")
	switch {
	case strings.HasPrefix(model, "dall-e"), strings.HasPrefix(model, "gpt-4-vision"):
		return "gpt-4", nil
	case strings.HasPrefix(model, "gpt-3"):
		return "text-davinci-002-render-sha", nil
	case strings.HasPrefix(model, "gpt-4"):
		return "gpt-4-gizmo", nil
	}
	return "", errors.New("unsupported model")
}
// generateBody fills req.Messages from the client's OpenAI-style body.
//
// In "chat" mode every message with string content becomes one upstream
// message with the same role, and its token count is added to
// p.PromptTokens. In "image" mode a single user message is built that asks
// the DALL·E tool for "n" images of "size" for "prompt".
//
// Malformed fields return an error instead of panicking on a type assertion.
func (p *UnofficialApiProcess) generateBody(req *ChatReqStr, requestBody map[string]interface{}) error {
	context.Logger.Debug("UnofficialApiProcess generateBody")
	if p.Mode == "chat" {
		logger.Log.Debug("Generate Chat Body")
		messageList, exists := requestBody["messages"]
		if !exists {
			return errors.New("no message body")
		}
		messages, ok := messageList.([]interface{})
		if !ok {
			return errors.New("messages must be an array")
		}
		for _, message := range messages {
			messageItem, _ := message.(map[string]interface{})
			role, _ := messageItem["role"].(string)
			switch content := messageItem["content"].(type) {
			case string:
				// NOTE(review): +7 presumably approximates per-message framing
				// tokens — confirm.
				p.PromptTokens += len(tke.Encode(content, nil, nil)) + 7
				reqMessage := GetChatReqTemplate()
				reqMessage.Content.Parts = reqMessage.Content.Parts[:0]
				reqMessage.Author.Role = role
				reqMessage.Content.Parts = append(reqMessage.Content.Parts, content)
				req.Messages = append(req.Messages, *reqMessage)
			case []map[string]interface{}:
				// NOTE(review): encoding/json decodes arrays as []interface{},
				// so this case never matches a decoded body. Array (file/vision)
				// content is ignored today: fileReqProcess is a stub and the
				// resulting message is not appended.
				reqFileMessage := GetChatFileReqTemplate()
				reqFileMessage.Content.Parts = reqFileMessage.Content.Parts[:0]
				reqFileMessage.Author.Role = role
				p.fileReqProcess(&content, &reqFileMessage.Content.Parts)
			}
		}
	}
	if p.Mode == "image" {
		logger.Log.Debug("Generate Image Body")
		prompt, ok := requestBody["prompt"].(string)
		if !ok || prompt == "" {
			return errors.New("please provide prompt")
		}
		// JSON numbers decode as float64. A missing, non-numeric or
		// non-positive "n" falls back to a single image. The previous
		// default stored an int and then always failed the float64 assertion.
		count := 1
		if n, isNumber := requestBody["n"].(float64); isNumber && n >= 1 {
			count = int(n)
		}
		size, ok := requestBody["size"].(string)
		if !ok || size == "" {
			size = "1024x1024"
		}
		reqMessage := GetChatReqTemplate()
		reqMessage.Content.Parts = reqMessage.Content.Parts[:0]
		reqMessage.Author.Role = "user"
		reqMessage.Content.Parts = append(reqMessage.Content.Parts, fmt.Sprintf("Requirements for image generation:\n- ImageCount: %d\n- Size: %s\n- Prompt: [%s]\n- Requirements: Using the DALLE tool, each image is generated according to the number of ImageCount. It is not allowed to contain multiple elements in one image. You must call the tool multiple times to generate the number of ImageCount images, and the details of each image are different\n", count, size, prompt))
		req.Messages = append(req.Messages, *reqMessage)
	}
	return nil
}
// fileReqProcess is a placeholder intended to convert array-style message
// content into upstream message parts. It currently does nothing with either
// argument.
// NOTE(review): its only caller in this file is generateBody's
// []map[string]interface{} branch — confirm that branch is reachable before
// implementing this.
func (p *UnofficialApiProcess) fileReqProcess(content *[]map[string]interface{}, part *[]interface{}) {
}
// decodeRequestBody decodes the client's JSON body into *requestBody. An
// absent body (net/http.NoBody) is not an error and leaves *requestBody
// untouched. Invalid JSON is answered with 400 and the decode error is
// returned.
func (p *UnofficialApiProcess) decodeRequestBody(requestBody *map[string]interface{}) error {
	current := p.GetContext()
	if current.RequestBody == shttp.NoBody {
		return nil
	}
	err := json.NewDecoder(current.RequestBody).Decode(requestBody)
	if err != nil {
		current.GinContext.JSON(400, gin.H{"error": "JSON invalid"})
	}
	return err
}
// UnOfficialApiRequestUrl builds the upstream URL for requests routed to this
// plugin.
type UnOfficialApiRequestUrl struct {
}

// Generate returns the upstream conversation endpoint on the configured
// OpenAI host, with rawquery appended when it is non-empty. The path argument
// is ignored: every request targets /backend-api/conversation.
func (u UnOfficialApiRequestUrl) Generate(path string, rawquery string) string {
	endpoint := "https://" + context.Env.OpenaiHost + "/backend-api" + "/conversation"
	if rawquery != "" {
		endpoint += "?" + rawquery
	}
	return endpoint
}
// Run stores com in the package-level context and registers a catch-all
// handler on /r/*path. Each request is served by a fresh UnofficialApiProcess
// driven through common.Do.
func (p *UnofficialApiProcess) Run(com *plugins.Component) {
	context = com
	handle := func(c *gin.Context) {
		pack := common.GetContextPack(c, UnOfficialApiRequestUrl{})
		common.Do[Context](new(UnofficialApiProcess), Context(pack))
	}
	context.Engine.Any("/r/*path", handle)
}