Spaces:
Sleeping
Sleeping
| """ | |
| 对话上下文管理器 | |
| 跟踪多轮对话的状态和上下文信息 | |
| """ | |
| from typing import List, Dict, Optional, Any | |
| from collections import defaultdict | |
| import re | |
class ConversationContext:
    """Track the state of a single multi-turn conversation.

    Responsibilities:
        1. Keep the conversation history.
        2. Record the active scenario and its confidence.
        3. Track which pieces of information are collected or still missing.
        4. Detect whether a user message follows up on the last HR question.
        5. Track the conversation stage and step counters.
    """

    # Courtesy words removed from a question before it is matched.
    _POLITE_WORDS = ("请问", "麻烦", "能否")

    # Punctuation removed from a question; both the full-width and the ASCII
    # form of each mark are listed so either style is stripped.
    _QUESTION_PUNCTUATION = ("?", "?", "。", ".", ",", ",")

    # (keywords looked up in the question, pattern the answer must contain).
    # A rule fires when any keyword occurs in the question and the pattern
    # is found anywhere in the answer.
    _ANSWER_RULES = (
        # Headcount.
        (("人数", "人"),
         re.compile(r'\d+|[三两四五六七八九十百千万亿]+|[几多少]个?人')),
        # Budget / cost.
        (("预算", "费用", "钱"),
         re.compile(r'\d+|[三两四五六七八九十百千万]+|[元块万k]')),
        # Time (Arabic or Chinese numerals, or a relative day/week word).
        (("时间", "天", "什么时候"),
         re.compile(r'\d+[月日天]|[一二三四五六七八九十百千万]+[月日天]|明天|后天|下周|本周')),
        # Year / joining date.
        (("哪一年", "哪年", "入职", "加入", "年份"),
         re.compile(r'(19|20)\d{2}年?|[一二三四五六七八九十]{4}年?')),
        # Calendar date.
        (("日期", "哪天", "几号"),
         re.compile(r'\d+[号日]|[一二三四五六七八九十]{1,2}[号日]')),
        # Duration / deadline.
        (("多久", "多长时间", "期限"),
         re.compile(r'\d+[天月年周]|[一二三四五六七八九十百千万]+[天月年周]')),
    )

    # Questions about department / position accept any non-trivial answer.
    _FREE_TEXT_KEYWORDS = ("部门", "岗位", "职位")

    # Interrogative words that make a short reply count as an answer.
    _QUESTION_WORDS = ("什么", "哪", "多少", "如何", "怎么", "是否", "能不能")

    def __init__(self):
        """Create an empty context with no scenario and no history."""
        # Conversation history; turns are appended by add_to_history().
        self.history: List[Dict] = []

        # Current scenario state.
        self.current_scenario: Optional[str] = None
        self.scenario_confidence: float = 0.0

        # Information-collection state.
        self.collected_info: Dict[str, Any] = {}   # fields gathered so far
        self.missing_info: List[str] = []          # fields still to gather
        self.info_confidence: Dict[str, Any] = {}  # confidence per field

        # Conversation stage: initial / in_progress / complete.
        self.conversation_stage: str = "initial"
        self.current_step: int = 0
        self.total_steps: int = 0
        self.last_action: Optional[str] = None

        # Most recent HR question and full HR reply.
        self.last_hr_question: Optional[str] = None
        self.last_hr_response: Optional[str] = None

    def update_from_analysis(
        self,
        analysis_report: Dict,
        current_turn: Dict
    ) -> Dict:
        """Fold an analysis report into the context.

        Args:
            analysis_report: Report dict. The keys read here are
                "scenario", "information_extraction",
                "missing_information" and "conversation_stage"; a section
                that is not a dict is ignored.
            current_turn: Information about the current turn. Accepted for
                interface compatibility; this method does not read it.

        Returns:
            The context summary after the update (see get_context_summary).
        """
        scenario = analysis_report.get("scenario", {})
        if not isinstance(scenario, dict):
            scenario = {}
        scenario_id = scenario.get("scenario_id")
        scenario_confidence = scenario.get("confidence", 0.0)

        if scenario_id and scenario_id != self.current_scenario:
            # Scenario switch: start information collection from scratch.
            self._reset_for_new_scenario(scenario_id, scenario)
        elif scenario_id and scenario_confidence > self.scenario_confidence:
            # Same scenario: only ever raise the confidence.
            self.scenario_confidence = scenario_confidence

        # Merge newly extracted fields; a field that is already collected
        # keeps its first value and confidence.
        extraction = analysis_report.get("information_extraction", {})
        if isinstance(extraction, dict):
            extracted = extraction.get("extracted_data", {})
            if isinstance(extracted, dict):
                confidence = extraction.get("extraction_confidence", 0.5)
                for key, value in extracted.items():
                    if key not in self.collected_info:
                        self.collected_info[key] = value
                        self.info_confidence[key] = confidence

        # The report is authoritative for the list of missing fields.
        missing = analysis_report.get("missing_information", {})
        if isinstance(missing, dict):
            # `or []` also covers an explicit None in the report.
            self.missing_info = missing.get("missing_fields") or []
        else:
            self.missing_info = []

        # NOTE(review): when the report omits "conversation_stage" (or any
        # of its fields) the values below fall back to fixed defaults rather
        # than to the current state - confirm this is intended.
        conv_stage = analysis_report.get("conversation_stage", {})
        if isinstance(conv_stage, dict):
            self.conversation_stage = conv_stage.get("stage", "initial")
            self.current_step = conv_stage.get("current_step", 0)
            self.total_steps = conv_stage.get("total_steps", 1)

        return self.get_context_summary()

    def _reset_for_new_scenario(self, scenario_id: str, scenario: Dict):
        """Reset the collection state for a newly detected scenario.

        Args:
            scenario_id: Identifier of the new scenario.
            scenario: Scenario dict; "confidence", "workflow" and
                "required_info" are read when present.
        """
        self.current_scenario = scenario_id
        self.scenario_confidence = scenario.get("confidence", 0.0)
        self.conversation_stage = "initial"
        self.current_step = 0
        # `or []` tolerates keys that are present but set to None.
        self.total_steps = len(scenario.get("workflow") or [])
        self.collected_info = {}
        # Copy so later edits never touch the scenario definition.
        self.missing_info = list(scenario.get("required_info") or [])
        self.info_confidence = {}

    def add_to_history(self, turn: Dict):
        """Append a turn to the history and remember who spoke last."""
        self.history.append(turn)
        self.last_action = turn.get("speaker")

    def is_followup_question(self, current_question: str, conversation_history: Optional[List[Dict]] = None) -> Dict:
        """Classify the user's message as a follow-up or a fresh request.

        Args:
            current_question: The current user input.
            conversation_history: Optional list of message dicts with
                "role" and "content" keys. It is only consulted when no HR
                question has been recorded on this context.

        Returns:
            A dict that always contains "is_followup" and "followup_type"
            ("information_supply", "clarification", "confirmation" or
            None). An information-supply result adds "answers_previous"
            and "hr_question"; the other follow-up types add "reason".
        """
        # Prefer the HR question recorded on the context itself.
        hr_question_to_check = self.last_hr_question

        # Otherwise fall back to the latest assistant message in the history.
        if not hr_question_to_check and conversation_history:
            for msg in reversed(conversation_history):
                if msg.get("role") == "assistant":
                    hr_question_to_check = msg.get("content", "")
                    break

        # Does the message answer the previous question?
        if hr_question_to_check:
            question_content = self._extract_question_content(hr_question_to_check)
            if self._answers_question(current_question, question_content):
                return {
                    "is_followup": True,
                    "followup_type": "information_supply",
                    "answers_previous": question_content,
                    "hr_question": hr_question_to_check
                }

        # Is the user clarifying an earlier reply?
        clarification_indicators = ["我是说", "也就是说", "我的意思是", "具体是"]
        if any(ind in current_question for ind in clarification_indicators):
            return {
                "is_followup": True,
                "followup_type": "clarification",
                "reason": "用户在澄清之前的回答"
            }

        # Is the user confirming?
        confirmation_indicators = ["对的", "是的", "没错", "就这些", "可以"]
        if any(ind in current_question for ind in confirmation_indicators):
            return {
                "is_followup": True,
                "followup_type": "confirmation",
                "reason": "用户确认信息"
            }

        return {
            "is_followup": False,
            "followup_type": None
        }

    def _extract_question_content(self, question: str) -> str:
        """Return the question without courtesy words and punctuation.

        Every occurrence is removed, wherever it appears in the text, and
        surrounding whitespace is trimmed from the result.
        """
        for fragment in self._POLITE_WORDS + self._QUESTION_PUNCTUATION:
            question = question.replace(fragment, "")
        return question.strip()

    def _answers_question(self, answer: str, question: str) -> bool:
        """Heuristically decide whether ``answer`` responds to ``question``.

        Args:
            answer: The user's reply.
            question: The cleaned-up text of the previous HR question.

        Returns:
            True when any keyword rule matches, otherwise False.
        """
        # Typed rules: keyword in the question + matching pattern in the answer.
        for keywords, pattern in self._ANSWER_RULES:
            if any(word in question for word in keywords) and pattern.search(answer):
                return True

        stripped = answer.strip()

        # Department / position: any reply of two or more characters counts.
        if len(stripped) >= 2 and any(word in question for word in self._FREE_TEXT_KEYWORDS):
            return True

        # A bare number or a short reply to a question that contains an
        # interrogative word is treated as an answer.
        if re.match(r'^\d+$', stripped) or len(stripped) <= 10:
            if any(word in question for word in self._QUESTION_WORDS):
                return True

        return False

    def should_continue_current_scenario(self) -> bool:
        """Return True while in progress with information still missing."""
        return (
            self.conversation_stage == "in_progress" and
            len(self.missing_info) > 0
        )

    def _completion_rate(self) -> float:
        """Return collected / (collected + missing), or 0.0 when both are empty."""
        collected = len(self.collected_info)
        return collected / max(1, collected + len(self.missing_info))

    def get_next_action_suggestion(self) -> Dict:
        """Suggest the next action for the current state.

        Returns:
            A dict whose "action" is "confirm_complete" when the stage is
            "complete" or nothing is missing (with "reason",
            "collected_info" and "suggested_response"), and
            "ask_next_question" otherwise (with "target_field",
            "missing_fields", "suggested_question" and "completion_rate").
        """
        if self.conversation_stage == "complete" or not self.missing_info:
            return {
                "action": "confirm_complete",
                "reason": "信息收集完成",
                "collected_info": self.collected_info,
                "suggested_response": "好的,您的申请信息已完整确认,我们将尽快处理。"
            }

        # At least one field is missing here, so ask for the first one.
        next_field = self.missing_info[0]
        return {
            "action": "ask_next_question",
            "target_field": next_field,
            "missing_fields": self.missing_info,
            "suggested_question": self._get_question_for_field(next_field),
            "completion_rate": self._completion_rate()
        }

    def _get_question_for_field(self, field: str) -> str:
        """Return the standard question for ``field``, or a generic prompt."""
        questions = {
            "training_type": "请问您想申请什么类型的培训?",
            "participant_count": "请问有多少人参加培训?",
            "budget": "请问培训预算大约是多少?",
            "duration": "请问培训计划进行多长时间?",
            "leave_type": "请问您想请什么类型的假期?",
            "start_date": "请问您打算从哪天开始请假?",
            "end_date": "请问您计划哪天回来上班?",
            "reason": "请问请假的原因是什么?",
            "issue_description": "请问能详细描述一下遇到的问题吗?",
            "affected_parties": "请问这个问题涉及哪些人员?",
            "last_working_day": "请问您计划的最后工作日是哪天?",
            "expense_type": "请问是哪种类型的费用?",
            "amount": "请问金额是多少?",
            "description": "请详细说明费用情况。",
            "destination": "请问要去哪里出差?",
            "purpose": "请问出差的目的是什么?",
            "overtime_date": "请问计划哪天加班?",
            "target_position": "请问想转到哪个岗位?",
            "benefit_type": "请问您想咨询哪方面的福利?"
        }
        return questions.get(field, "请问能提供更多信息吗?")

    def get_context_summary(self) -> Dict:
        """Return a snapshot of the context state.

        The "collected_info" and "missing_info" values are the live
        containers, not copies.
        """
        return {
            "current_scenario": self.current_scenario,
            "scenario_confidence": self.scenario_confidence,
            "conversation_stage": self.conversation_stage,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "completion_rate": self._completion_rate(),
            "collected_info": self.collected_info,
            "missing_info": self.missing_info,
            "history_length": len(self.history),
            "last_hr_question": self.last_hr_question
        }

    def record_hr_interaction(
        self,
        hr_response: str,
        extracted_question: Optional[str] = None
    ):
        """Remember the latest HR reply and, if given, the question in it.

        Args:
            hr_response: Full text of the HR reply.
            extracted_question: Question extracted from the reply. When it
                is empty or None the previously stored question is kept.
        """
        self.last_hr_response = hr_response
        if extracted_question:
            self.last_hr_question = extracted_question

    def to_dict(self) -> Dict:
        """Serialise the whole context to a plain dict."""
        return {
            "history": self.history,
            "current_scenario": self.current_scenario,
            "scenario_confidence": self.scenario_confidence,
            "collected_info": self.collected_info,
            "missing_info": self.missing_info,
            "info_confidence": self.info_confidence,
            "conversation_stage": self.conversation_stage,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "last_action": self.last_action,
            "last_hr_question": self.last_hr_question,
            "last_hr_response": self.last_hr_response
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'ConversationContext':
        """Rebuild a context from a dict produced by to_dict().

        Declared as a classmethod so it can be called on the class itself
        (``ConversationContext.from_dict(data)``). Keys that are absent or
        hold None fall back to the same defaults as a fresh context.
        """
        context = cls()
        context.history = data.get("history") or []
        context.current_scenario = data.get("current_scenario")
        context.scenario_confidence = data.get("scenario_confidence") or 0.0
        context.collected_info = data.get("collected_info") or {}
        context.missing_info = data.get("missing_info") or []
        context.info_confidence = data.get("info_confidence") or {}
        context.conversation_stage = data.get("conversation_stage") or "initial"
        context.current_step = data.get("current_step") or 0
        context.total_steps = data.get("total_steps") or 0
        context.last_action = data.get("last_action")
        context.last_hr_question = data.get("last_hr_question")
        context.last_hr_response = data.get("last_hr_response")
        return context
| # 对话会话管理器(多用户支持) | |
class ConversationManager:
    """Registry of conversation contexts, keyed by session id.

    Contexts are cached in memory. When the queue manager reports a
    "redis" backend they are also loaded from and saved to its client.
    """

    def __init__(self):
        """Set up the in-memory cache and pick up the Redis client, if any."""
        # session_id -> ConversationContext
        self.sessions = {}

        from services.queue_manager import queue_manager
        self.queue_manager = queue_manager

        # A client is only used when the backend is Redis.
        if queue_manager.backend == "redis":
            self.redis_client = getattr(queue_manager, 'client', None)
        else:
            self.redis_client = None

    def get_or_create_session(self, session_id: str) -> ConversationContext:
        """Return the context for ``session_id``, creating it if needed.

        Lookup order: in-memory cache, then Redis (when a client is
        configured), then a brand-new context. A Redis failure is printed
        and treated as a miss.
        """
        import json

        # 1. In-memory cache.
        try:
            return self.sessions[session_id]
        except KeyError:
            pass

        # 2. Redis.
        if self.redis_client:
            try:
                raw = self.redis_client.get(f"hr_eval:session:{session_id}")
                if raw:
                    restored = ConversationContext.from_dict(json.loads(raw))
                    self.sessions[session_id] = restored
                    return restored
            except Exception as e:
                print(f"Failed to load session from Redis: {e}")

        # 3. Fresh context.
        fresh = ConversationContext()
        self.sessions[session_id] = fresh
        return fresh

    def save_session(self, session_id: str):
        """Write the cached context to Redis with a 30-minute expiry.

        Does nothing when the session is not cached or no Redis client is
        configured. A failure is printed, not raised.
        """
        if session_id not in self.sessions or not self.redis_client:
            return

        import json
        try:
            snapshot = self.sessions[session_id].to_dict()
            payload = json.dumps(snapshot, ensure_ascii=False)
            # 1800 seconds = 30 minutes.
            self.redis_client.setex(f"hr_eval:session:{session_id}", 1800, payload)
        except Exception as e:
            print(f"Failed to save session to Redis: {e}")

    def clear_session(self, session_id: str):
        """Drop the session from the cache and, when configured, from Redis."""
        self.sessions.pop(session_id, None)

        if not self.redis_client:
            return
        try:
            self.redis_client.delete(f"hr_eval:session:{session_id}")
        except Exception as e:
            print(f"Failed to delete session from Redis: {e}")

    def get_all_sessions(self) -> Dict[str, Dict]:
        """Return the context summary of every cached session, by session id."""
        summaries = {}
        for session_id, ctx in self.sessions.items():
            summaries[session_id] = ctx.get_context_summary()
        return summaries
| # 全局单例 | |
# Module-wide manager, created on first use.
_conversation_manager_instance = None


def get_conversation_manager() -> ConversationManager:
    """Return the shared ConversationManager, creating it on the first call."""
    global _conversation_manager_instance
    manager = _conversation_manager_instance
    if manager is not None:
        return manager

    manager = ConversationManager()
    _conversation_manager_instance = manager
    return manager