import os

# Device selection must be in the environment before torch initializes CUDA,
# so these assignments precede the torch import below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import ast
import json
import re
from datetime import datetime

import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
    get_chat_template,
    standardize_data_formats,
    train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer

model_name = "unsloth/gemma-3-4b-it"
data_path = "/home/mshahidul/readctrl/code/support_check/support_check_bn/finetune_dataset_subclaim_support_bn.json"
test_size = 0.3
seed = 3407
finetune_mode = "subclaim_list"  # "single_subclaim" or "subclaim_list"
prompt_language = "en"  # "en" (English); any other value (e.g. "bn") selects the Bangla prompts
run_mode = "finetune_and_eval"  # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = False  # whether to save merged fp16 model after finetuning

# The only labels the classifier is allowed to produce.
_VALID_LABELS = frozenset({"supported", "not_supported"})

# A parameter-count token inside a model name: "4b", "1.5B", "135m", "8x7b", "E4B", ...
_MODEL_SIZE_TOKEN = re.compile(r"(?:\d+x)?[a-z]?\d+(?:\.\d+)?[bm]", re.IGNORECASE)


def get_model_size_from_name(name):
    """Return the first dash-separated size token of a model id (e.g. "4b"), else "unknown".

    Only the part after the last "/" is inspected. A token must look like a number
    followed by "b"/"m" (optionally "NxM" or one leading letter), so plain words that
    merely end in b/m ("smollm", "bnb") are not mistaken for a size.
    """
    base = name.split("/")[-1]
    for part in base.split("-"):
        if _MODEL_SIZE_TOKEN.fullmatch(part):
            return part  # original casing is preserved
    return "unknown"


model_size = get_model_size_from_name(model_name)


def _canonical_label_text(text):
    """Strip, lower-case and map spaces/hyphens to underscores ("Not supported" -> "not_supported")."""
    return text.strip().lower().replace("-", "_").replace(" ", "_")


def formatting_prompts_func(examples):
    """Render each conversation of a batched `examples` dict into one training string.

    Returns {"text": [...]} for `Dataset.map(batched=True)`. The Gemma-3 chat template
    begins with "<bos>"; following Unsloth's Gemma-3 recipe, that prefix is removed so
    tokenizing the text for training does not produce a doubled BOS token.
    Uses the module-level `tokenizer`.
    """
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<bos>")
        for convo in examples["conversations"]
    ]
    return {"text": texts}


def parse_label_array(raw_text):
    """Parse a model reply into a list of "supported"/"not_supported" labels.

    Markdown code fences are removed and the outermost "[...]" span is kept, then the
    text is parsed as JSON with a Python-literal fallback. Non-string entries and
    unrecognised labels become "not_supported". Returns [] if no list can be parsed.
    """
    text = (raw_text or "").strip()
    if not text:
        return []
    if "```" in text:
        text = text.replace("```json", "").replace("```", "").strip()
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        text = text[start : end + 1]

    parsed = None
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(text)
            break
        except Exception:
            # Best-effort parsing of free-form model output: any failure just means
            # "try the next parser"; an unparseable reply yields [] below.
            continue
    if not isinstance(parsed, list):
        return []

    labels = []
    for item in parsed:
        label = _canonical_label_text(item) if isinstance(item, str) else "not_supported"
        labels.append(label if label in _VALID_LABELS else "not_supported")
    return labels


def parse_single_label(raw_text):
    """Map a free-form single-label reply to "supported"/"not_supported", or None.

    Spaces and hyphens are normalised first, so "not supported", "Not-Supported" and
    "unsupported" are all recognised as negatives. Negatives are checked before
    "supported" because each of them contains that substring.
    """
    text = _canonical_label_text(raw_text or "")
    if "not_supported" in text or "unsupported" in text:
        return "not_supported"
    if "supported" in text:
        return "supported"
    return None


def normalize_label(label):
    """Return `label` canonicalised to "supported"/"not_supported", or None if it is neither."""
    if label is None:
        return None
    label = _canonical_label_text(str(label))
    return label if label in _VALID_LABELS else None


