# CUDA device selection must happen BEFORE torch / unsloth are imported,
# otherwise CUDA_VISIBLE_DEVICES is ignored by the CUDA runtime.
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import ast
import json
from datetime import datetime


import torch
from datasets import Dataset


from unsloth import FastModel
from unsloth.chat_templates import (
    get_chat_template,
    standardize_data_formats,
    train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer


# --- Run configuration -------------------------------------------------------
model_name = "unsloth/gemma-3-4b-it"
data_path = "/home/mshahidul/readctrl/code/support_check/support_check_bn/finetune_dataset_subclaim_support_bn.json"
test_size = 0.3  # fraction of raw records held out for evaluation
seed = 3407
finetune_mode = "subclaim_list"  # "single_subclaim" | "subclaim_list"
prompt_language = "en"  # "en" -> English prompts; anything else -> Bengali
run_mode = "finetune_and_eval"  # "finetune_and_eval" | "eval_base_only"
save_fp16_merged = False  # merge LoRA into fp16 weights when saving
|
|
|
|
def get_model_size_from_name(name):
    """Extract the parameter-size token (e.g. "4b", "270m") from a model id.

    Scans the hyphen-separated pieces of the repo basename and returns the
    first piece ending in 'b' or 'm' (case-insensitive), or "unknown" when
    no such piece exists.
    """
    basename = name.rsplit("/", 1)[-1]
    size_tokens = (
        piece for piece in basename.split("-") if piece.lower().endswith(("b", "m"))
    )
    return next(size_tokens, "unknown")
|
|
|
|
| model_size = get_model_size_from_name(model_name) |
|
|
|
|
def formatting_prompts_func(examples):
    """Render each conversation into a single training string.

    Uses the module-level ``tokenizer``'s chat template; the leading
    "<bos>" is stripped because the trainer's tokenizer will add it again.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<bos>"))
    return {"text": rendered}
|
|
|
|
def parse_label_array(raw_text):
    """Parse a model reply into a list of 'supported'/'not_supported' labels.

    Tolerates markdown code fences, prose surrounding the JSON array, and
    Python-literal list syntax. Non-string or unrecognized entries collapse
    to 'not_supported'. Returns [] when no list can be recovered.
    """
    text = (raw_text or "").strip()
    if not text:
        return []

    # Drop markdown code fences, if any.
    if "```" in text:
        text = text.replace("```json", "").replace("```", "").strip()

    # Narrow to the outermost [...] span when one exists.
    left, right = text.find("["), text.rfind("]")
    if left != -1 and right > left:
        text = text[left : right + 1]

    # Try strict JSON first, then Python literal syntax.
    parsed = None
    for loader in (json.loads, ast.literal_eval):
        try:
            parsed = loader(text)
        except Exception:
            continue
        else:
            break

    if not isinstance(parsed, list):
        return []

    valid = {"supported", "not_supported"}
    cleaned = []
    for entry in parsed:
        if isinstance(entry, str):
            canon = entry.strip().lower().replace("-", "_").replace(" ", "_")
            cleaned.append(canon if canon in valid else "not_supported")
        else:
            cleaned.append("not_supported")
    return cleaned
|
|
|
|
def parse_single_label(raw_text):
    """Parse a free-form model reply into a single support label.

    Separators are canonicalized to underscores first, so replies such as
    "not supported" or "Not-Supported" map to 'not_supported'. (The previous
    substring check matched "supported" inside "not supported" and returned
    the wrong label; it also contained an unreachable third branch.)

    Returns 'supported', 'not_supported', or None when neither label is
    present (including for None/empty input).
    """
    text = (raw_text or "").strip().lower()
    # Canonicalize hyphen/space variants so the substring checks below see
    # the underscore form.
    text = text.replace("-", "_").replace(" ", "_")
    # Check the negative label first: "not_supported" contains "supported".
    if "not_supported" in text:
        return "not_supported"
    if "supported" in text:
        return "supported"
    return None
|
|
|
|
def normalize_label(label):
    """Canonicalize a gold label to 'supported'/'not_supported', else None."""
    if label is None:
        return None
    canon = str(label).strip().lower().replace("-", "_").replace(" ", "_")
    return canon if canon in ("supported", "not_supported") else None
|
|
|
|
def build_single_user_prompt(input_text, subclaim):
    """Build the one-subclaim classification prompt.

    Language is chosen by the module-level ``prompt_language``: "en" gives
    the English template, anything else the Bengali one. Both ask for a
    one-word 'supported'/'not_supported' reply.
    """
    if prompt_language != "en":
        return (
            "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একটি সাবক্লেইম দেওয়া হবে। "
            "সাবক্লেইমটি টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
            f"টেক্সট:\n{input_text}\n\n"
            f"সাবক্লেইম:\n{subclaim}\n\n"
            "শুধু একটি শব্দ দিয়ে উত্তর দিন: 'supported' অথবা 'not_supported'."
        )
    return (
        "You will be given a medical case description and one subclaim. "
        "Determine whether the subclaim is supported by the text.\n\n"
        f"Text:\n{input_text}\n\n"
        f"Subclaim:\n{subclaim}\n\n"
        "Reply with exactly one word: 'supported' or 'not_supported'."
    )
|
|
|
|
def build_list_user_prompt(input_text, subclaims):
    """Build the list-of-subclaims classification prompt.

    Subclaims are numbered 1..N in order; the reply is expected to be a bare
    JSON array of labels aligned with that order. Language is selected by the
    module-level ``prompt_language``.
    """
    numbered = "\n".join(
        f"{position}. {claim}" for position, claim in enumerate(subclaims, start=1)
    )
    if prompt_language != "en":
        return (
            "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। "
            "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
            f"টেক্সট:\n{input_text}\n\n"
            f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n"
            "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। "
            "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n"
            '["supported", "not_supported", ...]\n'
            "অন্য কিছু লিখবেন না।"
        )
    return (
        "You will be given a medical case description and a list of subclaims. "
        "Determine for each subclaim whether it is supported by the text.\n\n"
        f"Text:\n{input_text}\n\n"
        f"List of subclaims:\n{numbered}\n\n"
        "Give the label for each subclaim in order. "
        "Reply with a JSON array only, e.g.:\n"
        '["supported", "not_supported", ...]\n'
        "Do not write anything else."
    )
|
|
|
|
def build_single_subclaim_examples(raw_records):
    """Flatten raw records into one chat example per (text, subclaim) pair.

    Subclaims whose gold label is missing or unrecognized are skipped.
    Each example is a two-turn conversation: the single-subclaim prompt
    followed by the gold label as the assistant reply.
    """
    dataset_rows = []
    for record in raw_records:
        case_text = record.get("input_text", "")
        output_items = (record.get("model_output") or {}).get("items") or []
        for output_item in output_items:
            for claim in output_item.get("subclaims") or []:
                gold = normalize_label(claim.get("label"))
                if not gold:
                    continue  # unlabeled subclaims are not trainable
                prompt = build_single_user_prompt(case_text, claim.get("subclaim", ""))
                dataset_rows.append(
                    {
                        "conversations": [
                            {"role": "user", "content": prompt},
                            {"role": "assistant", "content": gold},
                        ],
                    }
                )
    return dataset_rows
|
|
|
|
def build_list_subclaim_examples(raw_records):
    """Build one chat example per record, covering all of its subclaims.

    The assistant target is a JSON array of gold labels aligned with the
    subclaim order in the prompt. Records with no labeled subclaims are
    dropped entirely.
    """
    dataset_rows = []
    for record in raw_records:
        case_text = record.get("input_text", "")
        output_items = (record.get("model_output") or {}).get("items") or []

        # Collect every labeled subclaim of this record, preserving order.
        claims, golds = [], []
        for output_item in output_items:
            for claim in output_item.get("subclaims") or []:
                gold = normalize_label(claim.get("label"))
                if not gold:
                    continue
                claims.append(claim.get("subclaim", ""))
                golds.append(gold)

        if not claims:
            continue

        dataset_rows.append(
            {
                "conversations": [
                    {"role": "user", "content": build_list_user_prompt(case_text, claims)},
                    {"role": "assistant", "content": json.dumps(golds)},
                ],
            }
        )
    return dataset_rows
|
|
|
|
def extract_conversation_pair(conversations):
    """Return (first user message, first assistant message) from a chat list.

    Accepts either "role" or ShareGPT-style "from" speaker keys; a missing
    turn comes back as "".
    """
    user_text, assistant_text = "", ""
    for turn in conversations:
        speaker = turn.get("role") or turn.get("from")
        if speaker == "user" and not user_text:
            user_text = turn.get("content", "")
        elif speaker == "assistant" and not assistant_text:
            assistant_text = turn.get("content", "")
    return user_text, assistant_text
|
|
|
|
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to ``user_prompt``.

    Formats the prompt with the module-level ``tokenizer``'s chat template,
    generates up to 256 new tokens with the module-level ``model``, and
    returns only the newly generated text (special tokens and surrounding
    whitespace removed).
    """
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(text=chat_text, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[1]
    with torch.inference_mode():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            do_sample=False,
            temperature=0.0,
            use_cache=True,
        )
    # Slice off the prompt tokens; keep only the completion.
    completion_ids = generated[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
|
|
|
|
| |
# Load the base model in 4-bit and attach the Gemma-3 chat template.
# NOTE(review): max_seq_length=4092 looks like a typo for 4096 — confirm.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,
    load_in_4bit=True,
)


tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
# Raw records: list of dicts with "input_text" and "model_output" -> "items"
# -> "subclaims" (presumed from the builders below — verify against the file).
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)


