# readCtrl_lambda/code/text_classifier/bn/finetune/llama31_8b_32_3b.py
# mshahidul
# Initial commit of readCtrl code without large models
# 030876e
import os
# Select the GPU before torch/unsloth are imported below; CUDA reads these
# variables when it initializes, so they must be set first.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs by PCI bus ID so index "1" is stable
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1 to this process
import json
import ast
# NOTE(review): unsloth is imported ahead of torch/trl, presumably so its
# patches are applied first — keep this ordering.
from unsloth import FastLanguageModel
import torch
from trl import SFTConfig, SFTTrainer
from datasets import Dataset
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
# 1. Configuration
# Base checkpoint to fine-tune. The 8B variant "unsloth/Llama-3.1-8B" was
# used previously and can be swapped in here.
model_name = "unsloth/Llama-3.2-3B-Instruct"

# JSON list of conversation records used for SFT and held-out evaluation.
data_path = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"

max_seq_length = 2048  # maximum tokens per training example
dtype = None  # None -> let Unsloth auto-detect the compute dtype
load_in_4bit = True  # load the base weights 4-bit quantized
# 2. Load the base model and its tokenizer using the configuration above.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# 3. Attach LoRA adapters to the projection layers listed below.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=lora_target_modules,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# 4. Data prep: conversations are rendered with Unsloth's "llama-3.1" chat
# template. NOTE(review): applied to a Llama 3.2 checkpoint on the assumption
# that both share the same chat format — confirm.
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
def formatting_prompts_func(examples):
    """Render a batch of conversations into chat-templated training strings.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batched mapping whose ``"conversations"`` column holds one
            list of chat messages per row.

    Returns:
        ``{"text": [...]}`` with one rendered (untokenized) string per row.
    """
    rendered = []
    for conversation in examples["conversations"]:
        as_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        # NOTE(review): "<bos>" is Gemma's BOS literal. Llama 3 templates appear
        # to start with "<|begin_of_text|>" instead, so this strip is likely a
        # no-op for this model. Left as-is; confirm how the trainer handles BOS
        # before changing it.
        rendered.append(as_text.removeprefix("<bos>"))
    return {"text": rendered}
def parse_label_array(raw_text):
    """Parse a response into a list of normalized support labels.

    The first ``[`` ... last ``]`` span is extracted (after stripping Markdown
    code fences). It is decoded as JSON first and as a Python literal second.
    Each element becomes ``"supported"`` or ``"not_supported"``: strings are
    lower-cased with ``-``/spaces mapped to ``_``, and anything unrecognised
    (including non-strings) becomes ``"not_supported"``.

    Args:
        raw_text: Model or gold response text; ``None``/empty is allowed.

    Returns:
        The normalized label list, or ``[]`` when no list could be decoded.
    """
    cleaned = (raw_text or "").strip()
    if not cleaned:
        return []
    if "```" in cleaned:
        # Drop Markdown code fences such as ```json ... ```.
        cleaned = cleaned.replace("```json", "").replace("```", "").strip()

    left = cleaned.find("[")
    right = cleaned.rfind("]")
    if left != -1 and right != -1 and right > left:
        cleaned = cleaned[left:right + 1]

    parsed = None
    for loader in (json.loads, ast.literal_eval):
        try:
            parsed = loader(cleaned)
        except Exception:
            # Malformed output is expected; fall through to the next decoder.
            continue
        else:
            break

    if not isinstance(parsed, list):
        return []

    allowed = {"supported", "not_supported"}

    def _normalize(item):
        if not isinstance(item, str):
            return "not_supported"
        label = item.strip().lower().replace("-", "_").replace(" ", "_")
        return label if label in allowed else "not_supported"

    return [_normalize(item) for item in parsed]
