Spaces:
Running
Running
| import json | |
| import asyncio | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from sqlalchemy.orm import Session | |
| from app.models.models import ( | |
| OptimizationSession, OptimizationSegment, | |
| SessionHistory, ChangeLog | |
| ) | |
| from app.services.ai_service import ( | |
| AIService, split_text_into_segments, | |
| count_chinese_characters, count_text_length, get_default_polish_prompt, | |
| get_default_enhance_prompt, get_emotion_polish_prompt, get_compression_prompt | |
| ) | |
| from app.services.concurrency import concurrency_manager | |
| from app.services.stream_manager import stream_manager | |
| from app.config import settings | |
# Maximum stored error-message length, to avoid overflowing the DB column.
MAX_ERROR_MESSAGE_LENGTH = 500
class OptimizationService:
    """Optimization pipeline service: polishes/enhances a session's text."""

    def __init__(self, db: Session, session_obj: OptimizationSession):
        # Database session used for all persistence during this run.
        self.db = db
        # The optimization-session row this service operates on.
        self.session_obj = session_obj
        # Per-stage AI clients; populated later by _init_ai_services().
        self.polish_service: Optional[AIService] = None
        self.enhance_service: Optional[AIService] = None
        self.emotion_service: Optional[AIService] = None
        self.compression_service: Optional[AIService] = None
| def _init_ai_services(self): | |
| """初始化AI服务 | |
| 改进的初始化逻辑: | |
| 1. 验证必需的配置项 | |
| 2. 提供更详细的错误信息 | |
| 3. 确保所有服务都正确初始化 | |
| """ | |
| try: | |
| # 润色服务 | |
| self.polish_service = AIService( | |
| model=self.session_obj.polish_model or settings.POLISH_MODEL, | |
| api_key=self.session_obj.polish_api_key or settings.POLISH_API_KEY, | |
| base_url=self.session_obj.polish_base_url or settings.POLISH_BASE_URL | |
| ) | |
| # 增强服务 | |
| self.enhance_service = AIService( | |
| model=self.session_obj.enhance_model or settings.ENHANCE_MODEL, | |
| api_key=self.session_obj.enhance_api_key or settings.ENHANCE_API_KEY, | |
| base_url=self.session_obj.enhance_base_url or settings.ENHANCE_BASE_URL | |
| ) | |
| # 感情文章润色服务 | |
| self.emotion_service = AIService( | |
| model=self.session_obj.emotion_model or settings.POLISH_MODEL, | |
| api_key=self.session_obj.emotion_api_key or settings.POLISH_API_KEY, | |
| base_url=self.session_obj.emotion_base_url or settings.POLISH_BASE_URL | |
| ) | |
| # 压缩服务 | |
| self.compression_service = AIService( | |
| model=settings.COMPRESSION_MODEL, | |
| api_key=settings.COMPRESSION_API_KEY or settings.OPENAI_API_KEY, | |
| base_url=settings.COMPRESSION_BASE_URL or settings.OPENAI_BASE_URL | |
| ) | |
| print(f"[INFO] 所有 AI 服务初始化成功,会话: {self.session_obj.session_id}") | |
| except Exception as e: | |
| error_msg = f"AI 服务初始化失败: {str(e)}" | |
| print(f"[ERROR] {error_msg}") | |
| raise Exception(error_msg) | |
    async def start_optimization(self):
        """Run the full optimization pipeline for this session.

        Acquires a concurrency slot, (re)creates segment rows, runs the
        stages dictated by the session's processing mode, then marks the
        session completed or failed. The slot is released and AI clients
        cleaned up in the finally block regardless of outcome.
        """
        try:
            # Build the per-stage AI clients.
            self._init_ai_services()
            # Clear any error state left over from a previous failed run.
            self.session_obj.error_message = None
            self.session_obj.failed_segment_index = None
            self.db.commit()
            # Try to acquire a concurrency slot.
            acquired = await concurrency_manager.acquire(self.session_obj.session_id)
            if not acquired:
                self.session_obj.status = "queued"
                self.db.commit()
                # Wait for a slot - acquire() contains the waiting logic itself.
                acquired = await concurrency_manager.acquire(self.session_obj.session_id)
                if not acquired:
                    raise Exception("等待并发权限超时")
            # Mark the session as actively processing.
            self.session_obj.status = "processing"
            self.db.commit()
            # Re-read session state before proceeding; abort if the user
            # stopped the session while it was queued.
            self.db.refresh(self.session_obj)
            if self.session_obj.status == "stopped":
                raise Exception("会话已被用户停止")
            # Reuse existing segment rows to avoid duplicating them on resume.
            existing_segments = self.db.query(OptimizationSegment).filter(
                OptimizationSegment.session_id == self.session_obj.id
            ).order_by(OptimizationSegment.segment_index).all()
            if not existing_segments:
                # First run: split the text and create one row per segment.
                segments = split_text_into_segments(self.session_obj.original_text)
                self.session_obj.total_segments = len(segments)
                self.db.commit()
                for idx, segment_text in enumerate(segments):
                    segment = OptimizationSegment(
                        session_id=self.session_obj.id,
                        segment_index=idx,
                        stage="polish",
                        original_text=segment_text,
                        status="pending"
                    )
                    self.db.add(segment)
                self.db.commit()
            else:
                # Resume: sync the stored total with the actual row count.
                self.session_obj.total_segments = len(existing_segments)
                self.db.commit()
            # Dispatch the stages according to the session's processing mode.
            processing_mode = self.session_obj.processing_mode or 'paper_polish_enhance'
            if processing_mode == 'paper_polish':
                # Paper polish only.
                await self._process_stage("polish")
            elif processing_mode == 'paper_enhance':
                # Paper enhancement only (enhances the original text directly).
                await self._process_stage("enhance")
            elif processing_mode == 'emotion_polish':
                # Emotional-article polish only.
                await self._process_stage("emotion_polish")
            elif processing_mode == 'paper_polish_enhance':
                # Paper polish followed by paper enhancement.
                await self._process_stage("polish")
                await self._process_stage("enhance")
            else:
                raise ValueError(f"不支持的处理模式: {processing_mode}")
            # All stages finished - mark the session complete.
            self.session_obj.status = "completed"
            self.session_obj.completed_at = datetime.utcnow()
            self.session_obj.progress = 100.0
            self.session_obj.failed_segment_index = None
            self.db.commit()
        except Exception as e:
            self.session_obj.status = "failed"
            # Truncate the error message so it fits the DB column.
            error_msg = str(e)
            if len(error_msg) > MAX_ERROR_MESSAGE_LENGTH:
                error_msg = error_msg[:MAX_ERROR_MESSAGE_LENGTH - 50] + "... [错误信息已截断]"
            self.session_obj.error_message = error_msg
            self.db.commit()
            raise
        finally:
            # Release the concurrency slot.
            await concurrency_manager.release(self.session_obj.session_id)
            # Drop the AI-client references so connections can be reclaimed.
            self._cleanup_ai_services()
