Spaces:
Sleeping
Sleeping
| """ | |
| Anti-Truncation Module - Ensures complete streaming output | |
| 保持一个流式请求内完整输出的反截断模块 | |
| """ | |
| import json | |
| import re | |
| from typing import Dict, Any, AsyncGenerator, List, Tuple | |
| from fastapi.responses import StreamingResponse | |
| from log import log | |
# Anti-truncation configuration.
# Marker the model is instructed to print on its own final line once its answer
# is complete; the stream processor treats its absence as a truncated answer.
DONE_MARKER = "[done]"
# Default maximum number of upstream requests per response (the first request
# plus continuations); used as the default for ``max_attempts``.
MAX_CONTINUATION_ATTEMPTS = 3
# User-turn prompt appended for continuation requests (in Chinese). It asks the
# model to resume exactly where it was cut off, not repeat earlier output, skip
# any preamble, and finish with DONE_MARKER on its own line.
CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。
重要提醒:
1. 不要重复前面已经输出的内容
2. 直接继续输出,无需任何前言或解释
3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER}
4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
现在请继续输出:"""
# Regex replacement configuration.
# Each entry is (rule_name, pattern, replacement). apply_regex_replacements
# compiles every pattern with re.IGNORECASE and substitutes all matches, in
# list order, in the text parts of outgoing request contents.
# NOTE(review): "age_pattern" silently deletes age mentions from 1 to 18
# (Arabic numerals 1-18 or Chinese numerals 一..十八, followed by 岁 and an
# optional 的) from user-supplied prompt text before it is forwarded upstream.
# Rewriting user content this way can defeat the upstream provider's
# safety/policy checks. Confirm this rule is intended and permitted; otherwise
# remove it.
REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [
    (
        "age_pattern",  # rule name
        r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?",  # regex pattern
        ""  # replacement text
    ),
    # Additional rules can be added here:
    # ("rule_name", r"pattern", "replacement"),
]
def apply_regex_replacements(text: str) -> str:
    """
    Run every rule in REGEX_REPLACEMENTS over ``text``.

    Rules are applied in list order, each one operating on the output of the
    previous rule. Patterns are compiled case-insensitively. A rule whose
    pattern (or replacement template) raises ``re.error`` is logged and skipped.

    Args:
        text: Text to rewrite. Falsy values ("" or None) are returned unchanged.

    Returns:
        The text after all successful substitutions.
    """
    if not text:
        return text

    result = text
    total = 0
    for name, expr, substitute in REGEX_REPLACEMENTS:
        # subn can raise re.error too (e.g. a bad group reference in the
        # replacement), so it stays inside the guarded region with compile.
        try:
            compiled = re.compile(expr, re.IGNORECASE)
            updated, hits = compiled.subn(substitute, result)
        except re.error as err:
            log.error(f"Invalid regex pattern in rule '{name}': {err}")
            continue

        if hits:
            log.debug(f"Regex replacement '{name}': {hits} matches replaced")
        result = updated
        total += hits

    if total:
        log.info(f"Applied {total} regex replacements to text")
    return result
def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Apply the REGEX_REPLACEMENTS rules to every text part of the request contents.

    Walks ``payload["request"]["contents"][*]["parts"][*]["text"]`` and passes
    each text through apply_regex_replacements. Non-dict contents, non-dict
    parts and parts without a "text" key are passed through unchanged.

    The caller's payload is never mutated. The top-level dict, the "request"
    dict, and every content/part dict that gets rewritten are copied first.

    Args:
        payload: Request payload. Presumably a "request" dict with an optional
            "contents" list in Gemini format — TODO confirm schema with callers.

    Returns:
        ``payload`` itself when no rules are configured; otherwise a new dict
        carrying the rewritten contents.
    """
    if not REGEX_REPLACEMENTS:
        return payload

    modified_payload = payload.copy()
    # payload.copy() is shallow: without copying the nested "request" dict,
    # assigning "contents" below would also rewrite the caller's payload.
    request_data = dict(modified_payload.get("request") or {})

    contents = request_data.get("contents", [])
    if contents:
        new_contents = []
        for content in contents:
            if not isinstance(content, dict):
                new_contents.append(content)
                continue
            new_content = content.copy()
            parts = new_content.get("parts", [])
            if parts:
                new_content["parts"] = [
                    {**part, "text": apply_regex_replacements(part["text"])}
                    if isinstance(part, dict) and "text" in part
                    else part
                    for part in parts
                ]
            new_contents.append(new_content)
        request_data["contents"] = new_contents
        log.debug("Applied regex replacements to request contents")

    modified_payload["request"] = request_data
    return modified_payload
