# Unsloth recommends importing it before transformers/trl so its patches apply.
from unsloth import FastLanguageModel

import ast
import json
import os
from datetime import datetime

import torch
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Run configuration
model_name = "unsloth/Llama-3.2-3B-Instruct"
data_path = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"
test_size = 0.1
seed = 3407
max_seq_length = 2048
load_in_4bit = True
def formatting_prompts_func(examples):
    """Render each conversation into a single chat-templated text field.

    Uses the module-level `tokenizer` defined below; this works because the
    function is only called after the tokenizer has been loaded. The BOS marker
    is stripped so it is not duplicated when the trainer tokenizes the text.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<|begin_of_text|>")
        for convo in convos
    ]
    return {"text": texts}
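# Illustrative record shape this script assumes (not taken from the dataset itself):
# {"conversations": [
#     {"role": "user", "content": "<claims + evidence prompt>"},
#     {"role": "assistant", "content": '["supported", "not_supported"]'},
# ]}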
def parse_label_array(raw_text):
    """Parse a raw response into a normalized list of "supported"/"not_supported" labels."""
    text = (raw_text or "").strip()
if not text:
return []
if "```" in text:
text = text.replace("```json", "").replace("```", "").strip()
start = text.find("[")
end = text.rfind("]")
if start != -1 and end != -1 and end > start:
text = text[start : end + 1]
parsed = None
for parser in (json.loads, ast.literal_eval):
try:
parsed = parser(text)
break
except Exception:
continue
if not isinstance(parsed, list):
return []
normalized = []
for item in parsed:
if not isinstance(item, str):
normalized.append("not_supported")
continue
label = item.strip().lower().replace("-", "_").replace(" ", "_")
if label not in {"supported", "not_supported"}:
label = "not_supported"
normalized.append(label)
return normalized
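# Illustrative behavior on hypothetical inputs:
#   parse_label_array('```json\n["Supported", "not supported"]\n```')
#     -> ["supported", "not_supported"]
#   parse_label_array("no list here") -> []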
def extract_conversation_pair(conversations):
    """Return the first user message and the first assistant message, if any."""
    user_prompt = ""
    gold_response = ""
for message in conversations:
role = message.get("role") or message.get("from")
content = message.get("content", "")
if role == "user" and not user_prompt:
user_prompt = content
elif role == "assistant" and not gold_response:
gold_response = content
return user_prompt, gold_response
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to a single user prompt."""
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,  # greedy decoding; a temperature value would be ignored here
            use_cache=True,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
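# Illustrative call (hypothetical prompt and output):
#   generate_prediction("Check each claim against the evidence ...")
#   -> '["supported", "not_supported"]'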
# 1. Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,  # None lets Unsloth auto-detect the dtype (e.g. bfloat16 on Ampere+)
    load_in_4bit=load_in_4bit,
)
# 2. Add LoRA adapters
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
    lora_alpha=16,
    lora_dropout=0,  # Unsloth notes dropout = 0 as its optimized setting
    bias="none",  # likewise, bias="none" is the optimized setting
    use_gradient_checkpointing="unsloth",  # Unsloth's VRAM-saving checkpointing
random_state=seed,
)
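# Optional sanity check: get_peft_model returns a PEFT-wrapped model, so this
# should report trainable (LoRA) vs. total parameter counts (assumes the PEFT API).
model.print_trainable_parameters()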
# 3. Data preparation
with open(data_path, "r", encoding="utf-8") as f:
raw_data = json.load(f)
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]
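# Quick visibility into the split sizes.
print(f"Train samples: {len(train_raw)}, test samples: {len(test_raw)}")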
train_dataset = train_raw.map(formatting_prompts_func, batched=True)  # adds the "text" field
# 4. Save directories for this run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")
model_save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{model_tag}"
run_info_dir = os.path.join(
"/home/mshahidul/readctrl/code/support_check/model_info",
f"{model_tag}_{timestamp}",
)
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(run_info_dir, exist_ok=True)
# 5. Training setup
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
dataset_text_field="text",
max_seq_length=max_seq_length,
args=SFTConfig(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
max_steps=30,
learning_rate=2e-4,
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
logging_steps=1,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=seed,
output_dir=os.path.join(run_info_dir, "trainer_outputs"),
report_to="none",
),
)
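# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
# = 2 * 4 = 8 sequences per optimizer step, so max_steps=30 sees only ~240 training
# examples; this reads as a short smoke-test-scale run rather than a full epoch.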
# 6. Train
trainer.train()
# 7. Save merged model (LoRA weights folded into the base model, saved in 16-bit)
model.save_pretrained_merged(model_save_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(model_save_dir)
# 8. Test-set inference + accuracy
FastLanguageModel.for_inference(model)  # switch Unsloth into its faster inference mode
model.eval()
results = []
exact_match_correct = 0  # samples whose full predicted list equals the gold list
label_correct = 0  # per-position label matches, summed over all samples
label_total = 0  # total number of gold label positions seen
parsed_prediction_count = 0  # predictions that parsed into a non-empty list
for idx, sample in enumerate(test_raw):
conversations = sample.get("conversations", [])
user_prompt, gold_text = extract_conversation_pair(conversations)
if not user_prompt:
continue
gold_labels = parse_label_array(gold_text)
pred_text = generate_prediction(user_prompt)
pred_labels = parse_label_array(pred_text)
if pred_labels:
parsed_prediction_count += 1
    # Exact match requires a non-empty gold list and an identical prediction.
    exact_match = bool(gold_labels) and pred_labels == gold_labels
    if exact_match:
        exact_match_correct += 1
    # Positional label accuracy: credit each position where prediction == gold.
    sample_label_correct = 0
for pos, gold_label in enumerate(gold_labels):
if pos < len(pred_labels) and pred_labels[pos] == gold_label:
sample_label_correct += 1
label_correct += sample_label_correct
label_total += len(gold_labels)
results.append(
{
"sample_index": idx,
"gold_labels": gold_labels,
"predicted_labels": pred_labels,
"raw_prediction": pred_text,
"exact_match": exact_match,
"label_accuracy": (
sample_label_correct / len(gold_labels) if gold_labels else None
),
}
)
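# Aggregate metrics: exact-match accuracy is per sample (the entire predicted list
# must equal the gold list); label accuracy is micro-averaged over all gold label
# positions across samples.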
total_samples = len(results)
exact_match_accuracy = exact_match_correct / total_samples if total_samples else 0.0
label_accuracy = label_correct / label_total if label_total else 0.0
accuracy_summary = {
"model_name": model_name,
"model_save_dir": model_save_dir,
"run_info_dir": run_info_dir,
"dataset_path": data_path,
"seed": seed,
"test_size": test_size,
"test_samples_evaluated": total_samples,
"parsed_prediction_count": parsed_prediction_count,
"exact_match_accuracy": exact_match_accuracy,
"label_accuracy": label_accuracy,
"exact_match_correct": exact_match_correct,
"label_correct": label_correct,
"label_total": label_total,
"timestamp": timestamp,
}
predictions_path = os.path.join(run_info_dir, "test_inference.json")
accuracy_path = os.path.join(run_info_dir, "test_accuracy.json")
with open(predictions_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)
print(f"Saved merged model to: {model_save_dir}")
print(f"Saved run info folder to: {run_info_dir}")
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")