File size: 8,741 Bytes
84ed1d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import json
import csv
import os
import pandas as pd
import openai
import time
import requests
from dotenv import load_dotenv
from tqdm import tqdm

# 加载环境变量(如果有.env文件)
load_dotenv()

# 配置SiliconFlow API
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "sk-ypjvmantsostdxrkirhidrtswohjpmlzuhyqojpudbreakwk")
SILICONFLOW_API_BASE = os.getenv("SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1")

# 保留OpenAI API配置(作为备选)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your_openai_api_key_here")
openai.api_key = OPENAI_API_KEY

# 可以配置为Azure OpenAI
#AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "")
#AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
#AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME", "")

# 获取Azure配置参数
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME", "")

# 如果有Azure OpenAI配置,则使用Azure OpenAI
if AZURE_OPENAI_ENDPOINT.strip() and AZURE_OPENAI_API_KEY.strip() and AZURE_DEPLOYMENT_NAME.strip():
    openai.api_type = "azure"
    openai.api_base = AZURE_OPENAI_ENDPOINT
    openai.api_key = AZURE_OPENAI_API_KEY
    openai.api_version = "2023-05-15"  # 可能需要根据实际情况调整

# 定义TruePositive的标准(根据mvp三类案例)
def define_positive_sample_criteria():
    """

    定义TruePositive的标准

    根据搜索结果,TruePositive被定义为"mvp三类案例",但没有找到具体定义

    这里我们定义一些可能的标准,实际使用时可以根据需求调整

    """
    return """

    请判断以下消息是否属于TruePositive。TruePositive定义为与任务管理、待办事项、提醒、通知筛选相关的有用信息,具体包括:

    1. 包含明确的任务、待办事项或需要完成的工作

    2. 包含时间安排、截止日期或日程提醒

    3. 包含项目进展、状态更新或工作报告

    

    如果消息符合以上任一条件,则为TruePositive;否则为TrueNegative。

    请只回答"TruePositive"或"TrueNegative"。

    """

# 使用大模型API进行分类
def classify_with_llm(message, criteria, max_retries=3, retry_delay=2):
    """

    使用大模型API对消息进行分类

    

    Args:

        message: 要分类的消息内容

        criteria: 分类标准

        max_retries: 最大重试次数

        retry_delay: 重试延迟(秒)

        

    Returns:

        str: "TruePositive" 或 "TrueNegative"

    """
    prompt = f"{criteria}\n\n消息内容: {message}"
    system_message = "你是一个专业的数据分类助手,根据给定标准判断消息是TruePositive还是TrueNegative。"
    
    for attempt in range(max_retries):
        try:
            # 使用SiliconFlow API
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {SILICONFLOW_API_KEY}"
            }
            
            payload = {
                "model": "deepseek-ai/DeepSeek-V3",
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": prompt}
                ],
                "stream": False,
                "max_tokens": 512,
                "temperature": 0.1,
                "top_p": 0.7,
                "top_k": 50,
                "frequency_penalty": 0.5,
                "n": 1
            }
            
            response = requests.post(
                f"{SILICONFLOW_API_BASE}/chat/completions",
                headers=headers,
                json=payload
            )
            
            # 检查响应状态
            response.raise_for_status()
            response_data = response.json()
            
            # 解析响应
            result = response_data["choices"][0]["message"]["content"].strip()
            
            # 标准化结果
            if "TruePositive" in result:
                return "TruePositive"
            else:
                return "TrueNegative"
                
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"API调用失败,{retry_delay}秒后重试: {e}\n响应状态码: {response.status_code if 'response' in locals() else 'N/A'}\n响应内容: {response.text if 'response' in locals() else 'N/A'}")
                time.sleep(retry_delay)
            else:
                print(f"API调用失败,达到最大重试次数: {e}\n最后响应状态码: {response.status_code if 'response' in locals() else 'N/A'}\n最后响应内容: {response.text if 'response' in locals() else 'N/A'}")
                return "分类失败"  # 返回一个默认值

# 批量处理消息
def batch_process_messages(messages, batch_size=10, delay=1):
    """

    批量处理消息以避免API限制

    

    Args:

        messages: 消息列表

        batch_size: 每批处理的消息数量

        delay: 批次间延迟(秒)

        

    Returns:

        list: 处理结果列表

    """
    results = []
    criteria = define_positive_sample_criteria()
    
    for i in tqdm(range(0, len(messages), batch_size), desc="处理批次"):
        batch = messages[i:i+batch_size]
        batch_results = []
        
        for msg in tqdm(batch, desc="处理消息", leave=False):
            # 只处理有实际内容的消息
            if msg.get("content") and len(msg["content"]) > 5:  # 忽略过短的消息
                classification = classify_with_llm(msg["content"], criteria)
                msg["classification"] = classification
            else:
                msg["classification"] = "TrueNegative"  # 默认短消息为TrueNegative
                
            batch_results.append(msg)
            
        results.extend(batch_results)
        
        if i + batch_size < len(messages):
            time.sleep(delay)  # 批次间延迟
            
    return results

# 主函数
def main():
    # 检查API密钥是否配置
    if SILICONFLOW_API_KEY == "":
        print("警告: 未设置SiliconFlow API密钥。请设置环境变量SILICONFLOW_API_KEY或在代码中直接设置。")
        return
    
    # 确定输入文件
    input_file = "Messages.json"  # 默认使用JSON格式
    if not os.path.exists(input_file):
        print(f"错误: 找不到JSON输入文件 {input_file}")
        return
    
    print(f"使用输入文件: {input_file}")
    
    # 读取数据
    messages = []
    if input_file.endswith(".json"):
        with open(input_file, "r", encoding="utf-8") as f:
            messages = json.load(f)

    
    print(f"读取了 {len(messages)} 条消息")
    
    # 询问用户是否要处理所有消息或仅处理一部分样本
    sample_size = input("请输入要处理的消息数量(输入'all'处理所有消息,或输入一个数字如'100'处理部分消息): ")
    
    if sample_size.lower() != "all":
        try:
            sample_size = int(sample_size)
            if sample_size < len(messages):
                print(f"将处理 {sample_size} 条消息作为样本")
                messages = messages[:sample_size]
            else:
                print(f"样本大小大于等于总消息数,将处理所有 {len(messages)} 条消息")
        except ValueError:
            print("无效输入,将处理所有消息")
    
    # 批量处理消息
    print("开始处理消息...")
    classified_messages = batch_process_messages(messages)
    
    # 分离TruePositive from TrueNegative
    positive_samples = [msg for msg in classified_messages if msg.get("classification") == "TruePositive"]
    negative_samples = [msg for msg in classified_messages if msg.get("classification") == "TrueNegative"]
    
    print(f"分类完成: TruePositive {len(positive_samples)} 条, TrueNegative {len(negative_samples)} 条")
    
    # 保存结果
    if input_file.endswith(".json"):
        # 保存JSON格式
        with open("positive_samples.json", "w", encoding="utf-8") as f:
            json.dump(positive_samples, f, ensure_ascii=False, indent=2)
        
        with open("negative_samples.json", "w", encoding="utf-8") as f:
            json.dump(negative_samples, f, ensure_ascii=False, indent=2)
    

    
    print("结果已保存到 positive_samples.json/csv 和 negative_samples.json/csv")

if __name__ == "__main__":
    main()