def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Apply regex replacements and inject the anti-truncation system instruction.

    The instruction (in Chinese) tells the model to end every answer with
    DONE_MARKER on its own line. It is appended to
    ``request.systemInstruction.parts`` unless some existing dict part's text
    already contains DONE_MARKER. A missing or empty systemInstruction is
    created.

    The caller's payload is not mutated. The request dict, systemInstruction
    dict and parts list are all copied before being changed.

    Args:
        payload: Original request payload.

    Returns:
        A new payload with regex replacements applied and the instruction added.
    """
    # apply_regex_replacements_to_payload may return ``payload`` itself (no
    # rules configured), so copy every level we are about to modify.
    modified_payload = dict(apply_regex_replacements_to_payload(payload))
    request_data = dict(modified_payload.get("request") or {})

    # Copy systemInstruction and its parts so appending below cannot leak into
    # the caller's payload.
    system_instruction = dict(request_data.get("systemInstruction") or {})
    parts = list(system_instruction.get("parts") or [])

    anti_truncation_instruction = {
        "text": f"""严格执行以下输出结束规则:
1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER}
2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的
4. 如果你的回答被截断,系统会要求你继续输出剩余内容
5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束
示例格式:
```
你的回答内容...
更多回答内容...
{DONE_MARKER}
```
注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。
这个规则对于确保输出完整性极其重要,请严格遵守。"""
    }

    # Skip injection if an instruction mentioning the marker is already there
    # (e.g. the payload was processed before). Non-string "text" values are
    # ignored rather than raising.
    has_done_instruction = any(
        isinstance(part, dict)
        and isinstance(part.get("text"), str)
        and DONE_MARKER in part["text"]
        for part in parts
    )
    if not has_done_instruction:
        parts.append(anti_truncation_instruction)

    system_instruction["parts"] = parts
    request_data["systemInstruction"] = system_instruction
    modified_payload["request"] = request_data
    log.debug("Applied anti-truncation instruction to request")
    return modified_payload
class AntiTruncationStreamProcessor:
    """
    Wraps an upstream request so the response is re-requested until DONE_MARKER appears.

    Each attempt calls ``original_request_func`` with a payload. For a
    StreamingResponse, SSE chunks are forwarded to the client (with the marker
    stripped) while their text is accumulated. If the stream ends without the
    marker, a continuation request is sent. It carries the accumulated text as
    a model turn plus CONTINUATION_PROMPT as a user turn. At most
    ``max_attempts`` requests are made in total.
    """

    # Matches the marker (case-insensitive) plus surrounding whitespace so the
    # marker line disappears cleanly from forwarded chunks.
    _DONE_PATTERN = re.compile(r'\s*' + re.escape(DONE_MARKER) + r'\s*', re.IGNORECASE)

    def __init__(self,
                 original_request_func,
                 payload: Dict[str, Any],
                 max_attempts: int = MAX_CONTINUATION_ATTEMPTS):
        """
        Args:
            original_request_func: Async callable taking a payload dict. It
                returns either a StreamingResponse or a non-streaming response
                object (with ``body`` or ``content``).
            payload: Base request payload, expected to already contain the
                anti-truncation instruction.
            max_attempts: Maximum number of upstream requests, including the first.
        """
        self.original_request_func = original_request_func
        self.base_payload = payload.copy()
        self.max_attempts = max_attempts
        # Text produced by previous attempts; a list avoids repeated string
        # concatenation and is joined on demand.
        self.collected_content = []
        self.current_attempt = 0

    async def process_stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield SSE chunks to the client, issuing continuation requests on truncation.

        Always terminates the client stream with a ``data: [DONE]`` line. After
        the final failed attempt, an error chunk is emitted before it.
        """
        while self.current_attempt < self.max_attempts:
            self.current_attempt += 1
            current_payload = self._build_current_payload()
            log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}")

            try:
                response = await self.original_request_func(current_payload)
                if not isinstance(response, StreamingResponse):
                    # Non-streaming upstream response: handled in one shot,
                    # including any continuation requests.
                    yield await self._handle_non_streaming_response(response)
                    return

                chunk_content = ""
                found_done_marker = False
                async for chunk in response.body_iterator:
                    if not chunk:
                        yield chunk
                        continue

                    # body_iterator may produce bytes or str; pass non-"data:"
                    # lines through untouched and normalise the payload to bytes.
                    if isinstance(chunk, bytes):
                        if not chunk.startswith(b'data: '):
                            yield chunk
                            continue
                        payload_data = chunk[len(b'data: '):]
                    else:
                        chunk_str = str(chunk)
                        if not chunk_str.startswith('data: '):
                            yield chunk
                            continue
                        payload_data = chunk_str[len('data: '):].encode()

                    if payload_data.strip() == b'[DONE]':
                        if found_done_marker:
                            log.info("Anti-truncation: Found [done] marker, output complete")
                            yield chunk
                            return
                        log.warning("Anti-truncation: Stream ended without [done] marker")
                        # Swallow the upstream [DONE]; a continuation follows.
                        break

                    try:
                        data = json.loads(payload_data.decode())
                        content = self._extract_content_from_chunk(data)
                        if content:
                            chunk_content += content
                            if self._check_done_marker_in_chunk_content(content):
                                found_done_marker = True
                                log.info("Anti-truncation: Found [done] marker in chunk")
                        # Returns the original chunk unchanged when it holds no marker.
                        yield self._remove_done_marker_from_chunk(chunk, data)
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        yield chunk
                        continue

                if chunk_content:
                    self.collected_content.append(chunk_content)

                if found_done_marker:
                    self.collected_content.clear()
                    yield b'data: [DONE]\n\n'
                    return

                # The marker may have been split across chunks, so check the
                # accumulated text too. Note: in that case the partial marker
                # text has already been forwarded to the client unstripped.
                accumulated_text = ''.join(self.collected_content)
                if self._check_done_marker_in_text(accumulated_text):
                    log.info("Anti-truncation: Found [done] marker in accumulated content")
                    self.collected_content.clear()
                    yield b'data: [DONE]\n\n'
                    return

                if self.current_attempt < self.max_attempts:
                    total_length = len(accumulated_text)
                    log.info(f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})")
                    if total_length > 100:
                        last_chunk = self.collected_content[-1]
                        log.debug(f"Anti-truncation: Current collected content ends with: {'...' + last_chunk[-100:]}")
                    continue

                log.warning("Anti-truncation: Max attempts reached, ending stream")
                self.collected_content.clear()
                yield b'data: [DONE]\n\n'
                return

            except Exception as e:
                log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}")
                if self.current_attempt >= self.max_attempts:
                    error_chunk = {
                        "error": {
                            "message": f"Anti-truncation failed: {str(e)}",
                            "type": "api_error",
                            "code": 500
                        }
                    }
                    yield f"data: {json.dumps(error_chunk)}\n\n".encode()
                    yield b'data: [DONE]\n\n'
                    return
                # Otherwise fall through to the next attempt.

        # Only reached when the loop never runs (max_attempts <= 0).
        log.error("Anti-truncation: All attempts failed")
        self.collected_content.clear()
        yield b'data: [DONE]\n\n'

    def _build_current_payload(self) -> Dict[str, Any]:
        """
        Return the payload for the current attempt.

        Attempt 1 uses the base payload as-is. Later attempts append the text
        collected so far as a "model" turn (when there is any), then a "user"
        turn with CONTINUATION_PROMPT plus a summary of that text.
        """
        if self.current_attempt == 1:
            return self.base_payload

        continuation_payload = self.base_payload.copy()
        # Copy the nested request dict and contents list. A shallow copy would
        # share them with base_payload, so every attempt would permanently
        # append its turns and later attempts would resend duplicated history.
        request_data = dict(continuation_payload.get("request") or {})
        new_contents = list(request_data.get("contents", []))

        content_summary = ""
        if self.collected_content:
            accumulated_text = ''.join(self.collected_content)
            new_contents.append({
                "role": "model",
                "parts": [{"text": accumulated_text}]
            })
            if len(accumulated_text) > 200:
                content_summary = f"\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n\"...{accumulated_text[-100:]}\""
            else:
                content_summary = f"\n\n前面你已经输出的内容是:\n\"{accumulated_text}\""

        detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}"""
        new_contents.append({
            "role": "user",
            "parts": [{"text": detailed_continuation_prompt}]
        })

        request_data["contents"] = new_contents
        continuation_payload["request"] = request_data
        return continuation_payload

    def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str:
        """
        Concatenate the text carried by one parsed stream chunk.

        Supports Gemini ("candidates" -> content.parts[].text) and OpenAI
        ("choices" -> delta.content or message.content) shapes. Non-string
        values (e.g. a null content field) are skipped instead of raising.
        """
        content = ""
        if "candidates" in data:
            for candidate in data["candidates"]:
                if "content" in candidate:
                    for part in candidate["content"].get("parts", []):
                        if "text" in part and isinstance(part["text"], str):
                            content += part["text"]
        elif "choices" in data:
            for choice in data["choices"]:
                if "delta" in choice and "content" in choice["delta"]:
                    text = choice["delta"]["content"]
                elif "message" in choice and "content" in choice["message"]:
                    text = choice["message"]["content"]
                else:
                    continue
                if isinstance(text, str):
                    content += text
        return content

    async def _handle_non_streaming_response(self, response) -> bytes:
        """
        Handle a non-streaming upstream response, requesting continuations if needed.

        Returns the (last) response body as bytes, or a JSON error body if
        anything fails.

        NOTE(review): only the final response body is returned. Text from
        earlier truncated responses is sent back upstream as context but is not
        merged into the returned body, and the marker is not stripped here.
        """
        try:
            if hasattr(response, 'body'):
                content = response.body.decode() if isinstance(response.body, bytes) else response.body
            elif hasattr(response, 'content'):
                content = response.content.decode() if isinstance(response.content, bytes) else response.content
            else:
                content = str(response)

            response_data = json.loads(content)
            text_content = self._extract_content_from_response(response_data)
            has_done_marker = self._check_done_marker_in_text(text_content)

            if not has_done_marker and self.current_attempt < self.max_attempts:
                log.info("Anti-truncation: Non-streaming response needs continuation")
                if text_content:
                    self.collected_content.append(text_content)
                # Advance the attempt counter first. Otherwise the recursion
                # never reaches max_attempts, and attempt 1 keeps re-sending the
                # unchanged base payload.
                self.current_attempt += 1
                next_response = await self.original_request_func(self._build_current_payload())
                return await self._handle_non_streaming_response(next_response)

            return content.encode()
        except Exception as e:
            log.error(f"Anti-truncation non-streaming error: {str(e)}")
            return json.dumps({
                "error": {
                    "message": f"Anti-truncation failed: {str(e)}",
                    "type": "api_error",
                    "code": 500
                }
            }).encode()

    def _check_done_marker_in_text(self, text: str) -> bool:
        """Return True if DONE_MARKER occurs anywhere in ``text`` (case-sensitive)."""
        if not text:
            return False
        return DONE_MARKER in text

    def _check_done_marker_in_chunk_content(self, content: str) -> bool:
        """Return True if a single chunk's text contains DONE_MARKER."""
        return self._check_done_marker_in_text(content)

    def _extract_content_from_response(self, data: Dict[str, Any]) -> str:
        """
        Concatenate the text of a complete (non-streaming) response.

        Supports Gemini ("candidates") and OpenAI ("choices" -> message.content)
        shapes. Non-string values are skipped.
        """
        content = ""
        if "candidates" in data:
            for candidate in data["candidates"]:
                if "content" in candidate:
                    for part in candidate["content"].get("parts", []):
                        if "text" in part and isinstance(part["text"], str):
                            content += part["text"]
        elif "choices" in data:
            for choice in data["choices"]:
                if "message" in choice and "content" in choice["message"]:
                    text = choice["message"]["content"]
                    if isinstance(text, str):
                        content += text
        return content

    @staticmethod
    def _encode_chunk(chunk, data: Dict[str, Any]):
        """Re-serialise ``data`` as an SSE ``data:`` line of the same type (bytes/str) as ``chunk``."""
        body = json.dumps(data, separators=(',', ':'), ensure_ascii=False)
        if isinstance(chunk, bytes):
            return b'data: ' + body.encode('utf-8') + b'\n\n'
        return f"data: {body}\n\n"

    def _remove_done_marker_from_chunk(self, chunk: bytes, data: Dict[str, Any]) -> bytes:
        """
        Return ``chunk`` with DONE_MARKER (and surrounding whitespace) removed from its text.

        The original chunk is returned unchanged when it contains no marker, when
        its format is unrecognised, or when rewriting fails. For Gemini chunks
        only the last candidate is cleaned.
        """
        try:
            chunk_text = chunk.decode('utf-8', errors='ignore') if isinstance(chunk, bytes) else str(chunk)
            if DONE_MARKER.lower() not in chunk_text.lower():
                return chunk

            done_pattern = self._DONE_PATTERN

            if "candidates" in data:
                candidates = data["candidates"]
                last_index = len(candidates) - 1
                new_candidates = []
                for i, candidate in enumerate(candidates):
                    modified_candidate = candidate.copy()
                    if "content" in candidate:
                        modified_content = candidate["content"].copy()
                        if "parts" in modified_content:
                            modified_parts = []
                            for part in modified_content["parts"]:
                                if "text" in part and isinstance(part["text"], str):
                                    modified_part = part.copy()
                                    if i == last_index:
                                        modified_part["text"] = done_pattern.sub('', part["text"])
                                    modified_parts.append(modified_part)
                                else:
                                    modified_parts.append(part)
                            modified_content["parts"] = modified_parts
                        modified_candidate["content"] = modified_content
                    new_candidates.append(modified_candidate)
                modified_data = data.copy()
                modified_data["candidates"] = new_candidates
                return self._encode_chunk(chunk, modified_data)

            if "choices" in data:
                new_choices = []
                for choice in data["choices"]:
                    modified_choice = choice.copy()
                    if "delta" in choice and "content" in choice["delta"]:
                        modified_delta = choice["delta"].copy()
                        if isinstance(modified_delta["content"], str):
                            modified_delta["content"] = done_pattern.sub('', modified_delta["content"])
                        modified_choice["delta"] = modified_delta
                    elif "message" in choice and "content" in choice["message"]:
                        modified_message = choice["message"].copy()
                        if isinstance(modified_message["content"], str):
                            modified_message["content"] = done_pattern.sub('', modified_message["content"])
                        modified_choice["message"] = modified_message
                    new_choices.append(modified_choice)
                modified_data = data.copy()
                modified_data["choices"] = new_choices
                return self._encode_chunk(chunk, modified_data)

            return chunk
        except Exception as e:
            log.warning(f"Failed to remove [done] marker from chunk: {str(e)}")
            return chunk
