# Source: readCtrl repo — code/translation/translate_multiclinsum_en2bn_v2.py
# Author: mshahidul — "Initial commit of readCtrl code without large models" (commit 030876e)
import os
import json
import asyncio
import argparse
import httpx
from tqdm.asyncio import tqdm
from transformers import AutoProcessor
# ---- Configuration ----
# Input: raw English MultiClinSum test split (a JSON list of records with
# "fulltext" and "summary" fields, per the accesses in process_record).
DATA_PATH = "/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json"
# Output path template; {start}/{end} encode the dataset slice this run covers.
OUT_PATH_TEMPLATE = (
    "/home/mshahidul/readctrl/data/translated_data/translation_testing_3396/"
    "multiclinsum_test_{source_lang}2{target_lang}_gemma({start}_{end})_3396.json"
)
# Chunking for long fulltext: split and merge if output is null/bad, or if text exceeds this length
MAX_FULLTEXT_CHARS_BEFORE_CHUNK = 3500
MIN_TRANSLATION_RATIO = 0.15  # treat as bad if translation length < 15% of source
# NOTE(review): main() rebuilds this URL from --port, so this constant appears
# to be an unused default — confirm before relying on it.
TRANSLATE_URL = "http://127.0.0.1:8080/v1/chat/completions"
CONCURRENCY_LIMIT = 8  # Matches your server's "-np" or "--parallel" value
model_id = "google/translategemma-27b-it"
# Loaded at import time; used only for its chat template in build_gemma_prompt.
# Presumably may hit the HF hub on first run — import has a network side effect.
processor = AutoProcessor.from_pretrained(model_id)
# Global gate bounding concurrent requests (see call_llm) to the server's slots.
semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)
async def call_llm(client, url, model, messages, temperature=0.1, max_tokens=None):
    """Generic async caller for both Translation and Judge endpoints.

    Posts an OpenAI-style chat-completions request and returns the stripped
    message content, or None on any failure so callers can treat a failed
    call as "no output" and fall back to retry/chunking.

    Args:
        client: httpx.AsyncClient used for the request.
        url: chat-completions endpoint URL.
        model: model name forwarded in the payload.
        messages: chat messages list for the request body.
        temperature: sampling temperature (0.1 -> near-deterministic output).
        max_tokens: optional generation cap; omitted from the payload when None.

    Returns:
        str content on success, None on failure.
    """
    # Bound in-flight requests to the server's parallel capacity.
    async with semaphore:
        # Payload construction cannot raise; keep it outside the try so the
        # except only covers the actual network/parse work.
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens
        try:
            response = await client.post(url, json=payload, timeout=60.0)
            # Fix: fail fast on HTTP error statuses instead of attempting to
            # parse an error body as a completion.
            response.raise_for_status()
            result = response.json()
            return result['choices'][0]['message']['content'].strip()
        except Exception:
            # Deliberate best-effort boundary: any transport, JSON, or
            # response-shape error maps to None; callers interpret None as
            # "translation failed, retry or chunk".
            return None
def split_text_into_two_chunks(text):
    """Split *text* near its midpoint at a natural boundary.

    Returns (chunk1, chunk2, join_sep). chunk1 keeps the boundary punctuation
    (e.g. the trailing "."), so join_sep is only the whitespace part of the
    boundary: previously the full separator was returned, and rejoining
    "Hello. World" produced "Hello.. World" with a duplicated period.

    Args:
        text: source text to split; stripped before splitting.

    Returns:
        Tuple of (first chunk, second chunk, separator for rejoining).
        For empty/1-char input the second chunk is "" and the separator
        is a paragraph break.
    """
    text = text.strip()
    if len(text) <= 1:
        return (text, "", "\n\n")
    mid = len(text) // 2
    # Prefer a paragraph boundary, then sentence-final punctuation, so the
    # merge preserves the existing structure.
    for sep in ("\n\n", ". ", ".\n", "! ", "!\n", "? ", "?\n"):
        idx = text.rfind(sep, 0, mid + 1)
        if idx > 0:
            # Strip sentence punctuation from the join separator — it already
            # lives at the end of chunk1. "\n\n" is unaffected (pure whitespace).
            join_sep = sep.lstrip(".!?") or " "
            return (
                text[: idx + len(sep)].strip(),
                text[idx + len(sep):].strip(),
                join_sep,
            )
    # Fallback: split at the last space before mid, else a hard character split.
    space_idx = text.rfind(" ", 0, mid + 1)
    if space_idx > 0:
        return (text[:space_idx].strip(), text[space_idx:].strip(), " ")
    return (text[:mid].strip(), text[mid:].strip(), " ")
