File size: 9,804 Bytes
8d3471e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
package client

import (
	"bufio"
	"bytes"
	"context"
	dsprotocol "ds2api/internal/deepseek/protocol"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"ds2api/internal/auth"
	"ds2api/internal/config"
)

// defaultAutoContinueLimit is the maximum number of continuation rounds.
// wrapCompletionWithAutoContinue passes it explicitly, and
// newAutoContinueBody falls back to it when given a non-positive limit.
const defaultAutoContinueLimit = 8

// continueOpenFunc opens one continuation stream. pumpAutoContinue calls it
// with the session ID and the response message ID observed so far, and reads
// the Body of the returned response as the next SSE stream.
type continueOpenFunc func(context.Context, string, int) (*http.Response, error)

// continueState accumulates the signals observed in an SSE stream that decide
// whether a continue request should be issued (see shouldContinue).
type continueState struct {
	// sessionID is the chat session ID sent with each continue request.
	sessionID         string
	// responseMessageID is the most recent positive message ID seen in the
	// stream; it stays zero until one is observed.
	responseMessageID int
	// lastStatus is the last non-empty status observed, or "AUTO_CONTINUE"
	// when an auto_continue flag was true. Cleared by prepareForNextRound.
	lastStatus        string
	// finished is set once a FINISHED or CONTENT_FILTER status is observed.
	// Cleared by prepareForNextRound.
	finished          bool
}

// wrapCompletionWithAutoContinue replaces resp.Body with an auto-continuing
// body: when the upstream marks the answer as incomplete (INCOMPLETE /
// AUTO_CONTINUE), the DeepSeek continue endpoint is called and its SSE output
// is appended, so the caller reads one uninterrupted SSE stream.
//
// The response is returned untouched when it is nil, has no body, or when
// payload carries no non-blank "chat_session_id" string.
func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, resp *http.Response) *http.Response {
	if resp == nil || resp.Body == nil {
		return resp
	}
	rawSessionID, _ := payload["chat_session_id"].(string)
	sessionID := strings.TrimSpace(rawSessionID)
	if sessionID == "" {
		return resp
	}
	config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID)

	// Each continuation round reuses the same auth and PoW response.
	reopen := func(roundCtx context.Context, sid string, messageID int) (*http.Response, error) {
		return c.callContinue(roundCtx, a, sid, messageID, powResp)
	}
	resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, reopen)
	return resp
}

// callContinue sends a continue request to DeepSeek to resume generation.
//
// It posts chat_session_id, message_id and fallback_to_resume=true to
// dsprotocol.DeepSeekContinueURL over the stream client selected for a, using
// the auth headers for a.DeepSeekToken plus an x-ds-pow-response header set
// to powResp.
//
// On success the response has status 200 and the caller owns (and must close)
// its Body. An error is returned when a is nil, when sessionID is blank or
// responseMessageID is not positive, when the request itself fails, or when
// the upstream answers with a non-200 status (the body is closed in that
// case and the status code is included in the error).
func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int, powResp string) (*http.Response, error) {
	// Guard before dereferencing: this runs on the auto-continue goroutine,
	// where a nil-pointer panic could not be recovered by the request handler.
	if a == nil {
		return nil, errors.New("missing continue auth")
	}
	if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
		return nil, errors.New("missing continue identifiers")
	}
	clients := c.requestClientsForAuth(ctx, a)
	headers := c.authHeaders(a.DeepSeekToken)
	headers["x-ds-pow-response"] = powResp
	payload := map[string]any{
		"chat_session_id":    sessionID,
		"message_id":         responseMessageID,
		"fallback_to_resume": true,
	}
	config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID)
	captureSession := c.capture.Start("deepseek_continue", dsprotocol.DeepSeekContinueURL, a.AccountID, payload)
	resp, err := c.streamPost(ctx, clients.stream, dsprotocol.DeepSeekContinueURL, headers, payload)
	if err != nil {
		return nil, err
	}
	// Wrap before the status check so that a failing response is also seen
	// by the capture session when its body is closed below.
	if captureSession != nil {
		resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
	}
	if resp.StatusCode != http.StatusOK {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("continue failed: unexpected status %d", resp.StatusCode)
	}
	return resp, nil
}

