File size: 9,044 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
package reasoning

import (
	"strings"
)

// DetectThinkingStartToken checks whether the prompt (or template) contains a
// thinking start token and returns the detected token. A detected token
// indicates that the model's prompt template already includes the opening
// thinking tag, so the model output will begin with reasoning content without
// an explicit opening tag of its own.
//
// Matching is done anywhere in the prompt after trailing whitespace is
// stripped; the first matching token wins. Custom tokens from config are
// checked before the defaults, so they take priority.
//
// Default tokens (most specific first), based on llama.cpp's chat-parser.cpp:
//   - <|START_THINKING|>  (Command-R models)
//   - <|inner_prefix|>    (Apertus models)
//   - <seed:think>        (Seed models)
//   - <think>             (DeepSeek, Granite, ExaOne models)
//   - <|think|>           (Solar Open models)
//   - <thinking>          (general thinking tag)
//   - [THINK]             (Magistral models)
//
// Returns the detected token, or the empty string when none is found.
func DetectThinkingStartToken(prompt string, config *Config) string {
	// Default thinking start tokens (most specific first), mirroring
	// llama.cpp's chat-parser.cpp implementations.
	defaultTokens := []string{
		"<|START_THINKING|>", // Command-R models
		"<|inner_prefix|>",   // Apertus models
		"<seed:think>",       // Seed models
		"<think>",            // DeepSeek, Granite, ExaOne models
		"<|think|>",          // Solar Open models
		"<thinking>",         // General thinking tag
		"[THINK]",            // Magistral models
	}

	// Custom tokens from config come first so they win over the defaults.
	var thinkingStartTokens []string
	if config != nil && len(config.ThinkingStartTokens) > 0 {
		thinkingStartTokens = append(thinkingStartTokens, config.ThinkingStartTokens...)
	}
	thinkingStartTokens = append(thinkingStartTokens, defaultTokens...)

	// Scan the whole prompt (ignoring trailing whitespace/newlines) for the
	// first matching token. NOTE(review): a previous fallback that re-scanned
	// the last 100 characters was unreachable — any token present in the tail
	// is necessarily found by this Contains scan — and has been removed.
	trimmedPrompt := strings.TrimRight(prompt, " \t\n\r")
	for _, token := range thinkingStartTokens {
		if strings.Contains(trimmedPrompt, token) {
			return token
		}
	}

	return ""
}

// ExtractReasoningWithConfig extracts reasoning from content according to the
// given config. When reasoning is disabled, the original content is returned
// unchanged with empty reasoning. Unless thinking-start-token prefill is
// disabled, the detected start token is re-inserted before extraction so the
// standard tag parsing applies. When StripReasoningOnly is set, the reasoning
// is removed from the content but not surfaced to the caller.
// It returns the extracted reasoning and the cleaned content.
func ExtractReasoningWithConfig(content, thinkingStartToken string, config Config) (reasoning string, cleanedContent string) {
	// Reasoning disabled: nothing to extract, pass the content through as-is.
	if config.DisableReasoning != nil && *config.DisableReasoning {
		return "", content
	}

	cleanedContent = content

	// Re-insert the thinking start token the prompt template already
	// consumed, unless prefill has been explicitly disabled.
	if config.DisableReasoningTagPrefill == nil || !*config.DisableReasoningTagPrefill {
		cleanedContent = PrependThinkingTokenIfNeeded(cleanedContent, thinkingStartToken)
	}

	reasoning, cleanedContent = ExtractReasoning(cleanedContent, &config)

	// Strip-only mode: drop the reasoning rather than returning it.
	if config.StripReasoningOnly != nil && *config.StripReasoningOnly {
		reasoning = ""
	}

	return reasoning, cleanedContent
}

// PrependThinkingTokenIfNeeded prepends the thinking start token to content if
// it was detected in the prompt. This lets the standard extraction logic work
// for models whose prompt template already contains the opening thinking tag.
//
// The token is inserted after any leading whitespace so the result reads like
// normally tagged content. If startToken is empty, or the token already
// appears anywhere in content (in which case extraction will find it on its
// own), content is returned unchanged.
func PrependThinkingTokenIfNeeded(content string, startToken string) string {
	if startToken == "" {
		return content
	}

	// If the token already appears anywhere in the content, don't insert a
	// duplicate — the extraction logic will locate the existing one.
	if strings.Contains(content, startToken) {
		return content
	}

	// Length of the leading whitespace run; the token is inserted right
	// after it to make the output look like normal tagged content.
	wsLen := len(content) - len(strings.TrimLeft(content, " \t\n\r"))
	return content[:wsLen] + startToken + content[wsLen:]
}

