| import os |
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| os.environ["CUDA_VISIBLE_DEVICES"] = "7" |
| import json |
| import os |
| from datetime import datetime |
|
|
| import torch |
| from datasets import Dataset |
|
|
| from unsloth import FastModel |
| from unsloth.chat_templates import ( |
| get_chat_template, |
| standardize_data_formats, |
| train_on_responses_only, |
| ) |
| from trl import SFTConfig, SFTTrainer |
|
|
# ---- Run configuration ----
# Base model to fine-tune (Unsloth-packaged Gemma-3 instruct checkpoint).
model_name = "unsloth/gemma-3-4b-it"
# JSON file of records with fulltext/summary texts and their gold subclaim lists.
data_path = "/home/mshahidul/readctrl/data/extracting_subclaim/bn/multiclinsum_test_en2bn_gemma(0_1000)_3396_extracted_subclaims_bn_0_end.json"
# Fraction of records held out for evaluation.
test_size = 0.2
seed = 42
# "finetune_and_eval" trains LoRA adapters then evaluates;
# "eval_base_only" evaluates the untouched base model.
run_mode = "finetune_and_eval"
# When True, save a LoRA-merged fp16 checkpoint after training.
save_fp16_merged = True


# Caps written into the prompt instructions (prompt-level request only;
# not enforced on the model output).
MAX_SUBCLAIMS_FULLTEXT = 80
MAX_SUBCLAIMS_SUMMARY = 40
|
|
|
|
def get_model_size_from_name(name):
    """Guess the parameter-size token (e.g. "4b", "270m") from a model id.

    Scans the hyphen-separated pieces of the final path component and
    returns the first piece whose lowercase form ends in "b" or "m";
    returns "unknown" when no such piece exists.
    """
    repo_tail = name.rsplit("/", 1)[-1]
    for piece in repo_tail.split("-"):
        if piece.lower().endswith(("b", "m")):
            return piece
    return "unknown"
|
|
|
|
| model_size = get_model_size_from_name(model_name) |
|
|
|
|
def formatting_prompts_func(examples):
    """Render each conversation to a single "text" field for SFT.

    Applies the chat template (no generation prompt) and strips the
    leading "<bos>" token, which the tokenizer re-adds at encode time.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<bos>"))
    return {"text": rendered}
|
|
|
|
def build_subclaim_user_prompt(medical_text, is_summary=False, max_subclaims=None):
    """
    Build a Bangla instruction prompt for subclaim extraction.

    Uses the same wording as `extraction_prompt` in
    `extract_bn_subclaims_vllm.py`, with an optional cap on the number
    of subclaims appended to the instructions.

    Args:
        medical_text: Bangla source text embedded verbatim in the prompt.
        is_summary: Currently UNUSED — the prompt wording is identical for
            summary and full-text inputs. Kept for interface compatibility
            with existing callers. NOTE(review): either wire this into the
            instructions or drop it once all callers are updated.
        max_subclaims: When not None, a trailing note asks the model for at
            most this many subclaims.

    Returns:
        The complete user-prompt string.
    """
    base_prompt = f"""
You are an expert medical annotator. The following text is in Bangla (Bengali).

Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the Bangla medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
   - Be in Bangla (same language as the input)
   - Contain exactly ONE factual assertion
   - Come directly from the text (no inference or interpretation)
   - Preserve original wording as much as possible
   - Include any negation, uncertainty, or qualifier
4. Do NOT:
   - Combine multiple facts into one subclaim
   - Add new information
   - Translate to another language
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).

Medical Text (Bangla):
{medical_text}

