|
|
import os
|
|
|
import openai
|
|
|
import threading
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
from openai import APIError
|
|
|
|
|
|
# DeepSeek credential: taken from the DEEPSEEK_API_KEY environment variable,
# falling back to a literal placeholder string when the variable is unset.
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "your_api_key")
|
|
|
|
|
|
class ThreadSafeWriter:
    """Append lines to a UTF-8 text file from multiple threads.

    Each write is serialized by an internal lock and flushed immediately, and
    the number of lines written so far is tracked in ``counter``.

    The instance can be used as a context manager; the file is closed when
    the ``with`` block exits.
    """

    def __init__(self, output_path: str):
        """Open *output_path* in append mode, creating the file if needed.

        Raises:
            OSError: if the file cannot be opened.
        """
        self.file = open(output_path, 'a+', encoding='utf-8')
        self.lock = threading.Lock()
        self.counter = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def write_line(self, content: str):
        """Write *content* plus a newline, flush, and bump the line counter.

        Raises:
            ValueError: if the writer has already been closed.
        """
        with self.lock:
            self.file.write(content + '\n')
            # Flush per line so finished work is on disk even if the run dies.
            self.file.flush()
            self.counter += 1

    def get_progress(self):
        """Return how many lines have been written through this writer."""
        with self.lock:
            return self.counter

    def close(self):
        """Close the underlying file; safe to call more than once.

        The lock is held while closing so the file cannot be closed in the
        middle of another thread's write/flush pair.
        """
        with self.lock:
            if not self.file.closed:
                self.file.close()
