import json
import asyncio
import traceback
from typing import List, Dict, Optional
from datetime import datetime

from sqlalchemy.orm import Session

from app.models.models import (
    OptimizationSession, OptimizationSegment, SessionHistory, ChangeLog
)
from app.services.ai_service import (
    AIService, split_text_into_segments, count_chinese_characters,
    count_text_length, get_default_polish_prompt, get_default_enhance_prompt,
    get_emotion_polish_prompt, get_compression_prompt
)
from app.services.concurrency import concurrency_manager
from app.services.stream_manager import stream_manager
from app.config import settings

# Maximum length of an error message persisted on the session, to avoid
# overflowing the database column.
MAX_ERROR_MESSAGE_LENGTH = 500

# Marker appended to error messages that had to be truncated.
_TRUNCATION_SUFFIX = "... [错误信息已截断]"


def _truncate_error_message(message: str) -> str:
    """Return *message* unchanged if it fits, otherwise a truncated copy.

    The message is cut to ``MAX_ERROR_MESSAGE_LENGTH - 50`` characters and the
    truncation marker is appended, leaving headroom below the column limit.
    """
    if len(message) > MAX_ERROR_MESSAGE_LENGTH:
        return message[:MAX_ERROR_MESSAGE_LENGTH - 50] + _TRUNCATION_SUFFIX
    return message


