File size: 4,544 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import json
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from unsloth import FastLanguageModel
import torch
dataset_path = "/home/mshahidul/readctrl/data/finetuning_data/train_subclaim_support_v2.json"
lora_save_path = "/home/mshahidul/readctrl_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-lora"
full_model_save_path = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
lora=False
# === Load base model ===
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Nemotron-3-Nano-30B-A3B",
    max_seq_length = 2048, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    trust_remote_code = True,
    unsloth_force_compile = True,
    attn_implementation="eager",
    # token = "hf_...", # use one if using gated models
)

# === Prepare LoRA model ===
# Attach LoRA adapters to all attention and MLP projection matrices.
# alpha == r (32/32) gives an effective scaling factor of 1.0.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,                      # LoRA rank
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_alpha = 32,
    lora_dropout = 0,            # no dropout on the adapter weights
    bias = "none",               # do not train bias terms
    use_gradient_checkpointing = "unsloth",  # unsloth's memory-efficient checkpointing
    random_state = 3407,
    use_rslora = False,          # plain LoRA scaling, not rank-stabilized
    loftq_config = None,
)

# === Load non-reasoning dataset (Full dataset) ===
# NOTE(review): load_dataset and standardize_sharegpt are imported but never
# used below — candidates for removal.
from datasets import load_dataset
from unsloth.chat_templates import standardize_sharegpt

print("Loading dataset...")
# Read the raw JSON list of training records and wrap it in an HF Dataset.
with open(f"{dataset_path}") as f:
    data = json.load(f)
from datasets import Dataset
dataset = Dataset.from_list(data)
def training_prompt(medical_text, subclaim):
    """Compose the user-turn prompt for one (medical_text, subclaim) pair.

    The prompt is the auditor system instruction followed by the evaluation
    task, the medical text, and the subclaim; the model is asked to answer
    with exactly 'supported' or 'not_supported'.
    """
    auditor_role = (
        "You are a clinical evidence auditor. Your evaluation must be based "
        "STRICTLY and ONLY on the provided medical text. Do not use outside "
        "medical knowledge or assume facts not explicitly stated. If the text "
        "does not provide enough information to confirm the claim, you must "
        "mark it as 'not_supported'."
    )
    # Explicit "\n" concatenation instead of a triple-quoted literal; the
    # four-space indents are part of the original prompt and are preserved.
    task_block = (
        "EVALUATION TASK:\n"
        "    1. Read the Medical Text.\n"
        "    2. Verify the Subclaim.\n"
        "    3. If the evidence is missing, ambiguous, or unconfirmed in the text, label it 'not_supported'.\n"
        "\n"
        "    ### Medical Text:\n"
        f"    {medical_text}\n"
        "\n"
        "    ### Subclaim:\n"
        f"    {subclaim}\n"
        "\n"
        "    Output exactly one word ('supported' or 'not_supported'):"
    )
    return auditor_role + "\n\n" + task_block

def generate_conversation(examples):
    """Batched datasets.map fn: build chat-format conversations.

    Takes a batch dict with "medical_text", "subclaim" and "label" columns and
    returns a "conversations" column where each entry is a two-message
    user/assistant exchange (prompt -> gold label).
    """
    rows = zip(examples["medical_text"], examples["subclaim"], examples['label'])
    return {
        "conversations": [
            [
                {"role": "user", "content": training_prompt(text, claim)},
                {"role": "assistant", "content": verdict},
            ]
            for text, claim, verdict in rows
        ],
    }

# Add a "conversations" column (user/assistant message pairs) to every row.
dataset = dataset.map(generate_conversation, batched = True)

def formatting_prompts_func(examples):
   """Batched datasets.map fn: flatten conversations into training strings.

   Each conversation is rendered with the tokenizer's chat template (no
   tokenization, no trailing generation prompt) and stored in a "text" column
   for SFTTrainer.
   """
   rendered = []
   for convo in examples["conversations"]:
       rendered.append(
           tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False)
       )
   return { "text" : rendered, }

# Add the final "text" column consumed by SFTTrainer (dataset_text_field).
dataset = dataset.map(formatting_prompts_func, batched = True)


# === Training setup ===
# Supervised fine-tuning on the "text" column. Effective batch size is
# 4 (per device) x 2 (gradient accumulation) = 8 sequences per optimizer step.
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 2, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps = 60,
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit", # 8-bit optimizer states to cut memory
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use TrackIO/WandB etc
    ),
)

# === Train model ===
trainer_stats = trainer.train()


# Persist the result according to the `lora` switch defined at the top:
# True  -> save only the LoRA adapter + tokenizer (small, needs base model),
# False -> merge the adapter into the base weights and save a standalone
#          16-bit checkpoint.
# Fix: `if lora == True:` replaced with the idiomatic `if lora:` (PEP 8
# discourages comparing booleans to True with ==); behavior is unchanged.
if lora:
    model.save_pretrained(lora_save_path)
    tokenizer.save_pretrained(lora_save_path)
else:
    model.save_pretrained_merged(
        full_model_save_path,
        tokenizer,
        save_method="merged_16bit",
    )