File size: 6,156 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import json
import asyncio
import argparse
import httpx
from tqdm.asyncio import tqdm
from transformers import AutoProcessor

# ---- Configuration ----
# Input dataset: list of records, each with 'fulltext' and 'summary' keys
# (see process_record / record_key below).
DATA_PATH = "/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_en.json"
# Output path is parameterized by language pair; "(0_200)" reflects the
# first-200-records slice taken in main().
OUT_PATH_TEMPLATE = (
    "/home/mshahidul/readctrl/data/translated_data/"
    "multiclinsum_gs_train_{source_lang}2{target_lang}_gemma(0_200).json"
)

# Local OpenAI-compatible chat-completions endpoints.
TRANSLATE_URL = "http://localhost:8081/v1/chat/completions"
JUDGE_URL = "http://localhost:8004/v1/chat/completions"
CONCURRENCY_LIMIT = 8  # Matches your server's "-np" or "--parallel" value

# NOTE(review): module-level side effect — loading the processor may download
# model files from the Hugging Face hub on first run.
model_id = "google/translategemma-27b-it"
processor = AutoProcessor.from_pretrained(model_id)

# Global semaphore shared by every call_llm invocation to cap in-flight requests.
semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)

async def call_llm(client, url, model, messages, temperature=0.1, max_tokens=None):
    """Generic async caller for both the translation and judge endpoints.

    Args:
        client: ``httpx.AsyncClient`` used to issue the request.
        url: Chat-completions endpoint URL.
        model: Model name placed in the request payload.
        messages: OpenAI-style list of chat messages.
        temperature: Sampling temperature (low default for near-determinism).
        max_tokens: Optional completion-length cap; omitted from the payload
            when None so the server default applies.

    Returns:
        The stripped assistant message content, or None on any failure
        (network error, non-2xx status, malformed response body).
    """
    # Bound concurrent in-flight requests to what the server can serve.
    async with semaphore:
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens
        try:
            response = await client.post(url, json=payload, timeout=60.0)
            # Fail fast on HTTP errors instead of JSON-parsing an error body.
            response.raise_for_status()
            result = response.json()
            return result['choices'][0]['message']['content'].strip()
        except Exception as e:
            # Best-effort contract: log and return None so a single bad
            # request does not abort the whole batch (was silently swallowed).
            print(f"LLM call to {url} failed: {e}")
            return None

def build_gemma_prompt(text, source_lang="en", target_lang="bn"):
    """Render *text* through the TranslateGemma chat template.

    Builds a Gemma-style structured message carrying the source/target
    language codes, applies the processor's chat template (untokenized, with
    the generation prompt appended), and returns the rendered string wrapped
    as a single-turn OpenAI-style message list for the local server.
    """
    gemma_turn = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]
    rendered = processor.apply_chat_template(
        gemma_turn, tokenize=False, add_generation_prompt=True
    )
    return [{"role": "user", "content": rendered}]

def describe_lang(code):
    """Map an ISO 639-1 language code to its English display name.

    Unrecognized codes fall back to "Unknown Language".
    """
    known_languages = (
        ("en", "English"),
        ("bn", "Bengali"),
        ("zh", "Chinese"),
        ("vi", "Vietnamese"),
        ("hi", "Hindi"),
    )
    for lang_code, display_name in known_languages:
        if lang_code == code:
            return display_name
    return "Unknown Language"

