import os
import logging

# Avoid TypeError in transformers deprecation warning (message contains '%',
# extra args break %-formatting).
for _logger_name in (
    "transformers",
    "transformers.modeling_attn_mask_utils",
    "transformers.utils.logging",
):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
# If a handler still hits the buggy warning, don't crash the script.
logging.raiseExceptions = False

# Set before torch is imported so the CUDA runtime only sees the chosen GPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json
import re
from datetime import datetime

import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

model_name = "unsloth/Llama-3.2-3B-Instruct"
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"
test_size = 0.2  # 1 - train_ratio (0.8), same as Gemma script
seed = 42
prompt_language = "bn"  # "bn" (Bangla) or "en" (English)
run_mode = "finetune_and_eval"  # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = False  # whether to save merged fp16 model after finetuning
max_seq_length = 4096
load_in_4bit = False

# A size token is a number (optionally "NxM" for mixture-of-experts names)
# followed by "b" or "m", e.g. "3B", "1.5B", "135M", "8x7B".
_MODEL_SIZE_RE = re.compile(r"\d+(?:\.\d+)?(?:x\d+(?:\.\d+)?)?[bm]")


def get_model_size_from_name(name):
    """Return the parameter-size token of a model name, e.g. "3B".

    Only the last path component is inspected. It is split on "-", and the
    first part that looks like a numeric size (see ``_MODEL_SIZE_RE``) is
    returned with its original casing. Plain words that merely end in "b"/"m"
    (e.g. "llm", "medium") are not mistaken for sizes.

    Returns:
        The matching part, or "unknown" if none matches.
    """
    base = name.split("/")[-1]
    for part in base.split("-"):
        if _MODEL_SIZE_RE.fullmatch(part.lower()):
            return part
    return "unknown"


model_size = get_model_size_from_name(model_name)


def formatting_prompts_func(examples):
    """Render a batch of chat conversations into plain training text.

    Intended for ``Dataset.map(..., batched=True)``. ``examples["conversations"]``
    is a list of message lists, and the result is ``{"text": [...]}``. The
    leading "<|begin_of_text|>" produced by the chat template is stripped so
    that only the BOS added when the trainer tokenizes the text remains.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<|begin_of_text|>")
        for convo in convos
    ]
    return {"text": texts}


def build_classification_user_prompt(fulltext, gen_text):
    """Build the user-turn prompt asking for a health-literacy label.

    Input: fulltext + gen_text. Expected output: one label. The instructions
    are written in English when ``prompt_language == "en"`` and in Bangla
    otherwise. The label names themselves are always English.
    """
    if prompt_language == "en":
        return (
            "You will be given a medical case description (full text) and a generated summary. "
            "Classify the patient's health literacy level.\n\n"
            f"Full text:\n{fulltext}\n\n"
            f"Generated text:\n{gen_text}\n\n"
            "Reply with exactly one label from this set:\n"
            "low_health_literacy, intermediate_health_literacy, high_health_literacy"
        )

    # Bangla (default)
    return (
        "আপনাকে একটি মেডিকেল কেসের পূর্ণ বর্ণনা (full text) এবং তৈরি করা সারাংশ (generated text) দেওয়া হবে। "
        "রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
        f"Full text:\n{fulltext}\n\n"
        f"Generated text:\n{gen_text}\n\n"
        "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
        "low_health_literacy, intermediate_health_literacy, high_health_literacy"
    )


def build_classification_examples(raw_records):
    """Convert raw records into user/assistant conversations for SFT.

    Each record is read via ``.get("fulltext")``, ``.get("gen_text")`` and
    ``.get("label")``. Records whose label is missing or blank are skipped.
    The assistant turn is the stripped label.
    """
    examples = []
    for record in raw_records:
        fulltext = record.get("fulltext", "")
        gen_text = record.get("gen_text", "")
        label = (record.get("label") or "").strip()
        if not label:
            continue

        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        examples.append(
            {
                "conversations": [
                    {"role": "user", "content": user_prompt},
                    {"role": "assistant", "content": label},
                ],
            }
        )
    return examples


def generate_prediction(user_prompt, max_new_tokens=256):
    """Greedy-decode the model's reply to a single user turn.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        user_prompt: Text of the user message.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template may already start with BOS (the training path
    # strips it for that reason). Only let the tokenizer add special tokens
    # when it does not, so inference sees exactly one BOS like training does.
    bos = tokenizer.bos_token
    add_special = not (bos and prompt.startswith(bos))
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=add_special,
    ).to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; a temperature setting is ignored here and only
            # produces a generation-config warning, so none is passed.
            do_sample=False,
            use_cache=True,
        )
    prompt_len = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# 1. Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,  # let Unsloth auto-detect the dtype
    load_in_4bit=load_in_4bit,
)
# 2. Add LoRA adapters (kept same as original Llama script)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=seed,
)

