# ABOUTME: Validate fine-tuned model against a held-out test dataset
# ABOUTME: Reports accuracy and shows per-class breakdown

import json
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_test_dataset(paths: list[str]) -> list[dict]:
    """
    Load test examples from JSONL file(s) or folder(s).

    Each non-blank line must be a JSON object with a "messages" list of
    {"role": ..., "content": ...} entries. The content of the last "user"
    message and the stripped content of the last "assistant" message are
    kept; records missing either one (or where either is empty) are skipped.

    Args:
        paths: Files and/or directories. Directories contribute their
            "*.jsonl" files (non-recursive) in sorted order.

    Returns:
        List of {"input": str, "expected": str} dicts.

    Raises:
        FileNotFoundError: If a path is neither a file nor a directory.
        ValueError: If no test files were found at all.
    """
    # Resolve paths (files and folders)
    resolved_files = []
    for p in paths:
        path = Path(p)
        if path.is_dir():
            resolved_files.extend(sorted(path.glob("*.jsonl")))
        elif path.is_file():
            resolved_files.append(path)
        else:
            raise FileNotFoundError(f"Path not found: {path}")

    if not resolved_files:
        raise ValueError("No test files found")

    examples = []
    for file_path in resolved_files:
        print(f" Loading: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines between records

                data = json.loads(line)

                # Extract user content and expected assistant response.
                # Later messages overwrite earlier ones, so the last message
                # of each role wins.
                user_content = None
                expected = None
                for msg in data["messages"]:
                    if msg["role"] == "user":
                        user_content = msg["content"]
                    elif msg["role"] == "assistant":
                        expected = msg["content"].strip()

                if user_content and expected:
                    examples.append(
                        {
                            "input": user_content,
                            "expected": expected,
                        }
                    )

    return examples


def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    merge: bool = True,
):
    """
    Load the fine-tuned model.

    Args:
        adapter_path: Path to the LoRA adapter
        base_model_name: Base model to load adapter onto
        merge: If True, merge adapter into base model (faster inference)

    Returns:
        (model, tokenizer) with the model moved to the selected device and
        switched to eval mode.
    """
    print(f"Loading base model: {base_model_name}")

    # Determine device: MPS is checked first, then CUDA, then CPU fallback.
    # Each device gets its own dtype.
    if torch.backends.mps.is_available():
        device = "mps"
        torch_dtype = torch.float16
    elif torch.cuda.is_available():
        device = "cuda"
        torch_dtype = torch.bfloat16
    else:
        device = "cpu"
        torch_dtype = torch.float32
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # generate() is given pad_token_id, so make sure one exists.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(base_model, adapter_path)

    if merge:
        print("Merging adapter into base model...")
        model = model.merge_and_unload()

    model = model.to(device)
    model.eval()
    return model, tokenizer


def predict(model, tokenizer, user_input: str) -> "tuple[str | None, str]":
    """
    Run inference and extract the predicted score.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching the model, with a chat template.
        user_input: Content of the single user message to send.

    Returns:
        (score, raw) where ``score`` is the first character of the generated
        text that is one of "0".."3" (or None if there is none) and ``raw``
        is the stripped generated text.
    """
    messages = [{"role": "user", "content": user_input}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # The returned sequence starts with the prompt tokens. Slice them off at
    # the token level before decoding: the templated prompt string contains
    # special-token markup that skip_special_tokens=True removes, so slicing
    # the decoded string by len(text) would not line up with the prompt and
    # could let digits from the prompt leak into the extracted score.
    prompt_length = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    # Extract score - the first digit 0-3 in the generated text
    score = next((char for char in generated if char in "0123"), None)
    return score, generated.strip()


def validate(
    adapter_path: str,
    test_paths: list[str],
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    verbose: bool = False,
) -> dict:
    """
    Run validation and report results.

    Args:
        adapter_path: Path to the LoRA adapter directory.
        test_paths: JSONL files and/or folders containing test examples.
        base_model_name: Base model the adapter is loaded onto.
        verbose: If True, print details for up to the first 10 errors.

    Returns:
        Dict with "accuracy" (correct / total, 0.0 when there are no
        examples), "total", "results" (counters including the per-class
        breakdown for classes "0".."3") and "errors" (one entry per wrong or
        unparseable prediction).
    """
    print("=" * 60)
    print("Model Validation")
    print("=" * 60)

    # Load model
    model, tokenizer = load_model(adapter_path, base_model_name)

    # Load test data
    print("\nLoading test dataset:")
    test_examples = load_test_dataset(test_paths)
    print(f" Total test examples: {len(test_examples)}")

    # Run predictions
    print("\nRunning predictions...")
    results = {
        "correct": 0,
        "incorrect": 0,
        "unparseable": 0,
        "by_class": {str(i): {"correct": 0, "total": 0} for i in range(4)},
    }
    errors = []

    for i, example in enumerate(test_examples):
        expected = example["expected"]
        predicted, raw_output = predict(model, tokenizer, example["input"])

        # Track by class. Expected labels outside "0".."3" still count toward
        # the overall totals but have no per-class bucket.
        class_stats = results["by_class"].get(expected)
        if class_stats is not None:
            class_stats["total"] += 1

        if predicted is None:
            results["unparseable"] += 1
            errors.append(
                {
                    "input": example["input"][:100],
                    "expected": expected,
                    "predicted": predicted,
                    "raw": raw_output,
                    "error": "Could not parse score",
                }
            )
        elif predicted == expected:
            results["correct"] += 1
            if class_stats is not None:
                class_stats["correct"] += 1
        else:
            results["incorrect"] += 1
            errors.append(
                {
                    "input": example["input"][:100],
                    "expected": expected,
                    "predicted": predicted,
                    "raw": raw_output,
                    "error": "Wrong prediction",
                }
            )

        # Progress
        if (i + 1) % 10 == 0:
            print(f" Processed {i + 1}/{len(test_examples)}...")

    # Calculate metrics
    total = results["correct"] + results["incorrect"] + results["unparseable"]
    accuracy = results["correct"] / total if total > 0 else 0.0

    # Print results
    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f"\nOverall Accuracy: {accuracy:.1%} ({results['correct']}/{total})")
    print(f" Correct: {results['correct']}")
    print(f" Incorrect: {results['incorrect']}")
    print(f" Unparseable: {results['unparseable']}")

    print("\nPer-Class Accuracy:")
    for cls in sorted(results["by_class"]):
        data = results["by_class"][cls]
        if data["total"] > 0:
            cls_acc = data["correct"] / data["total"]
            print(f" Score {cls}: {cls_acc:.1%} ({data['correct']}/{data['total']})")
        else:
            print(f" Score {cls}: No examples")

    if errors and verbose:
        print(f"\nErrors ({len(errors)} total):")
        for err in errors[:10]:  # Show first 10
            print(f"\n Input: {err['input']}...")
            print(f" Expected: {err['expected']}, Predicted: {err['predicted']}")
            print(f" Raw output: {err['raw'][:50]}")

    print("\n" + "=" * 60)

    return {
        "accuracy": accuracy,
        "total": total,
        "results": results,
        "errors": errors,
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Validate fine-tuned model")
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        help="Path to the LoRA adapter directory",
    )
    parser.add_argument(
        "--test",
        type=str,
        nargs="+",
        required=True,
        help="Path(s) to test dataset(s) - files or folders",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model name",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Show detailed error output",
    )
    args = parser.parse_args()

    validate(
        adapter_path=args.adapter,
        test_paths=args.test,
        base_model_name=args.base_model,
        verbose=args.verbose,
    )