NtGdi committed on
Commit
12aa997
·
1 Parent(s): 553c00b

feat: support search result references

Browse files
Files changed (2) hide show
  1. internal/chat.go +88 -16
  2. internal/models.go +147 -21
internal/chat.go CHANGED
@@ -286,6 +286,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
286
  hasContent := false
287
  searchRefFilter := NewSearchRefFilter()
288
  thinkingFilter := &ThinkingFilter{}
 
289
 
290
  for scanner.Scan() {
291
  line := scanner.Text()
@@ -311,7 +312,28 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
311
 
312
  // 处理思考阶段的增量内容
313
  if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
 
315
  if reasoningContent != "" {
316
  hasContent = true
317
  chunk := ChatCompletionChunk{
@@ -332,17 +354,21 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
332
  continue
333
  }
334
 
335
- // 跳过搜索结果内容和搜索工具调用
336
- if upstream.Data.EditContent != "" && (IsSearchResultContent(upstream.Data.EditContent) || IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase)) {
 
 
 
 
 
 
 
 
337
  continue
338
  }
339
 
340
- // 解析 answer 阶段内容
341
- content := ""
342
- reasoningContent := ""
343
-
344
- // 先输出 thinking 缓冲区剩余内容
345
- if thinkingRemaining := thinkingFilter.Flush(); thinkingRemaining != "" {
346
  hasContent = true
347
  chunk := ChatCompletionChunk{
348
  ID: completionID,
@@ -351,13 +377,39 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
351
  Model: modelName,
352
  Choices: []Choice{{
353
  Index: 0,
354
- Delta: Delta{ReasoningContent: thinkingRemaining},
355
  FinishReason: nil,
356
  }},
357
  }
358
  data, _ := json.Marshal(chunk)
359
  fmt.Fprintf(w, "data: %s\n\n", data)
360
  flusher.Flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  }
362
 
363
  if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
@@ -376,6 +428,9 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
376
  }
377
 
378
  // 输出完整思考内容(如果有)
 
 
 
