| import os |
| import json |
| import asyncio |
| import argparse |
| import httpx |
| from tqdm.asyncio import tqdm |
| from transformers import AutoProcessor |
|
|
| |
# Input dataset and output path template; {source_lang}/{target_lang}/{start}/{end}
# placeholders are filled in main() from CLI arguments.
DATA_PATH = "/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json"
OUT_PATH_TEMPLATE = (
    "/home/mshahidul/readctrl/data/translated_data/translation_testing_3396/"
    "multiclinsum_test_{source_lang}2{target_lang}_gemma({start}_{end})_3396.json"
)

# Fulltexts longer than this many characters are translated as two chunks up front.
MAX_FULLTEXT_CHARS_BEFORE_CHUNK = 3500
# Translations shorter than this fraction of the source length are rejected as bad.
MIN_TRANSLATION_RATIO = 0.15

# Default endpoint; main() rebuilds the URL from the --port argument instead.
TRANSLATE_URL = "http://127.0.0.1:8080/v1/chat/completions"
CONCURRENCY_LIMIT = 8

# Processor is loaded at import time to render TranslateGemma chat templates
# (network/disk access happens here — NOTE(review): consider lazy loading).
model_id = "google/translategemma-27b-it"
processor = AutoProcessor.from_pretrained(model_id)

# Bounds the number of concurrent in-flight API requests.
semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)
|
|
async def call_llm(client, url, model, messages, temperature=0.1, max_tokens=None):
    """Generic async caller for both Translation and Judge.

    Sends a chat-completions request and returns the stripped message
    content, or None on any transport/HTTP/parse failure (callers treat
    None as "translation failed" and fall back to chunking).

    Args:
        client: httpx.AsyncClient used to POST the request.
        url: chat-completions endpoint URL.
        model: model name placed in the request payload.
        messages: OpenAI-style messages list.
        temperature: sampling temperature (default 0.1 for stable output).
        max_tokens: optional generation cap; omitted from payload if None.

    Returns:
        str | None: stripped completion text, or None on failure.
    """
    # Limit in-flight requests so the local inference server is not overwhelmed.
    async with semaphore:
        try:
            payload = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
            }
            if max_tokens is not None:
                payload["max_tokens"] = max_tokens
            response = await client.post(url, json=payload, timeout=60.0)
            # Surface non-2xx responses explicitly; previously an error body
            # simply failed the key lookup below and was swallowed silently.
            response.raise_for_status()
            result = response.json()
            return result['choices'][0]['message']['content'].strip()
        except Exception:
            # Best-effort contract: any failure maps to None so callers can
            # retry via chunked translation instead of crashing the run.
            return None
|
|
def split_text_into_two_chunks(text):
    """Split at a natural boundary (paragraph or sentence).

    Returns (chunk1, chunk2, separator), where separator is the boundary
    string that should be re-inserted when rejoining the translated parts.
    """
    text = text.strip()
    if len(text) <= 1:
        # Nothing meaningful to split.
        return (text, "", "\n\n")

    midpoint = len(text) // 2

    # Try paragraph breaks first, then sentence-ending punctuation,
    # using the last occurrence at or before the midpoint.
    for boundary in ("\n\n", ". ", ".\n", "! ", "!\n", "? ", "?\n"):
        cut = text.rfind(boundary, 0, midpoint + 1)
        if cut > 0:
            left = text[: cut + len(boundary)].strip()
            right = text[cut + len(boundary):].strip()
            return (left, right, boundary)

    # No sentence boundary: fall back to the nearest word break.
    word_cut = text.rfind(" ", 0, midpoint + 1)
    if word_cut > 0:
        return (text[:word_cut].strip(), text[word_cut:].strip(), " ")

    # Last resort: hard split mid-word at the midpoint.
    return (text[:midpoint].strip(), text[midpoint:].strip(), " ")