| def _cleanup_ai_services(self): | |
| """清理 AI 服务资源""" | |
| # 将服务引用设置为 None,让 Python 的垃圾回收处理 | |
| # AsyncOpenAI 客户端会自动清理连接 | |
| self.polish_service = None | |
| self.enhance_service = None | |
| self.emotion_service = None | |
| self.compression_service = None | |
    async def _process_stage(self, stage: str):
        """Process every segment of the session through one stage.

        *stage* is "polish", "emotion_polish" or "enhance". Resumes from a
        previously failed segment when ``failed_segment_index`` is set, keeps
        a rolling assistant-message history for context, and compresses that
        history once it exceeds the configured character threshold.
        """
        print(f"\n[STAGE START] Stage: {stage}, Session: {self.session_obj.session_id}", flush=True)
        self.session_obj.current_stage = stage
        self.db.commit()
        # Prompt for this stage.
        prompt = self._get_prompt(stage)
        # Pick the AI client that handles this stage.
        if stage == "emotion_polish":
            ai_service = self.emotion_service
        elif stage == "polish":
            ai_service = self.polish_service
        else:  # enhance
            ai_service = self.enhance_service
        # All segments, in order.
        segments = self.db.query(OptimizationSegment).filter(
            OptimizationSegment.session_id == self.session_obj.id
        ).order_by(OptimizationSegment.segment_index).all()
        # When resuming from a failure, skip segments finished earlier.
        start_index = 0
        if self.session_obj.failed_segment_index is not None:
            start_index = max(self.session_obj.failed_segment_index, 0)
        # Conversation history - AI replies only.
        # Only segments before start_index are loaded, so on retry the
        # history stays consistent with the current processing position.
        history: List[Dict[str, str]] = []
        total_chars = 0
        for segment in segments[:start_index]:
            if segment.is_title:
                # Title segments are excluded from the history context.
                continue
            if stage == "polish" and segment.polished_text:
                history.append({"role": "assistant", "content": segment.polished_text})
                total_chars += count_chinese_characters(segment.polished_text)
            elif stage == "emotion_polish" and segment.polished_text:
                history.append({"role": "assistant", "content": segment.polished_text})
                total_chars += count_chinese_characters(segment.polished_text)
            elif stage == "enhance" and segment.enhanced_text:
                history.append({"role": "assistant", "content": segment.enhanced_text})
                total_chars += count_chinese_characters(segment.enhanced_text)
        print(f"[STAGE] Loaded {len(history)} history messages from segments[:start_index={start_index}]", flush=True)
        skip_threshold = max(settings.SEGMENT_SKIP_THRESHOLD, 0)
        # Processing mode is needed to map stage progress onto 0-100%.
        processing_mode = self.session_obj.processing_mode or 'paper_polish_enhance'
        for idx, segment in enumerate(segments[start_index:], start=start_index):
            # Check for a user-initiated stop before each segment.
            self.db.refresh(self.session_obj)
            if self.session_obj.status == "stopped":
                raise Exception("会话已被用户停止")
            # Update progress (whether or not the segment gets skipped).
            self.session_obj.current_position = idx
            # Map progress according to the processing mode.
            if processing_mode == 'paper_polish_enhance':
                if stage == "polish":
                    # First stage covers 0-50%.
                    progress = (idx / len(segments)) * 50
                else:  # enhance
                    # Second stage covers 50-100%.
                    progress = 50 + (idx / len(segments)) * 50
            else:
                # Single-stage modes cover 0-100%.
                progress = (idx / len(segments)) * 100
            self.session_obj.progress = min(progress, 100.0)
            self.db.commit()
            # Handle titles / short segments first (checked up front):
            # anything below the threshold bypasses the AI entirely.
            if count_text_length(segment.original_text) < skip_threshold:
                if not segment.is_title:
                    segment.is_title = True
                segment.status = "completed"
                segment.polished_text = segment.original_text
                segment.enhanced_text = segment.original_text
                segment.completed_at = datetime.utcnow()
                segment.stage = stage
                self.db.commit()
                continue
            # Then skip segments already processed for this stage.
            if stage in ["polish", "emotion_polish"] and segment.polished_text:
                continue
            if stage == "enhance":
                if segment.enhanced_text:
                    continue
                if segment.is_title and not segment.enhanced_text:
                    segment.enhanced_text = segment.polished_text or segment.original_text
                    segment.status = "completed"
                    segment.completed_at = segment.completed_at or datetime.utcnow()
                    self.db.commit()
                    continue
            try:
                print(f"\n[SEGMENT {idx}] Processing segment {idx+1}/{len(segments)}, Stage: {stage}", flush=True)
                print(f"[SEGMENT {idx}] Input Length: {count_text_length(segment.original_text)}", flush=True)
                segment.status = "processing"
                segment.stage = stage
                self.db.commit()
                # Choose the input text.
                # For the enhance stage: prefer the polished result, falling
                # back to the original (used by the paper_enhance mode).
                if stage == "enhance":
                    input_text = segment.polished_text if segment.polished_text else segment.original_text
                else:
                    input_text = segment.original_text
                # Invoke the AI.
                async def execute_call():
                    # Streaming is controlled by config; default is
                    # non-streaming (False) to avoid API blocking.
                    use_stream = settings.USE_STREAMING
                    if stage == "polish":
                        response = await ai_service.polish_text(input_text, prompt, history, stream=use_stream)
                    elif stage == "emotion_polish":
                        response = await ai_service.polish_emotion_text(input_text, prompt, history, stream=use_stream)
                    else:  # enhance
                        response = await ai_service.enhance_text(input_text, prompt, history, stream=use_stream)
                    if use_stream:
                        full_text = ""
                        async for chunk in response:
                            if chunk:
                                full_text += chunk
                                # Push streaming updates to subscribers.
                                await stream_manager.broadcast(self.session_obj.session_id, {
                                    "type": "content",
                                    "segment_index": idx,
                                    "stage": stage,
                                    "content": chunk,
                                    "full_text": full_text  # incremental chunk plus full text (full text used for recovery)
                                })
                        return full_text
                    else:
                        return response
                output_text = await self._run_with_retry(idx, stage, execute_call)
                if stage in ["polish", "emotion_polish"]:
                    segment.polished_text = output_text
                else:  # enhance
                    segment.enhanced_text = output_text
                segment.status = "completed"
                segment.completed_at = datetime.utcnow()
                self.db.commit()
                # Record the change for auditing.
                await self._record_change(segment, input_text, output_text, stage)
                # Append the AI reply to the rolling history.
                history.append({"role": "assistant", "content": output_text})
                total_chars += count_chinese_characters(output_text)
                # Compress the history once the character threshold is exceeded.
                if total_chars > settings.HISTORY_COMPRESSION_THRESHOLD:
                    print(f"\n[HISTORY COMPRESS] Triggering compression, Stage: {stage}", flush=True)
                    print(f"[HISTORY COMPRESS] Before: {total_chars} chars, {len(history)} messages", flush=True)
                    compressed_history = await self._compress_history(history, stage)
                    # The compressed history replaces the original for later segments.
                    history = compressed_history
                    # Recount the characters after compression.
                    total_chars = sum(count_chinese_characters(msg.get("content", "")) for msg in history)
                    print(f"[HISTORY COMPRESS] After: {total_chars} chars, {len(history)} messages", flush=True)
                    # Notify the frontend that compression happened.
                    await stream_manager.broadcast(self.session_obj.session_id, {
                        "type": "history_compressed",
                        "stage": stage,
                        "message": f"历史会话已压缩({stage} 阶段),节省上下文空间",
                        "new_char_count": total_chars
                    })
                    # Persist history only after compression, to limit DB writes.
                    await self._save_history(history, stage, total_chars)
            except Exception as e:
                import traceback
                error_trace = traceback.format_exc()
                print(f"[ERROR] Segment {idx} processing failed:", flush=True)
                print(error_trace, flush=True)
                segment.status = "failed"
                self.session_obj.failed_segment_index = idx
                # Truncate the error message so it fits the DB column.
                error_msg = str(e)
                if len(error_msg) > MAX_ERROR_MESSAGE_LENGTH:
                    # Keep the leading part of the message plus a truncation marker.
                    prefix_len = MAX_ERROR_MESSAGE_LENGTH - 50
                    error_msg = error_msg[:prefix_len] + "... [错误信息已截断]"
                self.session_obj.error_message = error_msg
                self.db.commit()
                # Re-raise the original exception to preserve the stack trace.
                raise
