""" Anti-Truncation Module - Ensures complete streaming output 保持一个流式请求内完整输出的反截断模块 """ import io import json import re from typing import Any, AsyncGenerator, Dict, List, Tuple from fastapi.responses import StreamingResponse from log import log # 反截断配置 DONE_MARKER = "[done]" CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。 重要提醒: 1. 不要重复前面已经输出的内容 2. 直接继续输出,无需任何前言或解释 3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER} 4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 现在请继续输出:""" # 正则替换配置 REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [ ( "age_pattern", # 替换规则名称 r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?", # 正则模式 "", # 替换文本 ), # 可在此处添加更多替换规则 # ("rule_name", r"pattern", "replacement"), ] def apply_regex_replacements(text: str) -> str: """ 对文本应用正则替换规则 Args: text: 要处理的文本 Returns: 处理后的文本 """ if not text: return text processed_text = text replacement_count = 0 for rule_name, pattern, replacement in REGEX_REPLACEMENTS: try: # 编译正则表达式,使用IGNORECASE标志 regex = re.compile(pattern, re.IGNORECASE) # 执行替换 new_text, count = regex.subn(replacement, processed_text) if count > 0: log.debug(f"Regex replacement '{rule_name}': {count} matches replaced") processed_text = new_text replacement_count += count except re.error as e: log.error(f"Invalid regex pattern in rule '{rule_name}': {e}") continue if replacement_count > 0: log.info(f"Applied {replacement_count} regex replacements to text") return processed_text def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]: """ 对请求payload中的文本内容应用正则替换 Args: payload: 请求payload Returns: 应用替换后的payload """ if not REGEX_REPLACEMENTS: return payload modified_payload = payload.copy() request_data = modified_payload.get("request", {}) # 处理contents中的文本 contents = request_data.get("contents", []) if contents: new_contents = [] for content in contents: if isinstance(content, dict): new_content = content.copy() parts = new_content.get("parts", []) if parts: new_parts = [] for part in parts: if isinstance(part, dict) and "text" in part: new_part = part.copy() new_part["text"] = apply_regex_replacements(part["text"]) new_parts.append(new_part) else: new_parts.append(part) new_content["parts"] = new_parts new_contents.append(new_content) else: new_contents.append(content) request_data["contents"] = new_contents modified_payload["request"] = request_data log.debug("Applied regex replacements to request contents") return modified_payload def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]: """ 对请求payload应用反截断处理和正则替换 在systemInstruction中添加提醒,要求模型在结束时输出DONE_MARKER标记 Args: payload: 原始请求payload Returns: 添加了反截断指令并应用了正则替换的payload """ # 首先应用正则替换 modified_payload = apply_regex_replacements_to_payload(payload) request_data = modified_payload.get("request", {}) # 获取或创建systemInstruction system_instruction = request_data.get("systemInstruction", {}) if not system_instruction: system_instruction = {"parts": []} elif "parts" not in system_instruction: system_instruction["parts"] = [] # 添加反截断指令 anti_truncation_instruction = { "text": f"""严格执行以下输出结束规则: 1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER} 2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的 4. 如果你的回答被截断,系统会要求你继续输出剩余内容 5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束 示例格式: ``` 你的回答内容... 更多回答内容... {DONE_MARKER} ``` 注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。 这个规则对于确保输出完整性极其重要,请严格遵守。""" } # 检查是否已经包含反截断指令 has_done_instruction = any( part.get("text", "").find(DONE_MARKER) != -1 for part in system_instruction["parts"] if isinstance(part, dict) ) if not has_done_instruction: system_instruction["parts"].append(anti_truncation_instruction) request_data["systemInstruction"] = system_instruction modified_payload["request"] = request_data log.debug("Applied anti-truncation instruction to request") return modified_payload class AntiTruncationStreamProcessor: """反截断流式处理器""" def __init__( self, original_request_func, payload: Dict[str, Any], max_attempts: int = 3, enable_prefill_mode: bool = False, ): self.original_request_func = original_request_func self.base_payload = payload.copy() self.max_attempts = max_attempts self.enable_prefill_mode = enable_prefill_mode # 使用 StringIO 避免字符串拼接的内存问题 self.collected_content = io.StringIO() self.current_attempt = 0 def _get_collected_text(self) -> str: """获取收集的文本内容""" return self.collected_content.getvalue() def _append_content(self, content: str): """追加内容到收集器""" if content: self.collected_content.write(content) def _clear_content(self): """清空收集的内容,释放内存""" self.collected_content.close() self.collected_content = io.StringIO() async def process_stream(self) -> AsyncGenerator[bytes, None]: """处理流式响应,检测并处理截断""" while self.current_attempt < self.max_attempts: self.current_attempt += 1 # 构建当前请求payload current_payload = self._build_current_payload() log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}") # 发送请求 try: response = await self.original_request_func(current_payload) if not isinstance(response, StreamingResponse): # 非流式响应,直接处理 yield await self._handle_non_streaming_response(response) return # 处理流式响应(按行处理) chunk_buffer = io.StringIO() # 使用 StringIO 缓存当前轮次的chunk found_done_marker = False async for line in response.body_iterator: if not line: yield line continue # 处理上游生成器 yield 出 Response 对象的情况(错误响应) from fastapi import Response as FastAPIResponse if isinstance(line, FastAPIResponse): log.error(f"Anti-truncation: Received Response object from stream (status={line.status_code}), treating as error") error_chunk = { "error": { "message": line.body.decode('utf-8', errors='ignore') if hasattr(line, 'body') and line.body else "Upstream error", "type": "api_error", "code": line.status_code, } } yield f"data: {json.dumps(error_chunk)}\n\n".encode() yield b"data: [DONE]\n\n" return # 处理 bytes 类型的流式数据 if isinstance(line, bytes): # 解码 bytes 为字符串 line_str = line.decode('utf-8', errors='ignore').strip() else: line_str = str(line).strip() # 跳过空行 if not line_str: yield line continue # 处理 SSE 格式的数据行 if line_str.startswith("data: "): payload_str = line_str[6:] # 去掉 "data: " 前缀 # 检查是否是 [DONE] 标记 if payload_str.strip() == "[DONE]": if found_done_marker: log.info("Anti-truncation: Found [done] marker, output complete") yield line # 清理内存 chunk_buffer.close() self._clear_content() return else: log.warning("Anti-truncation: Stream ended without [done] marker") # 不发送[DONE],准备继续 break # 尝试解析 JSON 数据 try: data = json.loads(payload_str) content = self._extract_content_from_chunk(data) log.debug(f"Anti-truncation: Extracted content: {repr(content[:100] if content else '')}") if content: chunk_buffer.write(content) # 检查是否包含done标记 has_marker = self._check_done_marker_in_chunk_content(content) log.debug(f"Anti-truncation: Check done marker result: {has_marker}, DONE_MARKER='{DONE_MARKER}'") if has_marker: found_done_marker = True log.debug(f"Anti-truncation: Found [done] marker in chunk, content: {content[:200]}") # 清理行中的[done]标记后再发送 cleaned_line = self._remove_done_marker_from_line(line, line_str, data) yield cleaned_line except (json.JSONDecodeError, ValueError): # 无法解析的行,直接传递 yield line continue else: # 非 data: 开头的行,直接传递 yield line # 更新收集的内容 - 使用 StringIO 高效处理 chunk_text = chunk_buffer.getvalue() if chunk_text: self._append_content(chunk_text) chunk_buffer.close() log.debug(f"Anti-truncation: After processing stream, found_done_marker={found_done_marker}") # 如果找到了done标记,结束 if found_done_marker: # 立即清理内容释放内存 self._clear_content() yield b"data: [DONE]\n\n" return # 只有在单个chunk中没有找到done标记时,才检查累积内容(防止done标记跨chunk出现) if not found_done_marker: accumulated_text = self._get_collected_text() if self._check_done_marker_in_text(accumulated_text): log.info("Anti-truncation: Found [done] marker in accumulated content") # 立即清理内容释放内存 self._clear_content() yield b"data: [DONE]\n\n" return # 如果没找到done标记且不是最后一次尝试,准备续传 if self.current_attempt < self.max_attempts: accumulated_text = self._get_collected_text() total_length = len(accumulated_text) log.info( f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})" ) if total_length > 100: log.debug( f"Anti-truncation: Current collected content ends with: ...{accumulated_text[-100:]}" ) # 在下一次循环中会继续 continue else: # 最后一次尝试,直接结束 log.warning("Anti-truncation: Max attempts reached, ending stream") # 立即清理内容释放内存 self._clear_content() yield b"data: [DONE]\n\n" return except Exception as e: log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}") if self.current_attempt >= self.max_attempts: # 发送错误chunk error_chunk = { "error": { "message": f"Anti-truncation failed: {str(e)}", "type": "api_error", "code": 500, } } yield f"data: {json.dumps(error_chunk)}\n\n".encode() yield b"data: [DONE]\n\n" return # 否则继续下一次尝试 # 如果所有尝试都失败了 log.error("Anti-truncation: All attempts failed") # 清理内存 self._clear_content() yield b"data: [DONE]\n\n" def _build_current_payload(self) -> Dict[str, Any]: """构建当前请求的payload""" if self.current_attempt == 1: # 第一次请求,使用原始payload(已经包含反截断指令) return self.base_payload # 后续请求,添加续传指令 continuation_payload = self.base_payload.copy() request_data = continuation_payload.get("request", {}) # 获取原始对话内容 contents = request_data.get("contents", []) new_contents = contents.copy() # 如果有收集到的内容,添加到对话中 accumulated_text = self._get_collected_text() if accumulated_text: new_contents.append({"role": "model", "parts": [{"text": accumulated_text}]}) # 预填充模式:直接用拼接内容作为末尾 model 预填充,不再增加 user 续写指令 if self.enable_prefill_mode: log.debug("Anti-truncation: Using prefill continuation mode (no user continuation prompt)") request_data["contents"] = new_contents continuation_payload["request"] = request_data return continuation_payload # 构建具体的续写指令,包含前面的内容摘要 content_summary = "" if accumulated_text: if len(accumulated_text) > 200: content_summary = f'\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n"...{accumulated_text[-100:]}"' else: content_summary = f'\n\n前面你已经输出的内容是:\n"{accumulated_text}"' detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}""" # 添加继续指令 continuation_message = {"role": "user", "parts": [{"text": detailed_continuation_prompt}]} new_contents.append(continuation_message) request_data["contents"] = new_contents continuation_payload["request"] = request_data return continuation_payload def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str: """从chunk数据中提取文本内容""" content = "" # 先尝试解包 response 字段(Gemini API 格式) if "response" in data: data = data["response"] # 处理 Gemini 格式 if "candidates" in data: for candidate in data["candidates"]: if "content" in candidate: parts = candidate["content"].get("parts", []) for part in parts: if "text" in part: content += part["text"] # 处理 OpenAI 流式格式(choices/delta) elif "choices" in data: for choice in data["choices"]: if "delta" in choice and "content" in choice["delta"]: delta_content = choice["delta"]["content"] if delta_content: content += delta_content return content async def _handle_non_streaming_response(self, response) -> bytes: """处理非流式响应 - 使用循环代替递归避免栈溢出""" # 使用循环代替递归 while True: try: # 特殊处理:如果返回的是StreamingResponse,需要读取其body_iterator if isinstance(response, StreamingResponse): log.error("Anti-truncation: Received StreamingResponse in non-streaming handler - this should not happen") # 尝试读取流式响应的内容 chunks = [] async for chunk in response.body_iterator: chunks.append(chunk) content = b"".join(chunks).decode() if chunks else "" # 提取响应内容 elif hasattr(response, "body"): content = ( response.body.decode() if isinstance(response.body, bytes) else response.body ) elif hasattr(response, "content"): content = ( response.content.decode() if isinstance(response.content, bytes) else response.content ) else: log.error(f"Anti-truncation: Unknown response type: {type(response)}") content = str(response) # 验证内容不为空 if not content or not content.strip(): log.error("Anti-truncation: Received empty response content") return json.dumps( { "error": { "message": "Empty response from server", "type": "api_error", "code": 500, } } ).encode() # 尝试解析 JSON try: response_data = json.loads(content) except json.JSONDecodeError as json_err: log.error(f"Anti-truncation: Failed to parse JSON response: {json_err}, content: {content[:200]}") # 如果不是 JSON,直接返回原始内容 return content.encode() if isinstance(content, str) else content # 检查是否包含done标记 text_content = self._extract_content_from_response(response_data) has_done_marker = self._check_done_marker_in_text(text_content) if has_done_marker or self.current_attempt >= self.max_attempts: # 找到done标记或达到最大尝试次数,返回结果 return content.encode() if isinstance(content, str) else content # 需要继续,收集内容并构建下一个请求 if text_content: self._append_content(text_content) log.info("Anti-truncation: Non-streaming response needs continuation") # 增加尝试次数 self.current_attempt += 1 # 构建续传payload并发送下一个请求 next_payload = self._build_current_payload() response = await self.original_request_func(next_payload) # 继续循环处理下一个响应 except Exception as e: log.error(f"Anti-truncation non-streaming error: {str(e)}") return json.dumps( { "error": { "message": f"Anti-truncation failed: {str(e)}", "type": "api_error", "code": 500, } } ).encode() def _check_done_marker_in_text(self, text: str) -> bool: """检测文本中是否包含DONE_MARKER(只检测指定标记)""" if not text: return False # 只要文本中出现DONE_MARKER即可 return DONE_MARKER in text def _check_done_marker_in_chunk_content(self, content: str) -> bool: """检查单个chunk内容中是否包含done标记""" return self._check_done_marker_in_text(content) def _extract_content_from_response(self, data: Dict[str, Any]) -> str: """从响应数据中提取文本内容""" content = "" # 先尝试解包 response 字段(Gemini API 格式) if "response" in data: data = data["response"] # 处理Gemini格式 if "candidates" in data: for candidate in data["candidates"]: if "content" in candidate: parts = candidate["content"].get("parts", []) for part in parts: if "text" in part: content += part["text"] # 处理OpenAI格式 elif "choices" in data: for choice in data["choices"]: if "message" in choice and "content" in choice["message"]: content += choice["message"]["content"] return content def _remove_done_marker_from_line(self, line: bytes, line_str: str, data: Dict[str, Any]) -> bytes: """从行中移除[done]标记""" try: # 首先检查是否真的包含[done]标记 if "[done]" not in line_str.lower(): return line # 没有[done]标记,直接返回原始行 log.info(f"Anti-truncation: Attempting to remove [done] marker from line") log.debug(f"Anti-truncation: Original line (first 200 chars): {line_str[:200]}") # 编译正则表达式,匹配[done]标记(忽略大小写,包括可能的空白字符) done_pattern = re.compile(r"\s*\[done\]\s*", re.IGNORECASE) # 检查是否有 response 包裹层 has_response_wrapper = "response" in data log.debug(f"Anti-truncation: has_response_wrapper={has_response_wrapper}, data keys={list(data.keys())}") if has_response_wrapper: # 需要保留外层的 response 字段 inner_data = data["response"] else: inner_data = data log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") # 处理Gemini格式 if "candidates" in inner_data: log.info(f"Anti-truncation: Processing Gemini format to remove [done] marker") modified_inner = inner_data.copy() modified_inner["candidates"] = [] for i, candidate in enumerate(inner_data["candidates"]): modified_candidate = candidate.copy() # 只在最后一个candidate中清理[done]标记 is_last_candidate = i == len(inner_data["candidates"]) - 1 if "content" in candidate: modified_content = candidate["content"].copy() if "parts" in modified_content: modified_parts = [] for part in modified_content["parts"]: if "text" in part and isinstance(part["text"], str): modified_part = part.copy() original_text = part["text"] # 只在最后一个candidate中清理[done]标记 if is_last_candidate: modified_part["text"] = done_pattern.sub("", part["text"]) if "[done]" in original_text.lower(): log.debug(f"Anti-truncation: Removed [done] from text: '{original_text[:100]}' -> '{modified_part['text'][:100]}'") modified_parts.append(modified_part) else: modified_parts.append(part) modified_content["parts"] = modified_parts modified_candidate["content"] = modified_content modified_inner["candidates"].append(modified_candidate) # 如果有 response 包裹层,需要重新包装 if has_response_wrapper: modified_data = data.copy() modified_data["response"] = modified_inner else: modified_data = modified_inner # 重新编码为行格式 - SSE格式需要两个换行符 json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) result = f"data: {json_str}\n\n".encode("utf-8") log.debug(f"Anti-truncation: Modified line (first 200 chars): {result.decode('utf-8', errors='ignore')[:200]}") return result # 处理OpenAI格式 elif "choices" in inner_data: modified_inner = inner_data.copy() modified_inner["choices"] = [] for choice in inner_data["choices"]: modified_choice = choice.copy() if "delta" in choice and "content" in choice["delta"]: modified_delta = choice["delta"].copy() modified_delta["content"] = done_pattern.sub("", choice["delta"]["content"]) modified_choice["delta"] = modified_delta elif "message" in choice and "content" in choice["message"]: modified_message = choice["message"].copy() modified_message["content"] = done_pattern.sub("", choice["message"]["content"]) modified_choice["message"] = modified_message modified_inner["choices"].append(modified_choice) # 如果有 response 包裹层,需要重新包装 if has_response_wrapper: modified_data = data.copy() modified_data["response"] = modified_inner else: modified_data = modified_inner # 重新编码为行格式 - SSE格式需要两个换行符 json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) return f"data: {json_str}\n\n".encode("utf-8") # 如果没有找到支持的格式,返回原始行 return line except Exception as e: log.warning(f"Failed to remove [done] marker from line: {str(e)}") return line async def apply_anti_truncation_to_stream( request_func, payload: Dict[str, Any], max_attempts: int = 3, enable_prefill_mode: bool = False, ) -> StreamingResponse: """ 对流式请求应用反截断处理 Args: request_func: 原始请求函数 payload: 请求payload max_attempts: 最大续传尝试次数 enable_prefill_mode: 是否启用预填充模式。启用后续传请求不再添加 user 续写指令, 而是将已收集内容作为末尾 model 内容进行预填充 Returns: 处理后的StreamingResponse """ # 首先对payload应用反截断指令 anti_truncation_payload = apply_anti_truncation(payload) # 创建反截断处理器 processor = AntiTruncationStreamProcessor( lambda p: request_func(p), anti_truncation_payload, max_attempts, enable_prefill_mode, ) # 返回包装后的流式响应 return StreamingResponse(processor.process_stream(), media_type="text/event-stream") def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool: """ 检查请求是否启用了反截断功能 Args: request_data: 请求数据 Returns: 是否启用反截断 """ return request_data.get("enable_anti_truncation", False)