# ff6347's picture
# Upload folder using huggingface_hub
# 9ec3d1d verified
# ABOUTME: Validate fine-tuned model against a held-out test dataset
# ABOUTME: Reports accuracy and shows per-class breakdown
import json
from pathlib import Path
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_test_dataset(paths: list[str]) -> list[dict]:
"""
Load test examples from JSONL file(s) or folder(s).
Returns list of {"input": str, "expected": str} dicts.
"""
# Resolve paths (files and folders)
resolved_files = []
for p in paths:
path = Path(p)
if path.is_dir():
resolved_files.extend(sorted(path.glob("*.jsonl")))
elif path.is_file():
resolved_files.append(path)
else:
raise FileNotFoundError(f"Path not found: {path}")
if not resolved_files:
raise ValueError("No test files found")
examples = []
for file_path in resolved_files:
print(f" Loading: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
messages = data["messages"]
# Extract user content and expected assistant response
user_content = None
expected = None
for msg in messages:
if msg["role"] == "user":
user_content = msg["content"]
elif msg["role"] == "assistant":
expected = msg["content"].strip()
if user_content and expected:
examples.append(
{
"input": user_content,
"expected": expected,
}
)
return examples
def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    merge: bool = True,
):
    """
    Load the fine-tuned model (base model + LoRA adapter).

    Args:
        adapter_path: Path to the LoRA adapter
        base_model_name: Base model to load adapter onto
        merge: If True, merge adapter into base model (faster inference)

    Returns:
        (model, tokenizer) tuple; the model is moved to the best
        available device and switched to eval mode.
    """
    print(f"Loading base model: {base_model_name}")

    # Pick the accelerator: MPS first (Apple Silicon), then CUDA, else CPU.
    # Each device gets a matching compute dtype.
    if torch.backends.mps.is_available():
        device, dtype = "mps", torch.float16
    elif torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # Some base models ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(base_model, adapter_path)
    if merge:
        # Folding the adapter weights into the base model removes the
        # PEFT indirection at inference time.
        print("Merging adapter into base model...")
        model = model.merge_and_unload()

    model = model.to(device)
    model.eval()
    return model, tokenizer
def predict(model, tokenizer, user_input: str) -> tuple[str | None, str]:
    """
    Run greedy inference and extract the predicted score.

    Args:
        model: Causal LM already on its target device, in eval mode.
        tokenizer: Matching tokenizer with a chat template.
        user_input: The user-turn content to score.

    Returns:
        (score, raw_text) where score is the first character in "0123"
        found in the generated continuation (None if absent) and
        raw_text is the stripped generated text.
    """
    messages = [{"role": "user", "content": user_input}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated token ids. Slicing the decoded
    # string at len(text) is unreliable: skip_special_tokens changes the
    # decoded length relative to the templated prompt, so a character
    # offset can land mid-prompt or mid-answer. Token-level slicing is
    # exact because generate() returns prompt tokens followed by new ones.
    prompt_len = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Extract score - look for a digit 0-3
    score = next((char for char in generated if char in "0123"), None)
    return score, generated.strip()
def validate(
    adapter_path: str,
    test_paths: list[str],
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    verbose: bool = False,
):
    """
    Run validation and report results.

    Args:
        adapter_path: Path to the LoRA adapter directory.
        test_paths: JSONL file/folder paths with held-out chat examples.
        base_model_name: Base model to load the adapter onto.
        verbose: If True, print details for up to 10 failing examples.

    Returns:
        Dict with "accuracy" (float), "total" (int), the raw "results"
        counters (including per-class breakdown), and "errors" (list of
        misclassified/unparseable examples).
    """
    print("=" * 60)
    print("Model Validation")
    print("=" * 60)

    # Load model
    model, tokenizer = load_model(adapter_path, base_model_name)

    # Load test data
    print("\nLoading test dataset:")
    test_examples = load_test_dataset(test_paths)
    print(f" Total test examples: {len(test_examples)}")

    # Run predictions
    print("\nRunning predictions...")
    results = {
        "correct": 0,
        "incorrect": 0,
        "unparseable": 0,
        "by_class": {str(i): {"correct": 0, "total": 0} for i in range(4)},
    }
    errors = []
    for i, example in enumerate(test_examples):
        expected = example["expected"]
        predicted, raw_output = predict(model, tokenizer, example["input"])
        _record_prediction(results, errors, example, expected, predicted, raw_output)
        # Progress
        if (i + 1) % 10 == 0:
            print(f" Processed {i + 1}/{len(test_examples)}...")

    # Calculate metrics
    total = results["correct"] + results["incorrect"] + results["unparseable"]
    accuracy = results["correct"] / total if total > 0 else 0

    _print_report(results, errors, accuracy, total, verbose)

    return {
        "accuracy": accuracy,
        "total": total,
        "results": results,
        "errors": errors,
    }


def _record_prediction(results, errors, example, expected, predicted, raw_output):
    """Update counters (and the error log) for one prediction.

    Mutates ``results`` and ``errors`` in place. The per-class tally only
    tracks expected labels "0".."3"; anything else counts toward the
    overall totals but not the class breakdown.
    """
    # Track by class
    if expected in results["by_class"]:
        results["by_class"][expected]["total"] += 1

    if predicted is None:
        results["unparseable"] += 1
        errors.append(
            {
                "input": example["input"][:100],
                "expected": expected,
                "predicted": predicted,
                "raw": raw_output,
                "error": "Could not parse score",
            }
        )
    elif predicted == expected:
        results["correct"] += 1
        if expected in results["by_class"]:
            results["by_class"][expected]["correct"] += 1
    else:
        results["incorrect"] += 1
        errors.append(
            {
                "input": example["input"][:100],
                "expected": expected,
                "predicted": predicted,
                "raw": raw_output,
                "error": "Wrong prediction",
            }
        )


def _print_report(results, errors, accuracy, total, verbose):
    """Print the overall, per-class, and (optionally) error breakdowns."""
    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f"\nOverall Accuracy: {accuracy:.1%} ({results['correct']}/{total})")
    print(f" Correct: {results['correct']}")
    print(f" Incorrect: {results['incorrect']}")
    print(f" Unparseable: {results['unparseable']}")

    print("\nPer-Class Accuracy:")
    for cls in sorted(results["by_class"].keys()):
        data = results["by_class"][cls]
        if data["total"] > 0:
            cls_acc = data["correct"] / data["total"]
            print(f" Score {cls}: {cls_acc:.1%} ({data['correct']}/{data['total']})")
        else:
            print(f" Score {cls}: No examples")

    if errors and verbose:
        print(f"\nErrors ({len(errors)} total):")
        for err in errors[:10]:  # Show first 10
            print(f"\n Input: {err['input']}...")
            print(f" Expected: {err['expected']}, Predicted: {err['predicted']}")
            print(f" Raw output: {err['raw'][:50]}")

    print("\n" + "=" * 60)
if __name__ == "__main__":
    import argparse

    # CLI entry point: validate one adapter against held-out test data.
    cli = argparse.ArgumentParser(description="Validate fine-tuned model")
    cli.add_argument(
        "--adapter",
        type=str,
        required=True,
        help="Path to the LoRA adapter directory",
    )
    cli.add_argument(
        "--test",
        type=str,
        nargs="+",
        required=True,
        help="Path(s) to test dataset(s) - files or folders",
    )
    cli.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model name",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Show detailed error output",
    )

    opts = cli.parse_args()
    validate(
        adapter_path=opts.adapter,
        test_paths=opts.test,
        base_model_name=opts.base_model,
        verbose=opts.verbose,
    )