def build_single_user_prompt(input_text, subclaim):
    """Build the user prompt asking for a one-word verdict on a single subclaim."""
    if prompt_language == "en":
        return (
            "You will be given a medical case description and one subclaim. "
            "Determine whether the subclaim is supported by the text.\n\n"
            f"Text:\n{input_text}\n\n"
            f"Subclaim:\n{subclaim}\n\n"
            "Reply with exactly one word: 'supported' or 'not_supported'."
        )
    # Bangla (default)
    return (
        "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একটি সাবক্লেইম দেওয়া হবে। "
        "সাবক্লেইমটি টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
        f"টেক্সট:\n{input_text}\n\n"
        f"সাবক্লেইম:\n{subclaim}\n\n"
        "শুধু একটি শব্দ দিয়ে উত্তর দিন: 'supported' অথবা 'not_supported'."
    )


def build_list_user_prompt(input_text, subclaims):
    """Build the user prompt asking for a JSON array of labels, one per numbered subclaim."""
    numbered = "\n".join(f"{number}. {sc}" for number, sc in enumerate(subclaims, start=1))
    if prompt_language == "en":
        return (
            "You will be given a medical case description and a list of subclaims. "
            "Determine for each subclaim whether it is supported by the text.\n\n"
            f"Text:\n{input_text}\n\n"
            f"List of subclaims:\n{numbered}\n\n"
            "Give the label for each subclaim in order. "
            "Reply with a JSON array only, e.g.:\n"
            '["supported", "not_supported", ...]\n'
            "Do not write anything else."
        )
    # Bangla (default)
    return (
        "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। "
        "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
        f"টেক্সট:\n{input_text}\n\n"
        f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n"
        "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। "
        "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n"
        '["supported", "not_supported", ...]\n'
        "অন্য কিছু লিখবেন না।"
    )


def _iter_labeled_subclaims(record):
    """Yield (subclaim_text, label) for each subclaim of `record` whose label is valid, in order.

    Reads record["model_output"]["items"][*]["subclaims"][*]; missing or None levels are
    treated as empty.
    """
    model_output = record.get("model_output") or {}
    for item in model_output.get("items") or []:
        for sc in item.get("subclaims") or []:
            label = normalize_label(sc.get("label"))
            if label:
                yield sc.get("subclaim", ""), label


def _conversation(user_prompt, assistant_reply):
    """Wrap one user/assistant exchange in the {"conversations": [...]} training format."""
    return {
        "conversations": [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_reply},
        ],
    }


def build_single_subclaim_examples(raw_records):
    """Build one training conversation per labelled subclaim (the reply is the bare label)."""
    examples = []
    for record in raw_records:
        input_text = record.get("input_text", "")
        for subclaim_text, label in _iter_labeled_subclaims(record):
            prompt = build_single_user_prompt(input_text, subclaim_text)
            examples.append(_conversation(prompt, label))
    return examples


def build_list_subclaim_examples(raw_records):
    """Build one training conversation per record whose reply is a JSON array of labels.

    Records with no validly labelled subclaim are skipped.
    """
    examples = []
    for record in raw_records:
        pairs = list(_iter_labeled_subclaims(record))
        if not pairs:
            continue
        subclaims = [text for text, _ in pairs]
        labels = [label for _, label in pairs]
        prompt = build_list_user_prompt(record.get("input_text", ""), subclaims)
        examples.append(_conversation(prompt, json.dumps(labels)))
    return examples


def extract_conversation_pair(conversations):
    """Return (first user content, first assistant content) from a list of message dicts.

    The role is read from "role" or, failing that, "from"; the text from "content".
    Missing turns yield "". NOTE(review): ShareGPT-style "value"/"human"/"gpt" keys are
    not handled; this helper is not called anywhere in this script.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations:
        role = message.get("role") or message.get("from")
        content = message.get("content", "")
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response


def generate_prediction(user_prompt):
    """Greedily decode the model's reply to one user turn and return it as stripped text.

    Uses the module-level `model` and `tokenizer`; at most 256 new tokens are generated.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            # Greedy decoding. A temperature would be ignored here and only
            # triggers a transformers warning, so none is passed.
            do_sample=False,
            use_cache=True,
        )
    # Keep only the newly generated tokens, not the echoed prompt.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# 1. Load Model and Tokenizer
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,  # NOTE(review): possibly intended as 4096 — confirm
    load_in_4bit=True,
)

