#!/usr/bin/env python3
"""GPT-5 translation correction runner.

Reads a JSON list of records containing machine translations, asks an OpenAI
model to correct each translation using a prompt template, and writes the
corrected records back out in small batch files.
"""
import argparse
import json
import os
import re
import time
from typing import Any, Dict, Tuple

from openai import OpenAI
from tqdm import tqdm


def load_prompt_template(path: str) -> str:
    """Return the full contents of the prompt template file."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def load_api_key_from_json(path: str, key_name: str) -> str:
    """Load the API key stored under *key_name* in the JSON file at *path*.

    Raises:
        SystemExit: if the key is missing or empty.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    api_key = data.get(key_name, "")
    if not api_key:
        raise SystemExit(f"API key '{key_name}' not found in {path}.")
    return api_key


def build_prompt(template: str, src_text: str, target_language: str, target_translation: str) -> str:
    """Fill the template's {SRC_TEXT}/{TARGET_LANGUAGE}/{TARGET_TRANSLATION} placeholders."""
    return (
        template.replace("{SRC_TEXT}", src_text)
        .replace("{TARGET_LANGUAGE}", target_language)
        .replace("{TARGET_TRANSLATION}", target_translation)
    )


def extract_json(text: str) -> Dict[str, Any]:
    """Parse *text* as JSON.

    Falls back to the first ``{...}`` span when the model wrapped the JSON in
    extra prose; re-raises the original ``JSONDecodeError`` if no brace span
    is found.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            raise
        return json.loads(match.group(0))


def call_gpt5(client: OpenAI, model: str, prompt: str, max_retries: int = 5) -> Dict[str, Any]:
    """Call the Responses API and parse the model output as JSON.

    Retries up to *max_retries* times with capped exponential backoff and
    re-raises the last error if every attempt fails.
    """
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = client.responses.create(
                model=model,
                input=[{"role": "user", "content": prompt}],
            )
            return extract_json(resp.output_text)
        except Exception as err:
            last_err = err
            # FIX: don't sleep after the final attempt — we're about to
            # re-raise anyway, so the backoff would just waste up to 30 s.
            if attempt < max_retries:
                time.sleep(min(2 ** attempt, 30))
    # FIX: guard the degenerate max_retries < 1 case, where the loop never
    # runs and the original `raise last_err` would be `raise None` (TypeError).
    if last_err is None:
        raise RuntimeError("call_gpt5 requires max_retries >= 1")
    raise last_err


def process_record(
    client: OpenAI,
    model: str,
    template: str,
    target_language: str,
    record: Dict[str, Any],
    src_key: str,
    tgt_key: str,
    out_key: str,
) -> Tuple[str, Dict[str, Any]]:
    """Correct one translation field of *record*.

    Returns ``(out_key, result_dict)``; when the source or draft translation
    is missing, the draft is passed through unchanged instead of calling the
    model.
    """
    src_text = record.get(src_key, "")
    tgt_text = record.get(tgt_key, "")
    if not src_text or not tgt_text:
        return out_key, {"translated_text": tgt_text}
    prompt = build_prompt(template, src_text, target_language, tgt_text)
    return out_key, call_gpt5(client, model, prompt)


def write_batch(output_dir: str, base_name: str, batch_start: int, batch_end: int, batch: list) -> None:
    """Write *batch* to ``<base_name>_<start>_<end-1>.json`` inside *output_dir*.

    The file name embeds the inclusive 0-based index range of the records.
    """
    os.makedirs(output_dir, exist_ok=True)
    out_name = f"{base_name}_{batch_start:04d}_{batch_end - 1:04d}.json"
    out_path = os.path.join(output_dir, out_name)
    with open(out_path, "w", encoding="utf-8") as out_f:
        json.dump(batch, out_f, ensure_ascii=False, indent=2)


def main() -> None:
    """Parse CLI arguments, run corrections over the selected slice, write batches."""
    parser = argparse.ArgumentParser(description="GPT-5 translation correction runner")
    parser.add_argument(
        "--input",
        default="/home/mshahidul/readctrl/data/translated_data/translation_wo_judge/multiclinsum_gs_train_en2bn_gemma(0_200).json",
        help="Path to input JSON file",
    )
    parser.add_argument(
        "--output-dir",
        default="/home/mshahidul/readctrl/data/translated_data/dataset_correction_gpt5",
        help="Output directory (writes one file per 2 instances)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Number of instances per output file",
    )
    parser.add_argument(
        "--prompt",
        default="/home/mshahidul/readctrl/prompts/translation_correction_prompt",
        help="Path to prompt template",
    )
    parser.add_argument(
        "--target-language",
        default="Bengali",
        help="Target language name",
    )
    parser.add_argument(
        "--model",
        default="gpt-5",
        help="OpenAI model name",
    )
    parser.add_argument(
        "--api-json",
        default="/home/mshahidul/api_new.json",
        help="Path to JSON file containing API keys",
    )
    parser.add_argument(
        "--api-json-key",
        default="openai",
        help="Key name inside the JSON file",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Start index (0-based)",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=None,
        help="End index (exclusive)",
    )
    args = parser.parse_args()

    # Environment variable wins; fall back to the JSON key file.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        api_key = load_api_key_from_json(args.api_json, args.api_json_key)
    client = OpenAI(api_key=api_key)

    with open(args.input, "r", encoding="utf-8") as f:
        data = json.load(f)
    template = load_prompt_template(args.prompt)

    # Maps: field to correct -> its English source field, and
    # field to correct -> the key the corrected text is stored under.
    src_map = {
        "translated_fulltext": "fulltext",
        "translated_summary": "summary",
    }
    out_map = {
        "translated_fulltext": "corrected_translated_fulltext",
        "translated_summary": "corrected_translated_summary",
    }

    start = args.start
    end = args.end if args.end is not None else len(data)
    # FIX: clamp once instead of recomputing min(end, len(data)) twice.
    stop = min(end, len(data))
    base_name = os.path.splitext(os.path.basename(args.input))[0]
    batch_start = start
    batch = []
    for idx in tqdm(range(start, stop), desc="Processing", unit="item"):
        record = data[idx]
        for tgt_key, src_key in src_map.items():
            out_key = out_map[tgt_key]
            # Skip fields already corrected (allows resuming a partial run).
            if out_key in record:
                continue
            out_key, result = process_record(
                client,
                args.model,
                template,
                args.target_language,
                record,
                src_key,
                tgt_key,
                out_key,
            )
            record[out_key] = result.get("translated_text", record.get(tgt_key, ""))
        batch.append(record)
        if len(batch) >= args.batch_size:
            write_batch(args.output_dir, base_name, batch_start, idx + 1, batch)
            batch = []
            batch_start = idx + 1
    # Flush the trailing partial batch, if any.
    if batch:
        write_batch(args.output_dir, base_name, batch_start, stop, batch)


if __name__ == "__main__":
    main()