# 3. Data preparation (same dataset split and prompt style as Gemma script)
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

raw_dataset = Dataset.from_list(raw_data)
# Seeded shuffle split so the held-out test set is reproducible across runs.
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]

train_examples = build_classification_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

# 4. Optional finetuning
if run_mode == "finetune_and_eval":
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            # 60 steps x (2 x 4) = 480 examples on a single visible GPU, which
            # is less than one epoch if the train split is larger.
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )
    trainer.train()

    save_dir = f"/home/mshahidul/readctrl_model/text_classifier_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    if save_fp16_merged:
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
    else:
        # save_dir is reported as "model_save_dir" in the metrics, so always
        # persist the trained LoRA adapters there. Otherwise the directory
        # stays empty and the evaluated model cannot be recovered.
        model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate base model.
    # NOTE(review): the untrained LoRA adapters from step 2 are still attached;
    # this assumes their default init leaves outputs identical to the base
    # model — confirm.
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
# 5. Test-set inference + accuracy (same pattern and folders as Gemma script)
FastLanguageModel.for_inference(model)
model.eval()

model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")

# Closed label set the prompt instructs the model to choose from.
_CLASSIFICATION_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "high_health_literacy",
)


def _canonical_label(text):
    """Reduce free-form text to one label from ``_CLASSIFICATION_LABELS`` if possible.

    Replies that merely wrap a label (casing, quotes, markdown, a short
    sentence) are mapped to the bare label. If no label, or more than one
    distinct label, occurs in the text, the stripped text is returned
    unchanged and therefore compared verbatim. No label is a substring of
    another, so substring search is unambiguous.
    """
    stripped = (text or "").strip()
    lowered = stripped.lower()
    found = [label for label in _CLASSIFICATION_LABELS if label in lowered]
    if len(found) == 1:
        return found[0]
    return stripped


def evaluate_classification_mode(test_split):
    """Generate a label for every labelled test sample and score it.

    Samples with a missing or blank gold label are skipped and not counted.
    Gold and predicted labels are both passed through ``_canonical_label``
    before comparison.

    Returns:
        (results, metrics): per-sample result dicts (including the raw
        generation as "predicted_text"), and a summary dict with accuracy.
    """
    results = []
    total = 0
    correct = 0
    for idx, sample in enumerate(test_split):
        fulltext = sample.get("fulltext", "")
        gen_text = sample.get("gen_text", "")
        gold_label = (sample.get("label") or "").strip()
        if not gold_label:
            continue

        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        pred_text = (generate_prediction(user_prompt) or "").strip()
        pred_label = _canonical_label(pred_text)

        total += 1
        is_correct = pred_label == _canonical_label(gold_label)
        if is_correct:
            correct += 1

        results.append(
            {
                "sample_index": idx,
                "fulltext": fulltext,
                "gen_text": gen_text,
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "predicted_text": pred_text,
                "correct": is_correct,
            }
        )

    accuracy = correct / total if total else 0.0
    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": total,
        "examples_correct": correct,
        "accuracy": accuracy,
        "timestamp": timestamp,
    }
    return results, metrics


results, accuracy_summary = evaluate_classification_mode(test_raw)
accuracy_summary["finetune_mode"] = "classification"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode
accuracy_summary["prompt_language"] = prompt_language

predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_classification_{model_size}_{run_mode}_{timestamp}.json",
)

with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Accuracy: {accuracy_summary.get('accuracy', 0.0):.4f}")