import os

# Pin the GPU before importing any CUDA-aware library so device selection sticks.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import json
from datetime import datetime

import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
    get_chat_template,
    standardize_data_formats,
    train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer

# --------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------
model_name = "unsloth/gemma-3-4b-it"
data_path = "/home/mshahidul/readctrl/data/extracting_subclaim/bn/multiclinsum_test_en2bn_gemma(0_1000)_3396_extracted_subclaims_bn_0_end.json"
test_size = 0.2  # 1 - train_ratio (0.8)
seed = 42
run_mode = "finetune_and_eval"  # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = True  # whether to save merged fp16 model after finetuning

# Max subclaims to request in prompts
MAX_SUBCLAIMS_FULLTEXT = 80
MAX_SUBCLAIMS_SUMMARY = 40


def get_model_size_from_name(name):
    """Return the parameter-count token (e.g. "4b", "270m") from a model name.

    Scans the hyphen-separated parts of the final path component and returns
    the first part ending in "b" or "m"; falls back to "unknown".
    """
    base = name.split("/")[-1]
    for part in base.split("-"):
        token = part.lower()
        if token.endswith("b") or token.endswith("m"):
            return part
    return "unknown"


model_size = get_model_size_from_name(model_name)


def formatting_prompts_func(examples):
    """Render each chat conversation to a plain-text training string.

    Uses the module-level ``tokenizer`` (Gemma-3 chat template). The template
    prepends a BOS token which SFTTrainer adds again during tokenization, so
    it is stripped here.
    NOTE(review): the prefix literal was empty in the original source,
    making removeprefix a no-op; "<bos>" is the standard unsloth/Gemma-3
    value — confirm against the tokenizer's actual BOS token.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<bos>")
        for convo in convos
    ]
    return {"text": texts}


def build_subclaim_user_prompt(medical_text, is_summary=False, max_subclaims=None):
    """
    Build a Bangla instruction prompt for subclaim extraction.

    Uses the same wording as `extraction_prompt` in
    `extract_bn_subclaims_vllm.py`, with an optional cap on the number of
    subclaims described in the instructions.

    ``is_summary`` is accepted for caller symmetry but does not change the
    prompt wording.
    """
    base_prompt = f"""
You are an expert medical annotator. The following text is in Bangla (Bengali).
Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the Bangla medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
   - Be in Bangla (same language as the input)
   - Contain exactly ONE factual assertion
   - Come directly from the text (no inference or interpretation)
   - Preserve original wording as much as possible
   - Include any negation, uncertainty, or qualifier
4. Do NOT:
   - Combine multiple facts into one subclaim
   - Add new information
   - Translate to another language
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).

Medical Text (Bangla):
{medical_text}

Return format:
[
  "subclaim 1",
  "subclaim 2"
]
""".strip()

    # Optionally mention a maximum number of subclaims, but only in text,
    # so we keep the core wording identical to the vLLM prompt.
    if max_subclaims is not None:
        limit_note = (
            f"\n\nNote: Extract at most {max_subclaims} subclaims, "
            "prioritizing the most important factual statements."
        )
        return base_prompt + limit_note
    return base_prompt


def build_subclaim_examples(raw_records):
    """
    Build chat-style training examples for Bangla subclaim extraction.

    Each record can contribute up to two examples:
    - fulltext -> fulltext_subclaims
    - summary  -> summary_subclaims

    Records missing either the text or its gold subclaims contribute nothing
    for that side. Assistant targets are the gold lists serialized as JSON
    (non-ASCII preserved).
    """
    examples = []
    for record in raw_records:
        fulltext = (record.get("fulltext") or "").strip()
        fulltext_subclaims = record.get("fulltext_subclaims") or []
        summary = (record.get("summary") or "").strip()
        summary_subclaims = record.get("summary_subclaims") or []

        if fulltext and fulltext_subclaims:
            user_prompt = build_subclaim_user_prompt(
                fulltext,
                is_summary=False,
                max_subclaims=MAX_SUBCLAIMS_FULLTEXT,
            )
            assistant_content = json.dumps(fulltext_subclaims, ensure_ascii=False)
            examples.append(
                {
                    "conversations": [
                        {"role": "user", "content": user_prompt},
                        {"role": "assistant", "content": assistant_content},
                    ],
                }
            )

        if summary and summary_subclaims:
            user_prompt = build_subclaim_user_prompt(
                summary,
                is_summary=True,
                max_subclaims=MAX_SUBCLAIMS_SUMMARY,
            )
            assistant_content = json.dumps(summary_subclaims, ensure_ascii=False)
            examples.append(
                {
                    "conversations": [
                        {"role": "user", "content": user_prompt},
                        {"role": "assistant", "content": assistant_content},
                    ],
                }
            )
    return examples


def extract_conversation_pair(conversations):
    """Return (first user message, first assistant message) from a chat list.

    Accepts either "role" or "from" as the role key; missing messages yield
    empty strings.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations:
        role = message.get("role") or message.get("from")
        content = message.get("content", "")
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response