class OptimizationService:
    """Drive the segment-by-segment AI optimization of a single session.

    The service splits the session text into ``OptimizationSegment`` rows,
    runs one or two AI stages over them (depending on ``processing_mode``),
    tracks progress and errors on the session row, and keeps an in-memory
    "assistant history" that is periodically compressed and persisted.
    """

    # Stages executed, in order, for each supported processing mode.
    _MODE_STAGES = {
        'paper_polish': ("polish",),                       # paper polishing only
        'paper_enhance': ("enhance",),                     # enhance the original text directly
        'emotion_polish': ("emotion_polish",),             # emotional-article polishing only
        'paper_polish_enhance': ("polish", "enhance"),     # polish, then enhance
    }

    # Compression prompt used for the emotion_polish stage.
    _EMOTION_COMPRESSION_PROMPT = (
        "你是一个专业的文本摘要助手。请压缩以下历史处理内容,提取关键风格特征:\n"
        "1. 总结文本的表达风格和语言特点\n"
        "2. 提取关键的修改方向和处理模式\n"
        "3. 保留重要的词汇使用倾向\n"
        "4. 删除重复的内容和冗余表述\n"
        "\n"
        "要求:\n"
        "- 压缩后内容不超过原内容的30%\n"
        "- 只输出压缩后的摘要,不要添加任何解释和注释\n"
        "\n"
        "历史处理内容:"
    )

    # Compression prompt used for the academic (polish / enhance) stages.
    _ACADEMIC_COMPRESSION_PROMPT = (
        "你是一个专业的学术文本摘要助手。请压缩以下历史处理内容,提取关键信息:\n"
        "1. 保留论文的主要术语、核心概念和关键数据\n"
        "2. 总结已处理段落的主题和要点\n"
        "3. 提取处理风格和改进方向的关键特征\n"
        "4. 删除重复内容和冗余表述\n"
        "\n"
        "要求:\n"
        "- 压缩后内容不超过原内容的30%\n"
        "- 保持学术性和专业性\n"
        "- 只输出压缩后的摘要文本,不要添加任何解释和注释\n"
        "\n"
        "历史处理内容:"
    )

    def __init__(self, db: Session, session_obj: OptimizationSession):
        self.db = db
        self.session_obj = session_obj
        # AI clients are created lazily by _init_ai_services() and dropped by
        # _cleanup_ai_services() once start_optimization() finishes.
        self.polish_service: Optional[AIService] = None
        self.enhance_service: Optional[AIService] = None
        self.emotion_service: Optional[AIService] = None
        self.compression_service: Optional[AIService] = None

    def _init_ai_services(self):
        """Create the AI clients used by the processing stages.

        Per-session model / API key / base URL values take precedence over the
        corresponding ``settings`` defaults. The emotion service falls back to
        the polish settings. The compression service always uses the
        compression settings, falling back to the generic OpenAI key/base URL.

        Raises:
            Exception: wrapping (and chained to) whatever construction error
                occurred.
        """
        try:
            self.polish_service = AIService(
                model=self.session_obj.polish_model or settings.POLISH_MODEL,
                api_key=self.session_obj.polish_api_key or settings.POLISH_API_KEY,
                base_url=self.session_obj.polish_base_url or settings.POLISH_BASE_URL
            )
            self.enhance_service = AIService(
                model=self.session_obj.enhance_model or settings.ENHANCE_MODEL,
                api_key=self.session_obj.enhance_api_key or settings.ENHANCE_API_KEY,
                base_url=self.session_obj.enhance_base_url or settings.ENHANCE_BASE_URL
            )
            self.emotion_service = AIService(
                model=self.session_obj.emotion_model or settings.POLISH_MODEL,
                api_key=self.session_obj.emotion_api_key or settings.POLISH_API_KEY,
                base_url=self.session_obj.emotion_base_url or settings.POLISH_BASE_URL
            )
            self.compression_service = AIService(
                model=settings.COMPRESSION_MODEL,
                api_key=settings.COMPRESSION_API_KEY or settings.OPENAI_API_KEY,
                base_url=settings.COMPRESSION_BASE_URL or settings.OPENAI_BASE_URL
            )
            print(f"[INFO] 所有 AI 服务初始化成功,会话: {self.session_obj.session_id}")
        except Exception as e:
            error_msg = f"AI 服务初始化失败: {str(e)}"
            print(f"[ERROR] {error_msg}")
            # Chain the original error so its traceback is not lost.
            raise Exception(error_msg) from e

    async def start_optimization(self):
        """Run the whole optimization workflow for the session.

        Steps:
        1. Initialize the AI clients.
        2. Clear the previous error state.
        3. Obtain a concurrency slot, marking the session ``queued`` if the
           first attempt fails.
        4. Create segment rows on the first run, or resync ``total_segments``.
        5. Run the stage(s) for ``processing_mode``.
        6. Mark the session ``completed``.

        On any exception the session is marked ``failed`` with a truncated
        error message and the exception is re-raised. The concurrency slot and
        the AI clients are always released.

        Raises:
            ValueError: for an unsupported ``processing_mode``.
            Exception: on stop requests, concurrency timeouts, AI or DB errors.
        """
        # Cache the id so the ``finally`` block does not need to reload the
        # (possibly expired) ORM object after a failed transaction.
        session_id = self.session_obj.session_id
        # Only release a concurrency slot this run actually asked for; an early
        # failure (e.g. during AI client initialization) must not release one.
        acquire_attempted = False
        try:
            self._init_ai_services()

            # Reset error state left over from a previous run.
            self.session_obj.error_message = None
            self.session_obj.failed_segment_index = None
            self.db.commit()

            acquire_attempted = True
            acquired = await concurrency_manager.acquire(session_id)
            if not acquired:
                self.session_obj.status = "queued"
                self.db.commit()
                # NOTE(review): relies on acquire() waiting internally on this
                # second attempt; confirm against concurrency_manager.
                acquired = await concurrency_manager.acquire(session_id)
                if not acquired:
                    raise Exception("等待并发权限超时")

            self.session_obj.status = "processing"
            self.db.commit()

            # Re-read the row so that a stop requested meanwhile is observed.
            self.db.refresh(self.session_obj)
            if self.session_obj.status == "stopped":
                raise Exception("会话已被用户停止")

            self._ensure_segments()

            processing_mode = self.session_obj.processing_mode or 'paper_polish_enhance'
            stages = self._MODE_STAGES.get(processing_mode)
            if stages is None:
                raise ValueError(f"不支持的处理模式: {processing_mode}")
            for stage in stages:
                await self._process_stage(stage)

            self.session_obj.status = "completed"
            self.session_obj.completed_at = datetime.utcnow()
            self.session_obj.progress = 100.0
            self.session_obj.failed_segment_index = None
            self.db.commit()

        except Exception as e:
            # If the error came from a failed flush/commit the transaction must
            # be rolled back first, otherwise the commit below raises and hides
            # the original error. All earlier writes were already committed.
            self.db.rollback()
            # NOTE(review): a user stop also ends up here and is recorded as
            # "failed" (overwriting "stopped"); confirm this is intended.
            self.session_obj.status = "failed"
            self.session_obj.error_message = _truncate_error_message(str(e))
            self.db.commit()
            raise
        finally:
            if acquire_attempted:
                await concurrency_manager.release(session_id)
            self._cleanup_ai_services()

    def _ensure_segments(self):
        """Create segment rows on the first run, or resync ``total_segments``.

        On the first run the original text is split with
        ``split_text_into_segments`` and one pending ``polish`` segment row is
        created per piece. On later runs existing rows are reused.
        """
        existing_segments = self.db.query(OptimizationSegment).filter(
            OptimizationSegment.session_id == self.session_obj.id
        ).order_by(OptimizationSegment.segment_index).all()

        if existing_segments:
            self.session_obj.total_segments = len(existing_segments)
            self.db.commit()
            return

        segments = split_text_into_segments(self.session_obj.original_text)
        self.session_obj.total_segments = len(segments)
        self.db.commit()

        for idx, segment_text in enumerate(segments):
            self.db.add(OptimizationSegment(
                session_id=self.session_obj.id,
                segment_index=idx,
                stage="polish",
                original_text=segment_text,
                status="pending"
            ))
        self.db.commit()

    def _cleanup_ai_services(self):
        """Drop references to the AI clients so they can be garbage-collected."""
        self.polish_service = None
        self.enhance_service = None
        self.emotion_service = None
        self.compression_service = None

    @staticmethod
    def _calculate_progress(processing_mode: str, stage: str, index: int, total: int) -> float:
        """Return overall progress (0-100) before processing segment *index*.

        In ``paper_polish_enhance`` mode the polish stage covers 0-50% and the
        enhance stage covers 50-100%. Every other mode maps its single stage
        onto 0-100%.
        """
        if total <= 0:
            return 0.0
        ratio = index / total
        if processing_mode == 'paper_polish_enhance':
            progress = ratio * 50 if stage == "polish" else 50 + ratio * 50
        else:
            progress = ratio * 100
        return min(progress, 100.0)

    def _get_ai_service(self, stage: str) -> Optional[AIService]:
        """Return the AI client responsible for *stage*."""
        if stage == "emotion_polish":
            return self.emotion_service
        if stage == "polish":
            return self.polish_service
        return self.enhance_service  # enhance

    async def _process_stage(self, stage: str):
        """Process every not-yet-done segment of the session for *stage*.

        - Segments shorter than ``settings.SEGMENT_SKIP_THRESHOLD`` are marked
          as titles and copied through unchanged.
        - Already-processed segments are skipped.
        - Each processed segment's output is appended to the assistant history
          that is passed to the next AI call. The history is compressed when
          it grows past ``settings.HISTORY_COMPRESSION_THRESHOLD``.

        Raises:
            Exception: when the session is stopped, or when a segment fails.
                In the segment-failure case the segment and the session are
                marked failed before the error is re-raised.
        """
        print(f"\n[STAGE START] Stage: {stage}, Session: {self.session_obj.session_id}", flush=True)
        self.session_obj.current_stage = stage
        self.db.commit()

        prompt = self._get_prompt(stage)
        ai_service = self._get_ai_service(stage)

        segments = self.db.query(OptimizationSegment).filter(
            OptimizationSegment.session_id == self.session_obj.id
        ).order_by(OptimizationSegment.segment_index).all()
        total_segments = len(segments)

        # Resume after a previously failed segment, if one is recorded.
        start_index = 0
        if self.session_obj.failed_segment_index is not None:
            start_index = max(self.session_obj.failed_segment_index, 0)

        # Seed the history only from segments before start_index, so that it
        # matches the position processing resumes from. Titles are excluded.
        history: List[Dict[str, str]] = []
        total_chars = 0
        for segment in segments[:start_index]:
            if segment.is_title:
                continue
            if stage == "enhance":
                previous_output = segment.enhanced_text
            elif stage in ("polish", "emotion_polish"):
                previous_output = segment.polished_text
            else:
                previous_output = None
            if previous_output:
                history.append({"role": "assistant", "content": previous_output})
                total_chars += count_chinese_characters(previous_output)

        print(f"[STAGE] Loaded {len(history)} history messages from segments[:start_index={start_index}]", flush=True)

        skip_threshold = max(settings.SEGMENT_SKIP_THRESHOLD, 0)
        processing_mode = self.session_obj.processing_mode or 'paper_polish_enhance'

        for idx, segment in enumerate(segments[start_index:], start=start_index):
            # Honour a stop request before each segment.
            self.db.refresh(self.session_obj)
            if self.session_obj.status == "stopped":
                raise Exception("会话已被用户停止")

            # Progress is updated whether or not the segment is skipped.
            self.session_obj.current_position = idx
            self.session_obj.progress = self._calculate_progress(
                processing_mode, stage, idx, total_segments
            )
            self.db.commit()

            # Short segments (titles) are passed through unchanged.
            if count_text_length(segment.original_text) < skip_threshold:
                if not segment.is_title:
                    segment.is_title = True
                segment.status = "completed"
                segment.polished_text = segment.original_text
                segment.enhanced_text = segment.original_text
                segment.completed_at = datetime.utcnow()
                segment.stage = stage
                self.db.commit()
                continue

            # Skip segments already processed for this stage.
            if stage in ["polish", "emotion_polish"] and segment.polished_text:
                continue
            if stage == "enhance":
                if segment.enhanced_text:
                    continue
                if segment.is_title:
                    segment.enhanced_text = segment.polished_text or segment.original_text
                    segment.status = "completed"
                    segment.completed_at = segment.completed_at or datetime.utcnow()
                    self.db.commit()
                    continue

            try:
                print(f"\n[SEGMENT {idx}] Processing segment {idx+1}/{len(segments)}, Stage: {stage}", flush=True)
                print(f"[SEGMENT {idx}] Input Length: {count_text_length(segment.original_text)}", flush=True)
                segment.status = "processing"
                segment.stage = stage
                self.db.commit()

                # Enhance works on the polished text when present, otherwise on
                # the original text (paper_enhance mode).
                if stage == "enhance":
                    input_text = segment.polished_text if segment.polished_text else segment.original_text
                else:
                    input_text = segment.original_text

                # The lambda is awaited immediately by _run_with_retry, so
                # late binding of the loop variables is not a concern.
                output_text = await self._run_with_retry(
                    idx,
                    stage,
                    lambda: self._generate_segment_output(
                        ai_service, stage, idx, input_text, prompt, history
                    ),
                )

                if stage in ["polish", "emotion_polish"]:
                    segment.polished_text = output_text
                else:  # enhance
                    segment.enhanced_text = output_text
                segment.status = "completed"
                segment.completed_at = datetime.utcnow()
                self.db.commit()

                await self._record_change(segment, input_text, output_text, stage)

                history.append({"role": "assistant", "content": output_text})
                # NOTE(review): the threshold counts Chinese characters only,
                # so mostly non-Chinese text may never trigger compression.
                total_chars += count_chinese_characters(output_text)

                if total_chars > settings.HISTORY_COMPRESSION_THRESHOLD:
                    print(f"\n[HISTORY COMPRESS] Triggering compression, Stage: {stage}", flush=True)
                    print(f"[HISTORY COMPRESS] Before: {total_chars} chars, {len(history)} messages", flush=True)

                    history = await self._compress_history(history, stage)
                    total_chars = sum(count_chinese_characters(msg.get("content", "")) for msg in history)

                    print(f"[HISTORY COMPRESS] After: {total_chars} chars, {len(history)} messages", flush=True)

                    await stream_manager.broadcast(self.session_obj.session_id, {
                        "type": "history_compressed",
                        "stage": stage,
                        "message": f"历史会话已压缩({stage} 阶段),节省上下文空间",
                        "new_char_count": total_chars
                    })

                    # History is only persisted after compression, to limit writes.
                    await self._save_history(history, stage, total_chars)

            except Exception as e:
                print(f"[ERROR] Segment {idx} processing failed:", flush=True)
                print(traceback.format_exc(), flush=True)

                # Roll back first: if the failure was a DB error, committing
                # without a rollback would raise and hide the original error.
                self.db.rollback()
                segment.status = "failed"
                self.session_obj.failed_segment_index = idx
                self.session_obj.error_message = _truncate_error_message(str(e))
                self.db.commit()
                raise

    async def _generate_segment_output(
        self,
        ai_service: AIService,
        stage: str,
        segment_index: int,
        input_text: str,
        prompt: str,
        history: List[Dict[str, str]],
    ) -> str:
        """Call the AI method for *stage* on one segment and return its output.

        ``settings.USE_STREAMING`` selects streaming mode. When streaming,
        each non-empty chunk is broadcast to this session's listeners together
        with the text accumulated so far, and the concatenated text is returned.
        """
        use_stream = settings.USE_STREAMING
        if stage == "polish":
            response = await ai_service.polish_text(input_text, prompt, history, stream=use_stream)
        elif stage == "emotion_polish":
            response = await ai_service.polish_emotion_text(input_text, prompt, history, stream=use_stream)
        else:  # enhance
            response = await ai_service.enhance_text(input_text, prompt, history, stream=use_stream)

        if not use_stream:
            return response

        full_text = ""
        async for chunk in response:
            if chunk:
                full_text += chunk
                # Send the incremental chunk; full_text lets clients recover state.
                await stream_manager.broadcast(self.session_obj.session_id, {
                    "type": "content",
                    "segment_index": segment_index,
                    "stage": stage,
                    "content": chunk,
                    "full_text": full_text
                })
        return full_text

    async def _run_with_retry(self, segment_index: int, stage: str, task):
        """Await ``task()`` once (no automatic retry), adding segment context on failure.

        Raises:
            Exception: with the 1-based segment number and stage in the message,
                chained to the original error.
        """
        try:
            return await task()
        except Exception as exc:
            raise Exception(
                f"段落 {segment_index + 1} 在 {stage} 阶段失败: {str(exc)}"
            ) from exc

    def _get_prompt(self, stage: str) -> str:
        """Return the default prompt for *stage* (anything else is treated as enhance)."""
        if stage == "polish":
            return get_default_polish_prompt()
        elif stage == "emotion_polish":
            return get_emotion_polish_prompt()
        else:  # enhance
            return get_default_enhance_prompt()

    async def _compress_history(
        self,
        history: List[Dict[str, str]],
        stage: str
    ) -> List[Dict[str, str]]:
        """Compress the assistant history into a single system summary message.

        If the history is already a single system message it is returned as is.
        Otherwise the last (up to) three messages are summarized by the
        compression service with a stage-specific prompt. The summary only
        feeds later AI calls; the stored segment texts are not touched.

        Compression failures, including an empty summary, are not raised.
        Instead the last (up to) two messages are returned.
        """
        try:
            if len(history) == 1 and history[0].get("role") == "system":
                return history

            recent_messages = history[-3:] if len(history) > 3 else history

            if stage == "emotion_polish":
                compression_prompt = self._EMOTION_COMPRESSION_PROMPT
            else:
                compression_prompt = self._ACADEMIC_COMPRESSION_PROMPT

            compressed_summary = await self.compression_service.compress_history(
                recent_messages,
                compression_prompt
            )
            # An empty summary would wipe the context; treat it as a failure.
            if not compressed_summary:
                raise ValueError("压缩结果为空")

            return [
                {
                    "role": "system",
                    "content": f"之前处理的段落摘要:\n{compressed_summary}"
                }
            ]
        except Exception as e:
            print(f"[WARNING] 历史压缩失败: {str(e)}, 将使用最近的消息代替", flush=True)
            return history[-2:] if len(history) > 2 else history

    async def _save_history(self, history: List[Dict[str, str]], stage: str, char_count: int):
        """Persist a compressed history for *stage*; uncompressed history is ignored.

        Only a history in the compressed form produced by ``_compress_history``
        (exactly one system message) is saved. Processed texts already live in
        the segment rows, so saving uncompressed history would only bloat the
        database. An existing compressed record for the same session and stage
        is updated in place.
        """
        is_compressed = len(history) == 1 and history[0].get("role") == "system"
        if not is_compressed:
            return

        existing = self.db.query(SessionHistory).filter(
            SessionHistory.session_id == self.session_obj.id,
            SessionHistory.stage == stage,
            SessionHistory.is_compressed.is_(True)
        ).first()

        if existing:
            existing.history_data = json.dumps(history, ensure_ascii=False)
            existing.character_count = char_count
            existing.created_at = datetime.utcnow()
        else:
            self.db.add(SessionHistory(
                session_id=self.session_obj.id,
                stage=stage,
                history_data=json.dumps(history, ensure_ascii=False),
                is_compressed=True,
                character_count=char_count
            ))
        self.db.commit()

    async def _record_change(
        self,
        segment: OptimizationSegment,
        before: str,
        after: str,
        stage: str
    ):
        """Store a before/after change log for one segment and stage.

        The detail JSON holds both lengths and whether the text changed. The
        most recent existing log for the same segment and stage is updated
        instead of creating a duplicate entry.
        """
        changes = {
            "before_length": len(before),
            "after_length": len(after),
            "changed": before != after
        }

        existing_log = self.db.query(ChangeLog).filter(
            ChangeLog.session_id == self.session_obj.id,
            ChangeLog.segment_index == segment.segment_index,
            ChangeLog.stage == stage
        ).order_by(ChangeLog.created_at.desc()).first()

        serialized_detail = json.dumps(changes, ensure_ascii=False)

        if existing_log:
            existing_log.before_text = before
            existing_log.after_text = after
            existing_log.changes_detail = serialized_detail
        else:
            self.db.add(ChangeLog(
                session_id=self.session_obj.id,
                segment_index=segment.segment_index,
                stage=stage,
                before_text=before,
                after_text=after,
                changes_detail=serialized_detail
            ))
        self.db.commit()