File size: 5,729 Bytes
38d8dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import openai
import threading
from concurrent.futures import ThreadPoolExecutor
from openai import APIError

# DeepSeek API key, read from the DEEPSEEK_API_KEY environment variable.
# When the variable is unset this falls back to the placeholder string
# "your_api_key".
API_KEY = os.getenv("DEEPSEEK_API_KEY", "your_api_key")

class ThreadSafeWriter:
    """Thread-safe line writer.

    Opens ``output_path`` in append mode (UTF-8) and serializes writes with
    a lock. ``counter`` holds the number of lines written through
    ``write_line`` since this object was created.
    """

    def __init__(self, output_path: str):
        self.file = open(output_path, mode='a+', encoding='utf-8')
        self.counter = 0
        self.lock = threading.Lock()

    def write_line(self, content: str):
        """Append ``content`` plus a newline, flush, and bump the counter."""
        record = content + '\n'
        with self.lock:
            out = self.file
            out.write(record)
            out.flush()
            self.counter = self.counter + 1

    def get_progress(self):
        """Return the number of lines written so far (read under the lock)."""
        with self.lock:
            written = self.counter
        return written

    def close(self):
        """Close the underlying file."""
        self.file.close()

class DeepSeekBatchProcessor:
    """Turn ``keywords<colon>text`` lines into poems via the DeepSeek chat API.

    Every input line is handled by its own task on a thread pool. The number
    of API calls in flight at once is capped by a semaphore, and
    ``error_flag`` is set on an HTTP 402 response so that the remaining
    tasks stop early.
    """

    # Model name sent with every chat completion request.
    MODEL = "deepseek-reasoner"
    # Maximum number of attempts per line before it is recorded as failed.
    MAX_RETRIES = 3
    # Base delay before a retry; doubled after each failed attempt.
    RETRY_BACKOFF_SECONDS = 1.0
    # Keyword/text separators, tried in this order: the ASCII colon first,
    # then the full-width colon (U+FF1A).
    SEPARATORS = (":", "\uff1a")

    def __init__(self, max_workers: int = 100, max_concurrent_requests: int = 20):
        """Create the API client and the shared coordination primitives.

        Args:
            max_workers: Size of the thread pool used by ``process_batch``.
            max_concurrent_requests: Upper bound on simultaneous API calls.
        """
        self.client = openai.OpenAI(
            api_key=API_KEY,
            base_url="https://api.deepseek.com/v1"
        )
        self.max_workers = max_workers
        self.error_flag = threading.Event()
        self.rate_limiter = threading.Semaphore(max_concurrent_requests)

    def process_batch(self, batch, writer: "ThreadSafeWriter"):
        """Process a batch, one pool task per line.

        Args:
            batch: Iterable of ``(line_num, line)`` pairs.
            writer: Object with ``write_line`` / ``get_progress`` that
                receives the results.

        Submission stops as soon as ``error_flag`` is set. Any exception
        raised inside a worker is re-raised here by ``future.result()``.
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, line in batch:
                if self.error_flag.is_set():
                    break
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        line,
                        writer
                    )
                )
            for future in futures:
                future.result()

    def _process_single_line(self, line_num: int, line: str, writer: "ThreadSafeWriter"):
        """Generate a poem for one input line and write the result.

        On success the line ``keywords<think>reasoning</think>:poem`` is
        written (newlines in the poem become ``/``). A line without a colon
        separator is written back with a format-error prefix; a line whose
        retries are exhausted is written back with a failure prefix.

        Args:
            line_num: Line number, used only in console messages.
            line: Raw input line.
            writer: Destination for result lines.
        """
        if self.error_flag.is_set():
            return

        # Pick the first separator present in the line. The original code
        # tested the ASCII colon twice, so full-width colons never matched.
        separator = next((sep for sep in self.SEPARATORS if sep in line), None)
        if separator is None:
            print(f"\n行 {line_num} 格式错误")
            writer.write_line(f"格式错误:{line}")
            return

        # Only the keyword part (text before the first separator) is used.
        keywords_part = line.split(separator, 1)[0]
        keywords = [kw.strip() for kw in keywords_part.split(",") if kw.strip()]
        if not keywords:
            keywords = ["无关键词"]

        # Prompt: ask for a poem based on the keywords.
        prompt = "请根据以下关键词写一首诗:" + ",".join(keywords)
        messages = [{"role": "user", "content": prompt}]

        retries = 0
        while retries < self.MAX_RETRIES and not self.error_flag.is_set():
            try:
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model=self.MODEL,
                        messages=messages,
                        temperature=0.1
                    )
                # Extract the reasoning trace and the poem from the reply.
                message = response.choices[0].message
                reasoning_content = message.reasoning_content.replace('\n', '').replace('\r', '')
                poem_original = message.content.replace('\n', '/').replace('\r', '')

            except APIError as e:
                # Only status errors carry ``status_code``; connection and
                # timeout errors do not, so read it defensively instead of
                # letting an AttributeError escape and abort the batch.
                status = getattr(e, "status_code", None)
                if status == 402:
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                if status is not None and status != 429:
                    # Any other HTTP status is treated as non-retryable.
                    print(f"\n行 {line_num} API错误[{status}]:{e.message}")
                    return
                if status == 429:
                    print(f"\n行 {line_num} 速率受限,重试中...")
                else:
                    print(f"\n行 {line_num} 处理异常:{str(e)}")

            except Exception as e:
                print(f"\n行 {line_num} 处理异常:{str(e)}")

            else:
                # Writing happens outside the try body so that a local write
                # failure is not mistaken for an API failure and retried.
                final_line = f"{','.join(keywords)}<think>{reasoning_content}</think>:{poem_original}"
                writer.write_line(final_line)
                progress = writer.get_progress()
                print(f"\r已处理 {progress} 条", end='')
                return

            retries += 1
            if retries >= self.MAX_RETRIES:
                print(f"\n行 {line_num} 重试次数耗尽")
            else:
                # Exponential back-off; waiting on the event lets the pause
                # end early if another worker sets the error flag.
                self.error_flag.wait(self.RETRY_BACKOFF_SECONDS * 2 ** (retries - 1))

        if retries >= self.MAX_RETRIES and not self.error_flag.is_set():
            writer.write_line(f"处理失败:{line}")

def process_first_1000_lines(input_path: str, output_path: str, max_workers: int = 100,
                             max_lines: int = 1000):
    """Read the first ``max_lines`` lines of a file and process them concurrently.

    Blank lines are skipped but still count toward the limit, so no line
    beyond number ``max_lines`` is ever processed. (The original check ran
    only after a non-blank line, so a blank line at the limit let reading
    continue past it.)

    Args:
        input_path: UTF-8 text file with one ``keywords:text`` entry per line.
        output_path: File that result lines are appended to.
        max_workers: Thread pool size passed to ``DeepSeekBatchProcessor``.
        max_lines: Number of leading input lines to consider.
    """
    # Read the input before creating the writer or the API client, so that
    # a missing input file does not leave an empty output file behind.
    batch = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if line_num > max_lines:
                break
            stripped = line.strip()
            if stripped:
                batch.append((line_num, stripped))

    total = len(batch)
    print(f"总数据量:{total} 条")

    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)
    try:
        processor.process_batch(batch, writer)
        print("\n处理完成!")
    finally:
        # Always release the output file, even if a worker raised.
        writer.close()

if __name__ == '__main__':
    # Script entry point: process the keyword file and append the generated
    # results to the output file using 100 worker threads.
    input_file, output_file = "data/DSdata.txt", "data/CoTdata.txt"
    process_first_1000_lines(input_file, output_file, max_workers=100)