async def process_record(client, record, source_lang, target_lang):
    """Translate one record's fulltext and summary, then LLM-judge the result.

    Args:
        client: Shared ``httpx.AsyncClient``.
        record: Dict with at least 'fulltext' and 'summary' keys.
        source_lang / target_lang: ISO 639-1 language codes.

    Returns:
        The record augmented with 'translated_fulltext', 'translated_summary',
        and 'judge_pass', or None when either translation fails or the judge
        never answers PASS within three attempts.
    """
    # 1. Translate Fulltext & Summary
    # (Using the prompt format your local server expects)
    translated_fulltext_prompt = build_gemma_prompt(
        record['fulltext'], source_lang=source_lang, target_lang=target_lang
    )
    translated_summary_prompt = build_gemma_prompt(
        record['summary'], source_lang=source_lang, target_lang=target_lang
    )
    translated_fulltext = await call_llm(
        client, TRANSLATE_URL, "translate_gemma", translated_fulltext_prompt, max_tokens=1024
    )
    translated_summary = await call_llm(
        client, TRANSLATE_URL, "translate_gemma", translated_summary_prompt, max_tokens=512
    )

    # Bail out early if either translation failed: judging the literal string
    # "None" is meaningless, and a None translation must never be written out.
    if translated_fulltext is None or translated_summary is None:
        return None

    # 2. Judge Phase
    source_lang_label = describe_lang(source_lang)
    target_lang_label = describe_lang(target_lang)
    judge_prompt = f"""
    You are a linguistic judge. Evaluate the following {target_lang_label} translation of a {source_lang_label} medical text.
    Check for:
    1. Presence of any language other than {target_lang_label} or {source_lang_label} medical terms.
    2. Hallucinated keywords not present in the original.

    Original {source_lang_label}: {record['fulltext']}
    Translated {target_lang_label}: {translated_fulltext} 

    Does this translation pass? Respond with ONLY 'PASS' or 'FAIL'.
    """
    # Retry the judge up to three times; any PASS accepts the translation.
    judge_pass = False
    for _ in range(3):
        judge_res = await call_llm(client, JUDGE_URL, "Qwen/Qwen3-30B-A3B-Instruct-2507", [
            {"role": "user", "content": judge_prompt}
        ])
        judge_pass = "PASS" in (judge_res or "").upper()
        if judge_pass:
            break

    if not judge_pass:
        return None

    record['translated_fulltext'] = translated_fulltext
    record['translated_summary'] = translated_summary
    record['judge_pass'] = True
    return record

def record_key(record):
    """Stable deduplication key for a record.

    Uses the stringified 'id' field when present; otherwise falls back to
    the fulltext and summary joined by '||'.
    """
    identifier = record.get("id")
    if identifier is None:
        fulltext = record.get('fulltext', '')
        summary = record.get('summary', '')
        return f"{fulltext}||{summary}"
    return str(identifier)

async def main():
    """CLI entry point: translate, judge, and checkpoint the dataset.

    Loads the first --limit records (default 200, matching the original
    hard-coded slice), skips records already present in the output file,
    processes the rest in batches of 10, and rewrites the output JSON after
    every batch so progress survives interruption.
    """
    parser = argparse.ArgumentParser(description="Translate Multiclinsum dataset.")
    parser.add_argument("--source-lang", default="en", help="Source language code")
    parser.add_argument("--target-lang", default="bn", help="Target language code")
    # Generalizes the previously hard-coded [0:200] slice; default preserves
    # the original behavior.
    parser.add_argument("--limit", type=int, default=200,
                        help="Number of leading records to process (default: 200)")
    args = parser.parse_args()

    out_path = OUT_PATH_TEMPLATE.format(
        source_lang=args.source_lang, target_lang=args.target_lang
    )

    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)[:args.limit]

    # Ensure the checkpoint directory exists once, before any batch writes.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    async with httpx.AsyncClient() as client:
        # Resume support: previously completed records are reused, not re-run.
        existing_results = []
        if os.path.exists(out_path):
            with open(out_path, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)

        existing_by_key = {record_key(rec): rec for rec in existing_results}
        output_results = []

        batch_size = 10
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            pending = []
            pending_keys = []

            for rec in batch:
                key = record_key(rec)
                if key in existing_by_key:
                    output_results.append(existing_by_key[key])
                else:
                    pending.append(process_record(client, rec, args.source_lang, args.target_lang))
                    pending_keys.append(key)

            if pending:
                processed = await asyncio.gather(*pending)
                for key, rec in zip(pending_keys, processed):
                    # None means translation failure or judge rejection; drop it.
                    if rec is not None:
                        existing_by_key[key] = rec
                        output_results.append(rec)

            # Checkpoint after every batch so an interruption loses at most
            # one batch of work.
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(output_results, f, ensure_ascii=False, indent=4)

# Script entry point: run the async pipeline to completion.
if __name__ == "__main__":
    asyncio.run(main())