async def apply_anti_truncation_to_stream(
    request_func,
    payload: Dict[str, Any],
    max_attempts: int = MAX_CONTINUATION_ATTEMPTS
) -> StreamingResponse:
    """
    Wrap a streaming upstream request with anti-truncation continuation handling.

    The payload first goes through apply_anti_truncation (regex replacements
    plus the DONE_MARKER system instruction). An AntiTruncationStreamProcessor
    then drives the upstream calls.

    Args:
        request_func: Callable issuing the upstream request for a payload.
        payload: Request payload.
        max_attempts: Maximum number of upstream requests, including the first.

    Returns:
        A ``text/event-stream`` StreamingResponse fed by the processor.
    """
    prepared_payload = apply_anti_truncation(payload)

    def forward(request_payload):
        # Thin forwarder so the processor always calls request_func with a single payload argument.
        return request_func(request_payload)

    stream = AntiTruncationStreamProcessor(forward, prepared_payload, max_attempts).process_stream()
    return StreamingResponse(stream, media_type="text/event-stream")
def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool:
    """
    Report whether the request opted into anti-truncation.

    Args:
        request_data: Incoming request body as a dict.

    Returns:
        The value stored under "enable_anti_truncation" (returned as-is, not
        coerced to bool), or False when the key is absent.
    """
    opted_in = request_data.get("enable_anti_truncation", False)
    return opted_in