|
|
|
|
| def _join_with_separator(part1, part2, sep): |
| """Join two translated parts with the original boundary (paragraph/sentence).""" |
| p1 = (part1 or "").strip() |
| p2 = (part2 or "").strip() |
| if not p1: |
| return p2 |
| if not p2: |
| return p1 |
| return p1 + sep + p2 |
|
|
|
|
def is_translation_acceptable(source_text, translated_text, min_ratio=None):
    """Return False if translation is null, empty, or clearly bad (too short/garbage).

    Args:
        source_text: the original text that was translated.
        translated_text: the model output; may be None on API failure.
        min_ratio: minimum acceptable len(translation) / len(source).
            Defaults to the module-level MIN_TRANSLATION_RATIO, so existing
            callers are unaffected.

    Returns:
        bool: True when the translation is non-empty and long enough.
    """
    if min_ratio is None:
        min_ratio = MIN_TRANSLATION_RATIO
    if translated_text is None:
        return False
    candidate = translated_text.strip()
    if not candidate:
        return False
    # A translation dramatically shorter than the source is almost always a
    # truncated or refused generation; character length is a cheap proxy.
    if len(source_text) > 0 and len(candidate) < len(source_text) * min_ratio:
        return False
    return True
|
|
|
|
def build_gemma_prompt(text, source_lang="en", target_lang="bn"):
    """Build an OpenAI-style messages list carrying a TranslateGemma prompt.

    The text is wrapped in the processor's structured translation format,
    rendered through the chat template, and the rendered string is returned
    as the content of a single user message.
    """
    entry = {
        "type": "text",
        "source_lang_code": source_lang,
        "target_lang_code": target_lang,
        "text": text,
    }
    structured = [{"role": "user", "content": [entry]}]
    rendered = processor.apply_chat_template(
        structured, tokenize=False, add_generation_prompt=True
    )
    return [{"role": "user", "content": rendered}]
|
|
async def _translate_two_chunks(client, fulltext, source_lang, target_lang, translate_url):
    """Translate fulltext as two halves and rejoin with the original boundary.

    Failed/empty chunk translations contribute "" so the surviving half is
    still returned. Returns the merged string (possibly empty).
    """
    chunk1, chunk2, sep = split_text_into_two_chunks(fulltext)
    parts = []
    for chunk in (chunk1, chunk2):
        if not chunk.strip():
            parts.append("")
            continue
        prompt = build_gemma_prompt(chunk, source_lang=source_lang, target_lang=target_lang)
        out = await call_llm(
            client, translate_url, "translate_gemma", prompt, max_tokens=4092
        )
        parts.append(out if out else "")
    return _join_with_separator(parts[0], parts[1], sep)


async def translate_fulltext_with_chunking(client, fulltext, source_lang, target_lang, translate_url):
    """Translate fulltext; use two chunks and merge if text is long or first attempt fails.

    Returns the translated text, or None when the input is empty or every
    translation attempt fails.
    """
    if not (fulltext or "").strip():
        return None
    fulltext = fulltext.strip()

    # Long documents overflow the model context: translate in halves up front.
    if len(fulltext) > MAX_FULLTEXT_CHARS_BEFORE_CHUNK:
        merged = await _translate_two_chunks(client, fulltext, source_lang, target_lang, translate_url)
        return merged.strip() or None

    # Short enough: try a single-shot translation first.
    prompt = build_gemma_prompt(fulltext, source_lang=source_lang, target_lang=target_lang)
    translated = await call_llm(
        client, translate_url, "translate_gemma", prompt, max_tokens=4092
    )
    if is_translation_acceptable(fulltext, translated):
        return translated

    # Single shot failed or looked truncated: retry as two chunks, falling
    # back to the (possibly poor) single-shot output if chunking also fails.
    merged = await _translate_two_chunks(client, fulltext, source_lang, target_lang, translate_url)
    return merged.strip() if merged.strip() else translated