| async def _run_with_retry(self, segment_index: int, stage: str, task): | |
| """执行单次任务,不自动重试""" | |
| try: | |
| return await task() | |
| except Exception as exc: | |
| raise Exception( | |
| f"段落 {segment_index + 1} 在 {stage} 阶段失败: {str(exc)}" | |
| ) | |
| def _get_prompt(self, stage: str) -> str: | |
| """获取提示词""" | |
| if stage == "polish": | |
| return get_default_polish_prompt() | |
| elif stage == "emotion_polish": | |
| return get_emotion_polish_prompt() | |
| else: # enhance | |
| return get_default_enhance_prompt() | |
    async def _compress_history(
        self,
        history: List[Dict[str, str]],
        stage: str
    ) -> List[Dict[str, str]]:
        """Compress the conversation history into a single summary message.

        Reduces token usage while keeping the key stylistic features of the
        processing done so far. The compressed content is stored separately
        and does not affect the finished polished/enhanced texts. If
        compression fails, the most recent messages are returned instead of
        raising.
        """
        try:
            # Already compressed (single system message): return as-is.
            if len(history) == 1 and history[0].get("role") == "system":
                return history
            # Keep the last 2-3 messages as a style reference.
            recent_messages = history[-3:] if len(history) > 3 else history
            # Pick the compression prompt that matches the stage.
            if stage == "emotion_polish":
                compression_prompt = """你是一个专业的文本摘要助手。请压缩以下历史处理内容,提取关键风格特征:
1. 总结文本的表达风格和语言特点
2. 提取关键的修改方向和处理模式
3. 保留重要的词汇使用倾向
4. 删除重复的内容和冗余表述
要求:
- 压缩后内容不超过原内容的30%
- 只输出压缩后的摘要,不要添加任何解释和注释
历史处理内容:"""
            else:
                compression_prompt = """你是一个专业的学术文本摘要助手。请压缩以下历史处理内容,提取关键信息:
1. 保留论文的主要术语、核心概念和关键数据
2. 总结已处理段落的主题和要点
3. 提取处理风格和改进方向的关键特征
4. 删除重复内容和冗余表述
要求:
- 压缩后内容不超过原内容的30%
- 保持学术性和专业性
- 只输出压缩后的摘要文本,不要添加任何解释和注释
历史处理内容:"""
            compressed_summary = await self.compression_service.compress_history(
                recent_messages,
                compression_prompt
            )
            # Return the compressed history as one system message, used as
            # context for the remaining segments.
            return [
                {
                    "role": "system",
                    "content": f"之前处理的段落摘要:\n{compressed_summary}"
                }
            ]
        except Exception as e:
            # On failure, fall back to recent messages instead of raising.
            print(f"[WARNING] 历史压缩失败: {str(e)}, 将使用最近的消息代替", flush=True)
            # Return the last 2 messages to keep the context short.
            return history[-2:] if len(history) > 2 else history
