# Source: uploaded by shahidul034 ("Add files using upload-large-folder tool"),
# commit 903b1a4 (verified).
import ast
import json
import os
from datetime import datetime
import torch
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Base checkpoint to fine-tune (Hugging Face hub id).
model_name: str = "unsloth/Llama-3.2-3B-Instruct"
# JSON file holding a list of records with a "conversations" message list.
data_path: str = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"
# Fraction of the records held out for the post-training test evaluation.
test_size: float = 0.1
# Shared seed for the train/test split, LoRA initialisation and the trainer.
seed: int = 3407
# Sequence-length limit passed to model loading and to the SFT trainer.
max_seq_length: int = 2048
# Load the base weights in 4-bit precision.
load_in_4bit: bool = True
def formatting_prompts_func(examples):
    """Render every conversation in a batch into one training string.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    holds one message list per row. Each list is rendered through the
    module-level tokenizer's chat template (untokenized, without a generation
    prompt), and a leading ``<|begin_of_text|>`` marker is dropped —
    presumably because the tokenizer adds BOS again when the trainer
    tokenizes the text (TODO confirm against the trainer's tokenization).

    Returns:
        dict: ``{"text": [...]}`` with one rendered string per input row.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<|begin_of_text|>"))
    return {"text": rendered}
def parse_label_array(raw_text):
    """Parse a model or gold response into a list of normalized support labels.

    The response may be wrapped in Markdown code fences and surrounded by prose.
    Parsing proceeds in order:

    1. Strip ```json / ``` fences and take the span from the first ``[`` to the
       last ``]``. Try ``json.loads`` and then ``ast.literal_eval`` on it; the
       latter accepts Python-style single-quoted lists.
    2. If that span does not parse, scan each ``[`` and decode the first
       complete JSON array that starts there. This covers prose before or after
       the list that itself contains brackets, e.g.
       ``["supported"]\\nNote: claim [2] ...``.

    Each element is normalized to ``"supported"`` or ``"not_supported"``.
    Strings are stripped, lower-cased, and have ``-`` and spaces mapped to
    ``_``. Non-strings and unknown labels become ``"not_supported"``.

    Args:
        raw_text: Response text; ``None`` or empty is allowed.

    Returns:
        list[str]: Normalized labels, or ``[]`` when no list can be parsed.
    """
    text = (raw_text or "").strip()
    if not text:
        return []
    if "```" in text:
        text = text.replace("```json", "").replace("```", "").strip()

    candidate = text
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1 and end > start:
        candidate = text[start : end + 1]

    parsed = None
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(candidate)
            break
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # The failure modes documented for ast.literal_eval; json's
            # JSONDecodeError is a ValueError subclass. Parsing stays best-effort.
            continue

    if not isinstance(parsed, list):
        # The greedy first-"["..last-"]" span can swallow unrelated bracketed
        # prose; retry by decoding the first well-formed JSON array instead.
        parsed = _decode_first_json_array(text)
    if not isinstance(parsed, list):
        return []

    return [_normalize_label(item) for item in parsed]


def _decode_first_json_array(text):
    """Return the first JSON array decodable at some ``[`` in *text*, else None."""
    decoder = json.JSONDecoder()
    pos = text.find("[")
    while pos != -1:
        try:
            value, _end = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            value = None
        if isinstance(value, list):
            return value
        pos = text.find("[", pos + 1)
    return None


def _normalize_label(item):
    """Map one parsed element onto ``"supported"`` / ``"not_supported"``."""
    if not isinstance(item, str):
        return "not_supported"
    label = item.strip().lower().replace("-", "_").replace(" ", "_")
    return label if label in {"supported", "not_supported"} else "not_supported"
# ShareGPT-style role names mapped onto the OpenAI-style names used below.
_SHAREGPT_ROLES = {"human": "user", "gpt": "assistant"}


def extract_conversation_pair(conversations):
    """Return the first user prompt and the first assistant reply of a conversation.

    Two message schemas are accepted:

    * OpenAI-style ``{"role": "user" | "assistant", "content": ...}``
    * ShareGPT-style ``{"from": "human" | "gpt", "value": ...}``

    ShareGPT role names are mapped to ``user``/``assistant`` so the ``"from"``
    fallback actually matches. Content is read from ``"content"`` and falls
    back to ``"value"``. Other roles (e.g. ``system``) and non-dict entries are
    ignored, as are user/assistant turns after the first of each kind.

    Args:
        conversations: Iterable of message dicts, or ``None``.

    Returns:
        tuple: ``(user_prompt, gold_response)``; an element is ``""`` when
        that turn is not found.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations or []:
        if not isinstance(message, dict):
            continue
        role = message.get("role") or message.get("from")
        if isinstance(role, str):
            role = _SHAREGPT_ROLES.get(role, role)
        content = message.get("content")
        if content is None:
            content = message.get("value")
        if content is None:
            content = ""
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
        if user_prompt and gold_response:
            break  # both turns found; later turns are ignored anyway
    return user_prompt, gold_response
def generate_prediction(user_prompt, max_new_tokens=128):
    """Greedily generate the model's reply to a single user message.

    Uses the module-level ``model`` and ``tokenizer``. The prompt is rendered
    with the chat template, including the assistant generation header. It is
    then tokenized with ``add_special_tokens=False``, because the rendered
    template already begins with ``<|begin_of_text|>``. Letting the tokenizer
    add its own BOS would feed a doubled BOS at inference, whereas training
    text (see ``formatting_prompts_func``) strips the template's copy so BOS
    appears only once.

    Args:
        user_prompt: Text of the user turn.
        max_new_tokens: Generation budget; defaults to 128.

    Returns:
        str: The decoded completion, with special tokens removed and
        surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; temperature has no effect here and is omitted.
            do_sample=False,
            use_cache=True,
        )
    prompt_len = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load the base model together with its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=load_in_4bit,  # quantized base weights
    max_seq_length=max_seq_length,
    dtype=None,  # left to Unsloth's automatic selection
)
# 2. Wrap the model with LoRA adapters on the attention (q/k/v/o) and
#    MLP (gate/up/down) projection layers; rank and alpha are both 16.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=[
        f"{layer}_proj" for layer in ("q", "k", "v", "o", "gate", "up", "down")
    ],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=seed,
)
# 3. Data preparation: load the JSON records, hold out a seeded test split,
#    and render the training conversations to a plain "text" column.
with open(data_path, encoding="utf-8") as f:
    raw_data = json.load(f)
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(
    test_size=test_size,
    shuffle=True,
    seed=seed,
)
train_raw, test_raw = split_dataset["train"], split_dataset["test"]
train_dataset = train_raw.map(formatting_prompts_func, batched=True)
# 4. Output locations. The merged model goes to a per-model folder whose path
#    has no timestamp (so it is reused across runs of the same model); trainer
#    outputs, predictions and metrics go to a fresh timestamped folder.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
model_tag = model_name.rpartition("/")[2].replace(".", "_")
model_save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{model_tag}"
run_info_dir = os.path.join(
    "/home/mshahidul/readctrl/code/support_check/model_info",
    f"{model_tag}_{timestamp}",
)
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(run_info_dir, exist_ok=True)
# 5. Training setup: supervised fine-tuning on the rendered "text" column.
#    Each optimizer step sees 2 sequences x 4 accumulation steps per device;
#    mixed precision is bf16 when the GPU supports it, otherwise fp16.
use_bf16 = torch.cuda.is_bf16_supported()
training_args = SFTConfig(
    output_dir=os.path.join(run_info_dir, "trainer_outputs"),
    seed=seed,
    max_steps=30,
    warmup_steps=5,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="adamw_8bit",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=1,
    report_to="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    args=training_args,
)
# 6. Run fine-tuning.
trainer.train()

# 7. Write the LoRA-merged model with save_method "merged_16bit", then the
#    tokenizer files, into model_save_dir.
model.save_pretrained_merged(
    model_save_dir,
    tokenizer,
    save_method="merged_16bit",
)
tokenizer.save_pretrained(model_save_dir)
# 8. Test-set inference + accuracy: greedy predictions on the held-out split,
#    scored per sample (exact list match) and per label position.
FastLanguageModel.for_inference(model)
model.eval()

results = []
exact_match_correct = 0
label_correct = 0
label_total = 0
parsed_prediction_count = 0

for idx, sample in enumerate(test_raw):
    conversations = sample.get("conversations", [])
    user_prompt, gold_text = extract_conversation_pair(conversations)
    if not user_prompt:
        # Nothing to prompt with: the sample is skipped and not recorded.
        continue

    gold_labels = parse_label_array(gold_text)
    pred_text = generate_prediction(user_prompt)
    pred_labels = parse_label_array(pred_text)
    if pred_labels:
        parsed_prediction_count += 1

    # An empty (unparseable) gold list never counts as an exact match.
    exact_match = bool(gold_labels) and pred_labels == gold_labels
    if exact_match:
        exact_match_correct += 1

    # Position-wise agreement: zip stops at the shorter list, so missing
    # predictions score as wrong and surplus predictions are ignored.
    sample_label_correct = sum(
        1
        for pred_label, gold_label in zip(pred_labels, gold_labels)
        if pred_label == gold_label
    )
    label_correct += sample_label_correct
    label_total += len(gold_labels)

    sample_label_accuracy = (
        sample_label_correct / len(gold_labels) if gold_labels else None
    )
    results.append(
        {
            "sample_index": idx,
            "gold_labels": gold_labels,
            "predicted_labels": pred_labels,
            "raw_prediction": pred_text,
            "exact_match": exact_match,
            "label_accuracy": sample_label_accuracy,
        }
    )
# Aggregate the test-set metrics and write per-sample predictions plus the
# summary into run_info_dir.
total_samples = len(results)
exact_match_accuracy = exact_match_correct / total_samples if total_samples else 0.0
label_accuracy = label_correct / label_total if label_total else 0.0
# Samples whose gold response could not be parsed have an empty gold list.
# They can never be an exact match, yet they still count in `total_samples`.
# Report them explicitly, together with an exact-match accuracy restricted to
# parseable gold, so a gold-format problem is not silently read as a weak model.
gold_parsed_count = sum(1 for row in results if row["gold_labels"])
gold_unparsed_count = total_samples - gold_parsed_count
exact_match_accuracy_parsed_gold = (
    exact_match_correct / gold_parsed_count if gold_parsed_count else 0.0
)
accuracy_summary = {
    "model_name": model_name,
    "model_save_dir": model_save_dir,
    "run_info_dir": run_info_dir,
    "dataset_path": data_path,
    "seed": seed,
    "test_size": test_size,
    "test_samples_evaluated": total_samples,
    "gold_parsed_count": gold_parsed_count,
    "gold_unparsed_count": gold_unparsed_count,
    "parsed_prediction_count": parsed_prediction_count,
    "exact_match_accuracy": exact_match_accuracy,
    "exact_match_accuracy_parsed_gold": exact_match_accuracy_parsed_gold,
    "label_accuracy": label_accuracy,
    "exact_match_correct": exact_match_correct,
    "label_correct": label_correct,
    "label_total": label_total,
    "timestamp": timestamp,
}
predictions_path = os.path.join(run_info_dir, "test_inference.json")
accuracy_path = os.path.join(run_info_dir, "test_accuracy.json")
with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)
print(f"Saved merged model to: {model_save_dir}")
print(f"Saved run info folder to: {run_info_dir}")
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")
if gold_unparsed_count:
    print(
        f"Warning: {gold_unparsed_count} of {total_samples} evaluated test samples "
        f"have gold responses that could not be parsed into labels; "
        f"exact match accuracy on parseable gold: {exact_match_accuracy_parsed_gold:.4f}"
    )