|
|
|
|
|
|
class DeepSeekBatchProcessor:
    """Turn ``keywords:text`` lines into poems via the DeepSeek chat API.

    Each input line is handled by its own thread-pool task. A semaphore caps
    how many API requests are in flight at once, and ``error_flag`` is a
    shared abort switch that is set when the API reports HTTP 402
    (insufficient balance) so that remaining work stops quickly.
    """

    # Model requested for every completion.
    MODEL = "deepseek-reasoner"
    # Maximum number of attempts per input line before it is marked failed.
    MAX_RETRIES = 3
    # Accepted keyword/text separators, in priority order. The ASCII colon is
    # tried first so lines that already parsed keep splitting at the same
    # place; the full-width colon (U+FF1A) is the fallback. It is written as
    # an escape so the two characters cannot be confused again.
    _SEPARATORS = (":", "\uff1a")

    def __init__(self, max_workers: int = 100, max_concurrent_requests: int = 20):
        """Create the API client and the shared synchronization primitives.

        Args:
            max_workers: size of the thread pool used by ``process_batch``.
            max_concurrent_requests: maximum number of API calls allowed to
                run at the same time.
        """
        self.client = openai.OpenAI(
            api_key=API_KEY,
            base_url="https://api.deepseek.com/v1"
        )
        self.max_workers = max_workers
        self.error_flag = threading.Event()
        self.rate_limiter = threading.Semaphore(max_concurrent_requests)

    def process_batch(self, batch, writer: "ThreadSafeWriter"):
        """Process every ``(line_num, line)`` pair of *batch* in the pool.

        Submission stops early once ``error_flag`` is set. The call returns
        after all submitted tasks have finished.

        Raises:
            Exception: any exception that escaped a worker task is re-raised
                here by ``future.result()``.
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, line in batch:
                if self.error_flag.is_set():
                    break
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        line,
                        writer
                    )
                )
            for future in futures:
                future.result()

    @classmethod
    def _split_keywords(cls, line: str):
        """Return the keyword list found before the separator, or ``None``.

        ``None`` means the line contains no separator at all. A line whose
        keyword part is empty yields the single placeholder keyword.
        """
        for separator in cls._SEPARATORS:
            keywords_part, found, _original_text = line.partition(separator)
            if found:
                break
        else:
            return None
        keywords = [kw.strip() for kw in keywords_part.split(",") if kw.strip()]
        return keywords or ["无关键词"]

    def _process_single_line(self, line_num: int, line: str, writer: "ThreadSafeWriter"):
        """Generate a poem for one input line and write the result.

        Output written through *writer*:
          * ``<keywords><think>reasoning</think>:poem`` on success
            (newlines in the poem are replaced by ``/``);
          * ``格式错误:<line>`` when the line has no separator;
          * ``处理失败:<line>`` when retries are exhausted or the API returns
            a non-retryable error.

        Nothing is written when the run is aborted through ``error_flag``.
        """
        if self.error_flag.is_set():
            return

        keywords = self._split_keywords(line)
        if keywords is None:
            print(f"\n行 {line_num} 格式错误")
            writer.write_line(f"格式错误:{line}")
            return

        keyword_text = ",".join(keywords)
        prompt = "请根据以下关键词写一首诗:" + keyword_text
        messages = [{"role": "user", "content": prompt}]

        retries = 0
        while retries < self.MAX_RETRIES and not self.error_flag.is_set():
            try:
                # Hold the semaphore only for the network call itself.
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model=self.MODEL,
                        messages=messages,
                        temperature=0.1
                    )

                message = response.choices[0].message
                # Either field may be missing/None; treat that as empty text
                # instead of failing (and re-billing) the whole request.
                reasoning_content = (
                    getattr(message, "reasoning_content", None) or ""
                ).replace('\n', '').replace('\r', '')
                poem_original = (
                    message.content or ""
                ).replace('\n', '/').replace('\r', '')

                final_line = f"{keyword_text}<think>{reasoning_content}</think>:{poem_original}"
                writer.write_line(final_line)
                progress = writer.get_progress()
                print(f"\r已处理 {progress} 条", end='')
                return

            except APIError as e:
                # Only HTTP-status errors carry ``status_code``; connection
                # errors and timeouts do not, so read it defensively.
                status_code = getattr(e, "status_code", None)
                if status_code == 402:
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                if status_code == 429:
                    print(f"\n行 {line_num} 速率受限,重试中...")
                elif status_code is None:
                    # No HTTP response was received: treat as transient.
                    print(f"\n行 {line_num} 处理异常:{str(e)}")
                else:
                    print(f"\n行 {line_num} API错误[{status_code}]:{e.message}")
                    # Record the failure so the line is not silently lost.
                    writer.write_line(f"处理失败:{line}")
                    return

            except Exception as e:
                print(f"\n行 {line_num} 处理异常:{str(e)}")

            retries += 1
            if retries >= self.MAX_RETRIES:
                print(f"\n行 {line_num} 重试次数耗尽")
            else:
                # Exponential backoff (1s, 2s, ...). Waiting on the event
                # rather than sleeping lets an abort wake the thread at once.
                self.error_flag.wait(2 ** (retries - 1))

        if retries >= self.MAX_RETRIES and not self.error_flag.is_set():
            writer.write_line(f"处理失败:{line}")
|
|
|
|
|
|
def process_first_1000_lines(input_path: str, output_path: str, max_workers: int = 100,
                             max_lines: int = 1000):
    """Read the first lines of *input_path* and process them concurrently.

    Only the first *max_lines* physical lines of the file are considered;
    blank lines among them are skipped. Results are appended to
    *output_path*.

    Args:
        input_path: UTF-8 text file with one ``keywords:text`` entry per line.
        output_path: file that receives one result line per processed entry.
        max_workers: thread-pool size passed to ``DeepSeekBatchProcessor``.
        max_lines: how many leading lines of the input file to read.
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)
    batch = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                # Check the limit before the blank-line skip: otherwise a
                # blank line exactly at the limit bypasses the stop condition
                # and one extra line is read.
                if line_num > max_lines:
                    break
                stripped = line.strip()
                if stripped:
                    batch.append((line_num, stripped))

        total = len(batch)
        print(f"总数据量:{total} 条")
        processor.process_batch(batch, writer)
        print("\n处理完成!")
    finally:
        # Always release the output file, even if processing raised.
        writer.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read keyword lines from the source data file and
    # append the generated results to the output data file.
    input_file, output_file = "data/DSdata.txt", "data/CoTdata.txt"
    process_first_1000_lines(
        input_path=input_file,
        output_path=output_file,
        max_workers=100,
    )
|
|
|
|