| async def _save_history(self, history: List[Dict[str, str]], stage: str, char_count: int): | |
| """保存历史会话 - 只在压缩后保存 | |
| 只有压缩后的历史才保存到数据库,以避免频繁写入导致数据库膨胀。 | |
| 压缩后的内容单独保存,不影响已完成的润色和增强文本。 | |
| 注意:未压缩的历史不会保存,因为: | |
| 1. 润色/增强后的文本已经保存在 segments 表中 | |
| 2. 压缩只在字符数超过阈值时触发 | |
| 3. 压缩后的历史用于后续段落的上下文参考 | |
| """ | |
| # 检测是否为压缩后的历史:压缩后只有一条 system 消息,包含之前处理的摘要 | |
| # 这种检测方式与 _compress_history 的返回格式保持一致 | |
| is_compressed = len(history) == 1 and history[0].get("role") == "system" | |
| if not is_compressed: | |
| return # 非压缩状态不保存,减少数据库写入 | |
| # 检查是否已存在该阶段的压缩记录 | |
| existing = self.db.query(SessionHistory).filter( | |
| SessionHistory.session_id == self.session_obj.id, | |
| SessionHistory.stage == stage, | |
| SessionHistory.is_compressed.is_(True) | |
| ).first() | |
| if existing: | |
| # 更新现有记录 | |
| existing.history_data = json.dumps(history, ensure_ascii=False) | |
| existing.character_count = char_count | |
| existing.created_at = datetime.utcnow() | |
| else: | |
| # 创建新记录 | |
| history_obj = SessionHistory( | |
| session_id=self.session_obj.id, | |
| stage=stage, | |
| history_data=json.dumps(history, ensure_ascii=False), | |
| is_compressed=True, | |
| character_count=char_count | |
| ) | |
| self.db.add(history_obj) | |
| self.db.commit() | |
| async def _record_change( | |
| self, | |
| segment: OptimizationSegment, | |
| before: str, | |
| after: str, | |
| stage: str | |
| ): | |
| """记录变更""" | |
| # 简单的变更检测 | |
| changes = { | |
| "before_length": len(before), | |
| "after_length": len(after), | |
| "changed": before != after | |
| } | |
| existing_log = self.db.query(ChangeLog).filter( | |
| ChangeLog.session_id == self.session_obj.id, | |
| ChangeLog.segment_index == segment.segment_index, | |
| ChangeLog.stage == stage | |
| ).order_by(ChangeLog.created_at.desc()).first() | |
| serialized_detail = json.dumps(changes, ensure_ascii=False) | |
| if existing_log: | |
| # 如果之前已经生成过同一段落同一阶段的记录,直接更新内容避免重复条目 | |
| existing_log.before_text = before | |
| existing_log.after_text = after | |
| existing_log.changes_detail = serialized_detail | |
| else: | |
| change_log = ChangeLog( | |
| session_id=self.session_obj.id, | |
| segment_index=segment.segment_index, | |
| stage=stage, | |
| before_text=before, | |
| after_text=after, | |
| changes_detail=serialized_detail | |
| ) | |
| self.db.add(change_log) | |
| self.db.commit() | |