// newAutoContinueBody wraps initial in a pipe-backed ReadCloser that is fed
// by a background goroutine (pumpAutoContinue), which appends continuation
// rounds opened through openContinue.
//
// initial is returned unchanged when it is nil, when sessionID is blank, or
// when openContinue is nil. A non-positive maxRounds is replaced by
// defaultAutoContinueLimit.
func newAutoContinueBody(ctx context.Context, initial io.ReadCloser, sessionID string, maxRounds int, openContinue continueOpenFunc) io.ReadCloser {
	switch {
	case initial == nil, openContinue == nil, strings.TrimSpace(sessionID) == "":
		return initial
	}

	limit := maxRounds
	if limit <= 0 {
		limit = defaultAutoContinueLimit
	}

	reader, writer := io.Pipe()
	seed := continueState{sessionID: sessionID}
	go pumpAutoContinue(ctx, writer, initial, seed, limit, openContinue)
	return reader
}

// pumpAutoContinue is the goroutine that drives the auto-continue loop.
// It streams the initial SSE body into pw and, while the observed state asks
// for a continuation and fewer than maxRounds rounds have been used, opens
// the next continue stream via openContinue and splices it onto the same
// pipe writer.
//
// Every upstream body is closed once it has been consumed. The pipe is closed
// with the first streaming or continue error; otherwise it is closed cleanly
// after an optional trailing "data: [DONE]" line, which is emitted only when
// the last stream processed carried a [DONE] sentinel.
//
// A continue result without a body is treated as a failure instead of being
// dereferenced, because a panic here would occur on a background goroutine.
func pumpAutoContinue(ctx context.Context, pw *io.PipeWriter, initial io.ReadCloser, state continueState, maxRounds int, openContinue continueOpenFunc) {
	// Fallback clean close. Per the io.PipeWriter documentation an earlier
	// CloseWithError is never overwritten, so error closes below still win.
	defer func() { _ = pw.Close() }()
	current := initial
	rounds := 0
	for {
		hadDone, err := streamBodyWithContinueState(ctx, pw, current, &state)
		_ = current.Close()
		if err != nil {
			_ = pw.CloseWithError(err)
			return
		}
		if state.shouldContinue() && rounds < maxRounds {
			rounds++
			config.Logger.Info("[auto_continue] continuing", "round", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID, "status", state.lastStatus)
			nextResp, err := openContinue(ctx, state.sessionID, state.responseMessageID)
			if err != nil {
				config.Logger.Warn("[auto_continue] continue request failed", "round", rounds, "error", err)
				_ = pw.CloseWithError(err)
				return
			}
			if nextResp == nil || nextResp.Body == nil {
				// Nothing to stream from; surface it to the reader rather
				// than panicking on a nil body.
				config.Logger.Warn("[auto_continue] continue returned no body", "round", rounds)
				_ = pw.CloseWithError(errors.New("continue returned empty response"))
				return
			}
			current = nextResp.Body
			state.prepareForNextRound()
			continue
		}
		if state.shouldContinue() {
			// Only reachable when the round budget is exhausted: the stream
			// ends here even though the upstream still reports incomplete.
			config.Logger.Warn("[auto_continue] continue limit reached", "rounds", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID)
		}
		// Emit the final [DONE] sentinel if the upstream had one.
		if hadDone {
			if _, err := io.WriteString(pw, "data: [DONE]\n"); err != nil {
				_ = pw.CloseWithError(err)
			}
		}
		return
	}
}