def _join_with_separator(part1, part2, sep):
"""Join two translated parts with the original boundary (paragraph/sentence)."""
p1 = (part1 or "").strip()
p2 = (part2 or "").strip()
if not p1:
return p2
if not p2:
return p1
return p1 + sep + p2
def is_translation_acceptable(source_text, translated_text):
    """Heuristic sanity check: reject null, empty, or suspiciously short output.

    A translation far shorter than its source (below MIN_TRANSLATION_RATIO of
    the source length) is treated as truncated/garbage.
    """
    if translated_text is None:
        return False
    candidate = translated_text.strip()
    if not candidate:
        return False
    if len(source_text) > 0 and len(candidate) < len(source_text) * MIN_TRANSLATION_RATIO:
        return False
    return True
def build_gemma_prompt(text, source_lang="en", target_lang="bn"):
    """Build the chat `messages` payload for the Translate-Gemma server.

    Renders the structured translation request (source/target language codes
    plus the text) through the model's chat template, then wraps the rendered
    string as a single plain-text user turn for the chat-completions API.
    """
    translation_request = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]
    rendered = processor.apply_chat_template(
        translation_request, tokenize=False, add_generation_prompt=True
    )
    return [{"role": "user", "content": rendered}]
async def _translate_in_two_chunks(client, fulltext, source_lang, target_lang, translate_url):
    """Split fulltext in two, translate each half, and merge at the split boundary.

    A failed half becomes "" so the merge degrades gracefully; returns the
    merged string ("" when both halves failed).
    """
    chunk1, chunk2, sep = split_text_into_two_chunks(fulltext)
    parts = []
    for chunk in (chunk1, chunk2):
        if not chunk.strip():
            parts.append("")
            continue
        prompt = build_gemma_prompt(chunk, source_lang=source_lang, target_lang=target_lang)
        out = await call_llm(
            client, translate_url, "translate_gemma", prompt, max_tokens=4092
        )
        parts.append(out if out else "")
    return _join_with_separator(parts[0], parts[1], sep)


async def translate_fulltext_with_chunking(client, fulltext, source_lang, target_lang, translate_url):
    """Translate fulltext; use two chunks and merge if text is long or first attempt fails.

    Strategy:
      1. Empty input -> None.
      2. Very long input -> proactively translate in two chunks (long prompts
         tend to produce null/bad output).
      3. Otherwise try a single-shot translation; if the result fails the
         acceptability heuristic, retry in two chunks and fall back to the
         single-shot attempt only when the merged retry is empty.

    The duplicated chunk-translate-merge loop of the original was extracted
    into _translate_in_two_chunks.
    """
    if not (fulltext or "").strip():
        return None
    fulltext = fulltext.strip()
    # Proactively chunk if very long to avoid null/bad output.
    if len(fulltext) > MAX_FULLTEXT_CHARS_BEFORE_CHUNK:
        merged = await _translate_in_two_chunks(
            client, fulltext, source_lang, target_lang, translate_url
        )
        return merged.strip() or None
    # Try full translation first.
    prompt = build_gemma_prompt(fulltext, source_lang=source_lang, target_lang=target_lang)
    translated = await call_llm(
        client, translate_url, "translate_gemma", prompt, max_tokens=4092
    )
    if is_translation_acceptable(fulltext, translated):
        return translated
    # Retry with two chunks; fall back to the first attempt if the merge is empty.
    merged = await _translate_in_two_chunks(
        client, fulltext, source_lang, target_lang, translate_url
    )
    return merged.strip() if merged.strip() else translated
