# readCtrl_lambda / code / text_classifier / bn / finetune / gemma3-finetune.py
# Author: mshahidul — "Initial commit of readCtrl code without large models" (commit 030876e)
import os

# Pin this process to physical GPU 3. CUDA_DEVICE_ORDER=PCI_BUS_ID makes device
# indices match nvidia-smi; both variables must be set BEFORE torch/CUDA is
# first imported below, which is why this precedes the other imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import json
import os
from datetime import datetime
import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
get_chat_template,
standardize_data_formats,
train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
model_name = "unsloth/gemma-3-4b-it"  # base checkpoint (Unsloth build of Gemma-3 4B instruct)
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"  # JSON list of {fulltext, gen_text, label}
test_size = 0.2 # 1 - train_ratio (0.8)
seed = 42  # shared by the train/test split, LoRA init, and trainer for reproducibility
prompt_language = "en" # "bn" (Bangla) or "en" (English)
# run_mode options:
# - "finetune_and_eval": run LoRA finetuning then evaluate
# - "eval_base_only": evaluate the untouched base model
# - "eval_finetuned_only": load an already-saved finetuned model and only run inference (no finetuning)
run_mode = "eval_finetuned_only"
# If you want to run "eval_finetuned_only", point this to the merged fp16 model directory
# created by a previous "finetune_and_eval" run (where save_pretrained_merged was used).
finetuned_model_dir = "/home/mshahidul/readctrl_model/text_classifier_bn/gemma-3-4b-it" # e.g. "/home/mshahidul/readctrl_model/text_classifier_bn/gemma-3-4b-it"
save_fp16_merged = True # whether to save merged fp16 model after finetuning
def get_model_size_from_name(name):
    """Extract the parameter-size token (e.g. "4b", "270m") from a model name.

    Looks at the last path component of *name*, splits it on "-", and returns
    the first segment whose lowercase form ends in "b" or "m". Returns
    "unknown" when no segment matches.
    """
    basename = name.rsplit("/", 1)[-1]
    size_tokens = (seg for seg in basename.split("-") if seg.lower().endswith(("b", "m")))
    return next(size_tokens, "unknown")
