# NOTE(review): stripped editor-status artifact lines ("Spaces:", "Sleeping")
# that were pasted into the file and are not part of the Python source.
"""
HR Agent执行层 - 第二层
根据第一层的指令生成具体的回复
"""
from typing import Dict, List, Optional
import random
import os
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import MODEL_CONFIG, LLM_API_CONFIG
from models.compliance import ComplianceChecker
from models.correctness import CorrectnessEvaluator
class HRAgentExecutor:
    """HR agent execution layer (layer 2).

    Turns the analysis report and reply instruction produced by layer 1
    into a concrete reply text, then checks and scores it.
    """

    def __init__(self):
        """Set up checkers and pick the reply-generation backend (API or local model)."""
        self.compliance_checker = ComplianceChecker()
        self.correctness_evaluator = CorrectnessEvaluator()

        # Whether generation is delegated to a remote LLM API.
        self.use_api = LLM_API_CONFIG.get("enabled", False)

        # Backend handles; exactly one path below initializes its handle,
        # the others stay None and act as fallback signals in execute().
        self.model = None
        self.tokenizer = None
        self.llm_api_client = None

        if self.use_api:
            self._init_api_client()
        else:
            self._load_model()
| def _init_api_client(self): | |
| """初始化 LLM API 客户端""" | |
| try: | |
| from services.llm_api_client import get_llm_api_client | |
| self.llm_api_client = get_llm_api_client() | |
| provider = LLM_API_CONFIG.get("provider", "unknown") | |
| model = LLM_API_CONFIG.get("model", "unknown") | |
| print(f"使用 LLM API 模式: {provider} - {model}") | |
| except Exception as e: | |
| print(f"初始化 LLM API 客户端失败: {e}") | |
| self.use_api = False | |
| def _load_model(self): | |
| """加载对话生成模型""" | |
| model_path = MODEL_CONFIG.get("dialogue_model_path") | |
| if not model_path or not os.path.exists(model_path): | |
| print(f"Warning: Dialogue model path not found: {model_path}") | |
| return | |
| try: | |
| print(f"Loading dialogue model from {model_path}...") | |
| self.device = MODEL_CONFIG.get("device", "cpu") | |
| # 确定 dtype | |
| torch_dtype = torch.float32 | |
| if self.device == "cuda": | |
| torch_dtype = torch.float16 | |
| elif self.device == "mps": | |
| torch_dtype = torch.bfloat16 | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) | |
| # 确保加载chat_template | |
| if not self.tokenizer.chat_template: | |
| template_path = os.path.join(model_path, "chat_template.jinja") | |
| if os.path.exists(template_path): | |
| with open(template_path, "r", encoding="utf-8") as f: | |
| self.tokenizer.chat_template = f.read() | |
| print("Loaded chat template from file.") | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch_dtype, | |
| device_map=self.device, | |
| trust_remote_code=True | |
| ) | |
| print("Dialogue model loaded successfully.") | |
| except Exception as e: | |
| print(f"Error loading dialogue model: {e}") | |
| self.model = None | |
| def execute( | |
| self, | |
| instruction: Dict, | |
| analysis_report: Dict | |
| ) -> Dict: | |
| """ | |
| 执行回复生成 | |
| Args: | |
| instruction: 第一层生成的回复指令 | |
| analysis_report: 第一层的分析报告 | |
| Returns: | |
| { | |
| "answer": "好的,请问培训人数和预算是多少?", | |
| "template_used": "...", | |
| "modifications": [...], | |
| "compliance_check": {...}, | |
| "quality_score": 95 | |
| } | |
| """ | |
| # Step 1: 生成回复 | |
| if self.use_api and self.llm_api_client: | |
| answer = self._generate_with_api(instruction, analysis_report) | |
| template = f"generated_by_{LLM_API_CONFIG.get('provider', 'api')}" | |
| elif self.model: | |
| answer = self._generate_with_model(instruction, analysis_report) | |
| template = "generated_by_qwen_lora" | |
| else: | |
| # Fallback to template | |
| template = self._select_template(instruction) | |
| answer = self._customize_reply( | |
| template, | |
| instruction, | |
| analysis_report | |
| ) | |
| # Step 3: 合规性检查 | |
| compliance_check = self._check_compliance(answer) | |
| # Step 4: 正确性评估(对比知识库) | |
| correctness_check = self._check_correctness( | |
| answer, | |
| analysis_report | |
| ) | |
| # Step 5: 质量评分 | |
| quality_score = self._calculate_quality_score( | |
| instruction, | |
| compliance_check, | |
| correctness_check | |
| ) | |
| return { | |
| "answer": answer, | |
| "template_used": template, | |
| "modifications": [], | |
| "compliance_check": compliance_check, | |
| "correctness_check": correctness_check, | |
| "quality_score": quality_score | |
| } | |
| def _generate_with_model(self, instruction: Dict, analysis_report: Dict) -> str: | |
| """使用模型生成回复""" | |
| user_question = analysis_report.get("user_question", "") | |
| # 构建系统提示词 | |
| system_prompt = "你是一个专业的HR助手,请根据员工的问题提供准确、专业、合规的回答。" | |
| # 添加指令中的特殊要求 | |
| if instruction.get("tone_requirement"): | |
| system_prompt += f"\n语气要求: {instruction['tone_requirement']}" | |
| if instruction.get("must_include"): | |
| system_prompt += f"\n必须包含: {', '.join(instruction['must_include'])}" | |
| if instruction.get("must_avoid"): | |
| system_prompt += f"\n必须避免: {', '.join(instruction['must_avoid'])}" | |
| # 构建消息 | |
| messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_question} | |
| ] | |
| try: | |
| # 应用聊天模板 | |
| text = self.tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device) | |
| # 生成 | |
| generated_ids = self.model.generate( | |
| model_inputs.input_ids, | |
| max_new_tokens=512, | |
| temperature=0.7, | |
| top_p=0.9, | |
| do_sample=True, | |
| pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 151643, | |
| eos_token_id=[151645, 151643], # <|im_end|> and <|endoftext|> | |
| repetition_penalty=1.1 | |
| ) | |
| # 解码 | |
| generated_ids = [ | |
| output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
| ] | |
| response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| return self._clean_response(response) | |
| except Exception as e: | |
| print(f"Error generating response: {e}") | |
| # Fallback to template if generation fails | |
| template = self._select_template(instruction) | |
| return self._customize_reply(template, instruction, analysis_report) | |
| def _generate_with_api(self, instruction: Dict, analysis_report: Dict) -> str: | |
| """使用 LLM API 生成回复""" | |
| user_question = analysis_report.get("user_question", "") | |
| # 构建系统提示词 | |
| system_prompt = self._build_system_prompt(instruction, analysis_report) | |
| try: | |
| # 调用 API 生成 | |
| response = self.llm_api_client.generate( | |
| system_prompt=system_prompt, | |
| user_message=user_question, | |
| temperature=LLM_API_CONFIG.get("temperature", 0.7), | |
| max_tokens=LLM_API_CONFIG.get("max_tokens", 256) | |
| ) | |
| return response.strip() | |
| except Exception as e: | |
| print(f"API 生成失败: {e}") | |
| # Fallback to template | |
| template = self._select_template(instruction) | |
| return self._customize_reply(template, instruction, analysis_report) | |
| def _build_system_prompt(self, instruction: Dict, analysis_report: Dict) -> str: | |
| """构建系统提示词""" | |
| # 获取情绪信息 | |
| emotion = analysis_report.get("emotion", {}) | |
| emotion_type = emotion.get("emotion", "neutral") | |
| emotion_intensity = emotion.get("intensity", 0.3) | |
| # 获取风险等级 | |
| risk_assessment = analysis_report.get("risk_assessment", {}) | |
| risk_level = risk_assessment.get("risk_level", "low") | |
| # 判断是否是敏感场景 | |
| user_question = analysis_report.get("user_question", "") | |
| is_sensitive_topic = self._is_sensitive_topic(user_question) | |
| # 根据 情绪类型 + 情绪强度 + 敏感场景 来确定回复风格 | |
| style_mode = self._determine_reply_style(emotion_type, emotion_intensity, is_sensitive_topic, risk_level) | |
| # 根据风格模式构建不同的 prompt | |
| system_prompt = self._build_prompt_by_style(style_mode, user_question) | |
| # 添加场景信息 | |
| scenario = analysis_report.get("scenario", {}) | |
| if scenario: | |
| scenario_name = scenario.get("scenario_name", "") | |
| scenario_description = scenario.get("description", "") | |
| system_prompt += f"\n当前场景: {scenario_name}\n" | |
| if scenario_description: | |
| system_prompt += f"场景描述: {scenario_description}\n" | |
| # 添加语气要求 | |
| tone = instruction.get("tone_requirement", {}) | |
| if isinstance(tone, dict): | |
| keywords = tone.get("keywords", []) | |
| avoid = tone.get("avoid", []) | |
| if keywords: | |
| system_prompt += f"\n建议用词: {', '.join(keywords)}" | |
| if avoid: | |
| system_prompt += f"\n避免用词: {', '.join(avoid)}" | |
| # 添加必须包含的内容 | |
| must_include = instruction.get("must_include", []) | |
| if must_include: | |
| system_prompt += f"\n必须包含: {', '.join(must_include)}" | |
| # 添加必须避免的内容 | |
| must_avoid = instruction.get("must_avoid", []) | |
| if must_avoid: | |
| system_prompt += f"\n必须避免: {', '.join(must_avoid)}" | |
| # 添加对话阶段信息 | |
| conversation_stage = analysis_report.get("conversation_stage", {}) | |
| stage = conversation_stage.get("stage", "") | |
| if stage == "complete": | |
| system_prompt += "\n提示: 信息已收集完整,可以给出最终答复了" | |
| # 添加缺失信息提示 | |
| missing_info = analysis_report.get("missing_information", {}) | |
| missing_fields = missing_info.get("missing_fields", []) | |
| if missing_fields: | |
| # 将字段名转换为中文 | |
| field_names_map = { | |
| "training_type": "培训类型", | |
| "participant_count": "参与人数", | |
| "budget": "预算", | |
| "duration": "培训时长", | |
| "start_date": "开始日期", | |
| "location": "培训地点", | |
| "leave_type": "假期类型", | |
| "end_date": "结束日期", | |
| "reason": "原因" | |
| } | |
| missing_names = [field_names_map.get(f, f) for f in missing_fields] | |
| system_prompt += f"\n还需了解: {', '.join(missing_names)}" | |
| return system_prompt | |
| def _determine_reply_style(self, emotion_type: str, emotion_intensity: float, is_sensitive_topic: bool, risk_level: str) -> str: | |
| """ | |
| 根据情绪和场景确定回复风格 | |
| Args: | |
| emotion_type: 情绪类型 (positive/neutral/negative) | |
| emotion_intensity: 情绪强度 (0-1) | |
| is_sensitive_topic: 是否敏感话题 | |
| risk_level: 风险等级 (low/medium/high) | |
| Returns: | |
| 风格模式: empathetic/warm/normal/concise | |
| """ | |
| # 高风险 + 负面情绪 + 高强度 = 最需要同理心 | |
| if risk_level == "high" or (emotion_type == "negative" and emotion_intensity > 0.7): | |
| return "empathetic" # 高度同理心,安抚情绪 | |
| # 敏感话题(离职、劳资纠纷等)= 温暖关注 | |
| if is_sensitive_topic: | |
| return "warm" # 温暖理解,但不过度 | |
| # 负面情绪(即使不是敏感话题)= 温暖关注 | |
| if emotion_type == "negative": | |
| return "warm" | |
| # 正面情绪 + 低强度 = 简洁高效 | |
| if emotion_type == "positive" or emotion_intensity < 0.4: | |
| return "concise" # 简洁直接 | |
| # 默认:正常专业 | |
| return "normal" | |
| def _build_prompt_by_style(self, style_mode: str, user_question: str) -> str: | |
| """ | |
| 根据风格模式构建 prompt | |
| Args: | |
| style_mode: 风格模式 | |
| user_question: 用户问题(用于检测敏感词) | |
| Returns: | |
| 系统提示词 | |
| """ | |
| if style_mode == "empathetic": | |
| # 高度同理心:安抚情绪,承诺帮助 | |
| return """你是公司的 HR 助手,此刻最重要的是安抚员工情绪。 | |
| **沟通重点:** | |
| - 开场必须表达理解:"我理解您现在的心情"、"这确实让人着急"、"我明白您的感受" | |
| - 用"我"承诺行动:"我会亲自跟进来处理"、"我马上帮您确认" | |
| - 让员工感到被重视:"您别担心,这件事我会负责到底" | |
| - 绝对禁止引导投诉/仲裁/起诉 | |
| **回复结构:** | |
| 1. 先表达理解和关心 | |
| 2. 承诺会亲自跟进解决 | |
| 3. 引导相信公司内部协调 | |
| """ | |
| elif style_mode == "warm": | |
| # 温暖关注:表达理解,但不啰嗦 | |
| return """你是公司的 HR 助手,专业且有人情味。 | |
| **沟通风格:** | |
| - 可以简单表达理解:"我理解"、"我明白" | |
| - 用"我"代替"本公司" | |
| - 回复简洁,不要过度客套 | |
| **核心原则:** | |
| 1. 遇到问题时承诺会帮助协调 | |
| 2. 不引导外部投诉/仲裁途径 | |
| """ | |
| elif style_mode == "concise": | |
| # 简洁高效:直接回答,无废话 | |
| return """你是公司的 HR 助手,高效专业。 | |
| **沟通风格:** | |
| - 直接回答问题,不要客套话 | |
| - 需要信息时直接询问 | |
| - 用"我"代替"本公司",自然简洁 | |
| **禁止:** | |
| - 不要说"我理解您的心情"、"我明白您的顾虑"之类的话 | |
| - 不要过度表达关心 | |
| """ | |
| else: # normal | |
| # 正常专业:友好但不过度 | |
| return """你是公司的 HR 助手,专业友好。 | |
| **沟通风格:** | |
| - 回复简洁直接 | |
| - 用"我"代替"本公司"更自然 | |
| - 不要过度客套 | |
| **核心原则:** | |
| 高效帮助员工解决问题。 | |
| """ | |
| def _is_sensitive_topic(self, user_question: str) -> bool: | |
| """ | |
| 判断是否是需要同理心的敏感话题 | |
| Args: | |
| user_question: 用户问题 | |
| Returns: | |
| 是否是敏感话题 | |
| """ | |
| # 敏感关键词(劳资纠纷、投诉、不满等) | |
| sensitive_keywords = [ | |
| "欠薪", "拖欠工资", "不发工资", "克扣工资", | |
| "加班没工资", "加班不给钱", | |
| "投诉", "举报", "仲裁", "起诉", "诉讼", "告公司", | |
| "违法", "侵权", "逼迫", "威胁", "骚扰", "歧视", | |
| "不干了", "要辞职", "离职", "辞退", "开除", "赔偿", | |
| "不公平", "不合理", "太过分", "很生气", "不满" | |
| ] | |
| return any(kw in user_question for kw in sensitive_keywords) | |
| def _clean_response(self, text: str) -> str: | |
| """清理模型生成的回复,去除幻觉和重复内容""" | |
| # 常见的幻觉标记(模型开始模拟对话) | |
| stop_markers = [ | |
| "\nuser", "\nassistant", "\nSystem", "\nUser", "\nAssistant", | |
| "user:", "assistant:", "System:", | |
| "aeper", "рейт", "konkp", "okino", "torino" # 观察到的特定噪声 | |
| ] | |
| for marker in stop_markers: | |
| # 不区分大小写查找 | |
| idx = text.lower().find(marker.lower()) | |
| if idx != -1: | |
| text = text[:idx] | |
| return text.strip() | |
| def _select_template(self, instruction: Dict) -> str: | |
| """选择回复模板""" | |
| suggested_templates = instruction.get("suggested_templates", []) | |
| if not suggested_templates: | |
| return "好的,请问有什么可以帮您?" | |
| # 简单策略:选择第一个模板 | |
| # 实际可以根据上下文、历史等智能选择 | |
| return suggested_templates[0] | |
| def _customize_reply( | |
| self, | |
| template: str, | |
| instruction: Dict, | |
| analysis_report: Dict | |
| ) -> str: | |
| """根据指令定制回复""" | |
| answer = template | |
| # 根据语气要求调整 | |
| tone = instruction.get("tone_requirement", {}) | |
| if isinstance(tone, str): | |
| style = tone | |
| else: | |
| style = tone.get("style", "friendly professional") | |
| # 如果需要同理心 | |
| if style == "empathetic professional": | |
| # 检查是否已经包含同理心词汇 | |
| empathetic_keywords = ["理解", "抱歉", "不便"] | |
| if not any(kw in answer for kw in empathetic_keywords): | |
| # 在适当位置添加同理心表达 | |
| if "好的" in answer: | |
| answer = answer.replace("好的", "我理解您的需求", 1) | |
| elif "收到" in answer: | |
| answer = answer.replace("收到", "我理解您的诉求,收到", 1) | |
| # 确保包含必要内容 | |
| must_include = instruction.get("must_include", []) | |
| for item in must_include: | |
| if item not in answer: | |
| # 如果必要内容不在回复中,添加到末尾 | |
| answer = answer + " " + item | |
| # 根据对话阶段调整 | |
| conversation_stage = analysis_report.get("conversation_stage", {}) | |
| stage = conversation_stage.get("stage", "") | |
| # 检查是否是知识库答案(包含来源信息) | |
| is_knowledge_answer = "(来源:" in answer or "(来源:" in answer | |
| if stage == "complete" and not is_knowledge_answer: | |
| # 信息收集完成,添加确认信息 | |
| if "已记录" not in answer and "已确认" not in answer: | |
| scenario_name = analysis_report["scenario"]["scenario_name"] | |
| answer = answer + f" 您的{scenario_name}相关信息已全部确认。" | |
| return answer | |
| def _check_compliance(self, answer: str) -> Dict: | |
| """检查回复是否合规""" | |
| compliance_result = self.compliance_checker.check_turn(answer) | |
| return { | |
| "is_compliant": len(compliance_result["violations"]) == 0, | |
| "violations": compliance_result["violations"], | |
| "checked_text": answer | |
| } | |
| def _check_correctness( | |
| self, | |
| answer: str, | |
| analysis_report: Dict | |
| ) -> Dict: | |
| """ | |
| 检查回复的正确性(对比知识库) | |
| 优化:区分追问类型和陈述类型 | |
| """ | |
| # 判断回复类型 | |
| if self._is_question(answer): | |
| # 这是追问,不需要做语义相似度评估 | |
| return { | |
| "check_type": "question_validation", | |
| "is_question": True, | |
| "is_appropriate": True, | |
| "note": "这是合理的追问,用于收集更多信息", | |
| "question_detected": self._extract_question(answer), | |
| "checked_text": answer | |
| } | |
| # 陈述性回复,使用Sentence-BERT评估 | |
| user_question = analysis_report.get("user_question", "") | |
| dialogue = [ | |
| {"speaker": "Employee", "utterance": user_question}, | |
| {"speaker": "HR Assistant", "utterance": answer} | |
| ] | |
| # 使用正确性评估器 | |
| correctness_result = self.correctness_evaluator.evaluate_dialogue(dialogue) | |
| # 提取关键信息 | |
| details = correctness_result.get("details", []) | |
| best_match = details[0] if details else None | |
| return { | |
| "check_type": "semantic_similarity", | |
| "is_question": False, | |
| "similarity_score": correctness_result.get("avg_score", 0), | |
| "level": correctness_result.get("level", "unknown"), | |
| "matched_knowledge": best_match.get("matched_qa") if best_match else None, | |
| "is_correct": correctness_result.get("level") in ["good", "fair"], | |
| "checked_text": answer | |
| } | |
| def _is_question(self, text: str) -> bool: | |
| """判断文本是否是问题/追问""" | |
| question_indicators = [ | |
| "?", "?", | |
| "请问", "请问是", "请问有", | |
| "多少", "哪些", "哪个", | |
| "是否", "能不能", "可不可以", | |
| "需要", "请提供", "麻烦" | |
| ] | |
| text_lower = text.lower() | |
| return any(indicator in text for indicator in question_indicators) | |
| def _extract_question(self, text: str) -> str: | |
| """提取问题核心内容""" | |
| # 移除礼貌用语 | |
| for polite in ["请问", "麻烦", "能否"]: | |
| text = text.replace(polite, "") | |
| # 移除标点 | |
| for punct in ["?", "?", "。", "."]: | |
| text = text.replace(punct, "") | |
| return text.strip() | |
| def _calculate_quality_score( | |
| self, | |
| instruction: Dict, | |
| compliance_check: Dict, | |
| correctness_check: Dict | |
| ) -> int: | |
| """计算回复质量分数(优化版)""" | |
| score = 100 | |
| # 1. 正确性评分(根据类型调整) | |
| if correctness_check.get("is_question"): | |
| # 追问类型:检查问题是否合理 | |
| # 追问总是合理的,扣分较少 | |
| score = 95 # 追问默认高分 | |
| else: | |
| # 陈述类型:使用语义相似度 | |
| similarity = correctness_check.get("similarity_score", 0) | |
| correctness_penalty = (1 - similarity) * 40 | |
| score = max(0, score - int(correctness_penalty)) | |
| # 2. 合规性扣分(权重35%) | |
| if not compliance_check["is_compliant"]: | |
| violations = compliance_check["violations"] | |
| for violation in violations: | |
| severity = violation.get("severity", "low") | |
| if severity == "high": | |
| score -= 30 | |
| elif severity == "medium": | |
| score -= 15 | |
| else: | |
| score -= 5 | |
| # 检查是否包含必要内容 | |
| must_include = instruction.get("must_include", []) | |
| missing_content = [] | |
| for item in must_include: | |
| # 简化检查:看是否包含关键词 | |
| keywords = item.split()[:2] # 取前两个词作为关键词 | |
| if not any(kw in str(instruction.get("suggested_templates", "")) | |
| for kw in keywords): | |
| missing_content.append(item) | |
| score -= len(missing_content) * 5 | |
| return max(0, int(score)) | |