import os
import json
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from unsloth import FastLanguageModel
import torch
dataset_path = "/home/mshahidul/readctrl/data/finetuning_data/train_subclaim_support_v2.json"
lora_save_path = "/home/mshahidul/readctrl_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-lora"
full_model_save_path = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
lora=False
# === Load base model ===
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Nemotron-3-Nano-30B-A3B",
max_seq_length = 2048, # Choose any for long context!
load_in_4bit = False, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
trust_remote_code = True,
unsloth_force_compile = True,
attn_implementation="eager",
# token = "hf_...", # use one if using gated models
)
# === Prepare LoRA model ===
model = FastLanguageModel.get_peft_model(
model,
r = 32,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
lora_alpha = 32,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
# === Load non-reasoning dataset (Full dataset) ===
from datasets import load_dataset
from unsloth.chat_templates import standardize_sharegpt
print("Loading dataset...")
with open(f"{dataset_path}") as f:
data = json.load(f)
from datasets import Dataset
dataset = Dataset.from_list(data)
def training_prompt(medical_text, subclaim):
system_prompt = (
"You are a clinical evidence auditor. Your evaluation must be based "
"STRICTLY and ONLY on the provided medical text. Do not use outside "
"medical knowledge or assume facts not explicitly stated. If the text "
"does not provide enough information to confirm the claim, you must "
"mark it as 'not_supported'."
)
user_content = f"""EVALUATION TASK:
1. Read the Medical Text.
2. Verify the Subclaim.
3. If the evidence is missing, ambiguous, or unconfirmed in the text, label it 'not_supported'.
### Medical Text:
{medical_text}
### Subclaim:
{subclaim}
Output exactly one word ('supported' or 'not_supported'):"""
return f"{system_prompt}\n\n{user_content}"
def generate_conversation(examples):
# import ipdb; ipdb.set_trace()
medical_texts = examples["medical_text"]
subclaims = examples["subclaim"]
labels=examples['label']
conversations = []
for medical_text, subclaim, label in zip(medical_texts, subclaims, labels):
conversations.append([
{"role" : "user", "content" : training_prompt(medical_text, subclaim)},
{"role" : "assistant", "content" : label},
])
return { "conversations": conversations, }
dataset = dataset.map(generate_conversation, batched = True)
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
dataset = dataset.map(formatting_prompts_func, batched = True)
# === Training setup ===
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
eval_dataset = None, # Can set up evaluation!
args = SFTConfig(
dataset_text_field = "text",
per_device_train_batch_size = 4,
gradient_accumulation_steps = 2, # Use GA to mimic batch size!
warmup_steps = 5,
num_train_epochs = 1, # Set this for 1 full training run.
# max_steps = 60,
learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.001,
lr_scheduler_type = "linear",
seed = 3407,
report_to = "none", # Use TrackIO/WandB etc
),
)
# === Train model ===
trainer_stats = trainer.train()
if lora==True:
model.save_pretrained(lora_save_path)
tokenizer.save_pretrained(lora_save_path)
else:
model.save_pretrained_merged(
full_model_save_path,
tokenizer,
save_method="merged_16bit",
)