# Size token (e.g. "4b") parsed from the model name; used in output filenames.
model_size = get_model_size_from_name(model_name)
def formatting_prompts_func(examples):
    """Render each conversation into a single training string under key "text".

    Applies the gemma-3 chat template without a generation prompt, then strips
    a leading "<bos>" token — presumably because the tokenizer re-adds it
    during training (TODO confirm against unsloth docs).

    NOTE(review): relies on the module-level `tokenizer` global.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<bos>"))
    return {"text": rendered}
def build_classification_user_prompt(fulltext, gen_text):
    """Build the user-turn prompt for health-literacy classification.

    The model is shown the reference case description (`fulltext`) plus the
    text to classify (`gen_text`) and is asked to reply with exactly one of
    the three literacy labels. Instruction language is chosen by the
    module-level `prompt_language` flag: "en" -> English, anything else ->
    Bangla (the default branch).
    """
    # Input: fulltext (reference) + gen_text (main text to classify), Output: label
    if prompt_language == "en":
        return (
            "You will be given a medical case description as reference (full text) and a generated text to classify. "
            "Determine the patient's health literacy level based only on the generated text.\n\n"
            f"Reference (full text):\n{fulltext}\n\n"
            f"Generated text (to classify):\n{gen_text}\n\n"
            "Reply with exactly one label from this set:\n"
            "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
        )
    # Bangla (default) — matches reward_new_v6_bn_v2.py
    return (
        "আপনাকে রেফারেন্স হিসেবে মেডিকেল কেসের পূর্ণ বর্ণনা (reference full text) এবং মূলভাবে শ্রেণিবিন্যাস করার জন্য তৈরি করা টেক্সট (generated text) দেওয়া হবে। "
        "শুধুমাত্র তৈরি করা টেক্সট (generated text)-এর উপর ভিত্তি করে রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
        f"Reference (full text):\n{fulltext}\n\n"
        f"Generated text (যেটি শ্রেণিবিন্যাস করতে হবে):\n{gen_text}\n\n"
        "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
        "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
    )
def build_classification_examples(raw_records):
    """Convert raw records into two-turn chat examples (user prompt -> label).

    Each record contributes one {"conversations": [...]} dict; records whose
    "label" is missing or blank are skipped entirely.
    """
    chat_examples = []
    for rec in raw_records:
        label = (rec.get("label") or "").strip()
        if not label:
            continue  # unlabeled records cannot be used for supervised training
        prompt = build_classification_user_prompt(
            rec.get("fulltext", ""), rec.get("gen_text", "")
        )
        chat_examples.append(
            {
                "conversations": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": label},
                ],
            }
        )
    return chat_examples
def extract_conversation_pair(conversations):
    """Return (first user message, first assistant message) from a chat.

    Accepts messages keyed with either "role" or the ShareGPT-style "from";
    a turn missing from the conversation yields "" in its slot. Only the
    first non-empty message of each kind is kept.
    """
    user_msg, assistant_msg = "", ""
    for turn in conversations:
        speaker = turn.get("role") or turn.get("from")
        text = turn.get("content", "")
        if speaker == "user" and not user_msg:
            user_msg = text
        elif speaker == "assistant" and not assistant_msg:
            assistant_msg = text
    return user_msg, assistant_msg
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to a single-turn user prompt.

    Renders the prompt with the chat template (adding the generation prompt),
    runs `model.generate` deterministically, and returns only the newly
    generated text with special tokens stripped and whitespace trimmed.

    NOTE(review): relies on the module-level `model` and `tokenizer` globals.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            # Greedy decoding. `temperature` was removed: it is ignored when
            # do_sample=False, and recent transformers versions warn (or
            # reject) temperature=0.0.
            do_sample=False,
            use_cache=True,
        )
    # Slice off the prompt tokens so only the model's reply is decoded.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load Model and Tokenizer
if run_mode == "eval_finetuned_only":
    # Evaluating a previously merged fp16 model: fail fast if no path was given.
    if not finetuned_model_dir:
        raise ValueError(
            "run_mode is 'eval_finetuned_only' but 'finetuned_model_dir' is empty. "
            "Please set 'finetuned_model_dir' to the directory of your saved merged model."
        )
    model, tokenizer = FastModel.from_pretrained(
        model_name=finetuned_model_dir,
        max_seq_length=8192,
        load_in_4bit=False,  # full-precision weights; no 4-bit quantization
    )
else:
    # "finetune_and_eval" / "eval_base_only": start from the base checkpoint.
    model, tokenizer = FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=8192,
        load_in_4bit=False,
    )
# 2. Data Preparation
# Attach the gemma-3 chat template so apply_chat_template renders correctly.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)  # list of {fulltext, gen_text, label} records
raw_dataset = Dataset.from_list(raw_data)
# Deterministic 80/20 train/test split (seed fixed in the config above).
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]
# Only the train split is converted to chat-formatted "text"; the test split
# stays raw and is re-prompted sample-by-sample at inference time.
train_examples = build_classification_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
# 3. Optional Finetuning
if run_mode == "finetune_and_eval":
    # Add LoRA adapters for finetuning (rank 8, alpha 16, on all attention
    # and MLP projection matrices).
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )
    # Training setup. NOTE: max_steps=60 caps training regardless of dataset
    # size — this looks like a smoke-test budget; raise for a full run.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size 8
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),  # fall back to fp16 on older GPUs
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",  # disable wandb/tensorboard reporting
        ),
    )
    # Masking to train on assistant responses only — loss is computed on the
    # model turn, not the user prompt (markers are gemma-3 turn delimiters).
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )
    # Execute training
    save_dir = f"/home/mshahidul/readctrl_model/text_classifier_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()
    # Optional: save in float16 merged format (LoRA weights folded into the
    # base model) so the result can later be loaded with "eval_finetuned_only".
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate base (unmodified) model. save_dir is only a tag
    # recorded in the metrics JSON, not a real directory.
    save_dir = f"BASE_MODEL:{model_name}"
elif run_mode == "eval_finetuned_only":
    # No finetuning; evaluate an already-saved finetuned model
    save_dir = finetuned_model_dir
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
# 4. Test-set Inference + Accuracy
FastModel.for_inference(model)  # switch unsloth kernels to inference mode
model.eval()
# Output locations: per-sample predictions vs. aggregate accuracy summaries.
model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # embedded in output filenames
model_tag = model_name.split("/")[-1].replace(".", "_")  # filesystem-safe model id
def evaluate_classification_mode(test_split):
    """Predict a label for every test sample and compute exact-match accuracy.

    Samples without a gold label are skipped (not counted in the total).
    Returns (results, metrics): `results` holds one record per evaluated
    sample (inputs, gold label, prediction, correctness flag); `metrics` is a
    summary dict with accuracy and the run configuration.

    NOTE(review): reads module-level globals (model_name, save_dir, data_path,
    prompt_language, seed, test_size, timestamp) and calls the sibling
    helpers `build_classification_user_prompt` / `generate_prediction`.
    """
    results = []
    total = 0
    correct = 0
    for idx, sample in enumerate(test_split):
        fulltext = sample.get("fulltext", "")
        gen_text = sample.get("gen_text", "")
        gold_label = (sample.get("label") or "").strip()
        if not gold_label:
            continue  # cannot score a sample that has no gold label
        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        # Exact string match against the gold label; generate_prediction
        # already strips whitespace, the extra guard covers a None/empty reply.
        pred_label = (generate_prediction(user_prompt) or "").strip()
        total += 1
        is_correct = pred_label == gold_label
        if is_correct:
            correct += 1
        results.append(
            {
                "sample_index": idx,
                "fulltext": fulltext,
                "gen_text": gen_text,
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "correct": is_correct,
            }
        )
    accuracy = correct / total if total else 0.0  # guard against an empty split
    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total,
        "accuracy": accuracy,
        "timestamp": timestamp,
    }
    return results, metrics
# Run evaluation on the held-out split and enrich the summary with run metadata.
results, accuracy_summary = evaluate_classification_mode(test_raw)
accuracy_summary["finetune_mode"] = "classification"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode
accuracy_summary["prompt_language"] = prompt_language
# Per-sample predictions go to model_info; the aggregate summary goes to
# ablation_studies with run parameters encoded in the filename.
predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_classification_{model_size}_{run_mode}_{timestamp}.json",
)
with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
# "subclaim_accuracy" fallback kept for summaries written by older runs.
print(f"Accuracy: {accuracy_summary.get('accuracy', accuracy_summary.get('subclaim_accuracy', 0.0)):.4f}")