# mshahidul
# Initial commit of readCtrl code without large models
# 030876e
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import ast
import json
import os
from datetime import datetime
import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
get_chat_template,
standardize_data_formats,
train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer
# --- Run configuration ---
# Hugging Face id of the base checkpoint handed to FastModel.from_pretrained.
model_name = "unsloth/gemma-3-4b-it"
# JSON file that is loaded and split into train/test records.
data_path = "/home/mshahidul/readctrl/code/support_check/support_check_bn/finetune_dataset_subclaim_support_bn.json"
# Fraction of the records held out as the test split.
test_size = 0.3
# Seed shared by the dataset split, the LoRA adapter setup and the trainer.
seed = 3407
finetune_mode = "subclaim_list" # "single_subclaim" or "subclaim_list"
prompt_language = "en" # "bn" (Bangla) or "en" (English)
run_mode = "finetune_and_eval" # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = False # whether to save merged fp16 model after finetuning
def get_model_size_from_name(name):
    """Extract the parameter-count token (e.g. ``"4b"``) from a model id.

    The part after the last ``/`` is split on ``-`` and the first piece that
    ends with a digit followed by ``b`` or ``m`` (case-insensitive) is
    returned with its original casing.

    Args:
        name: Model identifier such as ``"unsloth/gemma-3-4b-it"``.

    Returns:
        The matching size token, or ``"unknown"`` when no piece looks like a
        size.
    """
    base = name.split("/")[-1]
    for part in base.split("-"):
        token = part.lower()
        # Require a digit right before the unit so ordinary words that merely
        # end in "b"/"m" (e.g. "bloom", "xlm") are not mistaken for a size.
        if len(token) >= 2 and token[-1] in "bm" and token[-2].isdigit():
            return part
    return "unknown"
