Spaces:
Runtime error
Runtime error
| package chatgpt | |
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/url"
	"strings"

	"github.com/PuerkitoBio/goquery"
	http "github.com/bogdanfinn/fhttp"
	"github.com/gin-gonic/gin"
	"github.com/linweiyuan/go-chatgpt-api/api"
	"github.com/linweiyuan/go-chatgpt-api/util/logger"
)
| //goland:noinspection GoUnhandledErrorResult | |
| func GetConversations(c *gin.Context) { | |
| offset, ok := c.GetQuery("offset") | |
| if !ok { | |
| offset = "0" | |
| } | |
| limit, ok := c.GetQuery("limit") | |
| if !ok { | |
| limit = "20" | |
| } | |
| handleGet(c, apiPrefix+"/conversations?offset="+offset+"&limit="+limit, getConversationsErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func CreateConversation(c *gin.Context) { | |
| var request CreateConversationRequest | |
| if err := c.BindJSON(&request); err != nil { | |
| c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage)) | |
| return | |
| } | |
| if request.ConversationID == nil || *request.ConversationID == "" { | |
| request.ConversationID = nil | |
| } | |
| if request.Messages[0].Author.Role == "" { | |
| request.Messages[0].Author.Role = defaultRole | |
| } | |
| if request.Model == gpt4Model { | |
| formParams := fmt.Sprintf( | |
| "public_key=%s", | |
| gpt4PublicKey, | |
| ) | |
| req, _ := http.NewRequest(http.MethodPost, gpt4TokenUrl, strings.NewReader(formParams)) | |
| req.Header.Set("Content-Type", api.ContentType) | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| responseMap := make(map[string]string) | |
| json.NewDecoder(resp.Body).Decode(&responseMap) | |
| request.ArkoseToken = responseMap["token"] | |
| } | |
| jsonBytes, _ := json.Marshal(request) | |
| logger.Info(fmt.Sprintf("ConversationRequest: %s", jsonBytes)) | |
| req, _ := http.NewRequest(http.MethodPost, apiPrefix+"/conversation", bytes.NewBuffer(jsonBytes)) | |
| req.Header.Set("User-Agent", api.UserAgent) | |
| req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader))) | |
| req.Header.Set("Accept", "text/event-stream") | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| if resp.StatusCode != http.StatusOK { | |
| responseMap := make(map[string]interface{}) | |
| json.NewDecoder(resp.Body).Decode(&responseMap) | |
| c.AbortWithStatusJSON(resp.StatusCode, responseMap) | |
| resp.Body.Close() | |
| return | |
| } | |
| c.Set("oldpart", "") | |
| Status, ParentMessageID, part := api.HandleConversationResponse(c, resp) | |
| if Status { | |
| resp.Body.Close() | |
| ContinueConversation(c, *request.ConversationID, ParentMessageID, request.Model, part) | |
| } else { | |
| resp.Body.Close() | |
| } | |
| } | |
| func ContinueConversation(c *gin.Context, conversationID string, parentMessageID string, model string, oldpart string) { | |
| var request ContinueConversationRequest | |
| request.ConversationID = &conversationID | |
| request.ParentMessageID = parentMessageID | |
| request.Model = model | |
| request.Action = "continue" | |
| if request.Model == gpt4Model { | |
| formParams := fmt.Sprintf( | |
| "public_key=%s", | |
| gpt4PublicKey, | |
| ) | |
| req, _ := http.NewRequest(http.MethodPost, gpt4TokenUrl, strings.NewReader(formParams)) | |
| req.Header.Set("Content-Type", api.ContentType) | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| responseMap := make(map[string]string) | |
| json.NewDecoder(resp.Body).Decode(&responseMap) | |
| request.ArkoseToken = responseMap["token"] | |
| } | |
| jsonBytes, _ := json.Marshal(request) | |
| logger.Info(fmt.Sprintf("ContinueConversationRequest: %s", jsonBytes)) | |
| req, _ := http.NewRequest(http.MethodPost, apiPrefix+"/conversation", bytes.NewBuffer(jsonBytes)) | |
| req.Header.Set("User-Agent", api.UserAgent) | |
| req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader))) | |
| req.Header.Set("Accept", "text/event-stream") | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| if resp.StatusCode != http.StatusOK { | |
| responseMap := make(map[string]interface{}) | |
| json.NewDecoder(resp.Body).Decode(&responseMap) | |
| c.AbortWithStatusJSON(resp.StatusCode, responseMap) | |
| resp.Body.Close() | |
| return | |
| } | |
| c.Set("oldpart", oldpart) | |
| Status, ParentMessageID, part := api.HandleConversationResponse(c, resp) | |
| if Status { | |
| resp.Body.Close() | |
| ContinueConversation(c, *request.ConversationID, ParentMessageID, request.Model, part) | |
| } else { | |
| resp.Body.Close() | |
| } | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func GenerateTitle(c *gin.Context) { | |
| var request GenerateTitleRequest | |
| if err := c.BindJSON(&request); err != nil { | |
| c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage)) | |
| return | |
| } | |
| jsonBytes, _ := json.Marshal(request) | |
| handlePost(c, apiPrefix+"/conversation/gen_title/"+c.Param("id"), string(jsonBytes), generateTitleErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func GetConversation(c *gin.Context) { | |
| handleGet(c, apiPrefix+"/conversation/"+c.Param("id"), getContentErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func UpdateConversation(c *gin.Context) { | |
| var request PatchConversationRequest | |
| if err := c.BindJSON(&request); err != nil { | |
| c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage)) | |
| return | |
| } | |
| // bool default to false, then will hide (delete) the conversation | |
| if request.Title != nil { | |
| request.IsVisible = true | |
| } | |
| jsonBytes, _ := json.Marshal(request) | |
| handlePatch(c, apiPrefix+"/conversation/"+c.Param("id"), string(jsonBytes), updateConversationErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func FeedbackMessage(c *gin.Context) { | |
| var request FeedbackMessageRequest | |
| if err := c.BindJSON(&request); err != nil { | |
| c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage)) | |
| return | |
| } | |
| jsonBytes, _ := json.Marshal(request) | |
| handlePost(c, apiPrefix+"/conversation/message_feedback", string(jsonBytes), feedbackMessageErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func ClearConversations(c *gin.Context) { | |
| jsonBytes, _ := json.Marshal(PatchConversationRequest{ | |
| IsVisible: false, | |
| }) | |
| handlePatch(c, apiPrefix+"/conversations", string(jsonBytes), clearConversationsErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func GetModels(c *gin.Context) { | |
| handleGet(c, apiPrefix+"/models", getModelsErrorMessage) | |
| } | |
| func GetAccountCheck(c *gin.Context) { | |
| handleGet(c, apiPrefix+"/accounts/check", getAccountCheckErrorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func Login(c *gin.Context) { | |
| var loginInfo api.LoginInfo | |
| if err := c.ShouldBindJSON(&loginInfo); err != nil { | |
| c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(api.ParseUserInfoErrorMessage)) | |
| return | |
| } | |
| userLogin := UserLogin{ | |
| client: api.NewHttpClient(), | |
| } | |
| // get csrf token | |
| req, _ := http.NewRequest(http.MethodGet, csrfUrl, nil) | |
| req.Header.Set("User-Agent", api.UserAgent) | |
| resp, err := userLogin.client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| if resp.StatusCode == http.StatusForbidden { | |
| doc, _ := goquery.NewDocumentFromReader(resp.Body) | |
| alert := doc.Find(".message").Text() | |
| if alert != "" { | |
| c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(strings.TrimSpace(alert))) | |
| return | |
| } | |
| } | |
| c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(getCsrfTokenErrorMessage)) | |
| return | |
| } | |
| // get authorized url | |
| responseMap := make(map[string]string) | |
| json.NewDecoder(resp.Body).Decode(&responseMap) | |
| authorizedUrl, statusCode, err := userLogin.GetAuthorizedUrl(responseMap["csrfToken"]) | |
| if err != nil { | |
| c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| // get state | |
| state, statusCode, err := userLogin.GetState(authorizedUrl) | |
| if err != nil { | |
| c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| // check username | |
| statusCode, err = userLogin.CheckUsername(state, loginInfo.Username) | |
| if err != nil { | |
| c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| // check password | |
| _, statusCode, err = userLogin.CheckPassword(state, loginInfo.Username, loginInfo.Password) | |
| if err != nil { | |
| c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| // get access token | |
| accessToken, statusCode, err := userLogin.GetAccessToken("") | |
| if err != nil { | |
| c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| c.Writer.WriteString(accessToken) | |
| } | |
| func Fallback(c *gin.Context) { | |
| method := c.Request.Method | |
| url := apiPrefix + c.Request.URL.Path | |
| queryParams := c.Request.URL.Query().Encode() | |
| if queryParams != "" { | |
| url += "?" + queryParams | |
| } | |
| var requestBody string | |
| if c.Request.Method == http.MethodPost || c.Request.Method == http.MethodPatch { | |
| body, _ := io.ReadAll(c.Request.Body) | |
| requestBody = string(body) | |
| } | |
| c.Status(http.StatusOK) | |
| switch method { | |
| case http.MethodGet: | |
| handleGet(c, url, fallbackErrorMessage) | |
| case http.MethodPost: | |
| handlePost(c, url, requestBody, fallbackErrorMessage) | |
| case http.MethodPatch: | |
| handlePatch(c, url, requestBody, fallbackErrorMessage) | |
| default: | |
| c.JSON(http.StatusMethodNotAllowed, gin.H{"message": fallbackMethodNotAllowedMessage}) | |
| } | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func handleGet(c *gin.Context, url string, errorMessage string) { | |
| req, _ := http.NewRequest(http.MethodGet, url, nil) | |
| req.Header.Set("User-Agent", api.UserAgent) | |
| req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader))) | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(errorMessage)) | |
| return | |
| } | |
| io.Copy(c.Writer, resp.Body) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func handlePost(c *gin.Context, url string, requestBody string, errorMessage string) { | |
| req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) | |
| handlePostOrPatch(c, req, errorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func handlePatch(c *gin.Context, url string, requestBody string, errorMessage string) { | |
| req, _ := http.NewRequest(http.MethodPatch, url, strings.NewReader(requestBody)) | |
| handlePostOrPatch(c, req, errorMessage) | |
| } | |
| //goland:noinspection GoUnhandledErrorResult | |
| func handlePostOrPatch(c *gin.Context, req *http.Request, errorMessage string) { | |
| req.Header.Set("User-Agent", api.UserAgent) | |
| req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader))) | |
| resp, err := api.Client.Do(req) | |
| if err != nil { | |
| c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(errorMessage)) | |
| return | |
| } | |
| io.Copy(c.Writer, resp.Body) | |
| } | |