|
|
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import openai
from openai import APIError
|
|
|
|
|
|
class FormatValidator:
    """Output-line format validator.

    Each output line has the shape ``kw1,kw2,kw3:original`` — up to three
    cleaned keywords joined by ASCII commas, a full-width colon separator,
    then the original text flattened onto a single line.
    """

    # Characters that would corrupt the line format if left inside a
    # keyword: ASCII/full-width colon, ASCII/full-width comma, newline.
    # (The original removed only ASCII ':' and '\n', so a keyword
    # containing ':' or ',' could break the field layout.)
    _SEPARATOR_CHARS = str.maketrans('', '', ':,:,\n')

    @staticmethod
    def validate_line(keywords: List[str], original: str) -> str:
        """Build one formatted output line.

        Args:
            keywords: candidate keywords; each is stripped, has separator
                characters removed, and is truncated to 10 characters.
                At most three non-empty keywords are kept.
            original: the source text; newlines are replaced by spaces so
                the result stays on one physical line.

        Returns:
            ``"kw1,kw2,kw3:original"``; the keyword part is the literal
            ``无关键词`` when no usable keyword remains.
        """
        cleaned = []
        for kw in keywords:
            kw = kw.strip().translate(FormatValidator._SEPARATOR_CHARS)[:10]
            if kw:  # drop keywords that were only separators/whitespace
                cleaned.append(kw)
            if len(cleaned) == 3:
                break

        keywords_str = ",".join(cleaned) if cleaned else "无关键词"

        cleaned_original = original.strip().replace('\n', ' ')
        return f"{keywords_str}:{cleaned_original}"
|
|
|
|
|
|
class ThreadSafeWriter:
    """Thread-safe line writer with a shared progress counter.

    The target file is opened in append mode so results from a previous,
    interrupted run are preserved, and every write is flushed immediately
    so progress survives a crash.
    """

    def __init__(self, output_path: str):
        self.file = open(output_path, 'a+', encoding='utf-8')
        self.lock = threading.Lock()
        self.counter = 0

    def write_line(self, content: str):
        """Append *content* plus a newline and bump the progress counter."""
        self.lock.acquire()
        try:
            self.file.write(f"{content}\n")
            self.file.flush()  # persist immediately; workers may die mid-run
            self.counter += 1
        finally:
            self.lock.release()

    def get_progress(self):
        """Return how many lines have been written so far."""
        with self.lock:
            return self.counter

    def close(self):
        """Close the underlying file handle."""
        self.file.close()