model_size = get_model_size_from_name(model_name)
def formatting_prompts_func(examples):
    """Render a batch of conversations into plain training text.

    Args:
        examples: Batch mapping with a ``"conversations"`` column; each entry
            is a list of chat messages.

    Returns:
        ``{"text": [...]}`` with one chat-template rendering per conversation.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Drop a leading "<bos>" from the rendered text if the template added one.
        rendered.append(chat_text.removeprefix("<bos>"))
    return {"text": rendered}
def parse_label_array(raw_text):
    """Parse a model reply into a list of normalized support labels.

    Markdown code fences are removed, the outermost ``[...]`` span is
    isolated, and the result is decoded as JSON or, failing that, as a Python
    literal.

    Args:
        raw_text: Raw generated text (may be ``None`` or empty).

    Returns:
        A list containing only ``"supported"`` / ``"not_supported"``. Entries
        that are not strings or not a known label become ``"not_supported"``.
        An empty list is returned when no list could be decoded.
    """
    cleaned = (raw_text or "").strip()
    if not cleaned:
        return []

    if "```" in cleaned:
        cleaned = cleaned.replace("```json", "").replace("```", "").strip()

    # Keep only the span from the first "[" to the last "]" when one exists.
    open_idx = cleaned.find("[")
    close_idx = cleaned.rfind("]")
    if 0 <= open_idx < close_idx:
        cleaned = cleaned[open_idx : close_idx + 1]

    decoded = None
    for decode in (json.loads, ast.literal_eval):
        try:
            decoded = decode(cleaned)
        except Exception:
            # Best effort: fall through to the next decoder.
            continue
        else:
            break

    if not isinstance(decoded, list):
        return []

    known_labels = {"supported", "not_supported"}

    def _canonical(entry):
        # Anything that is not a recognised label string counts as negative.
        if not isinstance(entry, str):
            return "not_supported"
        candidate = entry.strip().lower().replace("-", "_").replace(" ", "_")
        return candidate if candidate in known_labels else "not_supported"

    return [_canonical(entry) for entry in decoded]
def parse_single_label(raw_text):
    """Map a free-form model reply to ``"supported"`` / ``"not_supported"``.

    Whitespace, hyphens and underscores are treated as equivalent separators,
    so ``"Not supported"`` and ``"not-supported"`` are recognised as the
    negative label rather than being matched by the bare word "supported".
    Negative forms are checked first because they contain the positive word.

    Args:
        raw_text: Raw generated text (may be ``None`` or empty).

    Returns:
        ``"not_supported"``, ``"supported"``, or ``None`` when neither label
        appears in the text.
    """
    text = (raw_text or "").strip().lower()
    # Collapse every run of spaces/hyphens/underscores into a single "_".
    normalized = "_".join(text.replace("-", " ").replace("_", " ").split())
    if "not_supported" in normalized or "unsupported" in normalized:
        return "not_supported"
    if "supported" in normalized:
        return "supported"
    return None
def normalize_label(label):
    """Canonicalize a gold label.

    Args:
        label: Raw label value; converted with ``str`` before cleaning.

    Returns:
        ``"supported"`` or ``"not_supported"`` after trimming, lower-casing
        and turning hyphens/spaces into underscores; ``None`` for ``None``
        input or any other value.
    """
    if label is None:
        return None
    canonical = str(label).strip().lower().replace("-", "_").replace(" ", "_")
    return canonical if canonical in {"supported", "not_supported"} else None
def build_single_user_prompt(input_text, subclaim):
    """Build the user prompt asking about one subclaim.

    The wording is English when the module-level ``prompt_language`` is
    ``"en"`` and Bangla for any other value.

    Args:
        input_text: The case description the subclaim is checked against.
        subclaim: The single subclaim to judge.

    Returns:
        The full prompt string.
    """
    if prompt_language == "en":
        pieces = [
            "You will be given a medical case description and one subclaim. ",
            "Determine whether the subclaim is supported by the text.\n\n",
            f"Text:\n{input_text}\n\n",
            f"Subclaim:\n{subclaim}\n\n",
            "Reply with exactly one word: 'supported' or 'not_supported'.",
        ]
    else:
        # Bangla (default)
        pieces = [
            "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একটি সাবক্লেইম দেওয়া হবে। ",
            "সাবক্লেইমটি টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n",
            f"টেক্সট:\n{input_text}\n\n",
            f"সাবক্লেইম:\n{subclaim}\n\n",
            "শুধু একটি শব্দ দিয়ে উত্তর দিন: 'supported' অথবা 'not_supported'.",
        ]
    return "".join(pieces)
def build_list_user_prompt(input_text, subclaims):
    """Build the user prompt asking about a whole list of subclaims.

    The subclaims are rendered as a 1-based numbered list. The wording is
    English when the module-level ``prompt_language`` is ``"en"`` and Bangla
    for any other value.

    Args:
        input_text: The case description the subclaims are checked against.
        subclaims: Iterable of subclaim strings, in the order labels are
            expected back.

    Returns:
        The full prompt string.
    """
    numbered_lines = [
        f"{position}. {claim}" for position, claim in enumerate(subclaims, start=1)
    ]
    numbered = "\n".join(numbered_lines)

    if prompt_language == "en":
        pieces = [
            "You will be given a medical case description and a list of subclaims. ",
            "Determine for each subclaim whether it is supported by the text.\n\n",
            f"Text:\n{input_text}\n\n",
            f"List of subclaims:\n{numbered}\n\n",
            "Give the label for each subclaim in order. ",
            "Reply with a JSON array only, e.g.:\n",
            '["supported", "not_supported", ...]\n',
            "Do not write anything else.",
        ]
    else:
        # Bangla (default)
        pieces = [
            "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। ",
            "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n",
            f"টেক্সট:\n{input_text}\n\n",
            f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n",
            "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। ",
            "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n",
            '["supported", "not_supported", ...]\n',
            "অন্য কিছু লিখবেন না।",
        ]
    return "".join(pieces)
def build_single_subclaim_examples(raw_records):
    """Create one chat example per labelled subclaim.

    Args:
        raw_records: Iterable of records read with ``.get``; the code looks up
            ``"input_text"`` and ``"model_output" -> "items" -> "subclaims"``,
            where each subclaim carries ``"subclaim"`` and ``"label"``.

    Returns:
        A list of ``{"conversations": [user_message, assistant_message]}``
        dicts. Subclaims whose label does not normalize to a known value are
        skipped.
    """
    examples = []
    for record in raw_records:
        source_text = record.get("input_text", "")
        items = (record.get("model_output") or {}).get("items") or []
        for item in items:
            for entry in item.get("subclaims") or []:
                gold = normalize_label(entry.get("label"))
                if not gold:
                    continue
                prompt = build_single_user_prompt(
                    source_text, entry.get("subclaim", "")
                )
                conversation = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": gold},
                ]
                examples.append({"conversations": conversation})
    return examples
def build_list_subclaim_examples(raw_records):
    """Create one chat example per record, covering all its subclaims.

    Args:
        raw_records: Iterable of records read with ``.get``; the code looks up
            ``"input_text"`` and ``"model_output" -> "items" -> "subclaims"``,
            where each subclaim carries ``"subclaim"`` and ``"label"``.

    Returns:
        A list of ``{"conversations": [user_message, assistant_message]}``
        dicts whose assistant content is the JSON-encoded label list. Records
        without any validly labelled subclaim are skipped.
    """
    examples = []
    for record in raw_records:
        source_text = record.get("input_text", "")
        items = (record.get("model_output") or {}).get("items") or []

        # Collect (subclaim, label) pairs, dropping unusable labels.
        labelled = []
        for item in items:
            for entry in item.get("subclaims") or []:
                gold = normalize_label(entry.get("label"))
                if gold:
                    labelled.append((entry.get("subclaim", ""), gold))
        if not labelled:
            continue

        claims = [claim for claim, _ in labelled]
        golds = [gold for _, gold in labelled]
        conversation = [
            {"role": "user", "content": build_list_user_prompt(source_text, claims)},
            {"role": "assistant", "content": json.dumps(golds)},
        ]
        examples.append({"conversations": conversation})
    return examples
def extract_conversation_pair(conversations):
    """Return the first non-empty user and assistant contents.

    Args:
        conversations: Iterable of message dicts. The speaker is read from
            ``"role"`` (falling back to ``"from"``) and the text from
            ``"content"``.

    Returns:
        ``(user_prompt, gold_response)``; each is ``""`` when no non-empty
        message of that role was found.
    """
    first_seen = {"user": "", "assistant": ""}
    for message in conversations:
        speaker = message.get("role") or message.get("from")
        if speaker not in ("user", "assistant"):
            continue
        # Only fill a slot while it is still empty.
        if not first_seen[speaker]:
            first_seen[speaker] = message.get("content", "")
    return first_seen["user"], first_seen["assistant"]
def generate_prediction(user_prompt, max_new_tokens=256):
    """Generate the model's reply to a single user prompt with greedy decoding.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        user_prompt: Text placed in a single ``user`` chat message.
        max_new_tokens: Upper bound on generated tokens (default 256, the
            value that was previously hard-coded).

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    ).removeprefix("<bos>")
    # The "<bos>" prefix is stripped exactly as formatting_prompts_func does
    # for the training text, so inference inputs are built the same way and
    # the tokenization step below does not see a second literal "<bos>".
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; no temperature is passed because it only
            # applies to sampling-based generation.
            do_sample=False,
            use_cache=True,
        )
    # Keep only the tokens produced after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load Model and Tokenizer
