import os

# GPU selection is set before torch/unsloth are imported so it is in place
# before CUDA is initialized.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json
from datetime import datetime

import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
    get_chat_template,
    standardize_data_formats,
    train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer

model_name = "unsloth/gemma-3-4b-it"
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"
test_size = 0.2  # 1 - train_ratio (0.8)
seed = 42
prompt_language = "en"  # "bn" (Bangla) or "en" (English)

# run_mode options:
# - "finetune_and_eval": run LoRA finetuning then evaluate
# - "eval_base_only": evaluate the untouched base model
# - "eval_finetuned_only": load an already-saved finetuned model and only run
#   inference (no finetuning)
run_mode = "eval_finetuned_only"

# Used only by "eval_finetuned_only": the merged fp16 model directory written
# by a previous "finetune_and_eval" run (via save_pretrained_merged).
finetuned_model_dir = "/home/mshahidul/readctrl_model/text_classifier_bn/gemma-3-4b-it"
save_fp16_merged = True  # whether to save merged fp16 model after finetuning


def get_model_size_from_name(name):
    """Return the parameter-count tag embedded in a model name.

    The last path component of ``name`` is split on "-" and the first part
    that looks like a size tag is returned with its original casing, e.g.
    "unsloth/gemma-3-4b-it" -> "4b", "Qwen2.5-0.5B" -> "0.5B",
    "gemma-3n-E4B-it" -> "E4B". A size tag is an optional single letter,
    a number (optionally with a decimal part) and a trailing "b" or "m"
    (case-insensitive).

    Returns "unknown" when no part matches.
    """
    import re  # local import: only this helper needs it

    size_pattern = re.compile(r"[a-z]?\d+(?:\.\d+)?[bm]")
    base = name.split("/")[-1]
    for part in base.split("-"):
        # Requiring digits avoids false positives such as "bnb" or "medium"
        # that merely end in "b"/"m".
        if size_pattern.fullmatch(part.lower()):
            return part
    return "unknown"


model_size = get_model_size_from_name(model_name)


def formatting_prompts_func(examples):
    """Render a batch of chat conversations into SFT training text.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    is a list of message lists. The leading "<bos>" produced by the chat
    template is stripped, mirroring the Unsloth Gemma-3 recipe, where the
    tokenizer adds its own BOS when the trainer tokenizes this text; keeping
    both would give each sample a doubled BOS.

    Relies on the module-level ``tokenizer``.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation.
    """
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<bos>")
        for convo in examples["conversations"]
    ]
    return {"text": texts}


def build_classification_user_prompt(fulltext, gen_text):
    """Build the user-turn prompt for health-literacy classification.

    Args:
        fulltext: Reference text (the full case description).
        gen_text: Generated text whose literacy level is to be classified.

    Returns:
        The prompt string in English when ``prompt_language == "en"``,
        otherwise in Bangla. Both ask for exactly one of
        low_health_literacy / intermediate_health_literacy /
        proficient_health_literacy.
    """
    if prompt_language == "en":
        return (
            "You will be given a medical case description as reference (full text) and a generated text to classify. "
            "Determine the patient's health literacy level based only on the generated text.\n\n"
            f"Reference (full text):\n{fulltext}\n\n"
            f"Generated text (to classify):\n{gen_text}\n\n"
            "Reply with exactly one label from this set:\n"
            "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
        )
    # Bangla prompt (used for any prompt_language other than "en") — matches
    # reward_new_v6_bn_v2.py
    return (
        "আপনাকে রেফারেন্স হিসেবে মেডিকেল কেসের পূর্ণ বর্ণনা (reference full text) এবং মূলভাবে শ্রেণিবিন্যাস করার জন্য তৈরি করা টেক্সট (generated text) দেওয়া হবে। "
        "শুধুমাত্র তৈরি করা টেক্সট (generated text)-এর উপর ভিত্তি করে রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
        f"Reference (full text):\n{fulltext}\n\n"
        f"Generated text (যেটি শ্রেণিবিন্যাস করতে হবে):\n{gen_text}\n\n"
        "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
        "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
    )


def build_classification_examples(raw_records):
    """Convert raw records into user/assistant chat examples for SFT.

    Each record's ``fulltext`` and ``gen_text`` (default "") form the user
    prompt; its stripped ``label`` becomes the assistant reply. Records with
    a missing or blank label are skipped.

    Returns:
        A list of ``{"conversations": [user_msg, assistant_msg]}`` dicts.
    """
    examples = []
    for record in raw_records:
        fulltext = record.get("fulltext", "")
        gen_text = record.get("gen_text", "")
        label = (record.get("label") or "").strip()
        if not label:
            continue
        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        examples.append(
            {
                "conversations": [
                    {"role": "user", "content": user_prompt},
                    {"role": "assistant", "content": label},
                ],
            }
        )
    return examples


def extract_conversation_pair(conversations):
    """Return the first user message and first assistant reply of a chat.

    Accepts both OpenAI-style messages (``role``/``content`` with
    "user"/"assistant") and ShareGPT-style messages (``from``/``value`` with
    "human"/"gpt"). Missing parts come back as "".

    Returns:
        ``(user_prompt, gold_response)``.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations:
        role = message.get("role") or message.get("from")
        content = message.get("content", message.get("value", ""))
        if role in ("user", "human") and not user_prompt:
            user_prompt = content
        elif role in ("assistant", "gpt", "model") and not gold_response:
            gold_response = content
    return user_prompt, gold_response


def generate_prediction(user_prompt, max_new_tokens=256):
    """Greedily generate the model's reply to a single user turn.

    Uses the module-level ``model`` and ``tokenizer``. Only the newly
    generated tokens are decoded (the prompt tokens are sliced off), with
    special tokens skipped and surrounding whitespace stripped.

    Args:
        user_prompt: Content of the single user message.
        max_new_tokens: Generation budget (default 256, the previous
            hard-coded value).
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # NOTE(review): the rendered prompt may begin with "<bos>" while the
    # tokenizer call may add another BOS; training text has its "<bos>"
    # stripped in formatting_prompts_func. Confirm input ids start with a
    # single BOS.
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding: a temperature has no effect here, and
            # transformers warns when one is passed with do_sample=False.
            do_sample=False,
            use_cache=True,
        )
    generated_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# 1. Load Model and Tokenizer
