# dataset-builder / data2/step22/gemini_generation.py
# DouDou
# Upload data2/step22/gemini_generation.py with huggingface_hub
# 48128bb verified
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice
from pathlib import Path

import jsonlines
import vertexai
from tqdm import tqdm
from vertexai.generative_models import GenerativeModel
# ======================
# 1. Initialization
# ======================
# Respect a credentials path the caller has already exported; only fall back
# to the hard-coded credentials file when the variable is unset.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS", "/home/weifengsun/tangou1/step2/gemini.json"
)
vertexai.init(project="tangou")
# Shared model handle used by every worker thread in infer_one().
model = GenerativeModel("gemini-2.5-flash-lite")
# ======================
# 2. Gemini pricing (USD per 1M tokens)
# ======================
PRICE_INPUT_PER_1M = 0.1
PRICE_OUTPUT_PER_1M = 0.4
# ======================
# 3. Global usage tally & circuit breaker
# ======================
class APIMonitor:
    """Thread-safe running tally of estimated token usage and spend.

    Token counts come from a crude characters/3 heuristic, not from the API's
    own usage report. When spending crosses ``max_usd`` the monitor sets
    ``stop_event`` and raises ``RuntimeError``.
    """

    def __init__(self, max_usd: float):
        self.max_usd = max_usd
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_cost = 0.0
        # Guards the three counters above against concurrent workers.
        self.lock = threading.Lock()
        # Set once the budget has been exceeded; workers poll it before calling the API.
        self.stop_event = threading.Event()
        self.start_time = time.time()

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Roughly estimate the token count of *text* (always at least 1)."""
        approx = len(text) // 3
        return approx if approx > 1 else 1

    def reserve_input(self, prompt: str):
        """Charge the estimated input cost of *prompt* before the request is sent.

        Raises:
            RuntimeError: if the charge would push total cost past ``max_usd``.
                Nothing is recorded in that case, and ``stop_event`` is set.
        """
        n_tokens = self.estimate_tokens(prompt)
        charge = n_tokens / 1_000_000 * PRICE_INPUT_PER_1M
        with self.lock:
            projected = self.total_cost + charge
            if projected > self.max_usd:
                self.stop_event.set()
                raise RuntimeError("💥 API budget exceeded (input)")
            self.input_tokens += n_tokens
            self.total_cost = projected

    def record_output(self, text: str):
        """Record the estimated output cost of a generated *text*.

        The usage is always recorded first (the output has already been produced).

        Raises:
            RuntimeError: if the new total exceeds ``max_usd``; ``stop_event``
                is set before raising.
        """
        n_tokens = self.estimate_tokens(text)
        charge = n_tokens / 1_000_000 * PRICE_OUTPUT_PER_1M
        with self.lock:
            self.output_tokens += n_tokens
            self.total_cost += charge
            over_budget = self.total_cost > self.max_usd
            if over_budget:
                self.stop_event.set()
                raise RuntimeError("💥 API budget exceeded (output)")

    def snapshot(self):
        """Return a consistent copy of the counters plus seconds since creation."""
        with self.lock:
            elapsed = time.time() - self.start_time
            return dict(
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                total_cost=round(self.total_cost, 6),
                elapsed=round(elapsed, 2),
            )


# Process-wide monitor shared by all worker threads; hard budget of $100.
monitor = APIMonitor(max_usd=100.0)
# ======================
# 4. Inference function
# ======================
def infer_one(prompt: str, idx: int):
    """Send one prompt to the Gemini model and return a JSON-serializable record.

    Args:
        prompt: fully rendered prompt text.
        idx: position of the prompt in the input list; echoed back so that
            results completing out of order can be matched to their inputs.

    Returns:
        dict with keys "idx", "status" and "output":
          - "stopped": the budget breaker was already tripped; no request made.
          - "ok": generation succeeded; "output" holds the text.
          - "error": the request failed; an extra "error" key holds the message.
    """
    if monitor.stop_event.is_set():
        return {"idx": idx, "status": "stopped", "output": ""}
    try:
        monitor.reserve_input(prompt)
        resp = model.generate_content(prompt)
        text = resp.text or ""
    except Exception as e:  # worker boundary: report the failure as a record, don't kill the pool
        return {"idx": idx, "status": "error", "output": "", "error": str(e)}
    try:
        monitor.record_output(text)
    except RuntimeError:
        # This response pushed spending past the budget. The monitor has already
        # set stop_event (so no further requests start); the text is already
        # paid for, so keep it instead of discarding it as an error.
        pass
    return {"idx": idx, "status": "ok", "output": text}
# ======================
# 5. Read input prompts
# ======================
prompt_template = Path("prompt.txt").read_text(encoding="utf-8")
input_file = "/home/weifengsun/tangou1/step2/step22/output/function_filtered_scores.jsonl"
amount = 500000  # maximum number of input records to read

# Both placeholders are substituted in one pass, so placeholder-like text that
# happens to appear inside the inserted code or README is never re-expanded.
_PLACEHOLDER_RE = re.compile(r"<<<(CODE|README)>>>")


def _render_prompt(template: str, code: str, readme: str) -> str:
    """Fill the <<<CODE>>> and <<<README>>> placeholders of *template*."""
    values = {"CODE": code, "README": readme}
    # A callable replacement is inserted literally (no backslash processing).
    return _PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)


with jsonlines.open(input_file, "r") as reader:
    inputs = [
        _render_prompt(prompt_template, obj["code_content"], obj["md_summary"])
        for obj in islice(reader, amount)
    ]
# ======================
# 6. Resume from checkpoint
# ======================
output_file = "/home/weifengsun/tangou1/step2/step22/output/gemini_results.jsonl"
# Only successful results count as done: "error" records (including budget
# exhaustion) are retried on the next run. The output file is append-only, so
# an idx may appear more than once; consumers should prefer its "ok" record.
completed_idx = set()
if os.path.exists(output_file):
    with jsonlines.open(output_file, "r") as reader:
        # skip_invalid tolerates a truncated final line left by a killed run.
        completed_idx = {
            obj["idx"]
            for obj in reader.iter(skip_invalid=True)
            if obj.get("status") == "ok"
        }
# Process only the unfinished prompts
tasks = [(idx, prompt) for idx, prompt in enumerate(inputs) if idx not in completed_idx]
total_tasks = len(inputs)
remaining_tasks = len(tasks)
print(f"Total: {total_tasks}, Completed: {total_tasks - remaining_tasks}, Remaining: {remaining_tasks}")
# ======================
# 7. Parallel execution + immediate writes + progress bar
# ======================
# Results are written only from this (main) thread, so the lock is a safeguard
# rather than a necessity.
write_lock = threading.Lock()
MAX_WORKERS = 8
with jsonlines.open(output_file, mode="a", flush=True) as writer, ThreadPoolExecutor(
    max_workers=MAX_WORKERS
) as executor:
    futures = {executor.submit(infer_one, prompt, idx): idx for idx, prompt in tasks}
    stopping = False
    with tqdm(total=remaining_tasks, desc="Generating", unit="item") as pbar:
        for future in as_completed(futures):
            # Cancelled after the budget stop: the request never ran.
            if future.cancelled():
                continue
            result = future.result()
            # "stopped" results made no API call; leaving them out of the file
            # lets a later run pick those items up again.
            if result["status"] != "stopped":
                with write_lock:
                    writer.write(result)
            pbar.update(1)
            # Show running cost and token estimates
            snap = monitor.snapshot()
            pbar.set_postfix({
                "cost": f"${snap['total_cost']}",
                "in_tok": snap["input_tokens"],
                "out_tok": snap["output_tokens"],
                "elapsed_s": snap["elapsed"],
            })
            # Over budget: cancel everything still queued, but keep draining so
            # that requests already in flight (and already billed) are saved.
            if monitor.stop_event.is_set() and not stopping:
                stopping = True
                tqdm.write("🛑 Budget limit reached. Stopping all requests.")
                for pending in futures:
                    pending.cancel()
print("✅ All done.")