# 2. Data Preparation
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]

if finetune_mode == "single_subclaim":
    train_examples = build_single_subclaim_examples(train_raw)
elif finetune_mode == "subclaim_list":
    train_examples = build_list_subclaim_examples(train_raw)
else:
    raise ValueError(f"Unsupported finetune_mode: {finetune_mode}")

train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
# 3. Optional Finetuning
if run_mode == "finetune_and_eval":
    # Add LoRA adapters (rank 8, alpha 16) on every attention and MLP projection.
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Training setup: effective batch of 2 x 4 = 8 sequences per optimizer step, 60 steps.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            # Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )

    # Compute the loss on assistant turns only. Gemma-3 turns open with the
    # "<start_of_turn>" control token, so the markers must include it; bare
    # "user\n"/"model\n" can also match ordinary text inside a turn.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # NOTE: the directory depends only on the model name, so runs with different
    # finetune_mode settings write to (and overwrite) the same location.
    save_dir = f"/home/mshahidul/readctrl_model/support_checking_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    if save_fp16_merged:
        # Merge the LoRA weights into the base model and save it in 16-bit.
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
    else:
        # Persist the trained LoRA adapter. Without this, the finetuned weights are lost
        # when the process exits, and the reported `model_save_dir` stays empty.
        model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate the base model. This tag is recorded as `model_save_dir`.
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
# 4. Test-set Inference + Accuracy
FastModel.for_inference(model)
model.eval()

model_info_dir = "/home/mshahidul/readctrl/code/support_check/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/support_check/support_check_bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")


def _collect_gold_subclaims(sample):
    """Return [(subclaim_text, gold_label), ...] for every validly labelled subclaim of `sample`."""
    pairs = []
    model_output = sample.get("model_output") or {}
    for item in model_output.get("items") or []:
        for sc in item.get("subclaims") or []:
            gold_label = normalize_label(sc.get("label"))
            if gold_label:
                pairs.append((sc.get("subclaim", ""), gold_label))
    return pairs


def _update_confusion(counts, gold_label, pred_label):
    """Add one (gold, predicted) pair to `counts`, treating "supported" as the positive class."""
    gold_positive = gold_label == "supported"
    pred_positive = pred_label == "supported"
    if gold_positive and pred_positive:
        counts["tp"] += 1
    elif gold_positive:
        counts["fn"] += 1
    elif pred_positive:
        counts["fp"] += 1
    else:
        counts["tn"] += 1


def _supported_class_scores(counts):
    """Return (precision, recall, f1) for the "supported" class; each is 0.0 when undefined."""
    tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1


def _run_metadata(mode):
    """Return the run-configuration fields that both evaluation modes report."""
    return {
        "mode": mode,
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
    }


def _confusion_fields(counts):
    """Return the confusion counts under the metric keys used in the saved JSON."""
    return {
        "tp_supported": counts["tp"],
        "fp_supported": counts["fp"],
        "fn_supported": counts["fn"],
        "tn_supported": counts["tn"],
    }


def evaluate_single_subclaim_mode(test_split):
    """Query the model once per labelled subclaim and score the one-word verdicts.

    Returns (results, metrics): one result dict per subclaim, and a metrics dict with
    accuracy and "supported"-class precision/recall/F1. Replies that cannot be parsed
    are scored as "not_supported"; they are counted in "parse_failures" and flagged
    per row, with the raw model text kept for inspection.
    """
    results = []
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    total = correct = parse_failures = 0
    for idx, sample in enumerate(test_split):
        input_text = sample.get("input_text", "")
        for subclaim_text, gold_label in _collect_gold_subclaims(sample):
            pred_text = generate_prediction(build_single_user_prompt(input_text, subclaim_text))
            pred_label = parse_single_label(pred_text)
            parse_failed = pred_label is None
            if parse_failed:
                parse_failures += 1
                pred_label = "not_supported"

            total += 1
            is_correct = pred_label == gold_label
            if is_correct:
                correct += 1
            _update_confusion(counts, gold_label, pred_label)
            results.append(
                {
                    "sample_index": idx,
                    "input_text": input_text,
                    "subclaim": subclaim_text,
                    "gold_label": gold_label,
                    "predicted_label": pred_label,
                    "correct": is_correct,
                    "raw_output": pred_text,
                    "parse_failed": parse_failed,
                }
            )

    precision, recall, f1 = _supported_class_scores(counts)
    metrics = {
        **_run_metadata("single_subclaim"),
        "examples_evaluated": total,
        "accuracy": correct / total if total else 0.0,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        **_confusion_fields(counts),
        "timestamp": timestamp,
        "parse_failures": parse_failures,
    }
    return results, metrics