if run_mode == "eval_finetuned_only":
    if not finetuned_model_dir:
        raise ValueError(
            "run_mode is 'eval_finetuned_only' but 'finetuned_model_dir' is empty. "
            "Please set 'finetuned_model_dir' to the directory of your saved merged model."
        )
    # Load the merged model written by a previous "finetune_and_eval" run.
    model_source = finetuned_model_dir
else:
    # "finetune_and_eval" and "eval_base_only" both start from the base model.
    model_source = model_name

model, tokenizer = FastModel.from_pretrained(
    model_name=model_source,
    max_seq_length=8192,
    load_in_4bit=False,
)
# 2. Data Preparation
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

raw_dataset = Dataset.from_list(raw_data)
# Seeded split: for the same data file, every run mode holds out the same
# test rows, so base and finetuned evaluations are comparable.
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]

# 3. Optional Finetuning
if run_mode == "finetune_and_eval":
    # The SFT dataset is only needed for training; evaluation reads test_raw.
    train_examples = build_classification_examples(train_raw)
    train_dataset = Dataset.from_list(train_examples)
    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    # Add LoRA adapters for finetuning
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )

    # Training setup
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )

    # Mask the loss so only assistant turns are trained on. Use the full
    # Gemma-3 turn headers (as in the Unsloth Gemma-3 recipe); bare
    # "user\n"/"model\n" can also match ordinary text inside the prompt.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )

    # Execute training
    save_dir = f"/home/mshahidul/readctrl_model/text_classifier_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()

    # Optional: save in float16 merged format
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate base (unmodified) model
    save_dir = f"BASE_MODEL:{model_name}"
elif run_mode == "eval_finetuned_only":
    # No finetuning; evaluate an already-saved finetuned model
    save_dir = finetuned_model_dir
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")

# 4. Test-set Inference + Accuracy
FastModel.for_inference(model)
model.eval()

model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")

# The label set the classification prompt asks the model to choose from.
_VALID_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
)


def _normalize_predicted_label(text):
    """Map a raw model reply to a valid label when that is unambiguous.

    - A reply that is exactly a valid label, ignoring case and surrounding
      quotes, backticks, asterisks, brackets or punctuation, maps to that label.
    - Otherwise, if exactly one valid label occurs anywhere in the reply, that
      label is returned.
    - In all other cases (no label, or several labels mentioned) the stripped
      reply is returned unchanged.
    """
    cleaned = (text or "").strip()
    bare = cleaned.strip("\"'`*.,:;!() \t\r\n").lower()
    if bare in _VALID_LABELS:
        return bare
    lowered = cleaned.lower()
    mentioned = [label for label in _VALID_LABELS if label in lowered]
    if len(mentioned) == 1:
        return mentioned[0]
    return cleaned


def evaluate_classification_mode(test_split):
    """Run greedy inference on every labelled test sample and score it.

    Samples with a missing or blank ``label`` are skipped. A prediction is
    correct when the stripped reply equals the gold label exactly (the
    original criterion), or when its normalized label (see
    ``_normalize_predicted_label``) equals the gold label case-insensitively.
    Normalization can therefore only add correct answers, never remove one.

    Returns:
        ``(results, metrics)``. ``results`` has one dict per scored sample,
        including the normalized ``predicted_label`` and the unmodified
        ``raw_prediction``. ``metrics`` summarizes the run, including
        ``accuracy`` (0.0 when no sample was scored).
    """
    results = []
    total = 0
    correct = 0
    for idx, sample in enumerate(test_split):
        gold_label = (sample.get("label") or "").strip()
        if not gold_label:
            continue
        fulltext = sample.get("fulltext", "")
        gen_text = sample.get("gen_text", "")

        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        raw_prediction = (generate_prediction(user_prompt) or "").strip()
        pred_label = _normalize_predicted_label(raw_prediction)

        is_correct = (
            raw_prediction == gold_label
            or pred_label.casefold() == gold_label.casefold()
        )
        total += 1
        correct += int(is_correct)
        results.append(
            {
                "sample_index": idx,
                "fulltext": fulltext,
                "gen_text": gen_text,
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "raw_prediction": raw_prediction,
                "correct": is_correct,
            }
        )

    accuracy = correct / total if total else 0.0
    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total,
        "accuracy": accuracy,
        "timestamp": timestamp,
    }
    return results, metrics


results, accuracy_summary = evaluate_classification_mode(test_raw)
accuracy_summary["finetune_mode"] = "classification"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode

predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
# Per-run metrics go to the ablation directory; per-sample predictions
# (predictions_path) go to the model-info directory.
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_classification_{model_size}_{run_mode}_{timestamp}.json",
)

# ensure_ascii=False keeps Bangla text readable in the JSON output.
with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
# evaluate_classification_mode always sets "accuracy" in its metrics dict.
print(f"Accuracy: {accuracy_summary['accuracy']:.4f}")