import json import csv import os import pandas as pd import openai import time import requests from dotenv import load_dotenv from tqdm import tqdm # 加载环境变量(如果有.env文件) load_dotenv() # 配置SiliconFlow API SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "sk-ypjvmantsostdxrkirhidrtswohjpmlzuhyqojpudbreakwk") SILICONFLOW_API_BASE = os.getenv("SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1") # 保留OpenAI API配置(作为备选) OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your_openai_api_key_here") openai.api_key = OPENAI_API_KEY # 可以配置为Azure OpenAI #AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "") #AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "") #AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME", "") # 获取Azure配置参数 AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "") AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "") AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME", "") # 如果有Azure OpenAI配置,则使用Azure OpenAI if AZURE_OPENAI_ENDPOINT.strip() and AZURE_OPENAI_API_KEY.strip() and AZURE_DEPLOYMENT_NAME.strip(): openai.api_type = "azure" openai.api_base = AZURE_OPENAI_ENDPOINT openai.api_key = AZURE_OPENAI_API_KEY openai.api_version = "2023-05-15" # 可能需要根据实际情况调整 # 定义TruePositive的标准(根据mvp三类案例) def define_positive_sample_criteria(): """ 定义TruePositive的标准 根据搜索结果,TruePositive被定义为"mvp三类案例",但没有找到具体定义 这里我们定义一些可能的标准,实际使用时可以根据需求调整 """ return """ 请判断以下消息是否属于TruePositive。TruePositive定义为与任务管理、待办事项、提醒、通知筛选相关的有用信息,具体包括: 1. 包含明确的任务、待办事项或需要完成的工作 2. 包含时间安排、截止日期或日程提醒 3. 包含项目进展、状态更新或工作报告 如果消息符合以上任一条件,则为TruePositive;否则为TrueNegative。 请只回答"TruePositive"或"TrueNegative"。 """ # 使用大模型API进行分类 def classify_with_llm(message, criteria, max_retries=3, retry_delay=2): """ 使用大模型API对消息进行分类 Args: message: 要分类的消息内容 criteria: 分类标准 max_retries: 最大重试次数 retry_delay: 重试延迟(秒) Returns: str: "TruePositive" 或 "TrueNegative" """ prompt = f"{criteria}\n\n消息内容: {message}" system_message = "你是一个专业的数据分类助手,根据给定标准判断消息是TruePositive还是TrueNegative。" for attempt in range(max_retries): try: # 使用SiliconFlow API headers = { "Content-Type": "application/json", "Authorization": f"Bearer {SILICONFLOW_API_KEY}" } payload = { "model": "deepseek-ai/DeepSeek-V3", "messages": [ {"role": "system", "content": system_message}, {"role": "user", "content": prompt} ], "stream": False, "max_tokens": 512, "temperature": 0.1, "top_p": 0.7, "top_k": 50, "frequency_penalty": 0.5, "n": 1 } response = requests.post( f"{SILICONFLOW_API_BASE}/chat/completions", headers=headers, json=payload ) # 检查响应状态 response.raise_for_status() response_data = response.json() # 解析响应 result = response_data["choices"][0]["message"]["content"].strip() # 标准化结果 if "TruePositive" in result: return "TruePositive" else: return "TrueNegative" except Exception as e: if attempt < max_retries - 1: print(f"API调用失败,{retry_delay}秒后重试: {e}\n响应状态码: {response.status_code if 'response' in locals() else 'N/A'}\n响应内容: {response.text if 'response' in locals() else 'N/A'}") time.sleep(retry_delay) else: print(f"API调用失败,达到最大重试次数: {e}\n最后响应状态码: {response.status_code if 'response' in locals() else 'N/A'}\n最后响应内容: {response.text if 'response' in locals() else 'N/A'}") return "分类失败" # 返回一个默认值 # 批量处理消息 def batch_process_messages(messages, batch_size=10, delay=1): """ 批量处理消息以避免API限制 Args: messages: 消息列表 batch_size: 每批处理的消息数量 delay: 批次间延迟(秒) Returns: list: 处理结果列表 """ results = [] criteria = define_positive_sample_criteria() for i in tqdm(range(0, len(messages), batch_size), desc="处理批次"): batch = messages[i:i+batch_size] batch_results = [] for msg in tqdm(batch, desc="处理消息", leave=False): # 只处理有实际内容的消息 if msg.get("content") and len(msg["content"]) > 5: # 忽略过短的消息 classification = classify_with_llm(msg["content"], criteria) msg["classification"] = classification else: msg["classification"] = "TrueNegative" # 默认短消息为TrueNegative batch_results.append(msg) results.extend(batch_results) if i + batch_size < len(messages): time.sleep(delay) # 批次间延迟 return results # 主函数 def main(): # 检查API密钥是否配置 if SILICONFLOW_API_KEY == "": print("警告: 未设置SiliconFlow API密钥。请设置环境变量SILICONFLOW_API_KEY或在代码中直接设置。") return # 确定输入文件 input_file = "Messages.json" # 默认使用JSON格式 if not os.path.exists(input_file): print(f"错误: 找不到JSON输入文件 {input_file}") return print(f"使用输入文件: {input_file}") # 读取数据 messages = [] if input_file.endswith(".json"): with open(input_file, "r", encoding="utf-8") as f: messages = json.load(f) print(f"读取了 {len(messages)} 条消息") # 询问用户是否要处理所有消息或仅处理一部分样本 sample_size = input("请输入要处理的消息数量(输入'all'处理所有消息,或输入一个数字如'100'处理部分消息): ") if sample_size.lower() != "all": try: sample_size = int(sample_size) if sample_size < len(messages): print(f"将处理 {sample_size} 条消息作为样本") messages = messages[:sample_size] else: print(f"样本大小大于等于总消息数,将处理所有 {len(messages)} 条消息") except ValueError: print("无效输入,将处理所有消息") # 批量处理消息 print("开始处理消息...") classified_messages = batch_process_messages(messages) # 分离TruePositive from TrueNegative positive_samples = [msg for msg in classified_messages if msg.get("classification") == "TruePositive"] negative_samples = [msg for msg in classified_messages if msg.get("classification") == "TrueNegative"] print(f"分类完成: TruePositive {len(positive_samples)} 条, TrueNegative {len(negative_samples)} 条") # 保存结果 if input_file.endswith(".json"): # 保存JSON格式 with open("positive_samples.json", "w", encoding="utf-8") as f: json.dump(positive_samples, f, ensure_ascii=False, indent=2) with open("negative_samples.json", "w", encoding="utf-8") as f: json.dump(negative_samples, f, ensure_ascii=False, indent=2) print("结果已保存到 positive_samples.json/csv 和 negative_samples.json/csv") if __name__ == "__main__": main()