| package client |
|
|
| import ( |
| "bufio" |
| "bytes" |
| "context" |
| dsprotocol "ds2api/internal/deepseek/protocol" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
|
|
| "ds2api/internal/auth" |
| "ds2api/internal/config" |
| ) |
|
|
| const defaultAutoContinueLimit = 8 |
|
|
| type continueOpenFunc func(context.Context, string, int) (*http.Response, error) |
|
|
| type continueState struct { |
| sessionID string |
| responseMessageID int |
| lastStatus string |
| finished bool |
| } |
|
|
| |
| |
| |
| |
| |
| func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, resp *http.Response) *http.Response { |
| if resp == nil || resp.Body == nil { |
| return resp |
| } |
| sessionID, _ := payload["chat_session_id"].(string) |
| sessionID = strings.TrimSpace(sessionID) |
| if sessionID == "" { |
| return resp |
| } |
| config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID) |
| resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) { |
| return c.callContinue(ctx, a, sessionID, responseMessageID, powResp) |
| }) |
| return resp |
| } |
|
|
| |
| func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int, powResp string) (*http.Response, error) { |
| if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 { |
| return nil, errors.New("missing continue identifiers") |
| } |
| clients := c.requestClientsForAuth(ctx, a) |
| headers := c.authHeaders(a.DeepSeekToken) |
| headers["x-ds-pow-response"] = powResp |
| payload := map[string]any{ |
| "chat_session_id": sessionID, |
| "message_id": responseMessageID, |
| "fallback_to_resume": true, |
| } |
| config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID) |
| captureSession := c.capture.Start("deepseek_continue", dsprotocol.DeepSeekContinueURL, a.AccountID, payload) |
| resp, err := c.streamPost(ctx, clients.stream, dsprotocol.DeepSeekContinueURL, headers, payload) |
| if err != nil { |
| return nil, err |
| } |
| if captureSession != nil { |
| resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) |
| } |
| if resp.StatusCode != http.StatusOK { |
| _ = resp.Body.Close() |
| return nil, errors.New("continue failed") |
| } |
| return resp, nil |
| } |
|
|
| |
| |
| func newAutoContinueBody(ctx context.Context, initial io.ReadCloser, sessionID string, maxRounds int, openContinue continueOpenFunc) io.ReadCloser { |
| if initial == nil || strings.TrimSpace(sessionID) == "" || openContinue == nil { |
| return initial |
| } |
| if maxRounds <= 0 { |
| maxRounds = defaultAutoContinueLimit |
| } |
| pr, pw := io.Pipe() |
| go pumpAutoContinue(ctx, pw, initial, continueState{sessionID: sessionID}, maxRounds, openContinue) |
| return pr |
| } |
|
|
| |
| |
| |
| func pumpAutoContinue(ctx context.Context, pw *io.PipeWriter, initial io.ReadCloser, state continueState, maxRounds int, openContinue continueOpenFunc) { |
| defer func() { _ = pw.Close() }() |
| current := initial |
| rounds := 0 |
| for { |
| hadDone, err := streamBodyWithContinueState(ctx, pw, current, &state) |
| _ = current.Close() |
| if err != nil { |
| _ = pw.CloseWithError(err) |
| return |
| } |
| if state.shouldContinue() && rounds < maxRounds { |
| rounds++ |
| config.Logger.Info("[auto_continue] continuing", "round", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID, "status", state.lastStatus) |
| nextResp, err := openContinue(ctx, state.sessionID, state.responseMessageID) |
| if err != nil { |
| config.Logger.Warn("[auto_continue] continue request failed", "round", rounds, "error", err) |
| _ = pw.CloseWithError(err) |
| return |
| } |
| current = nextResp.Body |
| state.prepareForNextRound() |
| continue |
| } |
| |
| if hadDone { |
| if _, err := io.Copy(pw, bytes.NewBufferString("data: [DONE]\n")); err != nil { |
| _ = pw.CloseWithError(err) |
| } |
| } |
| return |
| } |
| } |
|
|
| |
| |
| |
| |
| func streamBodyWithContinueState(ctx context.Context, pw *io.PipeWriter, body io.Reader, state *continueState) (bool, error) { |
| reader := bufio.NewReaderSize(body, 64*1024) |
| hadDone := false |
| for { |
| select { |
| case <-ctx.Done(): |
| return hadDone, ctx.Err() |
| default: |
| } |
| line, err := reader.ReadBytes('\n') |
| if len(line) == 0 && err != nil { |
| if err == io.EOF { |
| return hadDone, nil |
| } |
| return hadDone, err |
| } |
| trimmed := strings.TrimSpace(string(line)) |
| if trimmed != "" { |
| if strings.HasPrefix(trimmed, "data:") { |
| data := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) |
| if data == "[DONE]" { |
| hadDone = true |
| if err != nil && err != io.EOF { |
| return hadDone, err |
| } |
| if err == io.EOF { |
| return hadDone, nil |
| } |
| continue |
| } |
| state.observe(data) |
| } |
| if !strings.HasSuffix(string(line), "\n") { |
| line = append(line, '\n') |
| } |
| if _, copyErr := io.Copy(pw, bytes.NewReader(line)); copyErr != nil { |
| return hadDone, copyErr |
| } |
| } |
| if err != nil { |
| if err == io.EOF { |
| return hadDone, nil |
| } |
| return hadDone, err |
| } |
| } |
| } |
|
|
| |
| func (s *continueState) observe(data string) { |
| if s == nil || strings.TrimSpace(data) == "" { |
| return |
| } |
| var chunk map[string]any |
| if err := json.Unmarshal([]byte(data), &chunk); err != nil { |
| return |
| } |
| |
| if id := intFrom(chunk["response_message_id"]); id > 0 { |
| s.responseMessageID = id |
| } |
| s.observeDirectPatch(asString(chunk["p"]), chunk["v"]) |
| if p, _ := chunk["p"].(string); p == "response" { |
| s.observeBatchPatches("response", chunk["v"]) |
| } else { |
| s.observeBatchPatches("", chunk["v"]) |
| } |
| if v, _ := chunk["v"].(map[string]any); v != nil { |
| s.observeResponseObject(v["response"]) |
| } |
| if message, _ := chunk["message"].(map[string]any); message != nil { |
| s.observeResponseObject(message["response"]) |
| } |
| } |
|
|
| func (s *continueState) observeDirectPatch(path string, value any) { |
| if s == nil { |
| return |
| } |
| switch strings.Trim(strings.TrimSpace(path), "/") { |
| case "response/status", "status", "response/quasi_status", "quasi_status": |
| s.setStatus(asString(value)) |
| case "response/auto_continue", "auto_continue": |
| if v, ok := value.(bool); ok && v { |
| s.lastStatus = "AUTO_CONTINUE" |
| } |
| } |
| } |
|
|
| func (s *continueState) observeResponseObject(raw any) { |
| if s == nil { |
| return |
| } |
| response, _ := raw.(map[string]any) |
| if response == nil { |
| return |
| } |
| if id := intFrom(response["message_id"]); id > 0 { |
| s.responseMessageID = id |
| } |
| s.setStatus(asString(response["status"])) |
| if autoContinue, ok := response["auto_continue"].(bool); ok && autoContinue { |
| s.lastStatus = "AUTO_CONTINUE" |
| } |
| } |
|
|
| func (s *continueState) observeBatchPatches(parentPath string, raw any) { |
| if s == nil { |
| return |
| } |
| patches, ok := raw.([]any) |
| if !ok { |
| return |
| } |
| for _, patch := range patches { |
| m, ok := patch.(map[string]any) |
| if !ok { |
| continue |
| } |
| path := strings.TrimSpace(asString(m["p"])) |
| if path == "" { |
| continue |
| } |
| fullPath := path |
| if parent := strings.Trim(strings.TrimSpace(parentPath), "/"); parent != "" && !strings.Contains(path, "/") { |
| fullPath = parent + "/" + path |
| } |
| switch strings.Trim(strings.TrimSpace(fullPath), "/") { |
| case "response/status", "status", "response/quasi_status", "quasi_status": |
| s.setStatus(asString(m["v"])) |
| case "response/auto_continue", "auto_continue": |
| if v, ok := m["v"].(bool); ok && v { |
| s.lastStatus = "AUTO_CONTINUE" |
| } |
| } |
| } |
| } |
|
|
| func (s *continueState) setStatus(status string) { |
| if s == nil { |
| return |
| } |
| normalized := strings.TrimSpace(status) |
| if normalized == "" { |
| return |
| } |
| s.lastStatus = normalized |
| if strings.EqualFold(normalized, "FINISHED") || strings.EqualFold(normalized, "CONTENT_FILTER") { |
| s.finished = true |
| } |
| } |
|
|
| |
| |
| |
| func (s *continueState) shouldContinue() bool { |
| if s == nil { |
| return false |
| } |
| if s.finished || s.responseMessageID <= 0 || strings.TrimSpace(s.sessionID) == "" { |
| return false |
| } |
| switch strings.ToUpper(strings.TrimSpace(s.lastStatus)) { |
| case "INCOMPLETE", "AUTO_CONTINUE": |
| return true |
| default: |
| return false |
| } |
| } |
|
|
| |
| |
| func (s *continueState) prepareForNextRound() { |
| if s == nil { |
| return |
| } |
| s.finished = false |
| s.lastStatus = "" |
| } |
|
|
| func asString(v any) string { |
| if v == nil { |
| return "" |
| } |
| switch x := v.(type) { |
| case string: |
| return x |
| default: |
| s := strings.TrimSpace(strings.ReplaceAll(strings.TrimSpace(fmt.Sprint(v)), "\u0000", "")) |
| if s == "<nil>" { |
| return "" |
| } |
| return s |
| } |
| } |
|
|