import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import jsonlines
from tqdm import tqdm
import vertexai
from vertexai.generative_models import GenerativeModel

# ======================
# 1. Initialization
# ======================
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/weifengsun/tangou1/step2/gemini.json"
vertexai.init(project="tangou")
model = GenerativeModel("gemini-2.5-flash-lite")

# ======================
# 2. Gemini price configuration (USD per one million tokens)
# ======================
PRICE_INPUT_PER_1M = 0.1
PRICE_OUTPUT_PER_1M = 0.4


# ======================
# 3. Global usage tracking & circuit breaker
# ======================
class APIMonitor:
    """Thread-safe running total of estimated token usage and cost.

    Once the accumulated estimated cost would exceed (input side) or has
    exceeded (output side) ``max_usd``, ``stop_event`` is set and a
    ``RuntimeError`` is raised so callers can stop issuing requests.
    """

    def __init__(self, max_usd: float):
        # Budget ceiling in USD.
        self.max_usd = max_usd
        # Running estimated totals; only touched while holding ``lock``.
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_cost = 0.0
        self.lock = threading.Lock()
        # Set as soon as the budget is hit; polled by workers and the main loop.
        self.stop_event = threading.Event()
        self.start_time = time.time()

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Return a rough token count: one token per three characters, at least 1."""
        approx = len(text) // 3
        return approx if approx > 1 else 1

    def reserve_input(self, prompt: str):
        """Charge the estimated input cost of ``prompt`` against the budget.

        Raises:
            RuntimeError: if adding the estimated cost would exceed
                ``max_usd``. Nothing is recorded in that case, and
                ``stop_event`` is set.
        """
        token_count = self.estimate_tokens(prompt)
        charge = token_count / 1_000_000 * PRICE_INPUT_PER_1M
        with self.lock:
            projected = self.total_cost + charge
            if projected > self.max_usd:
                self.stop_event.set()
                raise RuntimeError("💥 API budget exceeded (input)")
            self.input_tokens += token_count
            self.total_cost = projected

    def record_output(self, text: str):
        """Record the estimated output cost of ``text``.

        The usage is always recorded first; the budget is checked afterwards.

        Raises:
            RuntimeError: if the accumulated cost is now above ``max_usd``
                (``stop_event`` is set before raising).
        """
        token_count = self.estimate_tokens(text)
        charge = token_count / 1_000_000 * PRICE_OUTPUT_PER_1M
        with self.lock:
            self.output_tokens += token_count
            self.total_cost += charge
            if self.total_cost > self.max_usd:
                self.stop_event.set()
                raise RuntimeError("💥 API budget exceeded (output)")

    def snapshot(self):
        """Return a dict with the current token totals, rounded cost and elapsed seconds."""
        with self.lock:
            rounded_cost = round(self.total_cost, 6)
            seconds_running = round(time.time() - self.start_time, 2)
            return {
                "input_tokens": self.input_tokens,
                "output_tokens": self.output_tokens,
                "total_cost": rounded_cost,
                "elapsed": seconds_running,
            }


monitor = APIMonitor(max_usd=100.0)

# ======================
# 4. Inference function
# ======================
def infer_one(prompt: str, idx: int):
    """Run one prompt through the model and return a JSON-serializable record.

    Args:
        prompt: Fully rendered prompt text sent to the model.
        idx: Position of the prompt in ``inputs``; echoed back in the record.

    Returns:
        A dict with ``idx`` and ``status``:
        * ``"stopped"`` (with empty ``output``) when the budget stop event
          was already set before the call;
        * ``"ok"`` with the generated ``output`` text;
        * ``"error"`` with the exception message in ``error``.
    """
    if monitor.stop_event.is_set():
        return {"idx": idx, "status": "stopped", "output": ""}

    try:
        monitor.reserve_input(prompt)
        resp = model.generate_content(prompt)
        text = resp.text or ""
    except Exception as e:
        return {"idx": idx, "status": "error", "error": str(e)}

    try:
        monitor.record_output(text)
    except RuntimeError:
        # The budget was crossed by this very response. The monitor has
        # already recorded the usage and set the stop event, so the run will
        # wind down; the text was paid for, so keep it instead of turning a
        # successful call into an "error" record.
        pass
    return {"idx": idx, "status": "ok", "output": text}


# ======================
# 5. Load input prompts
# ======================
prompt_template = Path("prompt.txt").read_text(encoding="utf-8")
input_file = "/home/weifengsun/tangou1/step2/step22/output/function_filtered_scores.jsonl"

inputs = []
amount = 500000  # maximum number of input records to load
with jsonlines.open(input_file, "r") as reader:
    for obj in reader:
        if amount == 0:
            break
        amount -= 1
        # NOTE(review): both placeholders below are the same token "<<>>", so
        # the first replace consumes every occurrence and the md_summary
        # substitution only fires if code_content itself contains "<<>>".
        # The distinct placeholder names were presumably lost — confirm
        # against prompt.txt and restore them.
        prompt = prompt_template.replace("<<>>", obj["code_content"]).replace(
            "<<>>", obj["md_summary"]
        )
        inputs.append(prompt)

# ======================
# 6. Resume from a previous run
# ======================
output_file = "/home/weifengsun/tangou1/step2/step22/output/gemini_results.jsonl"

completed_idx = set()
if os.path.exists(output_file):
    with jsonlines.open(output_file, "r") as reader:
        for obj in reader:
            # Only successful records count as done; "error" and "stopped"
            # records must be retried on the next run rather than skipped
            # forever.
            if obj.get("status") == "ok" and "idx" in obj:
                completed_idx.add(obj["idx"])

# Only process prompts that have not completed successfully yet.
tasks = [(idx, prompt) for idx, prompt in enumerate(inputs) if idx not in completed_idx]
total_tasks = len(inputs)
remaining_tasks = len(tasks)
print(f"Total: {total_tasks}, Completed: {len(completed_idx)}, Remaining: {remaining_tasks}")

# ======================
# 7. Parallel execution + immediate writes + progress bar
# ======================
write_lock = threading.Lock()
MAX_WORKERS = 8

with jsonlines.open(output_file, mode="a", flush=True) as writer:
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(infer_one, prompt, idx): idx for idx, prompt in tasks}
        skipped = 0  # tasks that never ran because the budget was exhausted
        stop_announced = False

        # The context manager closes the bar even if a write or result() raises.
        with tqdm(total=remaining_tasks, desc="Generating", unit="item") as pbar:
            for future in as_completed(futures):
                if future.cancelled():
                    skipped += 1
                    continue

                result = future.result()
                if result["status"] == "stopped":
                    # Placeholder for a task that was never sent; writing it
                    # would pollute the results file and the resume scan.
                    skipped += 1
                    continue

                # Append to the JSONL file right away.
                with write_lock:
                    writer.write(result)

                # Advance the progress bar and show cost / token usage.
                pbar.update(1)
                snap = monitor.snapshot()
                pbar.set_postfix({
                    "cost": f"${snap['total_cost']}",
                    "in_tok": snap["input_tokens"],
                    "out_tok": snap["output_tokens"],
                    "elapsed_s": snap["elapsed"]
                })

                # Budget exhausted: cancel everything still queued, but keep
                # looping so requests already in flight (and already paid
                # for) are still collected and written.
                if monitor.stop_event.is_set() and not stop_announced:
                    stop_announced = True
                    print("🛑 Budget limit reached. Stopping all requests.")
                    for pending in futures:
                        pending.cancel()

if skipped == 0:
    print("✅ All done.")