| | |
| | import argparse |
| | import json |
| | import os |
| | import re |
| | import time |
| | from typing import Dict, Any, Tuple |
| |
|
| | from openai import OpenAI |
| | from tqdm import tqdm |
| |
|
| |
|
def load_prompt_template(path: str) -> str:
    """Return the full text of the prompt template stored at *path*."""
    with open(path, "r", encoding="utf-8") as template_file:
        return template_file.read()
| |
|
| |
|
def load_api_key_from_json(path: str, key_name: str) -> str:
    """Fetch the API key stored under *key_name* in the JSON file at *path*.

    Exits the program (SystemExit) when the key is missing or empty.
    """
    with open(path, "r", encoding="utf-8") as key_file:
        credentials = json.load(key_file)
    key_value = credentials.get(key_name, "")
    if not key_value:
        raise SystemExit(f"API key '{key_name}' not found in {path}.")
    return key_value
| |
|
| |
|
def build_prompt(template: str, src_text: str, target_language: str, target_translation: str) -> str:
    """Fill the template's three placeholders with the given values."""
    filled = template.replace("{SRC_TEXT}", src_text)
    filled = filled.replace("{TARGET_LANGUAGE}", target_language)
    return filled.replace("{TARGET_TRANSLATION}", target_translation)
| |
|
| |
|
def extract_json(text: str) -> Dict[str, Any]:
    """Parse *text* as JSON, salvaging the first {...} span when the raw
    string is not valid JSON (e.g. the model wrapped it in prose/fences).

    Re-raises the original JSONDecodeError when no braced span exists.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        embedded = re.search(r"\{.*\}", text, re.DOTALL)
        if embedded is None:
            raise
        return json.loads(embedded.group(0))
| |
|
| |
|
def call_gpt5(client: OpenAI, model: str, prompt: str, max_retries: int = 5) -> Dict[str, Any]:
    """Send *prompt* to the OpenAI Responses API and parse the reply as JSON.

    Retries up to *max_retries* times with exponential backoff capped at
    30 seconds, then re-raises the last error.

    Fix over the original: the backoff sleep now runs only *between*
    attempts — previously a failed final attempt slept up to 30 s for
    nothing before re-raising.
    """
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = client.responses.create(
                model=model,
                input=[{"role": "user", "content": prompt}],
            )
            return extract_json(resp.output_text)
        except Exception as err:  # retry uniformly on API/network/parse errors
            last_err = err
            if attempt < max_retries:
                # Backoff schedule: 2, 4, 8, 16 s (capped at 30).
                time.sleep(min(2 ** attempt, 30))
    raise last_err
| |
|
| |
|
def process_record(
    client: OpenAI,
    model: str,
    template: str,
    target_language: str,
    record: Dict[str, Any],
    src_key: str,
    tgt_key: str,
    out_key: str,
) -> Tuple[str, Dict[str, Any]]:
    """Correct one translated field of *record* via the model.

    Returns ``(out_key, result_dict)``. When either the source text or the
    existing translation is absent, the existing translation is passed
    through unchanged without an API call.
    """
    source = record.get(src_key, "")
    translation = record.get(tgt_key, "")
    if not (source and translation):
        return out_key, {"translated_text": translation}
    filled_prompt = build_prompt(template, source, target_language, translation)
    return out_key, call_gpt5(client, model, filled_prompt)
| |
|
| |
|
def write_batch(output_dir: str, base_name: str, batch_start: int, batch_end: int, batch: list) -> None:
    """Dump *batch* as pretty-printed JSON to
    ``<output_dir>/<base_name>_<start>_<end-1>.json`` (zero-padded indices).

    Creates *output_dir* if needed.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_name = f"{base_name}_{batch_start:04d}_{batch_end - 1:04d}.json"
    destination = os.path.join(output_dir, file_name)
    with open(destination, "w", encoding="utf-8") as sink:
        json.dump(batch, sink, ensure_ascii=False, indent=2)
| |
|
| |
|
def main() -> None:
    """CLI entry point: run GPT-based correction over a translated dataset.

    Loads a JSON list of records, asks the model to correct the
    ``translated_fulltext`` / ``translated_summary`` fields against their
    English sources, and writes results in small batches (one JSON file per
    ``--batch-size`` records) so partial progress survives interruptions.
    """
    parser = argparse.ArgumentParser(description="GPT-5 translation correction runner")
    parser.add_argument(
        "--input",
        default="/home/mshahidul/readctrl/data/translated_data/translation_wo_judge/multiclinsum_gs_train_en2bn_gemma(0_200).json",
        help="Path to input JSON file",
    )
    parser.add_argument(
        "--output-dir",
        default="/home/mshahidul/readctrl/data/translated_data/dataset_correction_gpt5",
        help="Output directory (writes one file per 2 instances)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Number of instances per output file",
    )
    parser.add_argument(
        "--prompt",
        default="/home/mshahidul/readctrl/prompts/translation_correction_prompt",
        help="Path to prompt template",
    )
    parser.add_argument(
        "--target-language",
        default="Bengali",
        help="Target language name",
    )
    parser.add_argument(
        "--model",
        default="gpt-5",
        help="OpenAI model name",
    )
    parser.add_argument(
        "--api-json",
        default="/home/mshahidul/api_new.json",
        help="Path to JSON file containing API keys",
    )
    parser.add_argument(
        "--api-json-key",
        default="openai",
        help="Key name inside the JSON file",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Start index (0-based)",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=None,
        help="End index (exclusive)",
    )
    args = parser.parse_args()

    # Credential lookup: environment variable wins; fall back to the key file.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        api_key = load_api_key_from_json(args.api_json, args.api_json_key)
    client = OpenAI(api_key=api_key)

    with open(args.input, "r", encoding="utf-8") as f:
        data = json.load(f)

    template = load_prompt_template(args.prompt)

    # Map each translated field to its English-source field, and to the key
    # the corrected text is stored under.
    src_map = {
        "translated_fulltext": "fulltext",
        "translated_summary": "summary",
    }
    out_map = {
        "translated_fulltext": "corrected_translated_fulltext",
        "translated_summary": "corrected_translated_summary",
    }

    start = args.start
    end = args.end if args.end is not None else len(data)

    base_name = os.path.splitext(os.path.basename(args.input))[0]
    batch_start = start
    batch = []

    for idx in tqdm(range(start, min(end, len(data))), desc="Processing", unit="item"):
        record = data[idx]
        for tgt_key, src_key in src_map.items():
            out_key = out_map[tgt_key]
            # Skip fields already corrected — makes interrupted runs resumable.
            if out_key in record:
                continue
            out_key, result = process_record(
                client,
                args.model,
                template,
                args.target_language,
                record,
                src_key,
                tgt_key,
                out_key,
            )
            # Fall back to the uncorrected translation when the model reply
            # lacks a "translated_text" field.
            record[out_key] = result.get("translated_text", record.get(tgt_key, ""))

        batch.append(record)

        # Flush a full batch to disk and start the next one.
        if len(batch) >= args.batch_size:
            write_batch(args.output_dir, base_name, batch_start, idx + 1, batch)
            batch = []
            batch_start = idx + 1

    # Flush any trailing records that did not fill a whole batch.
    if batch:
        write_batch(args.output_dir, base_name, batch_start, min(end, len(data)), batch)


if __name__ == "__main__":
    main()
| |
|