|
|
|
|
|
|
class DeepSeekBatchProcessor:
    """Concurrent batch client for the DeepSeek chat-completion API.

    A shared :class:`threading.Event` aborts the whole run on fatal errors
    (e.g. exhausted account balance), and a semaphore caps the number of
    requests in flight at once, independently of the thread-pool size.
    """

    def __init__(self, max_workers: int = 100):
        # SECURITY: the key must come from the environment. A live key was
        # previously hard-coded here as a fallback; it has been removed —
        # a key committed to source control must be treated as compromised
        # and rotated.
        self.client = openai.OpenAI(
            api_key=os.getenv("DEEPSEEK_API_KEY"),
            base_url="https://api.deepseek.com"
        )
        self.max_workers = max_workers
        # Set by any worker that hits an unrecoverable error; checked
        # before every submission and request so the run stops promptly.
        self.error_flag = threading.Event()
        # At most 20 concurrent API calls regardless of max_workers.
        self.rate_limiter = threading.Semaphore(20)

    def process_batch(self, batch: List[Tuple[int, str]], writer: "ThreadSafeWriter"):
        """Process one batch of ``(line_number, text)`` pairs concurrently.

        NOTE: output order within a batch is NOT guaranteed — each worker
        writes its result as soon as it finishes (the original docstring
        claimed order was preserved, which the code does not do).
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, original in batch:
                if self.error_flag.is_set():
                    break  # stop submitting new work after a fatal error
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        original,
                        writer
                    )
                )

            # Propagate any unexpected exception raised inside a worker.
            for future in futures:
                future.result()

    def _process_single_line(self, line_num: int, original: str, writer: "ThreadSafeWriter"):
        """Ask the model for keywords for one line and write the result.

        Retries up to 3 times on rate limits and unexpected errors with a
        short exponential backoff; sets ``error_flag`` on fatal API errors
        (402 — insufficient balance). On retry exhaustion a placeholder
        line is written so the output still accounts for the input line.
        """
        if self.error_flag.is_set():
            return

        retries = 0
        while retries < 3 and not self.error_flag.is_set():
            try:
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model="deepseek-reasoner",
                        messages=[
                            {"role": "system", "content": self._get_prompt()},
                            {"role": "user", "content": original}
                        ],
                        temperature=0.1,
                        max_tokens=30
                    )

                keywords = self._parse_response(response)
                formatted_line = FormatValidator.validate_line(keywords, original)
                writer.write_line(formatted_line)

                progress = writer.get_progress()
                print(f"\r已处理 {progress} 条", end='')

                break

            except APIError as e:
                if e.status_code == 402:
                    # Out of credit: no point in any thread continuing.
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                elif e.status_code == 429:
                    print(f"\n行 {line_num} 速率受限,重试中...")
                    retries += 1
                    if retries >= 3:
                        print(f"行 {line_num} 重试次数耗尽")
                    else:
                        # Back off (2s, then 4s) instead of retrying
                        # immediately against an already-saturated limit.
                        time.sleep(2 ** retries)
                else:
                    print(f"\n行 {line_num} API错误[{e.status_code}]:{e.message}")
                    return

            except Exception as e:
                print(f"\n行 {line_num} 处理异常:{str(e)}")
                retries += 1
                if retries >= 3:
                    print(f"行 {line_num} 重试次数耗尽")
                else:
                    time.sleep(2 ** retries)

        if retries >= 3 and not self.error_flag.is_set():
            # Placeholder keeps input and output line counts reconcilable.
            writer.write_line(f"处理失败:{original}")

    @staticmethod
    def _get_prompt() -> str:
        """Return the system prompt instructing the model to emit keywords.

        BUGFIX: the original implementation was a bare ``return`` and so
        sent ``None`` as the system-message content.
        """
        return (
            "你是关键词提取助手。"
            "请从用户提供的文本中提取最多3个关键词,"
            "只输出关键词,关键词之间用逗号分隔,不要输出任何其他内容。"
        )

    @staticmethod
    def _parse_response(response) -> List[str]:
        """Extract a keyword list from an API response.

        Normalises full-width commas to ASCII, splits, then strips
        whitespace and trailing ``。``/``、`` punctuation from each
        keyword, dropping entries that end up empty.
        """
        content = response.choices[0].message.content.strip()
        return [
            kw.strip().strip("。、")
            for kw in content.replace(',', ',').split(',')
            if kw.strip()
        ]
|
|
|
|
|
|
def process_large_file(
    input_path: str,
    output_path: str,
    batch_size: int = 500,
    max_workers: int = 100
):
    """Entry point: read *input_path*, extract keywords for every
    non-empty line via the DeepSeek API, and append formatted results
    to *output_path*.

    Lines are grouped into batches of *batch_size* and handed to a
    :class:`DeepSeekBatchProcessor`; processing stops early if the
    processor raises its error flag. The writer is always closed.
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)

    try:
        with open(input_path, 'r', encoding='utf-8') as src:
            # Keep 1-based line numbers; skip blank lines entirely.
            numbered = (
                (idx, text.strip())
                for idx, text in enumerate(src, 1)
                if text.strip()
            )
            batches = []
            chunk = []
            for item in numbered:
                chunk.append(item)
                if len(chunk) >= batch_size:
                    batches.append(chunk)
                    chunk = []
            if chunk:  # trailing partial batch
                batches.append(chunk)

        total = sum(len(b) for b in batches)
        print(f"总数据量:{total}条")

        for batch in batches:
            if processor.error_flag.is_set():
                break  # a worker reported a fatal error
            processor.process_batch(batch, writer)

        print("\n处理完成!")

    finally:
        writer.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Raw strings keep the Windows backslash literal. The originals
    # ("data\DSdata.txt", "data\CoTdata.txt") relied on \D and \C not
    # being recognised escapes — a SyntaxWarning since Python 3.12 and
    # slated to become an error. The runtime values are unchanged.
    input_file = r"data\DSdata.txt"
    output_file = r"data\CoTdata.txt"

    process_large_file(
        input_path=input_file,
        output_path=output_file,
        batch_size=500,
        max_workers=100
    )