def generate_prediction(user_prompt):
    """Greedy-decode a single-turn response from the module-level model.

    Renders the prompt with the chat template (generation prompt appended),
    generates up to 1024 new tokens, and returns only the newly generated
    text with special tokens stripped.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        # do_sample=False already selects greedy decoding; passing
        # temperature=0.0 alongside it is ignored and only triggers a
        # transformers warning, so it is omitted.
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            use_cache=True,
        )
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# --------------------------------------------------------------------------
# 1. Load Model and Tokenizer
# --------------------------------------------------------------------------
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,  # NOTE(review): 4092 looks like a typo for 4096 — confirm
    load_in_4bit=True,
)

# --------------------------------------------------------------------------
# 2. Data Preparation
# --------------------------------------------------------------------------
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(
    test_size=test_size, seed=seed, shuffle=True
)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]

train_examples = build_subclaim_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

# --------------------------------------------------------------------------
# 3. Optional Finetuning
# --------------------------------------------------------------------------
if run_mode == "finetune_and_eval":
    # Add LoRA adapters for finetuning
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Training setup
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )

    # Mask the loss so only assistant turns are trained on. These are the
    # standard Gemma-3 turn markers expected by unsloth's
    # train_on_responses_only helper.
    # NOTE(review): the original literals were bare "user\n"/"model\n",
    # which would not match the rendered template — confirm the restored
    # "<start_of_turn>" prefix against the chat template output.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # Execute training
    save_dir = (
        "/home/mshahidul/readctrl_model/subclaim_support_extraction_bn/"
        f"{model_name.split('/')[-1]}"
    )
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    # Optional: save in float16 merged format
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate base model
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")

# --------------------------------------------------------------------------
# 4. Test-set Inference + Accuracy
# --------------------------------------------------------------------------
FastModel.for_inference(model)
model.eval()

model_info_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/inference_data"
)
ablation_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/ablation_studies"
)
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")


def _parse_subclaim_list(text):
    """Best-effort parse of a JSON list of subclaims from model output.

    Strips any reasoning markup, then extracts the outermost ``[...]`` span
    and attempts to parse it as JSON. Returns [] on any failure.
    """
    if not text:
        return []
    text = text.strip()
    # Strip any trailing reasoning markup if present.
    # NOTE(review): the marker literal was empty in the original, which made
    # the membership test always true and text.split("") raise ValueError on
    # every call; a "</think>"-style closing tag appears intended — confirm.
    if "</think>" in text:
        text = text.split("</think>")[-1].strip()
    start_idx = text.find("[")
    end_idx = text.rfind("]") + 1
    if start_idx != -1 and end_idx > start_idx:
        text_slice = text[start_idx:end_idx]
    else:
        text_slice = text
    try:
        parsed = json.loads(text_slice)
        if isinstance(parsed, list):
            return [str(s).strip() for s in parsed if s]
    except Exception:
        return []
    return []


def _subclaim_metrics(gold, pred):
    """Compute simple set-based precision/recall/Jaccard for subclaim lists.

    Both lists are deduplicated and whitespace-stripped before comparison.
    Both empty -> perfect scores; empty prediction against non-empty gold
    -> all zeros.
    """
    gold_set = {s.strip() for s in gold if s}
    pred_set = {s.strip() for s in pred if s}
    if not gold_set and not pred_set:
        return 1.0, 1.0, 1.0
    if not pred_set:
        return 0.0, 0.0, 0.0
    inter = gold_set & pred_set
    union = gold_set | pred_set
    # pred_set is non-empty here (early return above) and union is non-empty
    # whenever either input set is, so no further zero-division guards needed.
    precision = len(inter) / len(pred_set)
    recall = len(inter) / len(gold_set) if gold_set else 0.0
    jaccard = len(inter) / len(union)
    return precision, recall, jaccard


def evaluate_subclaim_mode(test_split):
    """
    Evaluate subclaim extraction on the held-out split.

    For each example, we prompt on fulltext and/or summary (if present)
    and compare the predicted subclaim list with the gold subclaims.

    Returns (per-pair result records, aggregate metrics dict).
    """
    results = []
    total_pairs = 0
    sum_precision = 0.0
    sum_recall = 0.0
    sum_jaccard = 0.0

    for idx, sample in enumerate(test_split):
        sample_id = sample.get("id")

        # Fulltext side
        fulltext = (sample.get("fulltext") or "").strip()
        fulltext_gold = sample.get("fulltext_subclaims") or []
        if fulltext and fulltext_gold:
            user_prompt = build_subclaim_user_prompt(
                fulltext,
                is_summary=False,
                max_subclaims=MAX_SUBCLAIMS_FULLTEXT,
            )
            pred_text = generate_prediction(user_prompt)
            pred_list = _parse_subclaim_list(pred_text)
            precision, recall, jaccard = _subclaim_metrics(fulltext_gold, pred_list)
            total_pairs += 1
            sum_precision += precision
            sum_recall += recall
            sum_jaccard += jaccard
            results.append(
                {
                    "sample_index": idx,
                    "id": sample_id,
                    "source_type": "fulltext",
                    "input_text": fulltext,
                    "gold_subclaims": fulltext_gold,
                    "predicted_subclaims": pred_list,
                    "precision": precision,
                    "recall": recall,
                    "jaccard": jaccard,
                }
            )

        # Summary side
        summary = (sample.get("summary") or "").strip()
        summary_gold = sample.get("summary_subclaims") or []
        if summary and summary_gold:
            user_prompt = build_subclaim_user_prompt(
                summary,
                is_summary=True,
                max_subclaims=MAX_SUBCLAIMS_SUMMARY,
            )
            pred_text = generate_prediction(user_prompt)
            pred_list = _parse_subclaim_list(pred_text)
            precision, recall, jaccard = _subclaim_metrics(summary_gold, pred_list)
            total_pairs += 1
            sum_precision += precision
            sum_recall += recall
            sum_jaccard += jaccard
            results.append(
                {
                    "sample_index": idx,
                    "id": sample_id,
                    "source_type": "summary",
                    "input_text": summary,
                    "gold_subclaims": summary_gold,
                    "predicted_subclaims": pred_list,
                    "precision": precision,
                    "recall": recall,
                    "jaccard": jaccard,
                }
            )

    avg_precision = sum_precision / total_pairs if total_pairs else 0.0
    avg_recall = sum_recall / total_pairs if total_pairs else 0.0
    avg_jaccard = sum_jaccard / total_pairs if total_pairs else 0.0

    metrics = {
        "mode": "bangla_subclaim_extraction",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total_pairs,
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_jaccard": avg_jaccard,
        "subclaim_score": avg_jaccard,
        "timestamp": timestamp,
    }
    return results, metrics


results, accuracy_summary = evaluate_subclaim_mode(test_raw)
accuracy_summary["finetune_mode"] = "subclaim_extraction"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode
accuracy_summary["language"] = "bn"

predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_subclaim_{model_size}_{run_mode}_{timestamp}.json",
)

with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test metrics to: {accuracy_path}")
print(
    f"Avg Jaccard (subclaim_score): {accuracy_summary.get('subclaim_score', 0.0):.4f}"
)