Return format:
[
  "subclaim 1",
  "subclaim 2"
]
""".strip()

    if max_subclaims is None:
        return base_prompt

    # Appended as a trailing note so the base wording stays byte-identical
    # to the un-capped variant.
    return (
        base_prompt
        + f"\n\nNote: Extract at most {max_subclaims} subclaims, prioritizing the most important factual statements."
    )
|
|
|
|
def build_subclaim_examples(raw_records):
    """
    Build chat-style training examples for Bangla subclaim extraction.

    Each record contributes up to two user/assistant pairs:
      - fulltext -> fulltext_subclaims (capped at MAX_SUBCLAIMS_FULLTEXT)
      - summary  -> summary_subclaims  (capped at MAX_SUBCLAIMS_SUMMARY)
    A pair is skipped when either the text or its gold subclaims are missing.
    """
    # (text key, gold key, is_summary flag, prompt cap) — fulltext first,
    # matching the original per-record ordering.
    specs = (
        ("fulltext", "fulltext_subclaims", False, MAX_SUBCLAIMS_FULLTEXT),
        ("summary", "summary_subclaims", True, MAX_SUBCLAIMS_SUMMARY),
    )

    examples = []
    for record in raw_records:
        for text_key, gold_key, as_summary, cap in specs:
            source_text = (record.get(text_key) or "").strip()
            gold_subclaims = record.get(gold_key) or []
            if not (source_text and gold_subclaims):
                continue

            prompt = build_subclaim_user_prompt(
                source_text,
                is_summary=as_summary,
                max_subclaims=cap,
            )
            target = json.dumps(gold_subclaims, ensure_ascii=False)
            examples.append(
                {
                    "conversations": [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": target},
                    ],
                }
            )

    return examples
|
|
|
|
def extract_conversation_pair(conversations):
    """
    Pull the first user prompt and first assistant response from a chat list.

    Supports both {"role": ..., "content": ...} messages and ShareGPT-style
    {"from": ..., "value": ...} messages. Returns "" for either part when
    no matching message is found.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations:
        role = message.get("role") or message.get("from")
        # Mirror the role/"from" fallback for the payload: the original read
        # only "content", leaving ShareGPT-style messages empty.
        content = message.get("content")
        if content is None:
            content = message.get("value", "")
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response
|
|
|
|
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to a single-turn user prompt.

    Renders the prompt through the chat template (with the generation
    prompt appended), runs `model.generate` with sampling disabled, and
    returns only the newly generated text, special tokens stripped.
    """
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(text=chat_text, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.inference_mode():
        generation = model.generate(
            **encoded,
            max_new_tokens=1024,
            do_sample=False,
            temperature=0.0,  # ignored under greedy decoding (do_sample=False)
            use_cache=True,
        )
    new_tokens = generation[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
| |
# Load the base model + tokenizer in 4-bit via Unsloth.
# NOTE(review): max_seq_length=4092 looks like a typo for 4096 — confirm.
# The SFTTrainer below independently uses max_seq_length=2048.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,
    load_in_4bit=True,
)


# Attach the Gemma-3 chat template so apply_chat_template renders turns
# with <start_of_turn>user / <start_of_turn>model markers.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)


# Deterministic train/test split of the raw records (seeded above).
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(
    test_size=test_size, seed=seed, shuffle=True
)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]


# Build chat-format examples from the train split and render each one into
# the single "text" field the SFT trainer consumes.
train_examples = build_subclaim_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
|
|
| |
if run_mode == "finetune_and_eval":
    # Attach LoRA adapters (r=8, alpha=16) on all attention and MLP
    # projection modules; only adapter weights are trained.
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Short SFT run: effective batch size 2 * 4 = 8, 60 optimizer steps.
    # NOTE(review): max_seq_length=2048 here vs 4092 at model load —
    # confirm which limit is intended.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )

    # Mask the loss outside model-response turns; these markers match the
    # gemma-3 chat template applied above.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # Train, then optionally save a merged fp16 checkpoint + tokenizer.
    save_dir = f"/home/mshahidul/readctrl_model/subclaim_support_extraction_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)


elif run_mode == "eval_base_only":
    # No training: record a marker string instead of a real save path
    # (written into the metrics JSON below).
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
|
|
| |
# Switch the (possibly LoRA-wrapped) model into Unsloth inference mode.
FastModel.for_inference(model)
model.eval()


# Output directories: per-example predictions and aggregate ablation metrics.
model_info_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/inference_data"
)
ablation_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/ablation_studies"
)
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)


# Timestamp and dot-free model tag embedded in the output filenames.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")
|
|
|
|
| def _parse_subclaim_list(text): |
| """Best-effort parse of a JSON list of subclaims from model output.""" |
| if not text: |
| return [] |
| text = text.strip() |
|
|
| |
| if "</think>" in text: |
| text = text.split("</think>")[-1].strip() |
|
|
| start_idx = text.find("[") |
| end_idx = text.rfind("]") + 1 |
| if start_idx != -1 and end_idx > start_idx: |
| text_slice = text[start_idx:end_idx] |
| else: |
| text_slice = text |
|
|
| try: |
| parsed = json.loads(text_slice) |
| if isinstance(parsed, list): |
| return [str(s).strip() for s in parsed if s] |
| except Exception: |
| return [] |
| return [] |
|
|
|
|
| def _subclaim_metrics(gold, pred): |
| """Compute simple set-based precision/recall/Jaccard for subclaim lists.""" |
| gold_set = {s.strip() for s in gold if s} |
| pred_set = {s.strip() for s in pred if s} |
|
|
| if not gold_set and not pred_set: |
| return 1.0, 1.0, 1.0 |
| if not pred_set: |
| return 0.0, 0.0, 0.0 |
|
|
| inter = gold_set & pred_set |
| union = gold_set | pred_set |
|
|
| precision = len(inter) / len(pred_set) if pred_set else 0.0 |
| recall = len(inter) / len(gold_set) if gold_set else 0.0 |
| jaccard = len(inter) / len(union) if union else 0.0 |
| return precision, recall, jaccard |
|
|
|
|
def evaluate_subclaim_mode(test_split):
    """
    Evaluate subclaim extraction on the held-out split.

    For each sample, prompts the model on the fulltext and/or summary
    (whichever is present together with its gold subclaims), parses the
    prediction, and scores it against the gold list.

    Returns:
        (results, metrics): per-example result dicts and an aggregate
        metrics dict averaged over all evaluated (text, gold) pairs.
    """
    results = []
    total_pairs = 0
    sum_precision = 0.0
    sum_recall = 0.0
    sum_jaccard = 0.0

    # (source_type / text key, gold key, is_summary flag, prompt cap) —
    # fulltext first, matching the training-example ordering.
    eval_specs = (
        ("fulltext", "fulltext_subclaims", False, MAX_SUBCLAIMS_FULLTEXT),
        ("summary", "summary_subclaims", True, MAX_SUBCLAIMS_SUMMARY),
    )

    for idx, sample in enumerate(test_split):
        sample_id = sample.get("id")

        for source_type, gold_key, as_summary, cap in eval_specs:
            input_text = (sample.get(source_type) or "").strip()
            gold_list = sample.get(gold_key) or []
            if not (input_text and gold_list):
                continue

            prompt = build_subclaim_user_prompt(
                input_text,
                is_summary=as_summary,
                max_subclaims=cap,
            )
            predicted = _parse_subclaim_list(generate_prediction(prompt))
            precision, recall, jaccard = _subclaim_metrics(gold_list, predicted)

            total_pairs += 1
            sum_precision += precision
            sum_recall += recall
            sum_jaccard += jaccard

            results.append(
                {
                    "sample_index": idx,
                    "id": sample_id,
                    "source_type": source_type,
                    "input_text": input_text,
                    "gold_subclaims": gold_list,
                    "predicted_subclaims": predicted,
                    "precision": precision,
                    "recall": recall,
                    "jaccard": jaccard,
                }
            )

    # Averages over all evaluated pairs; 0.0 when nothing was evaluated.
    avg_precision = sum_precision / total_pairs if total_pairs else 0.0
    avg_recall = sum_recall / total_pairs if total_pairs else 0.0
    avg_jaccard = sum_jaccard / total_pairs if total_pairs else 0.0

    metrics = {
        "mode": "bangla_subclaim_extraction",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total_pairs,
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_jaccard": avg_jaccard,
        "subclaim_score": avg_jaccard,
        "timestamp": timestamp,
    }
    return results, metrics
|
|
|
|
# Evaluate on the held-out split (fine-tuned or base model per run_mode).
results, accuracy_summary = evaluate_subclaim_mode(test_raw)


# Extra run metadata for the ablation record.
accuracy_summary["finetune_mode"] = "subclaim_extraction"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode
accuracy_summary["language"] = "bn"


# Output paths: per-example predictions and the aggregate metrics file.
predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_subclaim_{model_size}_{run_mode}_{timestamp}.json",
)


with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)


with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)


print(f"Saved test inference to: {predictions_path}")
print(f"Saved test metrics to: {accuracy_path}")
print(
    f"Avg Jaccard (subclaim_score): {accuracy_summary.get('subclaim_score', 0.0):.4f}"
)