"""LoRA fine-tune an instruct model with Unsloth for support checking, then evaluate.

Steps: load the base model in 4-bit, attach LoRA adapters, run SFT on a
chat-formatted dataset, save a merged 16-bit copy, run greedy inference on a
held-out split, and report exact-match / per-label accuracy of the predicted
label arrays (each label is "supported" or "not_supported").
"""

import ast
import json
import os
from datetime import datetime

# Unsloth must be imported before trl / transformers / peft so that its
# patches are applied to those libraries.
from unsloth import FastLanguageModel  # isort: skip

import torch
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

model_name = "unsloth/Llama-3.2-3B-Instruct"
data_path = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"
test_size = 0.1
seed = 3407
max_seq_length = 2048
load_in_4bit = True
max_new_tokens = 128  # generation budget for each predicted label array

VALID_LABELS = frozenset({"supported", "not_supported"})


def formatting_prompts_func(examples):
    """Render each conversation in a batch into one training string.

    The chat template's leading ``<|begin_of_text|>`` is stripped so the
    training text carries a single BOS once it is tokenized.

    Args:
        examples: a batched ``datasets`` row dict with a ``conversations`` column.

    Returns:
        ``{"text": [...]}``, one rendered string per conversation.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix("<|begin_of_text|>")
        for convo in convos
    ]
    return {"text": texts}


def parse_label_array(raw_text):
    """Parse a gold or predicted response into a list of normalized labels.

    Code fences are removed and the outermost ``[...]`` span is kept. It is
    parsed as JSON, falling back to a Python literal. Each string element is
    lower-cased with ``-`` and spaces mapped to ``_``. Anything that is not a
    string, or is not in ``VALID_LABELS``, becomes ``"not_supported"``.

    Args:
        raw_text: response text; ``None`` or blank is allowed.

    Returns:
        The normalized label list, or ``[]`` if no list could be parsed.
    """
    text = (raw_text or "").strip()
    if not text:
        return []
    if "```" in text:
        text = text.replace("```json", "").replace("```", "").strip()
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        text = text[start : end + 1]

    parsed = None
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(text)
            break
        # json.loads raises JSONDecodeError (a ValueError) or RecursionError;
        # ast.literal_eval may raise any of these on malformed input.
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue

    if not isinstance(parsed, list):
        return []

    normalized = []
    for item in parsed:
        label = (
            item.strip().lower().replace("-", "_").replace(" ", "_")
            if isinstance(item, str)
            else ""
        )
        normalized.append(label if label in VALID_LABELS else "not_supported")
    return normalized


def extract_conversation_pair(conversations):
    """Return ``(first user message, first assistant message)`` of a conversation.

    The role is read from ``role``, falling back to ``from``. Missing or null
    content is treated as an empty string.
    """
    user_prompt = ""
    gold_response = ""
    for message in conversations:
        role = message.get("role") or message.get("from")
        content = message.get("content") or ""
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
    return user_prompt, gold_response


def extract_prompt_messages(conversations):
    """Return the role/content turns that precede the first assistant reply.

    Training renders the full conversation, so any turn before the answer
    (e.g. a system turn) is part of what the model conditioned on. Replaying
    those turns at evaluation keeps the prompts consistent. Turns without a
    role are skipped.
    """
    prompt_messages = []
    for message in conversations:
        role = message.get("role") or message.get("from")
        if role == "assistant":
            break
        if role:
            prompt_messages.append(
                {"role": role, "content": message.get("content") or ""}
            )
    return prompt_messages


def generate_prediction(user_prompt, max_new_tokens=128, prompt_messages=None):
    """Greedily generate the assistant reply for a prompt.

    Args:
        user_prompt: user message text; used when ``prompt_messages`` is None.
        max_new_tokens: maximum number of tokens to generate.
        prompt_messages: optional list of role/content turns to condition on
            instead of a single user turn built from ``user_prompt``.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    if prompt_messages is None:
        prompt_messages = [{"role": "user", "content": user_prompt}]
    prompt = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already starts with <|begin_of_text|>. Letting the
    # tokenizer add special tokens again would prepend a second BOS, unlike
    # the training text (formatting_prompts_func strips the template's copy).
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; temperature would be ignored
            use_cache=True,
        )
    # Drop the echoed prompt tokens and keep only the newly generated ones.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# 1. Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=load_in_4bit,
)

# 2. Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=seed,
)

# 3. Data preparation
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]
train_dataset = train_raw.map(formatting_prompts_func, batched=True)

# 4. Save directories for this run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")
# NOTE: not timestamped, so each run overwrites the previous merged model.
model_save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{model_tag}"
run_info_dir = os.path.join(
    "/home/mshahidul/readctrl/code/support_check/model_info",
    f"{model_tag}_{timestamp}",
)
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(run_info_dir, exist_ok=True)

# 5. Training setup
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    args=SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=30,
        learning_rate=2e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=seed,
        output_dir=os.path.join(run_info_dir, "trainer_outputs"),
        report_to="none",
    ),
)

# 6. Train
trainer.train()

# 7. Save merged model
model.save_pretrained_merged(model_save_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(model_save_dir)

# 8. Test-set inference + accuracy
FastLanguageModel.for_inference(model)
model.eval()

results = []
exact_match_correct = 0
label_correct = 0
label_total = 0
parsed_prediction_count = 0

for idx, sample in enumerate(test_raw):
    conversations = sample.get("conversations") or []
    user_prompt, gold_text = extract_conversation_pair(conversations)
    if not user_prompt:
        continue

    gold_labels = parse_label_array(gold_text)
    pred_text = generate_prediction(
        user_prompt,
        max_new_tokens=max_new_tokens,
        # Fall back to the single user turn if no turns precede the answer.
        prompt_messages=extract_prompt_messages(conversations) or None,
    )
    pred_labels = parse_label_array(pred_text)
    if pred_labels:
        parsed_prediction_count += 1

    exact_match = bool(gold_labels) and pred_labels == gold_labels
    if exact_match:
        exact_match_correct += 1

    # Position-wise agreement. Predicted labels beyond the gold length are
    # ignored, and missing ones count as wrong.
    sample_label_correct = sum(
        pred == gold for pred, gold in zip(pred_labels, gold_labels)
    )
    label_correct += sample_label_correct
    label_total += len(gold_labels)

    results.append(
        {
            "sample_index": idx,
            "gold_labels": gold_labels,
            "predicted_labels": pred_labels,
            "raw_prediction": pred_text,
            "exact_match": exact_match,
            "label_accuracy": (
                sample_label_correct / len(gold_labels) if gold_labels else None
            ),
        }
    )

total_samples = len(results)
exact_match_accuracy = exact_match_correct / total_samples if total_samples else 0.0
label_accuracy = label_correct / label_total if label_total else 0.0

accuracy_summary = {
    "model_name": model_name,
    "model_save_dir": model_save_dir,
    "run_info_dir": run_info_dir,
    "dataset_path": data_path,
    "seed": seed,
    "test_size": test_size,
    "test_samples_evaluated": total_samples,
    "parsed_prediction_count": parsed_prediction_count,
    "exact_match_accuracy": exact_match_accuracy,
    "label_accuracy": label_accuracy,
    "exact_match_correct": exact_match_correct,
    "label_correct": label_correct,
    "label_total": label_total,
    "timestamp": timestamp,
}

predictions_path = os.path.join(run_info_dir, "test_inference.json")
accuracy_path = os.path.join(run_info_dir, "test_accuracy.json")

with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)

print(f"Saved merged model to: {model_save_dir}")
print(f"Saved run info folder to: {run_info_dir}")
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")