Spaces:
Sleeping
Sleeping
File size: 8,741 Bytes
84ed1d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import json
import csv
import os
import pandas as pd
import openai
import time
import requests
from dotenv import load_dotenv
from tqdm import tqdm
# Load environment variables from a .env file, if one exists.
load_dotenv()

# SiliconFlow API configuration (the API used for classification requests).
# The key is read only from the environment / .env file: a real secret must
# never be committed as a default value. An empty value makes main() refuse
# to run with a warning.
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_API_BASE = os.getenv("SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1")

# OpenAI configuration kept as a fallback.
# NOTE(review): the visible request code calls SiliconFlow through `requests`;
# these `openai` module settings do not appear to be used — confirm before
# relying on them.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your_openai_api_key_here")
openai.api_key = OPENAI_API_KEY

# Optional Azure OpenAI configuration, read from the environment.
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME", "")

# Switch the openai module to Azure only when all three settings are non-blank.
if AZURE_OPENAI_ENDPOINT.strip() and AZURE_OPENAI_API_KEY.strip() and AZURE_DEPLOYMENT_NAME.strip():
    openai.api_type = "azure"
    openai.api_base = AZURE_OPENAI_ENDPOINT
    openai.api_key = AZURE_OPENAI_API_KEY
    openai.api_version = "2023-05-15"  # may need adjusting for the actual Azure deployment
# Prompt text describing what counts as a TruePositive message.
def define_positive_sample_criteria():
    """Return the LLM instructions that define a TruePositive message.

    The original notes say TruePositive was described only as "mvp三类案例"
    (the MVP three-category cases) and no concrete definition was found, so
    the three criteria below are a working definition meant to be adjusted
    to real requirements.

    Returns:
        str: Prompt text listing the three TruePositive criteria and asking
        the model to answer with only "TruePositive" or "TrueNegative". The
        text begins and ends with a newline.
    """
    prompt_lines = [
        "",
        "请判断以下消息是否属于TruePositive。TruePositive定义为与任务管理、待办事项、提醒、通知筛选相关的有用信息,具体包括:",
        "1. 包含明确的任务、待办事项或需要完成的工作",
        "2. 包含时间安排、截止日期或日程提醒",
        "3. 包含项目进展、状态更新或工作报告",
        "如果消息符合以上任一条件,则为TruePositive;否则为TrueNegative。",
        '请只回答"TruePositive"或"TrueNegative"。',
        "",
    ]
    # Empty first/last entries reproduce the leading and trailing newline.
    return "\n".join(prompt_lines)
# Classify a single message with the SiliconFlow chat-completions API.
def classify_with_llm(message, criteria, max_retries=3, retry_delay=2, timeout=60):
    """Classify one message as "TruePositive" or "TrueNegative" using an LLM.

    The criteria text and the message are combined into one user prompt and
    sent to the ``deepseek-ai/DeepSeek-V3`` model via
    ``{SILICONFLOW_API_BASE}/chat/completions``. Network/HTTP errors and
    malformed response bodies are retried, pausing ``retry_delay`` seconds
    between attempts.

    Args:
        message: Message text to classify.
        criteria: Classification instructions, e.g. the output of
            ``define_positive_sample_criteria()``.
        max_retries: Maximum number of attempts. A value <= 0 makes no request.
        retry_delay: Seconds to sleep between attempts.
        timeout: Timeout in seconds passed to ``requests.post`` for each
            attempt, so a stalled connection cannot block forever.

    Returns:
        str: "TruePositive" if the model's reply contains that label,
        "TrueNegative" otherwise, or "分类失败" (classification failed) when no
        attempt succeeded.
    """
    if max_retries <= 0:
        # No attempt allowed: report failure explicitly instead of returning None.
        return "分类失败"

    prompt = f"{criteria}\n\n消息内容: {message}"
    system_message = "你是一个专业的数据分类助手,根据给定标准判断消息是TruePositive还是TrueNegative。"

    # The request is identical for every attempt, so build it once.
    url = f"{SILICONFLOW_API_BASE}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {SILICONFLOW_API_KEY}"
    }
    payload = {
        "model": "deepseek-ai/DeepSeek-V3",
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt}
        ],
        "stream": False,
        "max_tokens": 512,
        "temperature": 0.1,
        "top_p": 0.7,
        "top_k": 50,
        "frequency_penalty": 0.5,
        "n": 1
    }

    for attempt in range(max_retries):
        # Reset every attempt so the error log never reports a stale response
        # left over from an earlier attempt.
        response = None
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            response_data = response.json()
            result = response_data["choices"][0]["message"]["content"].strip()
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
            # Compare with None explicitly: a requests.Response is falsy for
            # 4xx/5xx statuses, which are exactly the ones worth logging.
            status = response.status_code if response is not None else "N/A"
            body = response.text if response is not None else "N/A"
            if attempt < max_retries - 1:
                print(f"API调用失败,{retry_delay}秒后重试: {e}\n响应状态码: {status}\n响应内容: {body}")
                time.sleep(retry_delay)
            else:
                print(f"API调用失败,达到最大重试次数: {e}\n最后响应状态码: {status}\n最后响应内容: {body}")
            continue

        # Normalize the free-form reply to one of the two labels.
        return "TruePositive" if "TruePositive" in result else "TrueNegative"

    return "分类失败"  # every attempt failed