// streamBodyWithContinueState copies one SSE body to pw a line at a time and
// hands every "data:" payload to state so continue signals are tracked.
//
// Blank lines are skipped, and a newline is appended to a final line that
// lacks one. "data: [DONE]" lines are swallowed instead of forwarded; the
// boolean result reports whether one was seen so the caller can emit a single
// final sentinel. A nil error means the body reached EOF; otherwise the
// context, read, or pipe-write error is returned.
func streamBodyWithContinueState(ctx context.Context, pw *io.PipeWriter, body io.Reader, state *continueState) (bool, error) {
	const dataPrefix = "data:"
	// eofAsNil maps the normal end-of-body condition to success.
	eofAsNil := func(err error) error {
		if err == io.EOF {
			return nil
		}
		return err
	}

	lines := bufio.NewReaderSize(body, 64*1024)
	sawDone := false
	for {
		select {
		case <-ctx.Done():
			return sawDone, ctx.Err()
		default:
		}

		// ReadBytes may return a final unterminated line together with an
		// error, so the line is processed before the error is acted upon.
		raw, readErr := lines.ReadBytes('\n')
		text := strings.TrimSpace(string(raw))
		if text == "" {
			if readErr != nil {
				return sawDone, eofAsNil(readErr)
			}
			continue
		}

		if strings.HasPrefix(text, dataPrefix) {
			payload := strings.TrimSpace(text[len(dataPrefix):])
			if payload == "[DONE]" {
				sawDone = true
				if readErr != nil {
					return sawDone, eofAsNil(readErr)
				}
				continue
			}
			state.observe(payload)
		}

		// text is non-empty here, so raw has at least one byte.
		if raw[len(raw)-1] != '\n' {
			raw = append(raw, '\n')
		}
		if _, writeErr := pw.Write(raw); writeErr != nil {
			return sawDone, writeErr
		}
		if readErr != nil {
			return sawDone, eofAsNil(readErr)
		}
	}
}

// observe parses one SSE data payload as a JSON object and records any
// continue-relevant signals it carries: a top-level response_message_id, a
// direct patch (p/v), a batch of patches in v, and response objects nested
// under v or message. Blank or non-JSON-object payloads are ignored.
func (s *continueState) observe(data string) {
	if s == nil {
		return
	}
	if strings.TrimSpace(data) == "" {
		return
	}
	var event map[string]any
	if json.Unmarshal([]byte(data), &event) != nil {
		return
	}

	// Top-level response_message_id.
	if messageID := intFrom(event["response_message_id"]); messageID > 0 {
		s.responseMessageID = messageID
	}

	value := event["v"]
	s.observeDirectPatch(asString(event["p"]), value)

	// Batch entries are resolved relative to "response" only when the
	// enclosing patch path is exactly that.
	parent := ""
	if path, isString := event["p"].(string); isString && path == "response" {
		parent = "response"
	}
	s.observeBatchPatches(parent, value)

	if object, isObject := value.(map[string]any); isObject && object != nil {
		s.observeResponseObject(object["response"])
	}
	if message, isObject := event["message"].(map[string]any); isObject && message != nil {
		s.observeResponseObject(message["response"])
	}
}

// observeDirectPatch applies a single path/value patch. After trimming
// whitespace and surrounding slashes, the paths status and quasi_status
// (optionally prefixed with "response/") update the status, and auto_continue
// (same optional prefix) with a true boolean marks the state AUTO_CONTINUE.
// Any other path is ignored.
func (s *continueState) observeDirectPatch(path string, value any) {
	if s == nil {
		return
	}
	normalized := strings.Trim(strings.TrimSpace(path), "/")
	// Drop at most one leading "response/" so both spellings share a case.
	field := strings.TrimPrefix(normalized, "response/")
	switch field {
	case "status", "quasi_status":
		s.setStatus(asString(value))
	case "auto_continue":
		if enabled, isBool := value.(bool); isBool && enabled {
			s.lastStatus = "AUTO_CONTINUE"
		}
	}
}

