| import json |
| import os |
| from typing import List, Tuple |
|
|
| import gradio as gr |
| import httpx |
| from transformers import AutoProcessor |
|
|
# JSON dataset that is both read at startup and overwritten by "Save to file".
# Override with the TRANSLATION_REVIEW_DATA_PATH environment variable to review
# a different file without editing the source.
DATA_PATH = os.environ.get(
    "TRANSLATION_REVIEW_DATA_PATH",
    "/home/mshahidul/readctrl/data/translated_data/translation_wo_judge/"
    "multiclinsum_gs_train_en2bn_gemma(0_200).json",
)

# Chat-completions endpoint (OpenAI-style request/response shape) that serves
# the translation model. Override with the TRANSLATE_URL environment variable.
TRANSLATE_URL = os.environ.get(
    "TRANSLATE_URL", "http://172.16.34.29:8081/v1/chat/completions"
)
# Language codes passed to the TranslateGemma prompt template.
SOURCE_LANG = "en"
TARGET_LANG = "bn"

# Hugging Face model id; used here only to load the processor whose chat
# template renders the translation prompt locally.
MODEL_ID = "google/translategemma-27b-it"
# Value sent in the request's "model" field; must match the server's name.
SERVER_MODEL_NAME = "translate_gemma"

# Only the first MAX_INSTANCES records are offered in the UI.
MAX_INSTANCES = 80

# Loaded once at import time (from_pretrained may need to download files on
# first use).
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
|
|
|
def load_data(path: str) -> List[dict]:
    """Read and return the JSON document stored at ``path`` (decoded as UTF-8)."""
    with open(path, encoding="utf-8") as handle:
        records = json.load(handle)
    return records