|
|
async def process_record(client, record, source_lang, target_lang, translate_url):
    """Translates a single JSON record (fulltext and summary)."""
    source_text = record.get("fulltext", "")
    summary_text = record.get("summary", "")

    # Fulltext goes through the chunk-aware path; summaries are short enough
    # for a single capped call.
    translated_full = await translate_fulltext_with_chunking(
        client, source_text, source_lang, target_lang, translate_url
    )

    summary_prompt = build_gemma_prompt(
        summary_text, source_lang=source_lang, target_lang=target_lang
    )
    translated_sum = await call_llm(
        client, translate_url, "translate_gemma", summary_prompt, max_tokens=1024
    )

    # Mutate the record in place and hand it back for collection.
    record["translated_fulltext"] = translated_full
    record["translated_summary"] = translated_sum
    return record
|
|
def record_key(record):
    """Stable identity for a record: its id if present, else a content-based key."""
    rid = record.get("id")
    if rid is None:
        # No id available: fall back to the concatenated source texts.
        return f"{record.get('fulltext', '')}||{record.get('summary', '')}"
    return str(rid)
|
|
def has_valid_translation(record):
    """True when both translated fields exist and are not None on the record."""
    return (
        record.get("translated_fulltext") is not None
        and record.get("translated_summary") is not None
    )
|
|
async def main():
    """CLI entry point: translate a slice of the dataset with resume support.

    Reads DATA_PATH, translates records [start-idx, end-idx), and writes the
    results incrementally after every batch so an interrupted run can resume
    by reusing the partially-written output file.
    """
    parser = argparse.ArgumentParser(description="Translate Multiclinsum dataset.")
    parser.add_argument("--source-lang", default="en", help="Source language code")
    parser.add_argument("--target-lang", default="bn", help="Target language code")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index (inclusive) of the slice to translate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=200,
        help="End index (exclusive) of the slice to translate; use -1 for all",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port for the translation API server (default: 8080)",
    )
    args = parser.parse_args()

    # Per-run endpoint built from --port (the module-level TRANSLATE_URL
    # constant is not used here).
    translate_url = f"http://127.0.0.1:{args.port}/v1/chat/completions"

    start_idx = args.start_idx
    end_idx = args.end_idx

    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    if end_idx == -1:
        end_idx = len(all_data)

    # Output file name encodes the language pair and slice bounds.
    out_path = OUT_PATH_TEMPLATE.format(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        start=start_idx,
        end=end_idx,
    )
    data = all_data[start_idx:end_idx]

    async with httpx.AsyncClient() as client:
        # Resume support: load any results from a previous (possibly partial) run.
        existing_results = []
        if os.path.exists(out_path):
            with open(out_path, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)

        existing_by_key = {record_key(rec): rec for rec in existing_results}
        output_results = []

        batch_size = 10
        # Cap on fresh translations this run; set to len(data), so in practice
        # it never binds (every missing record is regenerated).
        max_regen = len(data)
        regenerated = 0
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            pending = []
            pending_keys = []
            new_generated = 0

            for rec in batch:
                key = record_key(rec)
                existing = existing_by_key.get(key)
                if existing and has_valid_translation(existing):
                    # Already fully translated in a previous run: reuse it.
                    output_results.append(existing)
                else:
                    if regenerated < max_regen:
                        pending.append(process_record(client, rec, args.source_lang, args.target_lang, translate_url))
                        pending_keys.append(key)
                        regenerated += 1
                    elif existing:
                        # Regeneration budget exhausted: keep the partial result.
                        output_results.append(existing)

            if pending:
                # Translate the batch concurrently (bounded by the module semaphore).
                processed = await asyncio.gather(*pending)
                for key, rec in zip(pending_keys, processed):
                    if rec is not None:
                        existing_by_key[key] = rec
                        output_results.append(rec)
                        new_generated += 1

            # Persist after every batch so progress survives interruption.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(output_results, f, ensure_ascii=False, indent=4)
            print(
                f"Batch {i // batch_size + 1}: new={new_generated}, total={len(output_results)}"
            )
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |