import argparse
import asyncio
import json
import logging
import os

import httpx
from tqdm.asyncio import tqdm
from transformers import AutoProcessor

# ---- Configuration ----
DATA_PATH = "/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json"
OUT_PATH_TEMPLATE = (
    "/home/mshahidul/readctrl/data/translated_data/translation_testing_3396/"
    "multiclinsum_test_{source_lang}2{target_lang}_gemma({start}_{end})_3396.json"
)

# Texts longer than this are split (recursively, at natural boundaries) before
# translation; shorter texts are split into two only if the whole-text attempt fails.
MAX_FULLTEXT_CHARS_BEFORE_CHUNK = 3500
MIN_TRANSLATION_RATIO = 0.15  # treat as bad if translation length < 15% of source
# A paragraph/sentence boundary is only used for splitting if it leaves at least
# this fraction of the text on each side; otherwise a boundary right after e.g. a
# title line would leave one chunk nearly as long as the original text.
MIN_SPLIT_BALANCE = 0.25

# Default endpoint; main() rebuilds the URL from --port.
TRANSLATE_URL = "http://127.0.0.1:8080/v1/chat/completions"
TRANSLATION_MODEL_NAME = "translate_gemma"
FULLTEXT_MAX_TOKENS = 4092
SUMMARY_MAX_TOKENS = 1024
REQUEST_TIMEOUT_S = 60.0
CONCURRENCY_LIMIT = 8  # Matches your server's "-np" or "--parallel" value

model_id = "google/translategemma-27b-it"
processor = AutoProcessor.from_pretrained(model_id)
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created outside a running
# loop binds to a different loop than the one asyncio.run() creates, which fails
# once the semaphore is contended; this presumably relies on Python 3.10+.
semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)

logger = logging.getLogger(__name__)

# (boundary, joiner) pairs in order of preference. The boundary's punctuation
# stays at the end of the first chunk, so only its whitespace is re-inserted when
# merging translations; re-inserting the whole boundary would duplicate the
# punctuation (e.g. "...।. ...").
_SPLIT_BOUNDARIES = (
    ("\n\n", "\n\n"),
    (". ", " "),
    (".\n", "\n"),
    ("! ", " "),
    ("!\n", "\n"),
    ("? ", " "),
    ("?\n", "\n"),
)


async def call_llm(client, url, model, messages, temperature=0.1, max_tokens=None,
                   timeout=REQUEST_TIMEOUT_S):
    """Send one chat-completion request and return the stripped reply text.

    Best effort: HTTP errors, timeouts and malformed responses are logged and
    reported as ``None`` so that one failed request does not abort the whole run.
    At most ``CONCURRENCY_LIMIT`` requests are in flight at the same time.

    Args:
        client: Shared ``httpx.AsyncClient``.
        url: Chat-completions endpoint.
        model: Model name sent in the payload.
        messages: Chat messages list.
        temperature: Sampling temperature.
        max_tokens: Optional generation cap; omitted from the payload when None.
        timeout: Per-request timeout in seconds.

    Returns:
        The reply text, or ``None`` on any request/response failure.
    """
    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens

    async with semaphore:
        try:
            response = await client.post(url, json=payload, timeout=timeout)
            # Surface 4xx/5xx explicitly instead of failing later with an
            # unexplained KeyError on 'choices'.
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
        except (httpx.HTTPError, ValueError, KeyError, IndexError, TypeError) as exc:
            logger.warning("LLM request to %s failed: %r", url, exc)
            return None

    if not isinstance(content, str):
        logger.warning("LLM response from %s contained no text content", url)
        return None
    return content.strip()


def _nearest_cut(text, boundary, target, lo, hi):
    """Return the cut index just past the occurrence of `boundary` nearest `target`.

    Only cuts with ``lo <= cut <= hi`` that leave text on both sides
    (boundary not at index 0, cut before the end) are considered.
    Returns None if no occurrence qualifies.
    """
    candidates = (
        # Last occurrence starting at or before target.
        text.rfind(boundary, 0, target + len(boundary)),
        # First occurrence starting after target.
        text.find(boundary, target + 1),
    )
    cuts = [
        idx + len(boundary)
        for idx in candidates
        if idx > 0 and lo <= idx + len(boundary) <= hi and idx + len(boundary) < len(text)
    ]
    if not cuts:
        return None
    return min(cuts, key=lambda cut: abs(cut - target))