|
|
|
|
def save_data(path: str, data: List[dict]) -> None:
    """Write ``data`` to ``path`` as pretty-printed UTF-8 JSON (non-ASCII kept as-is).

    Parent directories are created when ``path`` has any. The JSON is first
    written to a sibling ``<path>.tmp`` file, flushed to disk, and then moved
    over ``path`` with ``os.replace``, so an interrupted save cannot leave a
    truncated dataset behind.
    """
    directory = os.path.dirname(path)
    if directory:
        # os.makedirs("") raises FileNotFoundError, which is what happens for
        # a bare filename such as "out.json".
        os.makedirs(directory, exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
            f.flush()
            os.fsync(f.fileno())
        # Atomic rename: readers see either the old file or the new one.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then let the original error surface.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
def build_gemma_prompt(text: str, source_lang: str, target_lang: str) -> List[dict]:
    """Render a TranslateGemma prompt locally and wrap it as one chat message.

    A structured user turn (text plus source/target language codes) is passed
    through ``processor.apply_chat_template`` without tokenizing and with the
    generation prompt appended. The rendered string becomes the ``content`` of
    a single ``user`` message for the chat-completions request.

    NOTE(review): if the server also applies a chat template to ``messages``,
    the prompt would be templated twice. Confirm the server is set up to
    accept pre-rendered prompts.
    """
    text_part = {
        "type": "text",
        "source_lang_code": source_lang,
        "target_lang_code": target_lang,
        "text": text,
    }
    structured_turn = {"role": "user", "content": [text_part]}
    rendered = processor.apply_chat_template(
        [structured_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return [{"role": "user", "content": rendered}]
|
|
|
|
def call_llm(
    text: str,
    temperature: float = 0.1,
    max_tokens: int | None = None,
    source_lang: str = SOURCE_LANG,
    target_lang: str = TARGET_LANG,
    timeout: float = 60.0,
) -> Tuple[str | None, str | None]:
    """Translate ``text`` via the chat-completions server.

    Args:
        text: Source text; empty text is rejected without a request.
        temperature: Sampling temperature sent to the server.
        max_tokens: Optional generation cap; omitted from the payload when None.
        source_lang / target_lang: Language codes for the prompt template.
        timeout: HTTP timeout in seconds (default matches the previous
            hard-coded value).

    Returns:
        ``(translation, None)`` on success (whitespace-stripped), or
        ``(None, error_message)`` on failure. Network, HTTP-status, JSON and
        response-shape problems are reported as messages rather than raised,
        so the UI can display them.
    """
    if not text:
        return None, "Empty source text."
    messages = build_gemma_prompt(text, source_lang=source_lang, target_lang=target_lang)
    payload = {
        "model": SERVER_MODEL_NAME,
        "messages": messages,
        "temperature": float(temperature),
    }
    if max_tokens is not None:
        payload["max_tokens"] = int(max_tokens)
    try:
        response = httpx.post(TRANSLATE_URL, json=payload, timeout=timeout)
        # Surface 4xx/5xx explicitly instead of failing later on a missing
        # "choices" key in the server's error body.
        response.raise_for_status()
        result = response.json()
    except httpx.HTTPStatusError as exc:
        body = exc.response.text[:500]
        return None, f"LLM call failed: HTTP {exc.response.status_code}: {body}"
    except (httpx.HTTPError, ValueError) as exc:
        # ValueError covers an unparsable JSON body.
        return None, f"LLM call failed: {exc}"
    try:
        content = result["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        return None, f"LLM call failed: unexpected response shape ({exc!r})"
    if not isinstance(content, str):
        return None, "LLM call failed: response contained no text content."
    return content.strip(), None
|
|
|
|
# The whole dataset is loaded once at import time. The handlers below mutate
# these record dicts in place, and save_record writes the full list back.
data = load_data(DATA_PATH)
if not isinstance(data, list) or not data:
    # Fail fast: the UI preselects index 0, which cannot exist for an empty or
    # non-list dataset and would otherwise crash only once the page loads.
    raise ValueError(f"Expected a non-empty JSON list of records in {DATA_PATH!r}.")
limit = min(MAX_INSTANCES, len(data))
# Dropdown choices as (label, value): the label shows the zero-padded index and
# the record id; the value is the integer index into ``data``.
options = [
    (f"{i:03d} | {rec.get('id', 'no-id')}", i) for i, rec in enumerate(data[:limit])
]
|
|
|
|
def get_record(idx: int) -> dict:
    """Return the in-memory record at position ``idx`` of ``data``.

    The dict itself is returned (not a copy), so edits made through it are
    included in the next save.
    """
    record = data[idx]
    return record
|
|
|
|
def record_to_fields(idx: int):
    """Build the UI field values for record ``idx``.

    Returns ``(idx, id, source fulltext, source summary, translated fulltext,
    translated summary, status message)``. Missing or null translations are
    shown as empty strings.
    """
    record = get_record(idx)
    shown_full = record.get("translated_fulltext") or ""
    shown_summary = record.get("translated_summary") or ""
    return (
        idx,
        record.get("id", ""),
        record.get("fulltext", ""),
        record.get("summary", ""),
        shown_full,
        shown_summary,
        f"Loaded index {idx}.",
    )
|
|
|
|
def goto_index(idx: int):
    """Load record ``idx`` into the UI fields (``idx`` is coerced with ``int()`` first)."""
    target = int(idx)
    return record_to_fields(target)
|
|
|
|
def step_index(idx: int, delta: int):
    """Move ``delta`` records away from ``idx`` and load the resulting record.

    The target is clamped to at most ``limit - 1`` and then to at least 0.
    """
    target = int(idx) + delta
    # Upper bound first, lower bound second, so the result is never negative.
    if target > limit - 1:
        target = limit - 1
    if target < 0:
        target = 0
    return record_to_fields(target)
|
|
|
|
def regenerate_fulltext(idx: int, temperature: float, max_tokens: int):
    """Re-translate the source fulltext of record ``idx``.

    Returns ``(text_for_textbox, status_message)``. On success the new
    translation is stored on the in-memory record (not yet written to disk).
    On failure the previously stored translation (or "") is returned together
    with the error message.
    """
    record = get_record(int(idx))
    new_text, failure = call_llm(
        record.get("fulltext", ""), temperature=temperature, max_tokens=max_tokens
    )
    if new_text is None:
        return record.get("translated_fulltext") or "", failure or "Regenerate failed."
    record["translated_fulltext"] = new_text
    return new_text, f"Regenerated fulltext at index {idx}."
|
|
|
|
def regenerate_summary(idx: int, temperature: float, max_tokens: int):
    """Re-translate the source summary of record ``idx``.

    Returns ``(text_for_textbox, status_message)``. On success the new
    translation is stored on the in-memory record (not yet written to disk).
    On failure the previously stored translation (or "") is returned together
    with the error message.
    """
    record = get_record(int(idx))
    new_text, failure = call_llm(
        record.get("summary", ""), temperature=temperature, max_tokens=max_tokens
    )
    if new_text is None:
        return record.get("translated_summary") or "", failure or "Regenerate failed."
    record["translated_summary"] = new_text
    return new_text, f"Regenerated summary at index {idx}."
|
|
|
|
def regenerate_both(idx: int, temperature: float, max_tokens_full: int, max_tokens_sum: int):
    """Re-translate both the fulltext and the summary of record ``idx``.

    Each successful translation is stored on the in-memory record (not yet
    written to disk). Failed fields keep their previous translation.

    Returns:
        ``(translated_fulltext, translated_summary, status)``. The status is a
        success message only when both calls succeed. Otherwise it lists the
        real error(s) for the failed field(s).
    """
    # Call call_llm directly: the single-field regenerate helpers return a
    # status message in their second slot even on success, so it cannot be
    # used to detect failure.
    rec = get_record(int(idx))
    jobs = (
        ("fulltext", "translated_fulltext", max_tokens_full),
        ("summary", "translated_summary", max_tokens_sum),
    )
    texts = []
    errors = []
    for source_key, target_key, max_tokens in jobs:
        translated, error = call_llm(
            rec.get(source_key, ""), temperature=temperature, max_tokens=max_tokens
        )
        if translated is None:
            errors.append(f"{source_key}: {error or 'Regenerate failed.'}")
        else:
            rec[target_key] = translated
        texts.append(rec.get(target_key) or "")
    if not errors:
        status = "Regenerated fulltext and summary."
    elif len(errors) == len(jobs):
        status = f"Regenerate failed: {'; '.join(errors)}"
    else:
        status = f"Partial regenerate: {'; '.join(errors)}"
    return texts[0], texts[1], status
|
|
|
|
def save_record(idx: int, translated_fulltext: str, translated_summary: str):
    """Store the edited translations on record ``idx`` and write the dataset to DATA_PATH.

    Empty textbox values are stored as ``None`` (JSON null). The entire
    ``data`` list is written, so regenerated translations still held in
    memory for other records are persisted too. Shows a toast and returns the
    status message.
    """
    record = get_record(int(idx))
    record.update(
        translated_fulltext=translated_fulltext or None,
        translated_summary=translated_summary or None,
    )
    save_data(DATA_PATH, data)
    message = f"Saved index {idx} to file."
    gr.Info(message)
    return message
|
|
|
|
with gr.Blocks(title="Translation Review") as demo:
    # The header reports the real number of reviewable records, which is
    # MAX_INSTANCES capped by the dataset size (``limit``).
    gr.Markdown(f"## Translation review for first {limit} instances")

    with gr.Row():
        record_select = gr.Dropdown(
            label="Record",
            choices=options,
            value=0,
            interactive=True,
        )
        status = gr.Textbox(label="Status", value="Ready.", interactive=False)

    with gr.Row():
        prev_btn = gr.Button("Prev")
        next_btn = gr.Button("Next")

    # Source fields are read-only; only the translation boxes below are editable.
    record_id = gr.Textbox(label="Record ID", interactive=False)
    fulltext = gr.Textbox(label="Fulltext (source)", lines=8, interactive=False)
    summary = gr.Textbox(label="Summary (source)", lines=6, interactive=False)

    with gr.Row():
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.5,
            value=0.2,
            step=0.05,
            label="Temperature",
        )
        max_tokens_full = gr.Number(value=2048, precision=0, label="Max tokens (fulltext)")
        max_tokens_sum = gr.Number(value=1024, precision=0, label="Max tokens (summary)")

    translated_fulltext = gr.Textbox(label="Translated fulltext", lines=8)
    translated_summary = gr.Textbox(label="Translated summary", lines=6)

    with gr.Row():
        regen_full_btn = gr.Button("Regenerate Fulltext")
        regen_sum_btn = gr.Button("Regenerate Summary")
        regen_both_btn = gr.Button("Regenerate Both")
        save_btn = gr.Button("Save to file")

    # Components refreshed whenever a record is (re)loaded. The order must match
    # the tuple returned by record_to_fields.
    record_outputs = [
        record_select,
        record_id,
        fulltext,
        summary,
        translated_fulltext,
        translated_summary,
        status,
    ]

    record_select.change(goto_index, inputs=[record_select], outputs=record_outputs)
    prev_btn.click(
        lambda idx: step_index(idx, -1),
        inputs=[record_select],
        outputs=record_outputs,
    )
    next_btn.click(
        lambda idx: step_index(idx, 1),
        inputs=[record_select],
        outputs=record_outputs,
    )

    regen_full_btn.click(
        regenerate_fulltext,
        inputs=[record_select, temperature, max_tokens_full],
        outputs=[translated_fulltext, status],
    )
    regen_sum_btn.click(
        regenerate_summary,
        inputs=[record_select, temperature, max_tokens_sum],
        outputs=[translated_summary, status],
    )
    regen_both_btn.click(
        regenerate_both,
        inputs=[record_select, temperature, max_tokens_full, max_tokens_sum],
        outputs=[translated_fulltext, translated_summary, status],
    )
    save_btn.click(
        save_record,
        inputs=[record_select, translated_fulltext, translated_summary],
        outputs=[status],
    )

    # Fill the fields for the preselected record when the page first loads.
    demo.load(goto_index, inputs=[record_select], outputs=record_outputs)
|
|
|
|
if __name__ == "__main__":
    # share=True asks Gradio for a public share link. Anyone holding it can
    # regenerate and save translations, so distribute the link carefully.
    launch_options = {"share": True}
    demo.launch(**launch_options)
|
|