# Source path: readctrl/code/support_check/model_finetune/gemma3-finetune.py
# Hosting-page residue: shahidul034 - "Add files using upload-large-folder tool" (commit 903b1a4, verified)
import ast
import json
import os
from datetime import datetime
import torch
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
from unsloth import FastModel
from unsloth.chat_templates import (
get_chat_template,
standardize_data_formats,
train_on_responses_only,
)
# Base checkpoint to fine-tune. Its last path component is reused further down
# to build the save directory and the evaluation output file names.
model_name = "unsloth/gemma-3-4b-it"
# JSON file read with json.load and handed to Dataset.from_list; each record's
# "conversations" field is what the rest of the script consumes.
data_path = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"
# Fraction of the records held out by train_test_split for the evaluation pass.
test_size = 0.1
# Single seed passed to get_peft_model (random_state), train_test_split and SFTConfig.
seed = 3407
def formatting_prompts_func(examples):
    """Render a batch of conversations into plain training strings.

    Each entry of ``examples["conversations"]`` is passed through the
    module-level ``tokenizer``'s chat template (untokenized, without a
    generation prompt), and a leading ``<bos>`` is removed from the result.

    Returns:
        A dict with the single key ``"text"`` holding the rendered strings
        in the same order as the input conversations.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<bos>"))
    return {"text": rendered}
def parse_label_array(raw_text):
    """Extract a list of support labels from free-form model or gold text.

    The text is stripped, Markdown code fences are removed, and the span from
    the first ``[`` to the last ``]`` (when present in that order) is decoded,
    first as JSON and then as a Python literal.

    Returns:
        A list containing only ``"supported"`` / ``"not_supported"``. Entries
        that are not strings, or whose normalised form is not one of those two
        labels, become ``"not_supported"``. An empty list is returned when the
        input is empty or does not decode to a list.
    """
    cleaned = (raw_text or "").strip()
    if not cleaned:
        return []

    if "```" in cleaned:
        cleaned = cleaned.replace("```json", "").replace("```", "").strip()

    # Narrow to the outermost bracketed span so surrounding chatter is ignored.
    open_idx = cleaned.find("[")
    close_idx = cleaned.rfind("]")
    if 0 <= open_idx < close_idx:
        cleaned = cleaned[open_idx : close_idx + 1]

    # The first decoder that succeeds wins, even if its result is not a list.
    decoded = None
    for decode in (json.loads, ast.literal_eval):
        try:
            decoded = decode(cleaned)
        except Exception:
            continue
        else:
            break

    if not isinstance(decoded, list):
        return []

    allowed = {"supported", "not_supported"}

    def _canonical(entry):
        # Anything that is not a recognisable label counts as "not_supported".
        if not isinstance(entry, str):
            return "not_supported"
        candidate = entry.strip().lower().replace("-", "_").replace(" ", "_")
        return candidate if candidate in allowed else "not_supported"

    return [_canonical(entry) for entry in decoded]
def extract_conversation_pair(conversations):
    """Return the first user prompt and first assistant reply of a conversation.

    Messages may use either the ``role``/``content`` layout or the
    ShareGPT-style ``from``/``value`` layout with ``human``/``gpt`` speakers.
    The evaluation split is read without being standardised first, so both
    layouts have to be understood here.

    Args:
        conversations: Iterable of message dicts, or ``None``.

    Returns:
        A ``(user_prompt, gold_response)`` tuple. Either element is ``""``
        when no matching non-empty message is found.
    """
    # Map ShareGPT speaker names onto the role names used by the checks below.
    role_aliases = {
        "user": "user",
        "human": "user",
        "assistant": "assistant",
        "gpt": "assistant",
    }
    user_prompt = ""
    gold_response = ""
    for message in conversations or []:
        role = role_aliases.get(message.get("role") or message.get("from"))
        # Prefer "content"; fall back to ShareGPT's "value" only when absent.
        content = message.get("content")
        if content is None:
            content = message.get("value")
        if content is None:
            content = ""
        # Keep the first non-empty message of each role; later turns are ignored.
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to a single user prompt.

    Relies on the module-level ``tokenizer`` and ``model``. The prompt is
    rendered as a one-turn chat with the generation prompt appended, up to
    128 new tokens are generated without sampling, and only the newly
    generated tokens are decoded.

    Args:
        user_prompt: Text of the user turn.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    rendered = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # formatting_prompts_func strips the template's leading <bos> from training
    # text; do the same here so inference input matches training input and the
    # tokenizer call below does not (presumably) prepend a second BOS token.
    prompt = rendered.removeprefix("<bos>")
    # Passed by keyword so the call also works if `tokenizer` is a multimodal
    # processor whose first positional parameter is not the text.
    # NOTE(review): confirm which object FastModel returns for this checkpoint.
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        # temperature is a sampling-only setting; with do_sample=False it has no
        # effect and 0.0 is not a valid value for it, so it is left out.
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
            use_cache=True,
        )
    # Drop the echoed prompt tokens and keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load Model and Tokenizer
# Loads the base checkpoint through Unsloth with 4-bit loading enabled and a
# 2048-token maximum sequence length (the same value the trainer is given below).
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,
    load_in_4bit=True,
)
# 2. Add LoRA Adapters
# Wraps the loaded model with rank-8 LoRA adapters (alpha 16, dropout 0,
# bias "none") on the listed projection modules; `seed` is the adapter
# random_state.
model = FastModel.get_peft_model(
    model,
    r=8,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    random_state=seed,
)
# 3. Data Preparation
# Switch the tokenizer to the "gemma-3" chat template before any text is rendered.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

# Read the JSON records and wrap them in a Dataset.
with open(data_path, "r", encoding="utf-8") as data_file:
    raw_data = json.load(data_file)
raw_dataset = Dataset.from_list(raw_data)

# Seeded, shuffled hold-out split; the test part is used only after training.
split_dataset = raw_dataset.train_test_split(
    test_size=test_size,
    seed=seed,
    shuffle=True,
)
train_raw, test_raw = split_dataset["train"], split_dataset["test"]

# Only the training split is standardised and rendered to a "text" column;
# test_raw keeps the records as loaded.
train_dataset = standardize_data_formats(train_raw).map(
    formatting_prompts_func,
    batched=True,
)
# 4. Training Setup
# Supervised fine-tuning over the rendered "text" column of train_dataset.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    args=SFTConfig(
        # 2 samples per device per step, with gradients accumulated over 4 steps.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        # The run is capped at 30 steps.
        max_steps=30,
        learning_rate=2e-4,
        # Exactly one precision flag is set: bf16 when CUDA reports support,
        # fp16 otherwise.
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=seed,
        output_dir="outputs",
        # No external experiment tracker is configured.
        report_to="none",
    ),
)
# Masking to train on assistant responses only
# instruction_part / response_part are the literal turn-start markers for the
# user and model turns in the rendered text.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<start_of_turn>user\n",
    response_part="<start_of_turn>model\n",
)
# 5. Execute Training
# The save directory is named after the last path component of model_name and
# is created before training starts.
save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{model_name.split('/')[-1]}"
os.makedirs(save_dir, exist_ok=True)
trainer.train()
# 6. Save in float16 Format
# Writes the merged model with save_method="merged_16bit", then the tokenizer
# files, into save_dir.
model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(save_dir)
# 7. Test-set Inference + Accuracy
# Put the fine-tuned model into inference mode before generating.
FastModel.for_inference(model)
model.eval()

# Output location and naming pieces for the evaluation artefacts.
model_info_dir = "/home/mshahidul/readctrl/code/support_check/model_info"
os.makedirs(model_info_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")

# Running totals across the held-out split.
results = []
exact_match_correct = 0
label_correct = 0
label_total = 0
parsed_prediction_count = 0

for idx, sample in enumerate(test_raw):
    conversations = sample.get("conversations", [])
    user_prompt, gold_text = extract_conversation_pair(conversations)
    # Samples without a user prompt are skipped and do not count as evaluated.
    if not user_prompt:
        continue

    gold_labels = parse_label_array(gold_text)
    pred_text = generate_prediction(user_prompt)
    pred_labels = parse_label_array(pred_text)

    # A prediction counts as parsed when it yields at least one label.
    parsed_prediction_count += int(bool(pred_labels))

    # An exact match needs a non-empty gold list that the prediction equals.
    exact_match = pred_labels == gold_labels if gold_labels else False
    exact_match_correct += int(exact_match)

    # Position-wise agreement; gold positions with no prediction count as wrong.
    sample_label_correct = sum(
        1 for gold, pred in zip(gold_labels, pred_labels) if gold == pred
    )
    label_correct += sample_label_correct
    label_total += len(gold_labels)

    if gold_labels:
        sample_label_accuracy = sample_label_correct / len(gold_labels)
    else:
        sample_label_accuracy = None

    results.append(
        {
            "sample_index": idx,
            "gold_labels": gold_labels,
            "predicted_labels": pred_labels,
            "raw_prediction": pred_text,
            "exact_match": exact_match,
            "label_accuracy": sample_label_accuracy,
        }
    )
# Aggregate metrics; both ratios fall back to 0.0 when there is nothing to divide by.
total_samples = len(results)
if total_samples:
    exact_match_accuracy = exact_match_correct / total_samples
else:
    exact_match_accuracy = 0.0
if label_total:
    label_accuracy = label_correct / label_total
else:
    label_accuracy = 0.0

# Run metadata plus the aggregate counts and ratios.
accuracy_summary = {
    "model_name": model_name,
    "model_save_dir": save_dir,
    "dataset_path": data_path,
    "seed": seed,
    "test_size": test_size,
    "test_samples_evaluated": total_samples,
    "parsed_prediction_count": parsed_prediction_count,
    "exact_match_accuracy": exact_match_accuracy,
    "label_accuracy": label_accuracy,
    "exact_match_correct": exact_match_correct,
    "label_correct": label_correct,
    "label_total": label_total,
    "timestamp": timestamp,
}

# One file for per-sample predictions, one for the summary, both timestamped.
predictions_path = os.path.join(
    model_info_dir, f"{model_tag}_test_inference_{timestamp}.json"
)
accuracy_path = os.path.join(
    model_info_dir, f"{model_tag}_test_accuracy_{timestamp}.json"
)

# Predictions are written first, then the summary.
for out_path, payload in (
    (predictions_path, results),
    (accuracy_path, accuracy_summary),
):
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(payload, out_file, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")