379
  if reasoningContent != "" {
380
  hasContent = true
381
  chunk := ChatCompletionChunk{
@@ -475,6 +530,9 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
475
  var chunks []string
476
  var reasoningChunks []string
477
  thinkingFilter := &ThinkingFilter{}
 
 
 
478
 
479
  for scanner.Scan() {
480
  line := scanner.Text()
@@ -496,8 +554,12 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
496
  break
497
  }
498
 
499
- // 处理思考阶段的增量内容
500
  if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
 
 
 
 
 
501
  reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
502
  if reasoningContent != "" {
503
  reasoningChunks = append(reasoningChunks, reasoningContent)
@@ -505,17 +567,27 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
505
  continue
506
  }
507
 
508
- // 跳过搜索结果内容和搜索工具调用
509
- if upstream.Data.EditContent != "" && (IsSearchResultContent(upstream.Data.EditContent) || IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase)) {
 
 
 
 
 
 
510
  continue
511
  }
512
 
513
- // 解析 answer 阶段内容
 
 
 
 
 
514
  content := ""
515
  if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
516
  content = upstream.Data.DeltaContent
517
  } else if upstream.Data.Phase == "answer" && upstream.Data.EditContent != "" {
518
- // 思考模型首次 answer:提取完整思考内容 + 正常回复开头
519
  if strings.Contains(upstream.Data.EditContent, "</details>") {
520
  reasoningContent := thinkingFilter.ExtractCompleteThinking(upstream.Data.EditContent)
521
  if reasoningContent != "" {
@@ -534,10 +606,10 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
534
  }
535
  }
536
 
537
- // 合并所有内容后统一过滤搜索引用标记
538
  fullContent := strings.Join(chunks, "")
539
- fullContent = searchRefPattern.ReplaceAllString(fullContent, "")
540
  fullReasoning := strings.Join(reasoningChunks, "")
 
541
 
542
  if fullContent == "" {
543
  LogError("Non-stream response 200 but no content received")
 
286
  hasContent := false
287
  searchRefFilter := NewSearchRefFilter()
288
  thinkingFilter := &ThinkingFilter{}
289
+ pendingSourcesMarkdown := ""
290
 
291
  for scanner.Scan() {
292
  line := scanner.Text()
 
312
 
313
  // 处理思考阶段的增量内容
314
  if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
315
+ // 如果有待输出的搜索结果,先输出到 reasoning
316
+ if pendingSourcesMarkdown != "" {
317
+ hasContent = true
318
+ chunk := ChatCompletionChunk{
319
+ ID: completionID,
320
+ Object: "chat.completion.chunk",
321
+ Created: time.Now().Unix(),
322
+ Model: modelName,
323
+ Choices: []Choice{{
324
+ Index: 0,
325
+ Delta: Delta{ReasoningContent: pendingSourcesMarkdown},
326
+ FinishReason: nil,
327
+ }},
328
+ }
329
+ data, _ := json.Marshal(chunk)
330
+ fmt.Fprintf(w, "data: %s\n\n", data)
331
+ flusher.Flush()
332
+ pendingSourcesMarkdown = ""
333
+ }
334
+
335
  reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
336
+ reasoningContent = searchRefFilter.Process(reasoningContent)
337
  if reasoningContent != "" {
338
  hasContent = true
339
  chunk := ChatCompletionChunk{
 
354
  continue
355
  }
356
 
357
+ // 解析搜索结果,暂存等待下一个流决定放在哪里
358
+ if upstream.Data.EditContent != "" && IsSearchResultContent(upstream.Data.EditContent) {
359
+ if results := ParseSearchResults(upstream.Data.EditContent); len(results) > 0 {
360
+ searchRefFilter.AddSearchResults(results)
361
+ pendingSourcesMarkdown = searchRefFilter.GetSearchResultsMarkdown()
362
+ }
363
+ continue
364
+ }
365
+ // 跳过搜索工具调用
366
+ if upstream.Data.EditContent != "" && IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase) {
367
  continue
368
  }
369
 
370
+ // 进入 answer 阶段,如果有待输出的搜索结果,先输出到 content
371
+ if pendingSourcesMarkdown != "" {
 
 
 
 
372
  hasContent = true
373
  chunk := ChatCompletionChunk{
374
  ID: completionID,
 
377
  Model: modelName,
378
  Choices: []Choice{{
379
  Index: 0,
380
+ Delta: Delta{Content: pendingSourcesMarkdown},
381
  FinishReason: nil,
382
  }},
383
  }
384
  data, _ := json.Marshal(chunk)
385
  fmt.Fprintf(w, "data: %s\n\n", data)
386
  flusher.Flush()
387
+ pendingSourcesMarkdown = ""
388
+ }
389
+
390
+ content := ""
391
+ reasoningContent := ""
392
+
393
+ // 先输出 thinking 缓冲区剩余内容
394
+ if thinkingRemaining := thinkingFilter.Flush(); thinkingRemaining != "" {
395
+ thinkingRemaining = searchRefFilter.Process(thinkingRemaining) + searchRefFilter.Flush()
396
+ if thinkingRemaining != "" {
397
+ hasContent = true
398
+ chunk := ChatCompletionChunk{
399
+ ID: completionID,
400
+ Object: "chat.completion.chunk",
401
+ Created: time.Now().Unix(),
402
+ Model: modelName,
403
+ Choices: []Choice{{
404
+ Index: 0,
405
+ Delta: Delta{ReasoningContent: thinkingRemaining},
406
+ FinishReason: nil,
407
+ }},
408
+ }
409
+ data, _ := json.Marshal(chunk)
410
+ fmt.Fprintf(w, "data: %s\n\n", data)
411
+ flusher.Flush()
412
+ }
413
  }
414
 
415
  if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
 
428
  }
429
 
430
  // 输出完整思考内容(如果有)
431
+ if reasoningContent != "" {
432
+ reasoningContent = searchRefFilter.Process(reasoningContent) + searchRefFilter.Flush()
433
+ }
434
  if reasoningContent != "" {
435
  hasContent = true
436
  chunk := ChatCompletionChunk{
 
530
  var chunks []string
531
  var reasoningChunks []string
532
  thinkingFilter := &ThinkingFilter{}
533
+ searchRefFilter := NewSearchRefFilter()
534
+ hasThinking := false
535
+ pendingSourcesMarkdown := ""
536
 
537
  for scanner.Scan() {
538
  line := scanner.Text()
 
554
  break
555
  }
556
 
 
557
  if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
558
+ if pendingSourcesMarkdown != "" {
559
+ reasoningChunks = append(reasoningChunks, pendingSourcesMarkdown)
560
+ pendingSourcesMarkdown = ""
561
+ }
562
+ hasThinking = true
563
  reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
564
  if reasoningContent != "" {
565
  reasoningChunks = append(reasoningChunks, reasoningContent)
 
567
  continue
568
  }
569
 
570
+ if upstream.Data.EditContent != "" && IsSearchResultContent(upstream.Data.EditContent) {
571
+ if results := ParseSearchResults(upstream.Data.EditContent); len(results) > 0 {
572
+ searchRefFilter.AddSearchResults(results)
573
+ pendingSourcesMarkdown = searchRefFilter.GetSearchResultsMarkdown()
574
+ }
575
+ continue
576
+ }
577
+ if upstream.Data.EditContent != "" && IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase) {
578
  continue
579
  }
580
 
581
+ // 进入 answer 阶段,把待输出的搜索结果放到 content
582
+ if pendingSourcesMarkdown != "" && !hasThinking {
583
+ chunks = append(chunks, pendingSourcesMarkdown)
584
+ pendingSourcesMarkdown = ""
585
+ }
586
+
587
  content := ""
588
  if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
589
  content = upstream.Data.DeltaContent
590
  } else if upstream.Data.Phase == "answer" && upstream.Data.EditContent != "" {
 
591
  if strings.Contains(upstream.Data.EditContent, "</details>") {
592
  reasoningContent := thinkingFilter.ExtractCompleteThinking(upstream.Data.EditContent)
593
  if reasoningContent != "" {
 
606
  }
607
  }
608
 
 
609
  fullContent := strings.Join(chunks, "")
610
+ fullContent = searchRefFilter.Process(fullContent) + searchRefFilter.Flush()
611
  fullReasoning := strings.Join(reasoningChunks, "")
612
+ fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
613
 
614
  if fullContent == "" {
615
  LogError("Non-stream response 200 but no content received")
internal/models.go CHANGED
@@ -1,16 +1,19 @@
1
  package internal
2
 
3
  import (
 
 
4
  "regexp"
5
  "strings"
6
  )
7
 
8
  // 基础模型映射(不包含标签后缀)
9
  var BaseModelMapping = map[string]string{
10
- "GLM-4.5": "0727-360B-API",
11
- "GLM-4.6": "GLM-4-6-API-V1",
12
- "GLM-4.5-V": "glm-4.5v",
13
- "GLM-4.5-Air": "0727-106B-API",
 
14
  }
15
 
16
  // v1/models 返回的模型列表(不包含所有标签组合)
@@ -21,6 +24,7 @@ var ModelList = []string{
21
  "GLM-4.6-thinking",
22
  "GLM-4.5-V",
23
  "GLM-4.5-Air",
 
24
  }
25
 
26
  // 解析模型名称,提取基础模型名和标签
@@ -167,38 +171,62 @@ type ModelInfo struct {
167
  OwnedBy string `json:"owned_by"`
168
  }
169
 
170
- // 搜索引用标记正则:【turn数字search数字
171
- var searchRefPattern = regexp.MustCompile(`【turn\d+search\d+】`)
172
-
173
- // 搜索引用标记可能的前缀模式
174
  var searchRefPrefixPattern = regexp.MustCompile(`【(t(u(r(n(\d+(s(e(a(r(c(h(\d+)?)?)?)?)?)?)?)?)?)?)?)?$`)
175
 
176
- // SearchRefFilter 用于跨流过滤搜索引用标记
 
 
 
 
 
 
177
  type SearchRefFilter struct {
178
- buffer string
 
179
  }
180
 
181
- // NewSearchRefFilter 创建新的过滤器
182
  func NewSearchRefFilter() *SearchRefFilter {
183
- return &SearchRefFilter{}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  }
185
 
186
- // Process 处理内容返回以安全输出部分
187
- // 如果末尾有可能是引用标记的前缀,会暂存起来
188
  func (f *SearchRefFilter) Process(content string) string {
189
- // 合并之前暂存的内容
190
  content = f.buffer + content
191
  f.buffer = ""
192
 
193
- // 先移除完整的引用标记
194
- content = searchRefPattern.ReplaceAllString(content, "")
 
 
 
 
 
 
 
 
 
 
195
 
196
  if content == "" {
197
  return ""
198
  }
199
 
200
- // 检查末尾是否有可能是引用标记的前缀
201
- // 从末尾开始,最多检查【turn999search999】长度(约20字符)
202
  maxPrefixLen := 20
203
  if len(content) < maxPrefixLen {
204
  maxPrefixLen = len(content)
@@ -207,7 +235,6 @@ func (f *SearchRefFilter) Process(content string) string {
207
  for i := 1; i <= maxPrefixLen; i++ {
208
  suffix := content[len(content)-i:]
209
  if searchRefPrefixPattern.MatchString(suffix) {
210
- // 找到可能的前缀,暂存起来
211
  f.buffer = suffix
212
  return content[:len(content)-i]
213
  }
@@ -216,18 +243,117 @@ func (f *SearchRefFilter) Process(content string) string {
216
  return content
217
  }
218
 
219
- // Flush 返回所有暂存的内容(流结束时调用)
220
  func (f *SearchRefFilter) Flush() string {
221
  result := f.buffer
222
  f.buffer = ""
 
 
 
 
 
 
 
 
 
 
223
  return result
224
  }
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  // 检查是否为搜索结果内容(需要跳过)
227
  func IsSearchResultContent(editContent string) bool {
228
  return strings.Contains(editContent, `"search_result"`)
229
  }
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  // 检查是否为搜索工具调用内容(需要跳过)
232
  func IsSearchToolCall(editContent string, phase string) bool {
233
  if phase != "tool_call" {
 
1
  package internal
2
 
3
import (
	"encoding/json"
	"fmt"
	"regexp"
	"sort"
	"strings"
)
9
 
10
  // 基础模型映射(不包含标签后缀)
11
  var BaseModelMapping = map[string]string{
12
+ "GLM-4.5": "0727-360B-API",
13
+ "GLM-4.6": "GLM-4-6-API-V1",
14
+ "GLM-4.5-V": "glm-4.5v",
15
+ "GLM-4.5-Air": "0727-106B-API",
16
+ "0808-360B-DR": "0808-360B-DR",
17
  }
18
 
19
  // v1/models 返回的模型列表(不包含所有标签组合)
 
24
  "GLM-4.6-thinking",
25
  "GLM-4.5-V",
26
  "GLM-4.5-Air",
27
+ "0808-360B-DR",
28
  }
29
 
30
  // 解析模型名称,提取基础模型名和标签
 
171
  OwnedBy string `json:"owned_by"`
172
  }
173
 
174
// searchRefPattern matches one complete search citation marker, for example
// 【turn0search3】. The closing bracket is part of the pattern so the whole
// marker (not just its head) is replaced.
var searchRefPattern = regexp.MustCompile(`【turn\d+search(\d+)】`)

// searchRefPrefixPattern matches a trailing, still incomplete citation marker
// (anything from a bare 【 up to 【turnNsearchN without the closing bracket).
var searchRefPrefixPattern = regexp.MustCompile(`【(t(u(r(n(\d+(s(e(a(r(c(h(\d+)?)?)?)?)?)?)?)?)?)?)?)?$`)

// markdownTitleEscaper escapes the characters that would terminate or corrupt
// the text part of a markdown link.
var markdownTitleEscaper = strings.NewReplacer(`\`, `\\`, `[`, `\[`, `]`, `\]`)

// SearchResult is one entry of an upstream "search_result" array.
type SearchResult struct {
	Title string `json:"title"`
	URL   string `json:"url"`
	Index int    `json:"index"`
	RefID string `json:"ref_id"`
}

// SearchRefFilter rewrites citation markers found in streamed text into
// markdown links. Text may arrive split at arbitrary points, so a trailing
// partial marker is held back in buffer until the next Process or Flush call.
type SearchRefFilter struct {
	buffer        string
	searchResults map[string]SearchResult
}

// NewSearchRefFilter returns an empty filter ready for use.
func NewSearchRefFilter() *SearchRefFilter {
	return &SearchRefFilter{
		searchResults: make(map[string]SearchResult),
	}
}

// AddSearchResults registers results keyed by their RefID. A later result with
// the same RefID replaces the earlier one.
func (f *SearchRefFilter) AddSearchResults(results []SearchResult) {
	if f.searchResults == nil {
		// Keep a zero-value SearchRefFilter usable instead of panicking on a nil map.
		f.searchResults = make(map[string]SearchResult, len(results))
	}
	for _, r := range results {
		f.searchResults[r.RefID] = r
	}
}

// escapeMarkdownTitle escapes backslashes and square brackets in title so it
// can be placed inside the [...] part of a markdown link.
func escapeMarkdownTitle(title string) string {
	return markdownTitleEscaper.Replace(title)
}

// replaceRefs substitutes every complete citation marker in s with a markdown
// link to the registered result. Markers whose ref ID is unknown are removed.
func (f *SearchRefFilter) replaceRefs(s string) string {
	return searchRefPattern.ReplaceAllStringFunc(s, func(match string) string {
		refID := strings.TrimSuffix(strings.TrimPrefix(match, "【"), "】")
		if r, ok := f.searchResults[refID]; ok {
			return fmt.Sprintf(`[\[%d\]](%s)`, r.Index, r.URL)
		}
		return ""
	})
}

// Process converts the citation markers in content (prefixed with anything held
// back by the previous call) into markdown links and returns the part that is
// safe to emit. A trailing partial marker is buffered rather than returned.
func (f *SearchRefFilter) Process(content string) string {
	content = f.buffer + content
	f.buffer = ""
	if content == "" {
		return ""
	}

	content = f.replaceRefs(content)

	// The prefix pattern is anchored at the end of the string, so a match is
	// exactly the dangling partial marker, whatever its length.
	if loc := searchRefPrefixPattern.FindStringIndex(content); loc != nil {
		f.buffer = content[loc[0]:]
		content = content[:loc[0]]
	}
	return content
}

// Flush returns whatever is still buffered and clears the buffer. Call it when
// no more input will follow for the current piece of text.
func (f *SearchRefFilter) Flush() string {
	rest := f.buffer
	f.buffer = ""
	if rest == "" {
		return ""
	}
	return f.replaceRefs(rest)
}

// GetSearchResultsMarkdown renders all registered results as one markdown link
// per line, ordered by Index (ties broken by RefID so the output is stable),
// followed by a blank line. It returns "" when nothing is registered.
func (f *SearchRefFilter) GetSearchResultsMarkdown() string {
	if len(f.searchResults) == 0 {
		return ""
	}

	results := make([]SearchResult, 0, len(f.searchResults))
	for _, r := range f.searchResults {
		results = append(results, r)
	}
	sort.Slice(results, func(i, j int) bool {
		if results[i].Index != results[j].Index {
			return results[i].Index < results[j].Index
		}
		return results[i].RefID < results[j].RefID
	})

	var sb strings.Builder
	for _, r := range results {
		fmt.Fprintf(&sb, "[\\[%d\\] %s](%s)\n", r.Index, escapeMarkdownTitle(r.Title), r.URL)
	}
	sb.WriteString("\n")
	return sb.String()
}

// IsSearchResultContent reports whether editContent carries a "search_result"
// payload.
func IsSearchResultContent(editContent string) bool {
	return strings.Contains(editContent, `"search_result"`)
}

// ParseSearchResults extracts the array that follows the first
// "search_result": key in editContent. It returns nil when the key or array is
// missing, the array is not valid JSON, or the array is empty.
func ParseSearchResults(editContent string) []SearchResult {
	const searchResultKey = `"search_result":`

	idx := strings.Index(editContent, searchResultKey)
	if idx == -1 {
		return nil
	}
	rest := editContent[idx+len(searchResultKey):]

	start := strings.IndexByte(rest, '[')
	if start == -1 {
		return nil
	}

	// Decode exactly one JSON value starting at '['. Unlike counting bracket
	// bytes, the decoder is not confused by '[' or ']' inside string values,
	// and it ignores whatever follows the array.
	var results []SearchResult
	if err := json.NewDecoder(strings.NewReader(rest[start:])).Decode(&results); err != nil {
		return nil
	}
	if len(results) == 0 {
		return nil
	}
	return results
}
356
+
357
  // 检查是否为搜索工具调用内容(需要跳过)
358
  func IsSearchToolCall(editContent string, phase string) bool {
359
  if phase != "tool_call" {