# Deterministic train/test split over raw records (seeded shuffle).
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]


# Build training conversations in the configured prompt format.
if finetune_mode == "single_subclaim":
    train_examples = build_single_subclaim_examples(train_raw)
elif finetune_mode == "subclaim_list":
    train_examples = build_list_subclaim_examples(train_raw)
else:
    raise ValueError(f"Unsupported finetune_mode: {finetune_mode}")


# Materialize the "text" column the SFT trainer consumes.
train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
|
|
| |
if run_mode == "finetune_and_eval":
    # Wrap the base model with LoRA adapters on attention + MLP projections.
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Short SFT run: effective batch size 2 * 4 = 8, 60 optimizer steps.
    # NOTE(review): max_seq_length=2048 here is smaller than the model's
    # 4092 load-time limit — long list prompts may be truncated in training;
    # confirm this is intended.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )

    # Mask the loss so only assistant turns (after "<start_of_turn>model\n")
    # contribute; user-turn tokens are ignored.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # Create the save directory before training so a post-train save can't
    # fail on a missing path.
    save_dir = f"/home/mshahidul/readctrl_model/support_checking_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    # Optionally merge LoRA weights into fp16 and persist model + tokenizer.
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)


elif run_mode == "eval_base_only":
    # No training; tag recorded in metrics so runs are distinguishable.
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")