async def process_record(client, record, source_lang, target_lang, translate_url):
    """Translate one JSON record in place (fulltext and summary) and return it."""
    source_fulltext = record.get("fulltext", "")
    source_summary = record.get("summary", "")

    # Fulltext goes through the chunking path (handles long/failed inputs).
    record["translated_fulltext"] = await translate_fulltext_with_chunking(
        client, source_fulltext, source_lang, target_lang, translate_url
    )

    # Summaries are translated with a single direct call.
    summary_messages = build_gemma_prompt(
        source_summary, source_lang=source_lang, target_lang=target_lang
    )
    record["translated_summary"] = await call_llm(
        client, translate_url, "translate_gemma", summary_messages, max_tokens=1024
    )
    return record
def record_key(record):
    """Stable identity for a record: its "id" when present, else a content key."""
    rid = record.get("id")
    if rid is None:
        # No id — fall back to the concatenated source texts.
        return f"{record.get('fulltext', '')}||{record.get('summary', '')}"
    return str(rid)
def has_valid_translation(record):
    """True when both translated fields are present and non-None on the record."""
    return all(
        record.get(field) is not None
        for field in ("translated_fulltext", "translated_summary")
    )
async def main():
    """CLI entry point: translate a slice of the dataset, resuming from prior output.

    Loads the dataset slice [--start-idx, --end-idx), reuses records from any
    existing output file that already have both translated fields, translates
    the rest in concurrent batches, and rewrites the output JSON after every
    batch so progress survives interruption.
    """
    parser = argparse.ArgumentParser(description="Translate Multiclinsum dataset.")
    parser.add_argument("--source-lang", default="en", help="Source language code")
    parser.add_argument("--target-lang", default="bn", help="Target language code")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index (inclusive) of the slice to translate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=200,
        help="End index (exclusive) of the slice to translate; use -1 for all",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port for the translation API server (default: 8080)",
    )
    args = parser.parse_args()
    # --port supersedes the module-level TRANSLATE_URL default.
    translate_url = f"http://127.0.0.1:{args.port}/v1/chat/completions"
    start_idx = args.start_idx
    end_idx = args.end_idx
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    if end_idx == -1:
        end_idx = len(all_data)
    out_path = OUT_PATH_TEMPLATE.format(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        start=start_idx,
        end=end_idx,
    )
    data = all_data[start_idx:end_idx]
    async with httpx.AsyncClient() as client:
        # Resume support: load any prior output and index it by record key so
        # completed records are reused instead of re-translated.
        existing_results = []
        if os.path.exists(out_path):
            with open(out_path, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)
        existing_by_key = {record_key(rec): rec for rec in existing_results}
        output_results = []
        batch_size = 10
        # NOTE(review): max_regen equals the slice size, so the regeneration
        # cap below can never actually limit anything — confirm intent.
        max_regen = len(data)
        regenerated = 0
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            pending = []       # coroutines for records needing (re)translation
            pending_keys = []  # keys aligned 1:1 with `pending`
            new_generated = 0
            for rec in batch:
                key = record_key(rec)
                existing = existing_by_key.get(key)
                if existing and has_valid_translation(existing):
                    # Already fully translated on a previous run — reuse as-is.
                    output_results.append(existing)
                else:
                    if regenerated < max_regen:
                        pending.append(process_record(client, rec, args.source_lang, args.target_lang, translate_url))
                        pending_keys.append(key)
                        regenerated += 1
                    elif existing:
                        # Cap reached: keep the (incomplete) previous result.
                        output_results.append(existing)
            if pending:
                # Run the batch concurrently; the semaphore inside call_llm
                # bounds the actual number of in-flight HTTP requests.
                processed = await asyncio.gather(*pending)
                for key, rec in zip(pending_keys, processed):
                    if rec is not None:
                        existing_by_key[key] = rec
                        output_results.append(rec)
                        new_generated += 1
            # Persist after every batch so partial progress is never lost.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(output_results, f, ensure_ascii=False, indent=4)
            print(
                f"Batch {i // batch_size + 1}: new={new_generated}, total={len(output_results)}"
            )
if __name__ == "__main__":
    # Script entry point: run the async pipeline on a fresh event loop.
    asyncio.run(main())