File size: 6,984 Bytes
38d8dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import openai
import threading
from concurrent.futures import ThreadPoolExecutor
from openai import APIError
from typing import List, Tuple

class FormatValidator:
    """Validator that normalizes a keyword list plus source text into one output line."""

    @staticmethod
    def validate_line(keywords: List[str], original: str) -> str:
        """Build one output line in the format: kw1,kw2,kw3:original.

        Keywords are stripped of colons and newlines, truncated to 10
        characters each, and capped at three; the original text has its
        newlines flattened to spaces.
        """
        # Strip separator characters that would corrupt the line format,
        # and cap each keyword at 10 characters.
        cleaned_keywords = [
            kw.strip().replace(':', '').replace('\n', '')[:10]
            for kw in keywords if kw.strip()
        ][:3]  # keep at most the first three keywords

        # Fall back to a placeholder when no usable keyword survives cleaning.
        if not cleaned_keywords:
            keywords_str = "无关键词"
        else:
            keywords_str = ",".join(cleaned_keywords)

        # Flatten the original text onto a single line.
        cleaned_original = original.strip().replace('\n', ' ')
        # BUG FIX: the documented format requires a ':' separator between
        # the keyword list and the original text; it was missing, fusing
        # the two fields together.
        return f"{keywords_str}:{cleaned_original}"

class ThreadSafeWriter:
    """Thread-safe line writer with a write counter for progress reporting.

    Also usable as a context manager so the underlying file is closed
    even when processing raises.
    """

    def __init__(self, output_path: str):
        # Append mode: an interrupted run can be resumed without
        # clobbering previously written lines.
        self.file = open(output_path, 'a+', encoding='utf-8')
        self.lock = threading.Lock()
        self.counter = 0  # number of lines written so far

    def write_line(self, content: str) -> None:
        """Append one line and flush immediately so progress survives a crash."""
        with self.lock:
            self.file.write(content + '\n')
            self.file.flush()
            self.counter += 1

    def get_progress(self) -> int:
        """Return the number of lines written so far (thread-safe)."""
        with self.lock:
            return self.counter

    def close(self) -> None:
        """Close the underlying file; the writer must not be used afterwards."""
        self.file.close()

    # Context-manager support (backward-compatible addition).
    def __enter__(self) -> "ThreadSafeWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

class DeepSeekBatchProcessor:
    """Concurrent keyword-extraction pipeline backed by the DeepSeek chat API."""

    def __init__(self, max_workers: int = 100):
        # SECURITY FIX: the API key must come from the environment only;
        # a live key was previously hard-coded here as the fallback value.
        self.client = openai.OpenAI(
            api_key=os.getenv("DEEPSEEK_API_KEY"),
            base_url="https://api.deepseek.com"
        )
        self.max_workers = max_workers
        self.error_flag = threading.Event()          # set to abort all pending work
        self.rate_limiter = threading.Semaphore(20)  # cap concurrent in-flight API calls

    def process_batch(self, batch: List[Tuple[int, str]], writer: "ThreadSafeWriter"):
        """Process one batch of (line_number, text) pairs concurrently.

        Blocks until every submitted line has been attempted; stops
        submitting new work as soon as ``error_flag`` is set.
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, original in batch:
                if self.error_flag.is_set():
                    break
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        original,
                        writer
                    )
                )

            # Wait for the whole batch; re-raises any worker exception.
            for future in futures:
                future.result()

    def _process_single_line(self, line_num: int, original: str, writer: "ThreadSafeWriter"):
        """Extract keywords for one line with up to 3 retries, then write the result."""
        if self.error_flag.is_set():
            return

        retries = 0
        while retries < 3 and not self.error_flag.is_set():
            try:
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model="deepseek-reasoner",
                        messages=[
                            {"role": "system", "content": self._get_prompt()},
                            {"role": "user", "content": original}
                        ],
                        temperature=0.1,
                        max_tokens=30
                    )

                # Parse, format, and persist the result.
                keywords = self._parse_response(response)
                formatted_line = FormatValidator.validate_line(keywords, original)
                writer.write_line(formatted_line)

                progress = writer.get_progress()
                print(f"\r已处理 {progress} 条", end='')

                break  # success: leave the retry loop

            except APIError as e:
                if e.status_code == 402:  # insufficient balance: abort the whole run
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                elif e.status_code == 429:  # rate limited: back off and retry
                    print(f"\n行 {line_num} 速率受限,重试中...")
                    retries += 1
                    if retries >= 3:
                        print(f"行 {line_num} 重试次数耗尽")
                else:
                    print(f"\n行 {line_num} API错误[{e.status_code}]:{e.message}")
                    return  # other API errors are not retried

            except Exception as e:
                print(f"\n行 {line_num} 处理异常:{str(e)}")
                retries += 1
                if retries >= 3:
                    print(f"行 {line_num} 重试次数耗尽")

        # Retries exhausted: record the failed line so it is not silently lost.
        if retries >= 3 and not self.error_flag.is_set():
            writer.write_line(f"处理失败:{original}")

    @staticmethod
    def _get_prompt() -> str:
        # BUG FIX: this previously was a bare `return` (i.e. returned None),
        # which the chat API rejects as a system-message content. The prompt
        # instructs output in the exact shape _parse_response expects.
        return (
            "你是一个关键词提取助手。"
            "请从用户给出的文本中提取最多3个关键词,"
            "只输出关键词本身,关键词之间用中文逗号(,)分隔,"
            "不要输出任何其他内容。"
        )

    @staticmethod
    def _parse_response(response) -> List[str]:
        """Split the model reply on (normalized) Chinese commas into keywords."""
        content = response.choices[0].message.content.strip()
        # Normalize ASCII commas to Chinese ones, then strip trailing
        # punctuation from each keyword.
        return [kw.strip("。、") for kw in content.replace(',', ',').split(',') if kw]

def process_large_file(
    input_path: str,
    output_path: str,
    batch_size: int = 500,
    max_workers: int = 100
):
    """Entry point: batch the input file and run the keyword pipeline over it.

    Reads non-empty lines from ``input_path``, groups them into batches of
    (line_number, text) pairs of at most ``batch_size``, and appends the
    formatted results to ``output_path``.
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)

    try:
        # Collect (line_number, content) pairs, skipping blank lines,
        # and chunk them into fixed-size batches.
        batches = []
        pending = []
        with open(input_path, 'r', encoding='utf-8') as src:
            for num, raw in enumerate(src, 1):
                text = raw.strip()
                if not text:
                    continue
                pending.append((num, text))
                if len(pending) >= batch_size:
                    batches.append(pending)
                    pending = []
        if pending:
            batches.append(pending)

        total = sum(len(chunk) for chunk in batches)
        print(f"总数据量:{total}条")

        # Process batch by batch, stopping early on a fatal API error.
        for chunk in batches:
            if processor.error_flag.is_set():
                break
            processor.process_batch(chunk, writer)

        print("\n处理完成!")

    finally:
        # Always release the output file, even if processing raised.
        writer.close()

if __name__ == '__main__':
    # BUG FIX: the paths were written as "data\DSdata.txt" / "data\CoTdata.txt",
    # relying on "\D" and "\C" not being escape sequences (SyntaxWarning on
    # modern Python) and producing literal backslashes on non-Windows systems.
    # os.path.join builds a portable path on every platform.
    input_file = os.path.join("data", "DSdata.txt")
    output_file = os.path.join("data", "CoTdata.txt")

    # Kick off the full pipeline.
    process_large_file(
        input_path=input_file,
        output_path=output_file,
        batch_size=500,
        max_workers=100
    )