def split_text_into_two_chunks(text):
    """Split `text` in two at a natural boundary close to its middle.

    Preference order: paragraph break, sentence end (with the balance
    requirement ``MIN_SPLIT_BALANCE``), then the space nearest the middle,
    then a hard cut at the middle.

    Returns:
        ``(chunk1, chunk2, joiner)``. ``joiner`` is the whitespace to place
        between the two *translated* chunks when merging them. For input of
        length <= 1, ``chunk2`` is "".
    """
    text = text.strip()
    if len(text) <= 1:
        return (text, "", "\n\n")

    mid = len(text) // 2
    margin = int(len(text) * MIN_SPLIT_BALANCE)
    for boundary, joiner in _SPLIT_BOUNDARIES:
        cut = _nearest_cut(text, boundary, mid, margin, len(text) - margin)
        if cut is not None:
            return (text[:cut].strip(), text[cut:].strip(), joiner)

    # Fallback: the space nearest to the middle (no balance requirement).
    cut = _nearest_cut(text, " ", mid, 1, len(text))
    if cut is not None:
        return (text[:cut].strip(), text[cut:].strip(), " ")
    return (text[:mid].strip(), text[mid:].strip(), " ")


def _join_with_separator(part1, part2, sep):
    """Join two translated parts with `sep`, dropping the separator if either part is empty."""
    p1 = (part1 or "").strip()
    p2 = (part2 or "").strip()
    if not p1:
        return p2
    if not p2:
        return p1
    return p1 + sep + p2


def is_translation_acceptable(source_text, translated_text):
    """Return False if translation is None, blank, or shorter than
    ``MIN_TRANSLATION_RATIO`` times the source length; True otherwise."""
    if translated_text is None:
        return False
    t = translated_text.strip()
    if not t:
        return False
    if len(source_text) > 0 and len(t) < len(source_text) * MIN_TRANSLATION_RATIO:
        return False
    return True


def build_gemma_prompt(text, source_lang="en", target_lang="bn"):
    """Render a TranslateGemma prompt and wrap it as a single user chat message.

    The structured translation request (text plus source/target language codes)
    is rendered to a prompt string with the HF processor's chat template.
    The rendered string is then sent as the plain-text content of one user
    message.
    """
    structured_request = [{
        "role": "user",
        "content": [
            {
                "type": "text",
                "source_lang_code": source_lang,
                "target_lang_code": target_lang,
                "text": text,
            }
        ],
    }]
    prompt = processor.apply_chat_template(
        structured_request, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): /v1/chat/completions servers typically apply their own chat
    # template to `messages` too, so this prompt may be wrapped twice; verify the
    # server-side rendered prompt (or consider the plain /v1/completions endpoint).
    return [{"role": "user", "content": prompt}]


async def _translate_once(client, text, source_lang, target_lang, translate_url, max_tokens):
    """Make a single translation request for `text`; returns None on failure."""
    messages = build_gemma_prompt(text, source_lang=source_lang, target_lang=target_lang)
    return await call_llm(
        client, translate_url, TRANSLATION_MODEL_NAME, messages, max_tokens=max_tokens
    )


async def _translate_segment(client, text, source_lang, target_lang, translate_url,
                             max_tokens, allow_retry_split=True):
    """Translate `text`, splitting into halves when it is too long or a try fails.

    Behaviour by length:
    - Longer than ``MAX_FULLTEXT_CHARS_BEFORE_CHUNK``: split proactively.
      Each half is handled recursively and keeps its own retry.
    - Otherwise: translated whole. If the result is unacceptable and
      `allow_retry_split` is set, it is retried once as two halves; each half
      gets a single attempt.

    Returns:
        The translation, "" for blank input, or None if any piece failed.
        A partial merge is never returned.
    """
    text = text.strip()
    if not text:
        return ""

    if len(text) <= MAX_FULLTEXT_CHARS_BEFORE_CHUNK:
        translated = await _translate_once(
            client, text, source_lang, target_lang, translate_url, max_tokens
        )
        if is_translation_acceptable(text, translated):
            return translated
        if not allow_retry_split:
            return None
        halves_may_retry = False
    else:
        halves_may_retry = True

    first, second, joiner = split_text_into_two_chunks(text)
    if not first or not second:
        return None  # single character: nothing smaller to retry with

    parts = await asyncio.gather(
        _translate_segment(client, first, source_lang, target_lang, translate_url,
                           max_tokens, allow_retry_split=halves_may_retry),
        _translate_segment(client, second, source_lang, target_lang, translate_url,
                           max_tokens, allow_retry_split=halves_may_retry),
    )
    if any(part is None for part in parts):
        return None
    return _join_with_separator(parts[0], parts[1], joiner) or None


