File size: 5,018 Bytes
48128bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import itertools
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import jsonlines
from tqdm import tqdm

import vertexai
from vertexai.generative_models import GenerativeModel

# ======================
# 1. Initialization
# ======================
# Fall back to the hard-coded credentials file only when the caller has not
# already pointed GOOGLE_APPLICATION_CREDENTIALS somewhere else; assigning
# unconditionally would silently override the caller's environment.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS", "/home/weifengsun/tangou1/step2/gemini.json"
)
vertexai.init(project="tangou")
model = GenerativeModel("gemini-2.5-flash-lite")

# ======================
# 2. Gemini pricing (USD per 1M tokens; used by APIMonitor cost estimates)
# ======================
PRICE_INPUT_PER_1M = 0.1
PRICE_OUTPUT_PER_1M = 0.4

# ======================
# 3. Global usage tracking & budget circuit breaker
# ======================
class APIMonitor:
    """Lock-protected tally of estimated token usage and spend with a USD cap.

    Token counts come from ``estimate_tokens`` (a character-count heuristic)
    and are priced with the module-level ``PRICE_INPUT_PER_1M`` and
    ``PRICE_OUTPUT_PER_1M`` rates. When the cap is hit, ``stop_event`` is set
    so that callers can stop issuing new requests.
    """

    def __init__(self, max_usd: float):
        # Spending ceiling in USD.
        self.max_usd = max_usd
        # Running totals, guarded by ``lock``.
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_cost = 0.0
        self.lock = threading.Lock()
        # Set once the budget has been exceeded.
        self.stop_event = threading.Event()
        self.start_time = time.time()

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Approximate the token count as one token per three characters, minimum 1."""
        return len(text) // 3 or 1

    def reserve_input(self, prompt: str):
        """Charge the estimated input cost of *prompt* before it is sent.

        Raises:
            RuntimeError: if charging would push spend past ``max_usd``; in
                that case nothing is charged and ``stop_event`` is set.
        """
        tokens = self.estimate_tokens(prompt)
        cost = tokens / 1_000_000 * PRICE_INPUT_PER_1M
        with self.lock:
            over_budget = self.total_cost + cost > self.max_usd
            if over_budget:
                self.stop_event.set()
            else:
                self.input_tokens += tokens
                self.total_cost += cost
        if over_budget:
            raise RuntimeError("💥 API budget exceeded (input)")

    def record_output(self, text: str):
        """Charge the estimated output cost of *text*; the charge is always kept.

        Raises:
            RuntimeError: if total spend now exceeds ``max_usd``; the cost
                remains recorded and ``stop_event`` is set.
        """
        tokens = self.estimate_tokens(text)
        cost = tokens / 1_000_000 * PRICE_OUTPUT_PER_1M
        with self.lock:
            self.output_tokens += tokens
            self.total_cost += cost
            exceeded = self.total_cost > self.max_usd
            if exceeded:
                self.stop_event.set()
        if exceeded:
            raise RuntimeError("💥 API budget exceeded (output)")

    def snapshot(self):
        """Return a consistent copy of the counters plus seconds since creation."""
        with self.lock:
            return dict(
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                total_cost=round(self.total_cost, 6),
                elapsed=round(time.time() - self.start_time, 2),
            )


monitor = APIMonitor(max_usd=100.0)

# ======================
# 4. Inference function
# ======================
def infer_one(prompt: str, idx: int):
    """Send one prompt to the model and return a JSON-serializable result record.

    Args:
        prompt: Fully rendered prompt text.
        idx: Position of the prompt in ``inputs``; echoed back so the caller
            can match the result to its input.

    Returns:
        One of:

        - ``{"idx", "status": "ok", "output"}`` on success.
        - ``{"idx", "status": "stopped", "output": ""}`` if the budget breaker
          had already tripped (no request is made).
        - ``{"idx", "status": "error", "error"}`` if budget reservation or the
          API call failed.
    """
    if monitor.stop_event.is_set():
        return {"idx": idx, "status": "stopped", "output": ""}

    try:
        monitor.reserve_input(prompt)
        resp = model.generate_content(prompt)
        text = resp.text or ""
    except Exception as e:
        # Worker boundary: one failed request must not take down the pool.
        return {"idx": idx, "status": "error", "error": str(e)}

    try:
        monitor.record_output(text)
    except RuntimeError:
        # record_output has already charged this response and set stop_event
        # before raising, so the run will halt anyway. Keep the paid-for
        # output instead of discarding it as an error.
        pass
    return {"idx": idx, "status": "ok", "output": text}

# ======================
# 5. Read input prompts
# ======================
prompt_template = Path("prompt.txt").read_text(encoding="utf-8")
input_file = "/home/weifengsun/tangou1/step2/step22/output/function_filtered_scores.jsonl"
inputs = []

amount = 500000  # maximum number of input records to load

# Substitute both placeholders in a single pass. Chained str.replace() would
# also rewrite a literal "<<<README>>>" that happens to occur inside the
# inserted code. A replacement *function* (not a replacement string) is used
# so backslashes in code/README text are inserted verbatim.
_PLACEHOLDER_RE = re.compile(r"<<<(CODE|README)>>>")

with jsonlines.open(input_file, "r") as reader:
    # islice keeps file order, which matters: a prompt's index in `inputs`
    # is the key used for resuming.
    for obj in itertools.islice(reader, amount):
        fields = {
            "CODE": obj["code_content"] or "",
            "README": obj["md_summary"] or "",
        }
        prompt = _PLACEHOLDER_RE.sub(lambda m: fields[m.group(1)], prompt_template)
        inputs.append(prompt)

# ======================
# 6. Resume from checkpoint
# ======================
output_file = "/home/weifengsun/tangou1/step2/step22/output/gemini_results.jsonl"
completed_idx = set()
if os.path.exists(output_file):
    with jsonlines.open(output_file, "r") as reader:
        # skip_invalid tolerates a truncated/corrupt line left behind by an
        # interrupted run instead of aborting the whole resume.
        for obj in reader.iter(type=dict, skip_invalid=True):
            # Only successful results count as done; "error" and "stopped"
            # records are retried. Because the output file is opened in append
            # mode, it may then hold several records for the same idx;
            # consumers should keep the "ok" one.
            if obj.get("status") == "ok":
                completed_idx.add(obj["idx"])

# Only process items that have not finished successfully yet.
tasks = [(idx, prompt) for idx, prompt in enumerate(inputs) if idx not in completed_idx]
total_tasks = len(inputs)
remaining_tasks = len(tasks)
print(f"Total: {total_tasks}, Completed: {len(completed_idx)}, Remaining: {remaining_tasks}")

# ======================
# 7. Parallel execution + immediate writes + progress bar
# ======================
# All writes happen on this (main) thread; the lock is kept only as a safeguard.
write_lock = threading.Lock()
MAX_WORKERS = 8

with jsonlines.open(output_file, mode="a", flush=True) as writer, ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = {executor.submit(infer_one, prompt, idx): idx for idx, prompt in tasks}

    try:
        with tqdm(total=remaining_tasks, desc="Generating", unit="item") as pbar:
            for future in as_completed(futures):
                result = future.result()

                # Append each result immediately (flush=True) so progress
                # survives a crash. "stopped" placeholders carry no output, so
                # they are not persisted and those items stay eligible for
                # the next run.
                if result["status"] != "stopped":
                    with write_lock:
                        writer.write(result)

                pbar.update(1)

                # Show running cost and token counts next to the ETA.
                snap = monitor.snapshot()
                pbar.set_postfix({
                    "cost": f"${snap['total_cost']}",
                    "in_tok": snap["input_tokens"],
                    "out_tok": snap["output_tokens"],
                    "elapsed_s": snap["elapsed"]
                })

                # Over budget: stop consuming results.
                if monitor.stop_event.is_set():
                    tqdm.write("🛑 Budget limit reached. Stopping all requests.")
                    break
    finally:
        # Leaving the executor's `with` calls shutdown(wait=True), which would
        # still run every queued task. After a Ctrl+C (stop_event unset) those
        # tasks would keep calling the API with nobody writing their results.
        # Cancel everything not yet started; at most MAX_WORKERS in-flight
        # requests are left to finish.
        for pending in futures:
            pending.cancel()

if monitor.stop_event.is_set():
    print("⚠️ Budget limit reached; some tasks may be unfinished. Rerun to resume.")
else:
    print("✅ All done.")