# NOTE(review): max_seq_length=4092 may be a typo for 4096 — confirm.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,
    load_in_4bit=True,
)

# 2. Data Preparation
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Split at record level, before records are expanded into chat examples.
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(
    test_size=test_size,
    seed=seed,
    shuffle=True,
)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]

# Pick the example builder that matches the configured finetune mode.
_example_builders = {
    "single_subclaim": build_single_subclaim_examples,
    "subclaim_list": build_list_subclaim_examples,
}
if finetune_mode not in _example_builders:
    raise ValueError(f"Unsupported finetune_mode: {finetune_mode}")
train_examples = _example_builders[finetune_mode](train_raw)

train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
# 3. Optional Finetuning
if run_mode not in ("finetune_and_eval", "eval_base_only"):
    raise ValueError(f"Unsupported run_mode: {run_mode}")

if run_mode == "eval_base_only":
    # No finetuning; evaluate base model
    save_dir = f"BASE_MODEL:{model_name}"
else:
    # Add LoRA adapters for finetuning
    _lora_target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=_lora_target_modules,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Training setup: use bf16 when the GPU supports it, otherwise fp16.
    _use_bf16 = torch.cuda.is_bf16_supported()
    _training_args = SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=60,
        learning_rate=2e-4,
        fp16=not _use_bf16,
        bf16=_use_bf16,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=seed,
        output_dir="outputs",
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=_training_args,
    )

    # Masking to train on assistant responses only
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # Execute training
    save_dir = f"/home/mshahidul/readctrl_model/support_checking_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    # Optional: save in float16 merged format
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)
# 4. Test-set Inference + Accuracy
FastModel.for_inference(model)
model.eval()