async def translate_fulltext_with_chunking(client, fulltext, source_lang, target_lang, translate_url):
    """Translate a possibly long fulltext, chunking at natural boundaries as needed.

    Returns the merged translation, "" for empty input, or None if any part
    could not be translated. A partially translated text is reported as None
    rather than stored, because `has_valid_translation` would otherwise treat
    the record as finished and it would never be regenerated.
    """
    return await _translate_segment(
        client, fulltext or "", source_lang, target_lang, translate_url, FULLTEXT_MAX_TOKENS
    )


async def process_record(client, record, source_lang, target_lang, translate_url):
    """Translate a record's fulltext and summary in place and return the record.

    Sets ``translated_fulltext`` and ``translated_summary``. Each is None if its
    translation failed, so the record is picked up again on the next run.
    """
    fulltext = record.get("fulltext") or ""
    summary = record.get("summary") or ""

    translated_fulltext, translated_summary = await asyncio.gather(
        translate_fulltext_with_chunking(
            client, fulltext, source_lang, target_lang, translate_url
        ),
        _translate_segment(
            client, summary, source_lang, target_lang, translate_url, SUMMARY_MAX_TOKENS
        ),
    )

    record["translated_fulltext"] = translated_fulltext
    record["translated_summary"] = translated_summary
    return record


def record_key(record):
    """Resume key: the record's ``id`` if present, else its fulltext+summary."""
    record_id = record.get("id")
    if record_id is not None:
        return str(record_id)
    return f"{record.get('fulltext', '')}||{record.get('summary', '')}"


def has_valid_translation(record):
    """True if both translated fields are present and not None."""
    translated_fulltext = record.get("translated_fulltext")
    translated_summary = record.get("translated_summary")
    return translated_fulltext is not None and translated_summary is not None


def _write_json_atomic(path, data):
    """Write `data` as JSON to `path` via a temp file and ``os.replace``.

    An interrupted run therefore never leaves a truncated file that would make
    the next run's resume ``json.load`` fail.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


async def main():
    """Translate a slice of the dataset, resuming from any existing output file.

    The output is checkpointed after every batch. Records that already have
    both translations are reused; every other record in the slice is
    (re)translated.
    """
    parser = argparse.ArgumentParser(description="Translate Multiclinsum dataset.")
    parser.add_argument("--source-lang", default="en", help="Source language code")
    parser.add_argument("--target-lang", default="bn", help="Target language code")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index (inclusive) of the slice to translate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=200,
        help="End index (exclusive) of the slice to translate; use -1 for all",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port for the translation API server (default: 8080)",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(message)s")

    translate_url = f"http://127.0.0.1:{args.port}/v1/chat/completions"
    start_idx = args.start_idx
    end_idx = args.end_idx

    with open(DATA_PATH, "r", encoding="utf-8") as f:
        all_data = json.load(f)
    if end_idx == -1:
        end_idx = len(all_data)

    out_path = OUT_PATH_TEMPLATE.format(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        start=start_idx,
        end=end_idx,
    )
    data = all_data[start_idx:end_idx]

    existing_results = []
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as f:
            existing_results = json.load(f)
    existing_by_key = {record_key(rec): rec for rec in existing_results}

    output_results = []
    batch_size = 10
    async with httpx.AsyncClient() as client:
        for batch_start in tqdm(range(0, len(data), batch_size)):
            batch = data[batch_start:batch_start + batch_size]

            pending = []
            for rec in batch:
                existing = existing_by_key.get(record_key(rec))
                if existing and has_valid_translation(existing):
                    output_results.append(existing)
                else:
                    pending.append(rec)

            if pending:
                processed = await asyncio.gather(*(
                    process_record(client, rec, args.source_lang, args.target_lang, translate_url)
                    for rec in pending
                ))
                for rec in processed:
                    existing_by_key[record_key(rec)] = rec
                output_results.extend(processed)

            _write_json_atomic(out_path, output_results)
            # tqdm.write keeps the progress bar intact (plain print garbles it).
            tqdm.write(
                f"Batch {batch_start // batch_size + 1}: new={len(pending)}, total={len(output_results)}"
            )


if __name__ == "__main__":
    asyncio.run(main())