def evaluate_subclaim_list_mode(test_split):
    """Query the model once per test record for a JSON array of labels and score it.

    Samples without any validly labelled subclaim are skipped. The parsed prediction
    is aligned to the gold length: missing positions count as "not_supported" and
    extra ones are dropped. Empty or unparseable replies (all "not_supported") are
    counted in "parse_failures"; wrong-length arrays are counted in
    "length_mismatches". Returns (results, metrics).
    """
    results = []
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    total_samples = exact_match_correct = 0
    total_subclaims = correct_subclaims = 0
    parse_failures = length_mismatches = 0
    for idx, sample in enumerate(test_split):
        pairs = _collect_gold_subclaims(sample)
        if not pairs:
            continue
        input_text = sample.get("input_text", "")
        subclaims = [text for text, _ in pairs]
        gold_labels = [label for _, label in pairs]
        n_gold = len(gold_labels)

        pred_text = generate_prediction(build_list_user_prompt(input_text, subclaims))
        parsed_labels = parse_label_array(pred_text)
        parse_failed = not parsed_labels
        if parse_failed:
            parse_failures += 1
        elif len(parsed_labels) != n_gold:
            length_mismatches += 1
        # Pad with "not_supported" and truncate, so exactly one prediction per gold label.
        pred_labels = (parsed_labels + ["not_supported"] * n_gold)[:n_gold]

        sample_correct = 0
        for gold_label, pred_label in zip(gold_labels, pred_labels):
            if pred_label == gold_label:
                sample_correct += 1
            _update_confusion(counts, gold_label, pred_label)
        total_subclaims += n_gold
        correct_subclaims += sample_correct
        total_samples += 1

        exact_match = sample_correct == n_gold
        if exact_match:
            exact_match_correct += 1
        results.append(
            {
                "sample_index": idx,
                "input_text": input_text,
                "subclaims": subclaims,
                "gold_labels": gold_labels,
                "predicted_labels": pred_labels,
                "exact_match": exact_match,
                "per_sample_accuracy": sample_correct / n_gold,
                "raw_output": pred_text,
                "parse_failed": parse_failed,
            }
        )

    precision, recall, f1 = _supported_class_scores(counts)
    metrics = {
        **_run_metadata("subclaim_list"),
        "test_samples_evaluated": total_samples,
        "total_subclaims": total_subclaims,
        "correct_subclaims": correct_subclaims,
        "subclaim_accuracy": correct_subclaims / total_subclaims if total_subclaims else 0.0,
        "exact_match_accuracy": exact_match_correct / total_samples if total_samples else 0.0,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        **_confusion_fields(counts),
        "timestamp": timestamp,
        "parse_failures": parse_failures,
        "length_mismatches": length_mismatches,
    }
    return results, metrics


if finetune_mode == "single_subclaim":
    results, accuracy_summary = evaluate_single_subclaim_mode(test_raw)
else:
    results, accuracy_summary = evaluate_subclaim_list_mode(test_raw)

accuracy_summary["finetune_mode"] = finetune_mode
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode

predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_{finetune_mode}_{model_size}_{run_mode}_{timestamp}.json",
)

with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
# single_subclaim reports "accuracy"; subclaim_list reports "subclaim_accuracy".
print(f"Accuracy: {accuracy_summary.get('accuracy', accuracy_summary.get('subclaim_accuracy', 0.0)):.4f}")
print(f"F1 (supported class): {accuracy_summary.get('f1_supported', 0.0):.4f}")
print(f"Unparseable model outputs: {accuracy_summary.get('parse_failures', 0)}")