// ExtractReasoning extracts reasoning content from thinking tags and returns
// both the extracted reasoning and the cleaned content (with tags removed).
// It handles <thinking>...</thinking> and <think>...</think> tags.
// Multiple reasoning blocks are concatenated with blank lines ("\n\n").
// Custom tag pairs from config are checked first, then default tag pairs.
//
// An unclosed start tag consumes everything to the end of the string as
// reasoning (this supports truncated/streaming model output). When two tag
// pairs start at the same position, the one earlier in the priority list
// (custom before default) wins, because the comparison below is strict.
func ExtractReasoning(content string, config *Config) (reasoning string, cleanedContent string) {
	if content == "" {
		return "", content
	}

	var reasoningParts []string
	var cleanedParts []string
	remaining := content

	// Define default tag pairs to look for (matching llama.cpp's chat-parser.cpp)
	defaultTagPairs := []struct {
		start string
		end   string
	}{
		{"<|START_THINKING|>", "<|END_THINKING|>"},            // Command-R models
		{"<|inner_prefix|>", "<|inner_suffix|>"},              // Apertus models
		{"<seed:think>", "</seed:think>"},                     // Seed models
		{"<think>", "</think>"},                               // DeepSeek, Granite, ExaOne models
		{"<|think|>", "<|end|><|begin|>assistant<|content|>"}, // Solar Open models (complex end)
		{"<thinking>", "</thinking>"},                         // General thinking tag
		{"[THINK]", "[/THINK]"},                               // Magistral models
	}

	// Merge custom tag pairs with default tag pairs (custom pairs first for
	// priority). Pairs with an empty start or end tag are skipped — an empty
	// tag would match at every position.
	var tagPairs []struct {
		start string
		end   string
	}
	if config != nil && len(config.TagPairs) > 0 {
		for _, pair := range config.TagPairs {
			if pair.Start != "" && pair.End != "" {
				tagPairs = append(tagPairs, struct {
					start string
					end   string
				}{pair.Start, pair.End})
			}
		}
	}
	// Add default tag pairs
	for _, pair := range defaultTagPairs {
		tagPairs = append(tagPairs, pair)
	}

	// lastPos tracks how far into `remaining` we have already processed;
	// everything before it has been emitted into cleanedParts/reasoningParts.
	lastPos := 0

	// Repeatedly find the earliest-starting tag at or after lastPos, emit the
	// plain text before it, capture its inner text as reasoning, and advance.
	for {
		// Find the earliest tag start across all pairs.
		earliestStart := -1
		earliestEnd := -1
		isUnclosed := false
		var matchedTag struct {
			start string
			end   string
		}

		for _, tagPair := range tagPairs {
			startIdx := strings.Index(remaining[lastPos:], tagPair.start)
			if startIdx == -1 {
				continue
			}
			startIdx += lastPos

			// Find the corresponding end tag (searched only after the start tag).
			endIdx := strings.Index(remaining[startIdx+len(tagPair.start):], tagPair.end)
			if endIdx == -1 {
				// Unclosed tag - extract what we have (to end of string).
				// Strict '<' keeps the earlier-priority pair on position ties.
				if earliestStart == -1 || startIdx < earliestStart {
					earliestStart = startIdx
					earliestEnd = len(remaining)
					isUnclosed = true
					matchedTag = tagPair
				}
				continue
			}
			endIdx += startIdx + len(tagPair.start)

			// Found a complete tag pair; earliestEnd points just past the end tag.
			if earliestStart == -1 || startIdx < earliestStart {
				earliestStart = startIdx
				earliestEnd = endIdx + len(tagPair.end)
				isUnclosed = false
				matchedTag = tagPair
			}
		}

		if earliestStart == -1 {
			// No more tags found, add remaining content
			if lastPos < len(remaining) {
				cleanedParts = append(cleanedParts, remaining[lastPos:])
			}
			break
		}

		// Add (keep) the plain content that precedes the tag.
		if earliestStart > lastPos {
			cleanedParts = append(cleanedParts, remaining[lastPos:earliestStart])
		}

		// Extract reasoning content between the start tag and the end boundary.
		reasoningStart := earliestStart + len(matchedTag.start)
		// For unclosed tags, earliestEnd is already at the end of the string
		// For closed tags, earliestEnd points to after the closing tag, so we subtract the end tag length
		var reasoningEnd int
		if isUnclosed {
			// Unclosed tag - extract everything to the end
			reasoningEnd = len(remaining)
		} else {
			// Closed tag - exclude the end tag
			reasoningEnd = earliestEnd - len(matchedTag.end)
		}
		if reasoningEnd > reasoningStart {
			// Whitespace-only blocks are dropped after trimming.
			reasoningContent := strings.TrimSpace(remaining[reasoningStart:reasoningEnd])
			if reasoningContent != "" {
				reasoningParts = append(reasoningParts, reasoningContent)
			}
		}

		// Move past this tag (past the closing tag, or to end if unclosed).
		lastPos = earliestEnd
	}

	// Combine reasoning parts with blank-line separators.
	reasoning = strings.Join(reasoningParts, "\n\n")
	// Combine cleaned content parts (tags and their inner text removed).
	cleanedContent = strings.Join(cleanedParts, "")

	return reasoning, cleanedContent
}