# Classify messages in batches to stay within API rate limits.
def batch_process_messages(messages, batch_size=10, delay=1, min_length=5):
    """Classify every message, sleeping between batches to limit the API call rate.

    Args:
        messages: List of message dicts. The text is read from the "content" key.
        batch_size: Number of messages per batch. Must be a positive integer.
        delay: Seconds to sleep between consecutive batches (not after the last).
        min_length: Messages whose content has ``min_length`` characters or
            fewer (or no content at all) are not sent to the API and are
            labelled "TrueNegative".

    Returns:
        list: Shallow copies of the input dicts, in input order, each with an
        added "classification" key ("TruePositive", "TrueNegative", or the
        failure value returned by ``classify_with_llm``). The input dicts are
        left unmodified.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    results = []
    criteria = define_positive_sample_criteria()
    for start in tqdm(range(0, len(messages), batch_size), desc="处理批次"):
        batch = messages[start:start + batch_size]
        for msg in tqdm(batch, desc="处理消息", leave=False):
            # Label a copy so the caller's message dicts are not mutated.
            labelled = dict(msg)
            content = labelled.get("content")
            if content and len(content) > min_length:
                labelled["classification"] = classify_with_llm(content, criteria)
            else:
                # Empty or very short messages skip the API call and default
                # to TrueNegative.
                labelled["classification"] = "TrueNegative"
            results.append(labelled)
        if start + batch_size < len(messages):
            time.sleep(delay)  # pause between batches, not after the last one
    return results
def _limit_messages(messages, answer):
    """Return the prefix of ``messages`` selected by the user's sample-size answer."""
    answer = answer.strip()
    if answer.lower() == "all":
        return messages
    try:
        sample_size = int(answer)
    except ValueError:
        print("无效输入,将处理所有消息")
        return messages
    if sample_size <= 0:
        # A negative value would slice from the end (messages[:-k] drops the
        # last k messages) and 0 would process nothing: treat both as invalid.
        print("无效输入,将处理所有消息")
        return messages
    if sample_size < len(messages):
        print(f"将处理 {sample_size} 条消息作为样本")
        return messages[:sample_size]
    print(f"样本大小大于等于总消息数,将处理所有 {len(messages)} 条消息")
    return messages


def _save_json(path, data):
    """Write ``data`` to ``path`` as indented UTF-8 JSON, keeping non-ASCII text readable."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


# Main entry point.
def main():
    """Classify the messages in Messages.json and save the results as JSON files.

    Steps:
    1. Check that the SiliconFlow API key is set.
    2. Load the message list from ``Messages.json``.
    3. Ask the user how many messages to process.
    4. Classify them with ``batch_process_messages``.
    5. Write the TruePositive and TrueNegative subsets to
       ``positive_samples.json`` and ``negative_samples.json``.

    Messages whose classification failed are counted and reported but are
    not written to either file.
    """
    # Refuse to run without an API key (blank/whitespace-only counts as unset).
    if not SILICONFLOW_API_KEY.strip():
        print("警告: 未设置SiliconFlow API密钥。请设置环境变量SILICONFLOW_API_KEY或在代码中直接设置。")
        return

    input_file = "Messages.json"  # only JSON input is supported
    if not os.path.exists(input_file):
        print(f"错误: 找不到JSON输入文件 {input_file}")
        return
    print(f"使用输入文件: {input_file}")

    with open(input_file, "r", encoding="utf-8") as f:
        messages = json.load(f)
    if not isinstance(messages, list):
        # Slicing and batching below require a list of message objects.
        print(f"错误: {input_file} 的顶层结构必须是消息列表")
        return
    print(f"读取了 {len(messages)} 条消息")

    # Let the user process either every message or only a leading sample.
    answer = input("请输入要处理的消息数量(输入'all'处理所有消息,或输入一个数字如'100'处理部分消息): ")
    messages = _limit_messages(messages, answer)

    print("开始处理消息...")
    classified_messages = batch_process_messages(messages)

    positive_samples = [msg for msg in classified_messages if msg.get("classification") == "TruePositive"]
    negative_samples = [msg for msg in classified_messages if msg.get("classification") == "TrueNegative"]
    failed_count = len(classified_messages) - len(positive_samples) - len(negative_samples)
    print(f"分类完成: TruePositive {len(positive_samples)} 条, TrueNegative {len(negative_samples)} 条")
    if failed_count:
        # These messages would otherwise vanish silently from both outputs.
        print(f"警告: {failed_count} 条消息分类失败,未保存到结果文件")

    _save_json("positive_samples.json", positive_samples)
    _save_json("negative_samples.json", negative_samples)
    # Only JSON files are written; the old message wrongly claimed CSV output too.
    print("结果已保存到 positive_samples.json 和 negative_samples.json")
# Run the interactive classifier only when executed as a script, not on import.
if __name__ == "__main__":
    main()