# Switch to inference mode for evaluation (applies to both run modes).
FastModel.for_inference(model)
model.eval()


# Output locations for per-example predictions and aggregate metrics.
model_info_dir = "/home/mshahidul/readctrl/code/support_check/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/support_check/support_check_bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)


# Timestamp + sanitized model tag make every output filename unique.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")
|
|
def evaluate_single_subclaim_mode(test_split):
    """Evaluate the model one subclaim at a time.

    For every labeled subclaim in ``test_split``, queries the model with the
    single-subclaim prompt and scores the parsed reply (an unparseable reply
    counts as 'not_supported'). Returns (per-subclaim result dicts, metrics
    dict); precision/recall/F1 treat 'supported' as the positive class.
    """
    results = []
    seen = 0
    hits = 0
    true_pos = false_pos = false_neg = true_neg = 0

    for sample_idx, sample in enumerate(test_split):
        case_text = sample.get("input_text", "")
        output_items = (sample.get("model_output") or {}).get("items") or []

        for output_item in output_items:
            for claim in output_item.get("subclaims") or []:
                claim_text = claim.get("subclaim", "")
                gold = normalize_label(claim.get("label"))
                if not gold:
                    continue  # unlabeled subclaims are not scored

                reply = generate_prediction(build_single_user_prompt(case_text, claim_text))
                prediction = parse_single_label(reply) or "not_supported"

                seen += 1
                matched = prediction == gold
                if matched:
                    hits += 1

                # Confusion counts with 'supported' as the positive class.
                if gold == "supported":
                    if prediction == "supported":
                        true_pos += 1
                    else:
                        false_neg += 1
                else:
                    if prediction == "supported":
                        false_pos += 1
                    else:
                        true_neg += 1

                results.append(
                    {
                        "sample_index": sample_idx,
                        "input_text": case_text,
                        "subclaim": claim_text,
                        "gold_label": gold,
                        "predicted_label": prediction,
                        "correct": matched,
                    }
                )

    # Every ratio is guarded against an empty denominator.
    accuracy = hits / seen if seen else 0.0
    pred_pos = true_pos + false_pos
    gold_pos = true_pos + false_neg
    precision = true_pos / pred_pos if pred_pos else 0.0
    recall = true_pos / gold_pos if gold_pos else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    metrics = {
        "mode": "single_subclaim",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": seen,
        "accuracy": accuracy,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        "tp_supported": true_pos,
        "fp_supported": false_pos,
        "fn_supported": false_neg,
        "tn_supported": true_neg,
        "timestamp": timestamp,
    }
    return results, metrics
