| |
| """ |
| Extract Bangla subclaims from translated MultiClinSum files using the |
| subclaim-extractor vLLM server (Qwen3-30B-A3B on port 8050). |
| |
| - Input: JSON files in translation_testing_3396 (attrs: translated_fulltext, translated_summary) |
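| - Output: JSON saved to extracting_subclaim/bn; each entry holds id, fulltext/summary (the translated Bangla text) and the extracted subclaim lists. No other input fields are copied. |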
| """ |
|
|
| import os |
| import json |
| import glob |
| import argparse |
| from openai import OpenAI |
|
|
| |
| |
| |
# Base URL of the OpenAI-compatible server, used when the caller passes no URL.
DEFAULT_API_URL = "http://localhost:8050/v1"
# Model name sent with each request when the caller does not name one.
DEFAULT_MODEL_NAME = "subclaim-extractor"


# Shared OpenAI client instance; stays None until get_client() creates it.
client = None
|
|
|
|
def get_client(base_url: str = None, api_key: str = "EMPTY"):
    """Return the shared OpenAI client, creating it on first use.

    Args:
        base_url: API base URL. ``None`` means "use whatever client is already
            cached", or ``DEFAULT_API_URL`` when no client exists yet.
        api_key: API key passed to the client when it is (re)built.

    Returns:
        The module-level ``client`` instance.

    The client is rebuilt when an explicit ``base_url`` differs from the URL
    the cached client was built with. Previously the first URL ever used
    silently won for every later call. A client assigned to the module-level
    ``client`` by other code is returned untouched.
    """
    global client
    if client is None:
        resolved_url = base_url or DEFAULT_API_URL
        client = OpenAI(base_url=resolved_url, api_key=api_key)
        get_client._base_url = resolved_url
        return client

    # Default to ``base_url`` so an externally injected client (no recorded
    # URL) never looks stale and is never replaced.
    cached_url = getattr(get_client, "_base_url", base_url)
    if base_url is not None and base_url != cached_url:
        client = OpenAI(base_url=base_url, api_key=api_key)
        get_client._base_url = base_url
    return client
|
|
|
|
| |
| |
| |
def extraction_prompt(medical_text: str, is_summary: bool = False) -> str:
    """Build the subclaim-extraction prompt for one Bangla text.

    Args:
        medical_text: The Bangla text to embed in the prompt.
        is_summary: When true the prompt calls the text a "summary",
            otherwise a "full medical text".

    Returns:
        The prompt with leading and trailing whitespace stripped.
    """
    if is_summary:
        source_type = "summary"
    else:
        source_type = "full medical text"

    prompt = f"""
You are an expert medical annotator. The following text is in Bangla (Bengali).

Your task is to extract granular, factual subclaims from the provided {source_type}.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the Bangla medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
- Be in Bangla (same language as the input)
- Contain exactly ONE factual assertion
- Come directly from the text (no inference or interpretation)
- Preserve original wording as much as possible
- Include any negation, uncertainty, or qualifier
4. Do NOT:
- Combine multiple facts into one subclaim
- Add new information
- Translate to another language
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).

Medical Text (Bangla):
{medical_text}

Return format:
[
"subclaim 1",
"subclaim 2"
]
"""
    return prompt.strip()
|
|
|
|
def _strip_markdown_json_block(text: str) -> str:
    """Strip optional markdown code fence (e.g. ```json\\n[...]\\n```)."""
    text = text.strip()
    # Try the longer opener first so the "json" tag goes away with the fence.
    for opener in ("```json", "```"):
        if text.startswith(opener):
            text = text[len(opener):].lstrip("\n")
            break
    if text.endswith("```"):
        text = text[:-3].rstrip("\n")
    return text.strip()


def _parse_subclaims_output(output_text: str) -> list:
    """Parse raw model output into a list of non-empty subclaim strings.

    Steps: keep only the text after the last ``</think>`` marker (if any),
    drop a surrounding markdown code fence, then JSON-decode the span from the
    first ``[`` to the last ``]``.

    Args:
        output_text: Raw model output; ``None`` or blank yields ``[]``.

    Returns:
        Stripped string forms of the array elements. Blank entries and JSON
        ``null`` entries are skipped.

    Raises:
        ValueError: If no complete ``[...]`` span is found.
            ``json.JSONDecodeError`` (a ``ValueError`` subclass) is raised
            when the span is not valid JSON.
    """
    output_text = (output_text or "").strip()
    if not output_text:
        return []

    # Discard any reasoning trace emitted before the final answer.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    output_text = _strip_markdown_json_block(output_text)

    start_idx = output_text.find("[")
    end_idx = output_text.rfind("]") + 1
    if start_idx != -1 and end_idx > start_idx:
        parsed = json.loads(output_text[start_idx:end_idx])
        if isinstance(parsed, list):
            subclaims = []
            for entry in parsed:
                # JSON null decodes to None; str(None) would otherwise leak
                # a literal "None" subclaim into the results.
                if entry is None:
                    continue
                cleaned = str(entry).strip()
                if cleaned:
                    subclaims.append(cleaned)
            return subclaims

    raise ValueError("Incomplete or invalid JSON list")
|
|
|
|
def infer_subclaims_api(
    medical_text: str,
    is_summary: bool = False,
    temperature: float = 0.2,
    max_tokens: int = 2048,
    retries: int = 2,
    base_url: str = None,
    model_name: str = None,
) -> list:
    """Extract subclaims from one text with a single chat-completion call.

    Any exception (request error, parse error) triggers a retry with the
    token budget raised by 1024, up to ``retries`` extra attempts.

    Args:
        medical_text: Text to extract from; ``None`` or blank returns ``[]``.
        is_summary: Passed through to the prompt builder.
        temperature: Sampling temperature for the request.
        max_tokens: Token budget for the first attempt.
        retries: Number of additional attempts after the first one fails.
        base_url: API base URL handed to ``get_client``.
        model_name: Served model name; falls back to ``DEFAULT_MODEL_NAME``.

    Returns:
        The parsed subclaim list, or ``[]`` once every attempt has failed.
    """
    if not (medical_text and medical_text.strip()):
        return []

    prompt = extraction_prompt(medical_text, is_summary=is_summary)
    api = get_client(base_url=base_url)
    served_model = model_name or DEFAULT_MODEL_NAME
    token_budget = max_tokens

    for attempt in range(retries + 1):
        try:
            completion = api.chat.completions.create(
                model=served_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=token_budget,
            )
            raw_answer = completion.choices[0].message.content.strip()
            return _parse_subclaims_output(raw_answer)
        except Exception as e:  # also covers JSONDecodeError and ValueError
            if attempt >= retries:
                print(f" [Error] Failed after retries: {e}")
                return []
            token_budget = token_budget + 1024
            print(f" [Warning] {e}. Retry with max_tokens={token_budget}")

    # Reached only when the loop body never ran (negative ``retries``).
    return []
|
|
|
|
def infer_subclaims_batch_api(
    medical_texts: list,
    is_summary: bool = False,
    temperature: float = 0.2,
    max_tokens: int = 2048,
    retries: int = 2,
    base_url: str = None,
    model_name: str = None,
) -> list:
    """Extract subclaims for many texts with one batched completions request.

    All non-blank texts are sent together through ``completions.create`` with
    a list of prompts. Outputs that cannot be parsed are redone one by one via
    ``infer_subclaims_api``. If the batched request keeps failing, every
    non-blank text is processed individually instead.

    Args:
        medical_texts: Input texts; ``None`` or blank entries yield ``[]``.
        is_summary: Passed through to the prompt builder.
        temperature: Sampling temperature for the requests.
        max_tokens: Token budget for the first attempt; raised by 1024 after
            each failed batched request.
        retries: Number of additional batched attempts after the first.
        base_url: API base URL handed to ``get_client``.
        model_name: Served model name; falls back to ``DEFAULT_MODEL_NAME``.

    Returns:
        One subclaim list per input, in input order (``[]`` for empty input).
    """
    if not medical_texts:
        return []

    # None marks blank inputs so they keep their slot but are never sent.
    prompts = [
        extraction_prompt(text, is_summary=is_summary) if (text or "").strip() else None
        for text in medical_texts
    ]
    results = [[] for _ in prompts]
    active = [pos for pos, prompt in enumerate(prompts) if prompt is not None]
    if not active:
        return results

    api = get_client(base_url=base_url)
    served_model = model_name or DEFAULT_MODEL_NAME

    def extract_single(pos, token_budget):
        # Per-example path, used for unparseable outputs and as last resort.
        return infer_subclaims_api(
            medical_texts[pos],
            is_summary=is_summary,
            temperature=temperature,
            max_tokens=token_budget,
            retries=retries,
            base_url=base_url,
            model_name=model_name,
        )

    payload = [prompts[pos] for pos in active]
    expected = len(payload)
    token_budget = max_tokens

    for attempt in range(retries + 1):
        try:
            response = api.completions.create(
                model=served_model,
                prompt=payload,
                temperature=temperature,
                max_tokens=token_budget,
            )

            # Prefer the index reported on each choice to restore prompt order.
            indexed = {}
            for choice in response.choices:
                try:
                    indexed[int(choice.index)] = choice.text
                except Exception:
                    continue

            if len(indexed) == expected:
                raw_outputs = [indexed[k] for k in range(expected)]
            else:
                # Indices unusable: take choices positionally, pad with "".
                raw_outputs = [getattr(choice, "text", "") for choice in response.choices][:expected]
                raw_outputs.extend([""] * (expected - len(raw_outputs)))

            unparsed = []
            for pos, raw in zip(active, raw_outputs):
                try:
                    results[pos] = _parse_subclaims_output(raw)
                except Exception:
                    unparsed.append(pos)

            # Redo only the entries whose batched output failed to parse.
            for pos in unparsed:
                results[pos] = extract_single(pos, token_budget)
            return results
        except Exception as e:
            if attempt >= retries:
                print(f" [Error] batch request failed after retries: {e}")
                break
            token_budget = token_budget + 1024
            print(f" [Warning] batch request failed: {e}. Retry with max_tokens={token_budget}")

    # The batched request never succeeded: process each text on its own.
    for pos in active:
        results[pos] = extract_single(pos, token_budget)
    return results
|
|
|
|
def _has_null_translation(item: dict) -> bool:
    """Tell whether either translated field is missing or None.

    Such items are ignored by the loaders. An empty string is not null.
    """
    required_fields = ("translated_fulltext", "translated_summary")
    return any(item.get(field) is None for field in required_fields)
|
|
|
|
def load_from_single_file(input_path: str) -> list:
    """Read one JSON file and return its items that have both translations.

    A top-level JSON value that is not a list is treated as a single item.
    Items for which ``_has_null_translation`` is true are left out.
    """
    with open(input_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    records = payload if isinstance(payload, list) else [payload]

    kept = []
    for record in records:
        if _has_null_translation(record):
            continue
        kept.append(record)
    return kept
|
|
|
|
def load_all_translation_items(input_dir: str) -> list:
    """Merge the items of every ``*.json`` file in ``input_dir``.

    Files are read in sorted path order. Items with a null translation are
    skipped, and only the first item seen for each ``id`` is kept.

    Raises:
        FileNotFoundError: If the directory contains no ``*.json`` file.
    """
    json_paths = sorted(glob.glob(os.path.join(input_dir, "*.json")))
    if not json_paths:
        raise FileNotFoundError(f"No JSON files in {input_dir}")

    # Insertion-ordered dict: first occurrence of each id wins.
    first_by_id = {}
    for json_path in json_paths:
        with open(json_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        records = payload if isinstance(payload, list) else [payload]
        for record in records:
            if _has_null_translation(record):
                continue
            record_id = record.get("id")
            if record_id not in first_by_id:
                first_by_id[record_id] = record
    return list(first_by_id.values())
|
|
|
|
def _process_batch(batch: list, processed_by_id: dict, api_url: str, model_name: str) -> None:
    """Extract subclaims for one batch of items and record them by id."""
    fulltexts = [(it.get("translated_fulltext") or "") for it in batch]
    summaries = [(it.get("translated_summary") or "") for it in batch]

    fulltext_subclaims_list = infer_subclaims_batch_api(
        fulltexts,
        is_summary=False,
        max_tokens=4096,
        base_url=api_url,
        model_name=model_name,
    )
    summary_subclaims_list = infer_subclaims_batch_api(
        summaries,
        is_summary=True,
        max_tokens=2048,
        base_url=api_url,
        model_name=model_name,
    )

    # Blank texts come back from the batch call as [], so an item with both
    # texts blank needs no special case to end up with empty subclaim lists.
    for b_i, it in enumerate(batch):
        uid = it.get("id")
        processed_by_id[uid] = {
            "id": uid,
            "fulltext": fulltexts[b_i],
            "summary": summaries[b_i],
            "fulltext_subclaims": fulltext_subclaims_list[b_i],
            "summary_subclaims": summary_subclaims_list[b_i],
        }


def _write_results(processed_by_id: dict, output_file: str) -> None:
    """Write all processed entries to ``output_file`` as a JSON array."""
    # Write to a sibling temp file and swap it in, so an interrupted run
    # cannot leave a truncated checkpoint that breaks a later --resume.
    tmp_path = output_file + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(list(processed_by_id.values()), f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, output_file)


def main():
    """Command-line entry point: extract subclaims and save them as JSON.

    Loads translated items from ``--input_file`` or ``--input_dir``, slices
    them with ``--start``/``--end``, skips ids already present in the
    ``--resume`` file, processes the rest in batches of ``--batch_size``, and
    writes the results to the output file. A checkpoint is written whenever
    at least 20 new entries have accumulated.

    Raises:
        FileNotFoundError: If ``--input_file`` does not exist.
    """
    parser = argparse.ArgumentParser(description="Extract Bangla subclaims via subclaim-extractor vLLM")
    parser.add_argument(
        "--input_dir",
        type=str,
        default="/home/mshahidul/readctrl/data/translated_data/translation_testing_3396",
        help="Directory containing translated JSON files (used when --input_file is not set)",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default=None,
        help="Single JSON file to process (overrides --input_dir)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/home/mshahidul/readctrl/data/extracting_subclaim/bn",
        help="Directory to save output JSON files",
    )
    parser.add_argument(
        "--api_url",
        type=str,
        default=DEFAULT_API_URL,
        help="vLLM OpenAI-compatible API base URL (default: http://localhost:8050/v1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=None,
        help="Server port (e.g. 8050). Builds API URL as http://localhost:PORT/v1 (overrides --api_url if set)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL_NAME,
        help="Served model name (default: subclaim-extractor)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of items to process per batch (each batch sends prompts in bulk to vLLM)",
    )
    parser.add_argument("--start", type=int, default=0, help="Start index")
    parser.add_argument("--end", type=int, default=None, help="End index (exclusive)")
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Path to existing output JSON to resume (append new items by id)",
    )
    args = parser.parse_args()

    if args.port is not None:
        args.api_url = f"http://localhost:{args.port}/v1"
    print(f"Using API URL: {args.api_url}")

    os.makedirs(args.save_dir, exist_ok=True)

    if args.input_file:
        if not os.path.isfile(args.input_file):
            raise FileNotFoundError(f"Input file not found: {args.input_file}")
        all_items = load_from_single_file(args.input_file)
        print(f"Loaded {len(all_items)} items from {args.input_file}")
    else:
        all_items = load_all_translation_items(args.input_dir)
    end = args.end if args.end is not None else len(all_items)
    subset = all_items[args.start : end]
    print(f"Processing indices [{args.start}:{end}], total items: {len(subset)}")

    # Entries from a previous run are kept and their ids are skipped below.
    processed_by_id = {}
    if args.resume and os.path.isfile(args.resume):
        with open(args.resume, "r", encoding="utf-8") as f:
            existing = json.load(f)
        for item in existing:
            processed_by_id[item["id"]] = item
        print(f"Resumed: {len(processed_by_id)} existing entries from {args.resume}")
    last_checkpoint_count = len(processed_by_id)
    checkpoint_every = 20

    # The output name encodes the processed index range; --resume overrides it.
    end_tag = end if end != len(all_items) else "end"
    if args.input_file:
        base = os.path.splitext(os.path.basename(args.input_file))[0]
        output_name = f"{base}_extracted_subclaims_bn_{args.start}_{end_tag}.json"
    else:
        output_name = f"extracted_subclaims_bn_{args.start}_{end_tag}.json"
    output_file = os.path.join(args.save_dir, output_name)
    if args.resume:
        output_file = args.resume

    # tqdm is optional; fall back to plain iteration when it is not installed.
    try:
        import tqdm
        iterator = tqdm.tqdm(subset, desc="Extracting subclaims")
    except ImportError:
        iterator = subset

    batch_size = max(1, int(args.batch_size))
    batch = []
    for item in iterator:
        if item.get("id") in processed_by_id:
            continue
        batch.append(item)
        if len(batch) < batch_size:
            continue

        _process_batch(batch, processed_by_id, args.api_url, args.model)
        batch = []

        if len(processed_by_id) - last_checkpoint_count >= checkpoint_every:
            _write_results(processed_by_id, output_file)
            last_checkpoint_count = len(processed_by_id)

    # Flush the final, partially filled batch.
    if batch:
        _process_batch(batch, processed_by_id, args.api_url, args.model)

    _write_results(processed_by_id, output_file)
    print(f"Saved {len(processed_by_id)} entries to {output_file}")
|
|
|
|
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
|