# Output locations: per-example predictions and the metrics summary.
model_info_dir = "/home/mshahidul/readctrl/code/support_check/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/support_check/support_check_bn/ablation_studies"
for _report_dir in (model_info_dir, ablation_dir):
    os.makedirs(_report_dir, exist_ok=True)

# Identifiers embedded in the output file names.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.rsplit("/", 1)[-1].replace(".", "_")
def evaluate_single_subclaim_mode(test_split):
    """Evaluate the model by querying it once per labelled subclaim.

    Args:
        test_split: Iterable of records read with ``.get``; the code looks up
            ``"input_text"`` and ``"model_output" -> "items" -> "subclaims"``.

    Returns:
        ``(results, metrics)``: a list with one dict per evaluated subclaim,
        and a summary dict with accuracy plus precision/recall/F1 and the
        confusion counts, treating ``"supported"`` as the positive class. A
        reply that cannot be parsed counts as ``"not_supported"``.
    """
    results = []
    total = 0
    correct = 0
    # Confusion counts keyed by (gold, predicted).
    confusion = {
        ("supported", "supported"): 0,  # tp
        ("supported", "not_supported"): 0,  # fn
        ("not_supported", "supported"): 0,  # fp
        ("not_supported", "not_supported"): 0,  # tn
    }

    for idx, sample in enumerate(test_split):
        input_text = sample.get("input_text", "")
        items = (sample.get("model_output") or {}).get("items") or []
        for item in items:
            for entry in item.get("subclaims") or []:
                subclaim_text = entry.get("subclaim", "")
                gold_label = normalize_label(entry.get("label"))
                if not gold_label:
                    continue

                reply = generate_prediction(
                    build_single_user_prompt(input_text, subclaim_text)
                )
                pred_label = parse_single_label(reply) or "not_supported"

                total += 1
                is_correct = pred_label == gold_label
                if is_correct:
                    correct += 1
                outcome = (gold_label, pred_label)
                if outcome in confusion:
                    confusion[outcome] += 1

                results.append(
                    {
                        "sample_index": idx,
                        "input_text": input_text,
                        "subclaim": subclaim_text,
                        "gold_label": gold_label,
                        "predicted_label": pred_label,
                        "correct": is_correct,
                    }
                )

    tp = confusion[("supported", "supported")]
    fn = confusion[("supported", "not_supported")]
    fp = confusion[("not_supported", "supported")]
    tn = confusion[("not_supported", "not_supported")]

    # Every ratio falls back to 0.0 when its denominator is zero.
    accuracy = correct / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0

    metrics = {
        "mode": "single_subclaim",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total,
        "accuracy": accuracy,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        "tp_supported": tp,
        "fp_supported": fp,
        "fn_supported": fn,
        "tn_supported": tn,
        "timestamp": timestamp,
    }
    return results, metrics