def extract_conversation_pair(conversations):
    """Return the first user prompt and first assistant reply of a conversation.

    Understands both message schemas:

    - ``{"role": ..., "content": ...}``
    - ShareGPT ``{"from": ..., "value": ...}``, where ``"human"`` maps to
      user and ``"gpt"`` maps to assistant.

    Callers therefore get a usable pair whether or not the data was
    standardized first.

    Args:
        conversations: Iterable of message dicts, or ``None``.

    Returns:
        ``(user_prompt, gold_response)``. An element is ``""`` when no matching
        message with text was found. Later user/assistant messages are ignored.
    """
    role_aliases = {"human": "user", "gpt": "assistant"}
    user_prompt = ""
    gold_response = ""
    for message in conversations or []:
        role = message.get("role") or message.get("from")
        role = role_aliases.get(role, role)
        content = message.get("content")
        if content is None:
            # ShareGPT-style messages keep their text under "value".
            content = message.get("value")
        if content is None:
            content = ""
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response
def generate_prediction(user_prompt, max_new_tokens=128):
    """Greedily generate the model's reply to a single user message.

    Args:
        user_prompt: Text of the user turn. It is sent as the only message,
            with no system prompt.
        max_new_tokens: Generation budget (default 128). Raise it if long gold
            label arrays get truncated.

    Returns:
        The decoded continuation with prompt tokens removed and special tokens
        skipped, stripped of surrounding whitespace.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered chat template already carries its own special tokens (BOS).
    # Re-tokenizing it with the default add_special_tokens=True would prepend
    # a second BOS, so disable that here.
    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)  # follow the model instead of assuming "cuda"
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; temperature would be ignored
            use_cache=True,
        )
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
# Load the raw records, standardize them with Unsloth's standardize_sharegpt,
# then hold out 10% (fixed seed) for evaluation. Only the training split is
# rendered to "text"; the test split keeps its raw conversations.
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)
dataset = standardize_sharegpt(Dataset.from_list(raw_data))
dataset = dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
train_dataset = dataset["train"].map(formatting_prompts_func, batched=True)
test_dataset = dataset["test"]
# 5. Training
# Query bf16 support once and enable exactly one mixed-precision mode.
use_bf16 = torch.cuda.is_bf16_supported()
training_args = SFTConfig(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # 2 x 4 = 8 examples per optimizer step per device
    warmup_steps=5,
    max_steps=60,  # Increase for full training
    learning_rate=2e-4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
)
# NOTE(review): tokenizer/dataset_text_field/max_seq_length/packing are passed
# to the trainer directly; this relies on the installed TRL/Unsloth accepting
# them there (newer TRL moves them onto SFTConfig) — confirm on upgrade.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    packing=False,
    args=training_args,
)
trainer.train()
# 6. Test-set Inference + Accuracy
# Switch the trained model into Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)
model.eval()
print("\n--- Testing Model on Test Set Samples ---")
# Preview up to three held-out samples. A 10% split of a small dataset can
# have fewer than three rows, which previously raised IndexError.
for i in range(min(3, len(test_dataset))):
    sample = test_dataset[i]
    user_prompt, _ = extract_conversation_pair(sample["conversations"])
    print(f"\nTest Sample {i+1} Prompt: {user_prompt}")
    decoded_output = generate_prediction(user_prompt)
    print(f"Model Response: {decoded_output}")
# Score every held-out sample that has a user prompt and parseable gold labels.
#   - Exact match: the full predicted label list equals the gold list.
#   - Label accuracy: position-wise agreement over all gold labels. Missing
#     predicted positions count as wrong; extra predicted labels are ignored
#     here and only affect exact match.
exact_match_correct = 0
label_correct = 0
label_total = 0
evaluated_samples = 0
parsed_prediction_count = 0
skipped_no_gold = 0
for sample in test_dataset:
    conversations = sample.get("conversations", [])
    user_prompt, gold_text = extract_conversation_pair(conversations)
    if not user_prompt:
        continue
    gold_labels = parse_label_array(gold_text)
    if not gold_labels:
        # A gold answer with no parseable labels can never be scored as an
        # exact match, so counting it would silently deflate that metric.
        # Skip it (no generation needed) and report how many were skipped.
        skipped_no_gold += 1
        continue
    pred_labels = parse_label_array(generate_prediction(user_prompt))
    evaluated_samples += 1
    if pred_labels:
        parsed_prediction_count += 1
    if pred_labels == gold_labels:
        exact_match_correct += 1
    # zip stops at the shorter list, so unpredicted gold positions stay wrong.
    label_correct += sum(1 for gold, pred in zip(gold_labels, pred_labels) if gold == pred)
    label_total += len(gold_labels)
exact_match_accuracy = exact_match_correct / evaluated_samples if evaluated_samples else 0.0
label_accuracy = label_correct / label_total if label_total else 0.0
print("\n--- Test Accuracy ---")
print(f"Samples evaluated: {evaluated_samples}")
print(f"Skipped (no parseable gold labels): {skipped_no_gold}")
print(f"Parsed predictions: {parsed_prediction_count}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")
# 7. Save in FP16 format (merged).
# Writes full model weights with the trained adapters merged in, rather than
# adapter-only files; the folder name is derived from the base checkpoint name.
save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/it_{model_name.split('/')[-1]}"
model.save_pretrained_merged(
    save_dir,
    tokenizer,
    save_method = "merged_16bit",
)
print(f"\nModel successfully saved in FP16 format to {save_dir}")