|
|
|
|
def evaluate_subclaim_list_mode(test_split):
    """Evaluate the model on whole records, one JSON-array reply per record.

    Builds the list prompt over all labeled subclaims of each test sample,
    parses the model's JSON-array reply, and scores it position-by-position
    against the gold labels. Returns (per-sample result dicts, metrics
    dict); precision/recall/F1 treat 'supported' as the positive class.
    """
    results = []
    total_samples = 0
    exact_match_correct = 0
    total_subclaims = 0
    correct_subclaims = 0
    tp = fp = fn = tn = 0

    for idx, sample in enumerate(test_split):
        input_text = sample.get("input_text", "")
        model_output = sample.get("model_output") or {}
        items = model_output.get("items") or []

        # Collect every labeled subclaim of this record, preserving order.
        subclaims = []
        gold_labels = []
        for item in items:
            for sc in item.get("subclaims") or []:
                subclaim_text = sc.get("subclaim", "")
                label = normalize_label(sc.get("label"))
                if not label:
                    continue
                subclaims.append(subclaim_text)
                gold_labels.append(label)

        # Records without labeled subclaims cannot be scored.
        if not subclaims:
            continue

        user_prompt = build_list_user_prompt(input_text, subclaims)
        pred_text = generate_prediction(user_prompt)
        pred_labels = parse_label_array(pred_text)

        # Unparseable reply: score every position as 'not_supported'.
        if not pred_labels:
            pred_labels = ["not_supported"] * len(gold_labels)

        # Align prediction length with gold: pad short replies with
        # 'not_supported', truncate overlong ones.
        if len(pred_labels) < len(gold_labels):
            pred_labels = pred_labels + ["not_supported"] * (len(gold_labels) - len(pred_labels))
        elif len(pred_labels) > len(gold_labels):
            pred_labels = pred_labels[: len(gold_labels)]

        sample_correct = 0
        for gold_label, pred_label in zip(gold_labels, pred_labels):
            total_subclaims += 1
            if pred_label == gold_label:
                correct_subclaims += 1
                sample_correct += 1

            # Confusion counts with 'supported' as the positive class.
            if gold_label == "supported" and pred_label == "supported":
                tp += 1
            elif gold_label == "supported" and pred_label == "not_supported":
                fn += 1
            elif gold_label == "not_supported" and pred_label == "supported":
                fp += 1
            elif gold_label == "not_supported" and pred_label == "not_supported":
                tn += 1

        total_samples += 1
        # Exact match = every subclaim of the record predicted correctly.
        exact_match = sample_correct == len(gold_labels)
        if exact_match:
            exact_match_correct += 1

        results.append(
            {
                "sample_index": idx,
                "input_text": input_text,
                "subclaims": subclaims,
                "gold_labels": gold_labels,
                "predicted_labels": pred_labels,
                "exact_match": exact_match,
                "per_sample_accuracy": sample_correct / len(gold_labels),
            }
        )

    # Every ratio is guarded against an empty denominator.
    accuracy = correct_subclaims / total_subclaims if total_subclaims else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    exact_match_accuracy = (
        exact_match_correct / total_samples if total_samples else 0.0
    )

    metrics = {
        "mode": "subclaim_list",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "test_samples_evaluated": total_samples,
        "total_subclaims": total_subclaims,
        "correct_subclaims": correct_subclaims,
        "subclaim_accuracy": accuracy,
        "exact_match_accuracy": exact_match_accuracy,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        "tp_supported": tp,
        "fp_supported": fp,
        "fn_supported": fn,
        "tn_supported": tn,
        "timestamp": timestamp,
    }
    return results, metrics
|
|
|
|
# Evaluate the held-out split with the same prompt format that the model
# was (or would have been) fine-tuned on.
if finetune_mode == "single_subclaim":
    results, accuracy_summary = evaluate_single_subclaim_mode(test_raw)
else:
    results, accuracy_summary = evaluate_subclaim_list_mode(test_raw)


# Attach run configuration so each metrics file is self-describing.
accuracy_summary["finetune_mode"] = finetune_mode
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode


# Per-example predictions and aggregate metrics go to separate,
# timestamped JSON files.
predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_{finetune_mode}_{model_size}_{run_mode}_{timestamp}.json",
)


with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)


with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)


print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
# "accuracy" exists in single mode; "subclaim_accuracy" in list mode.
print(f"Accuracy: {accuracy_summary.get('accuracy', accuracy_summary.get('subclaim_accuracy', 0.0)):.4f}")
print(f"F1 (supported class): {accuracy_summary.get('f1_supported', 0.0):.4f}")