readctrl / code /translation /translate_correction_gpt5.py
shahidul034's picture
Add files using upload-large-folder tool
c7a6fe6 verified
#!/usr/bin/env python3
import argparse
import json
import os
import re
import time
from typing import Dict, Any, Tuple
from openai import OpenAI
from tqdm import tqdm
def load_prompt_template(path: str) -> str:
    """Read the prompt template file at *path* and return its full text."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()
def load_api_key_from_json(path: str, key_name: str) -> str:
    """Look up *key_name* in the JSON key file at *path* and return its value.

    Exits the program (SystemExit) when the key is missing or empty.
    """
    with open(path, "r", encoding="utf-8") as handle:
        keys = json.load(handle)
    value = keys.get(key_name, "")
    if value:
        return value
    raise SystemExit(f"API key '{key_name}' not found in {path}.")
def build_prompt(template: str, src_text: str, target_language: str, target_translation: str) -> str:
    """Fill the {SRC_TEXT}, {TARGET_LANGUAGE} and {TARGET_TRANSLATION} slots of *template*."""
    filled = template.replace("{SRC_TEXT}", src_text)
    filled = filled.replace("{TARGET_LANGUAGE}", target_language)
    return filled.replace("{TARGET_TRANSLATION}", target_translation)
def extract_json(text: str) -> Dict[str, Any]:
    """Parse *text* as JSON; on failure, parse the outermost {...} span instead.

    The fallback handles model replies that wrap the JSON object in prose.
    Re-raises json.JSONDecodeError when no brace-delimited span is present.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        candidate = re.search(r"\{.*\}", text, re.DOTALL)
        if candidate is None:
            raise
        return json.loads(candidate.group(0))
def call_gpt5(client: OpenAI, model: str, prompt: str, max_retries: int = 5) -> Dict[str, Any]:
    """Send *prompt* to the OpenAI Responses API and parse the reply as JSON.

    Retries up to *max_retries* times with capped exponential backoff
    (2**attempt seconds, at most 30).  Any exception — network/API errors as
    well as unparseable (non-JSON) model output — triggers a retry.

    Raises the last captured exception when every attempt fails.
    """
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = client.responses.create(
                model=model,
                input=[{"role": "user", "content": prompt}],
            )
            return extract_json(resp.output_text)
        except Exception as err:
            last_err = err
            # Bug fix: only back off between attempts — the original slept
            # (up to 30 s) even after the final attempt, right before raising.
            if attempt < max_retries:
                time.sleep(min(2 ** attempt, 30))
    raise last_err
def process_record(
    client: OpenAI,
    model: str,
    template: str,
    target_language: str,
    record: Dict[str, Any],
    src_key: str,
    tgt_key: str,
    out_key: str,
) -> Tuple[str, Dict[str, Any]]:
    """Correct one translated field of *record* via the model.

    Returns ``(out_key, result_dict)``.  When either the source text or the
    draft translation is missing/empty, the draft is echoed back unchanged
    instead of spending an API call.
    """
    source = record.get(src_key, "")
    draft = record.get(tgt_key, "")
    if source and draft:
        prompt = build_prompt(template, source, target_language, draft)
        return out_key, call_gpt5(client, model, prompt)
    return out_key, {"translated_text": draft}
def write_batch(output_dir: str, base_name: str, batch_start: int, batch_end: int, batch: list) -> None:
    """Dump *batch* as pretty-printed JSON into *output_dir*.

    The filename encodes the zero-padded inclusive index range, e.g.
    ``base_0003_0004.json`` for batch_start=3, batch_end=5.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_name = f"{base_name}_{batch_start:04d}_{batch_end - 1:04d}.json"
    destination = os.path.join(output_dir, file_name)
    with open(destination, "w", encoding="utf-8") as handle:
        json.dump(batch, handle, ensure_ascii=False, indent=2)
def main() -> None:
    """CLI entry point: correct machine translations in a JSON dataset with GPT-5.

    Loads a list of records from the input JSON file, asks the model to
    correct the 'translated_fulltext' / 'translated_summary' fields of each
    record, and writes the updated records out in small batch files
    (one file per --batch-size records).
    """
    parser = argparse.ArgumentParser(description="GPT-5 translation correction runner")
    parser.add_argument(
        "--input",
        default="/home/mshahidul/readctrl/data/translated_data/translation_wo_judge/multiclinsum_gs_train_en2bn_gemma(0_200).json",
        help="Path to input JSON file",
    )
    parser.add_argument(
        "--output-dir",
        default="/home/mshahidul/readctrl/data/translated_data/dataset_correction_gpt5",
        help="Output directory (writes one file per 2 instances)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Number of instances per output file",
    )
    parser.add_argument(
        "--prompt",
        default="/home/mshahidul/readctrl/prompts/translation_correction_prompt",
        help="Path to prompt template",
    )
    parser.add_argument(
        "--target-language",
        default="Bengali",
        help="Target language name",
    )
    parser.add_argument(
        "--model",
        default="gpt-5",
        help="OpenAI model name",
    )
    parser.add_argument(
        "--api-json",
        default="/home/mshahidul/api_new.json",
        help="Path to JSON file containing API keys",
    )
    parser.add_argument(
        "--api-json-key",
        default="openai",
        help="Key name inside the JSON file",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Start index (0-based)",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=None,
        help="End index (exclusive)",
    )
    args = parser.parse_args()
    # The OPENAI_API_KEY environment variable takes precedence over the
    # on-disk key file.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        api_key = load_api_key_from_json(args.api_json, args.api_json_key)
    client = OpenAI(api_key=api_key)
    with open(args.input, "r", encoding="utf-8") as f:
        data = json.load(f)
    template = load_prompt_template(args.prompt)
    # Field holding the draft translation -> field holding the English source.
    src_map = {
        "translated_fulltext": "fulltext",
        "translated_summary": "summary",
    }
    # Draft-translation field -> field the corrected text is stored under.
    out_map = {
        "translated_fulltext": "corrected_translated_fulltext",
        "translated_summary": "corrected_translated_summary",
    }
    start = args.start
    end = args.end if args.end is not None else len(data)
    base_name = os.path.splitext(os.path.basename(args.input))[0]
    batch_start = start
    batch = []
    for idx in tqdm(range(start, min(end, len(data))), desc="Processing", unit="item"):
        record = data[idx]
        for tgt_key, src_key in src_map.items():
            out_key = out_map[tgt_key]
            # Resume support: skip fields that were already corrected in a
            # previous run.
            if out_key in record:
                continue
            out_key, result = process_record(
                client,
                args.model,
                template,
                args.target_language,
                record,
                src_key,
                tgt_key,
                out_key,
            )
            # Fall back to the original draft translation when the model
            # response carries no 'translated_text'.
            record[out_key] = result.get("translated_text", record.get(tgt_key, ""))
        # NOTE(review): source indentation was mangled in transit; this append
        # is assumed to run once per record (after both fields) — confirm.
        batch.append(record)
        if len(batch) >= args.batch_size:
            write_batch(args.output_dir, base_name, batch_start, idx + 1, batch)
            batch = []
            batch_start = idx + 1
    # Flush the final, possibly short, batch.
    if batch:
        write_batch(args.output_dir, base_name, batch_start, min(end, len(data)), batch)
if __name__ == "__main__":
    main()