def evaluate_subclaim_list_mode(test_split):
    """Evaluate the model by querying it once per record for all subclaims.

    Args:
        test_split: Iterable of records read with ``.get``; the code looks up
            ``"input_text"`` and ``"model_output" -> "items" -> "subclaims"``.

    Returns:
        ``(results, metrics)``: a list with one dict per evaluated record,
        and a summary dict with subclaim-level accuracy, exact-match accuracy
        and precision/recall/F1 with ``"supported"`` as the positive class.
        Predictions are padded with ``"not_supported"`` or truncated so they
        line up with the gold labels.
    """
    results = []
    total_samples = 0
    exact_match_correct = 0
    total_subclaims = 0
    correct_subclaims = 0
    # Confusion counts keyed by (gold, predicted).
    confusion = {
        ("supported", "supported"): 0,  # tp
        ("supported", "not_supported"): 0,  # fn
        ("not_supported", "supported"): 0,  # fp
        ("not_supported", "not_supported"): 0,  # tn
    }

    for idx, sample in enumerate(test_split):
        input_text = sample.get("input_text", "")
        items = (sample.get("model_output") or {}).get("items") or []

        subclaims = []
        gold_labels = []
        for item in items:
            for entry in item.get("subclaims") or []:
                subclaim_text = entry.get("subclaim", "")
                gold = normalize_label(entry.get("label"))
                if not gold:
                    continue
                subclaims.append(subclaim_text)
                gold_labels.append(gold)
        if not subclaims:
            continue

        reply = generate_prediction(build_list_user_prompt(input_text, subclaims))
        pred_labels = parse_label_array(reply)

        # Pad with "not_supported" (also covers an empty parse) and cut any
        # surplus so predictions align one-to-one with the gold labels.
        expected = len(gold_labels)
        padding = ["not_supported"] * (expected - len(pred_labels))
        pred_labels = (pred_labels + padding)[:expected]

        sample_correct = 0
        for gold_label, pred_label in zip(gold_labels, pred_labels):
            total_subclaims += 1
            if pred_label == gold_label:
                correct_subclaims += 1
                sample_correct += 1
            outcome = (gold_label, pred_label)
            if outcome in confusion:
                confusion[outcome] += 1

        total_samples += 1
        exact_match = sample_correct == expected
        if exact_match:
            exact_match_correct += 1

        results.append(
            {
                "sample_index": idx,
                "input_text": input_text,
                "subclaims": subclaims,
                "gold_labels": gold_labels,
                "predicted_labels": pred_labels,
                "exact_match": exact_match,
                "per_sample_accuracy": sample_correct / len(gold_labels),
            }
        )

    tp = confusion[("supported", "supported")]
    fn = confusion[("supported", "not_supported")]
    fp = confusion[("not_supported", "supported")]
    tn = confusion[("not_supported", "not_supported")]

    # Every ratio falls back to 0.0 when its denominator is zero.
    accuracy = correct_subclaims / total_subclaims if total_subclaims else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0
    exact_match_accuracy = exact_match_correct / total_samples if total_samples else 0.0

    metrics = {
        "mode": "subclaim_list",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "test_samples_evaluated": total_samples,
        "total_subclaims": total_subclaims,
        "correct_subclaims": correct_subclaims,
        "subclaim_accuracy": accuracy,
        "exact_match_accuracy": exact_match_accuracy,
        "precision_supported": precision,
        "recall_supported": recall,
        "f1_supported": f1,
        "tp_supported": tp,
        "fp_supported": fp,
        "fn_supported": fn,
        "tn_supported": tn,
        "timestamp": timestamp,
    }
    return results, metrics
# Run the evaluation matching the finetune mode; anything other than
# "single_subclaim" is evaluated in list mode.
_evaluate = (
    evaluate_single_subclaim_mode
    if finetune_mode == "single_subclaim"
    else evaluate_subclaim_list_mode
)
results, accuracy_summary = _evaluate(test_raw)

# Record the run configuration next to the metrics.
accuracy_summary.update(
    finetune_mode=finetune_mode,
    model_size=model_size,
    run_mode=run_mode,
)

predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_{finetune_mode}_{model_size}_{run_mode}_{timestamp}.json",
)

# ensure_ascii=False keeps non-ASCII text readable in the JSON files.
for _report_path, _payload in (
    (predictions_path, results),
    (accuracy_path, accuracy_summary),
):
    with open(_report_path, "w", encoding="utf-8") as f:
        json.dump(_payload, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")

# Single mode stores "accuracy"; list mode stores "subclaim_accuracy".
_overall_accuracy = accuracy_summary.get(
    "accuracy", accuracy_summary.get("subclaim_accuracy", 0.0)
)
_f1_supported = accuracy_summary.get("f1_supported", 0.0)
print(f"Accuracy: {_overall_accuracy:.4f}")
print(f"F1 (supported class): {_f1_supported:.4f}")