// observeResponseObject reads signals from a full response object: a positive
// message_id, the status, and a true auto_continue flag (which is applied
// last and therefore overrides lastStatus). Values that are not JSON objects
// are ignored.
func (s *continueState) observeResponseObject(raw any) {
	if s == nil {
		return
	}
	object, isObject := raw.(map[string]any)
	if !isObject || object == nil {
		return
	}
	if messageID := intFrom(object["message_id"]); messageID > 0 {
		s.responseMessageID = messageID
	}
	s.setStatus(asString(object["status"]))
	if enabled, _ := object["auto_continue"].(bool); enabled {
		s.lastStatus = "AUTO_CONTINUE"
	}
}

// observeBatchPatches applies every patch object in raw, which must be a JSON
// array; anything else is ignored. Entries that are not objects or have a
// blank "p" are skipped. A path without a slash is resolved under parentPath
// when that is non-empty, and the resulting path/value pair is handled
// exactly like a direct patch.
func (s *continueState) observeBatchPatches(parentPath string, raw any) {
	if s == nil {
		return
	}
	entries, isList := raw.([]any)
	if !isList {
		return
	}
	parent := strings.Trim(strings.TrimSpace(parentPath), "/")
	for _, entry := range entries {
		patch, isObject := entry.(map[string]any)
		if !isObject {
			continue
		}
		target := strings.TrimSpace(asString(patch["p"]))
		if target == "" {
			continue
		}
		if parent != "" && !strings.Contains(target, "/") {
			target = parent + "/" + target
		}
		s.observeDirectPatch(target, patch["v"])
	}
}

// setStatus records a non-blank status (whitespace-trimmed) as lastStatus and
// marks the state finished when it equals FINISHED or CONTENT_FILTER,
// compared case-insensitively. Blank statuses leave the state untouched.
func (s *continueState) setStatus(status string) {
	if s == nil {
		return
	}
	trimmed := strings.TrimSpace(status)
	if trimmed == "" {
		return
	}
	s.lastStatus = trimmed
	for _, terminal := range [...]string{"FINISHED", "CONTENT_FILTER"} {
		if strings.EqualFold(trimmed, terminal) {
			s.finished = true
			return
		}
	}
}

// shouldContinue reports whether a continue request should be issued: the
// state must not be finished, must hold a positive response message ID and a
// non-blank session ID, and the last status must be INCOMPLETE or
// AUTO_CONTINUE (case-insensitive). A plain WIP status does not qualify,
// because ordinary streams start out in WIP.
func (s *continueState) shouldContinue() bool {
	if s == nil || s.finished {
		return false
	}
	if s.responseMessageID <= 0 {
		return false
	}
	if strings.TrimSpace(s.sessionID) == "" {
		return false
	}
	status := strings.ToUpper(strings.TrimSpace(s.lastStatus))
	return status == "INCOMPLETE" || status == "AUTO_CONTINUE"
}

// prepareForNextRound clears the per-stream signals (lastStatus and finished)
// ahead of the next continuation stream. The session ID and response message
// ID are kept.
func (s *continueState) prepareForNextRound() {
	if s != nil {
		s.lastStatus, s.finished = "", false
	}
}

// asString renders v as a string. A string value is returned unchanged and
// nil yields "". Any other value is formatted with fmt.Sprint, stripped of
// NUL characters and surrounding whitespace; a result of "<nil>" (for
// example from a typed nil pointer) is reported as "".
func asString(v any) string {
	if text, isString := v.(string); isString {
		return text
	}
	if v == nil {
		return ""
	}
	rendered := strings.TrimSpace(fmt.Sprint(v))
	rendered = strings.ReplaceAll(rendered, "\u0000", "")
	rendered = strings.TrimSpace(rendered)
	if rendered == "<nil>" {
		return ""
	}
	return rendered
}