# RippleGPT-Nano / validation/code/validate_code.py
# Tavernari's picture
# Upload folder using huggingface_hub
# 148b631 verified
"""
validate_code.py - Executes the complete code completion validation suite.
This script:
1. Loads the trained model
2. Executes all test cases
3. Calculates evaluation metrics
4. Generates a detailed report
Usage:
    python validation/code/validate_code.py
    python validation/code/validate_code.py --verbose
    python validation/code/validate_code.py --category brackets
"""
import os
import sys
import pickle
import argparse
import json
from datetime import datetime
from typing import List, Optional
import torch
# Add root directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.model import RippleGPT
from src.config import RippleConfig
from validation.code.test_cases import get_all_test_cases, get_tests_by_category, get_categories, TestCase
from validation.code.metrics import (
TestResult,
evaluate_test_case,
generate_report,
format_report,
check_brackets_balanced
)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# All working directories live next to this script.
_BASE_DIR = os.path.dirname(__file__)
DATA_DIR = os.path.join(_BASE_DIR, 'data')
CKPT_DIR = os.path.join(_BASE_DIR, 'checkpoints')
RESULTS_DIR = os.path.join(_BASE_DIR, 'results')

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
def load_model(checkpoint_path: Optional[str] = None) -> tuple:
    """
    Loads the model and returns (model, encode_fn, decode_fn).

    Args:
        checkpoint_path: Path to a checkpoint file. When None, 'ckpt_best.pt'
            in CKPT_DIR is used if it exists, otherwise 'ckpt_final.pt'.

    Returns:
        A (model, encode, decode) tuple:
        - model: the RippleGPT instance, moved to DEVICE and set to eval mode.
        - encode: maps a string to a list of token ids, one per character;
          characters missing from the vocabulary map to the id of '?',
          falling back to the id of ' ', then to 0.
        - decode: maps an iterable of token ids to a string; ids missing
          from the vocabulary become '?'.

    Raises:
        FileNotFoundError: If checkpoint_path is None and neither default
            checkpoint exists in CKPT_DIR.
    """
    # Find checkpoint
    if checkpoint_path is None:
        best_path = os.path.join(CKPT_DIR, 'ckpt_best.pt')
        final_path = os.path.join(CKPT_DIR, 'ckpt_final.pt')
        if os.path.exists(best_path):
            checkpoint_path = best_path
        elif os.path.exists(final_path):
            checkpoint_path = final_path
        else:
            raise FileNotFoundError(
                f"No checkpoint found in {CKPT_DIR}\n"
                "Run first: python validation/train_code.py"
            )
    print(f"📦 Loading model from: {checkpoint_path}")

    # Load checkpoint
    # NOTE(security): weights_only=False allows arbitrary object unpickling
    # (presumably required because the checkpoint stores a config object next
    # to the weights). Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']

    # Initialize model
    model = RippleGPT(config)

    # Clean compiled models prefix. The keys are renamed in place so that any
    # attributes attached to the original state-dict mapping are preserved.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict.keys()):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    # Load vocabulary
    # NOTE(security): pickle.load executes arbitrary code from the file; the
    # vocabulary file must come from a trusted source.
    meta_path = os.path.join(DATA_DIR, 'meta.pkl')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi = meta['stoi']
    itos = meta['itos']

    # Encode/decode functions (with fallback for unknown characters)
    unknown_token = stoi.get('?', stoi.get(' ', 0))

    def encode(text):
        """Map each character of text to its token id."""
        return [stoi.get(ch, unknown_token) for ch in text]

    def decode(token_ids):
        """Map token ids back to text, using '?' for unknown ids."""
        return ''.join(itos.get(tok, '?') for tok in token_ids)

    print(f" ✅ Model loaded ({model.get_num_params()/1e6:.2f}M parameters)")
    return model, encode, decode
@torch.no_grad()
def generate_completion(
    model: RippleGPT,
    prompt: str,
    encode,
    decode,
    max_tokens: int = 50,
    temperature: float = 0.7,
    top_k: int = 50
) -> str:
    """
    Returns the text the model generates after `prompt` (prompt excluded).

    Args:
        model: Model exposing a `generate` method.
        prompt: Text to continue.
        encode: Callable turning a string into a sequence of token ids.
        decode: Callable turning a list of token ids into a string.
        max_tokens: Passed to `model.generate` as `max_new_tokens`.
        temperature: Passed to `model.generate`.
        top_k: Passed to `model.generate`.

    Returns:
        The decoded output with its first len(prompt) characters removed.
    """
    # Build a batch holding the single encoded prompt on the target device
    prompt_ids = torch.tensor(encode(prompt), dtype=torch.long, device=DEVICE)
    batch = prompt_ids[None]

    # Let the model extend the sequence
    sequence = model.generate(
        batch,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )

    # Decode everything, then drop the leading prompt characters
    decoded = decode(sequence[0].tolist())
    return decoded[len(prompt):]
def run_test_case(
    model: RippleGPT,
    test: TestCase,
    encode,
    decode,
    verbose: bool = False
) -> TestResult:
    """
    Runs one test case through the model and scores the completion.

    Args:
        model: Model used to generate the completion.
        test: Test case supplying the prompt, token budget and patterns.
        encode: String-to-token-ids callable forwarded to generation.
        decode: Token-ids-to-string callable forwarded to generation.
        verbose: When True, prints a short per-test summary.

    Returns:
        A TestResult holding the prompt, the generated text and the
        evaluation outcome.
    """
    # Ask the model to continue the test prompt
    generated = generate_completion(
        model, test.prompt, encode, decode,
        max_tokens=test.max_tokens
    )

    # Score the completion against the expected/forbidden patterns
    evaluation = evaluate_test_case(
        prompt=test.prompt,
        generated=generated,
        expected_patterns=test.expected_patterns,
        forbidden_patterns=test.forbidden_patterns
    )
    passed, score, matched, failed, forbidden = evaluation

    result = TestResult(
        test_name=test.name,
        category=test.category,
        passed=passed,
        prompt=test.prompt,
        generated=generated,
        expected_patterns=test.expected_patterns,
        matched_patterns=matched,
        failed_patterns=failed,
        forbidden_matches=forbidden,
        score=score
    )

    # Quiet mode: nothing to print
    if not verbose:
        return result

    if passed:
        status = "✅"
    else:
        status = "❌"
    print(f"\n{status} {test.name} ({test.category})")
    print(f" Prompt: {repr(test.prompt[:50])}...")
    print(f" Generated: {repr(generated[:50])}...")
    print(f" Score: {score:.2f}")
    if failed:
        print(f" Missing patterns: {failed}")
    return result
def run_validation(
    model: RippleGPT,
    encode,
    decode,
    categories: Optional[List[str]] = None,
    verbose: bool = False
) -> List[TestResult]:
    """
    Runs the selected validation tests and collects their results.

    Args:
        model: Model under test.
        encode: String-to-token-ids callable.
        decode: Token-ids-to-string callable.
        categories: Category names to run; when empty or None, every test
            case is run.
        verbose: When True, each test prints its own details; otherwise a
            single-line progress counter is shown.

    Returns:
        One TestResult per executed test, in execution order.
    """
    # Gather the tests to execute
    if categories:
        tests = [
            case
            for category in categories
            for case in get_tests_by_category(category)
        ]
    else:
        tests = get_all_test_cases()
    print(f"\n🧪 Running {len(tests)} tests...")

    show_progress = not verbose
    results = []
    for i, test in enumerate(tests):
        if show_progress:
            print(f"\r Progress: {i+1}/{len(tests)}", end="", flush=True)
        results.append(run_test_case(model, test, encode, decode, verbose=verbose))

    if show_progress:
        print()  # Terminate the in-place progress line
    return results
def save_results(report, results: List[TestResult]):
    """
    Writes the report summary and per-test details to a timestamped JSON file.

    Args:
        report: Report object providing the overall and per-category figures.
        results: Individual test results to include in the file.

    Returns:
        Path of the JSON file written inside RESULTS_DIR.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = report.model_name

    # Overall figures
    summary = {
        'total_tests': report.total_tests,
        'passed': report.total_passed,
        'accuracy': report.overall_accuracy,
        'bracket_accuracy': report.bracket_accuracy,
        'indentation_accuracy': report.indentation_accuracy,
        'structure_accuracy': report.structure_accuracy
    }

    # Per-category figures
    per_category = {}
    for cat_name, cat_result in report.category_results.items():
        per_category[cat_name] = {
            'total': cat_result.total_tests,
            'passed': cat_result.passed_tests,
            'accuracy': cat_result.accuracy
        }

    # Per-test details
    per_test = []
    for res in results:
        per_test.append({
            'name': res.test_name,
            'category': res.category,
            'passed': res.passed,
            'score': res.score,
            'prompt': res.prompt,
            'generated': res.generated,
            'matched': res.matched_patterns,
            'failed': res.failed_patterns
        })

    results_data = {
        'timestamp': timestamp,
        'model': model_name,
        'summary': summary,
        'categories': per_category,
        'tests': per_test
    }

    results_path = os.path.join(RESULTS_DIR, f'validation_{timestamp}.json')
    with open(results_path, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f"\n💾 Results saved to: {results_path}")
    return results_path
def main():
    """
    Command-line entry point: loads the model, runs the tests, reports.

    Returns:
        Process exit code: 0 when overall accuracy is at least 0.5,
        1 when it is lower or when no checkpoint/vocabulary file is found.
    """
    parser = argparse.ArgumentParser(description='RippleGPT Code Completion Validation')
    parser.add_argument('--checkpoint', type=str, help='Path to specific checkpoint')
    parser.add_argument('--category', type=str, choices=get_categories(), help='Run only one category')
    parser.add_argument('--verbose', '-v', action='store_true', help='Show details for each test')
    parser.add_argument('--no-save', action='store_true', help='Do not save results to file')
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("🧪 CODE COMPLETION VALIDATION - RippleGPT")
    print(banner)

    # Model loading: a missing file is reported, not raised
    try:
        model, encode, decode = load_model(args.checkpoint)
    except FileNotFoundError as e:
        print(f"\n❌ {e}")
        return 1

    # Restrict to one category only when requested
    if args.category:
        categories = [args.category]
    else:
        categories = None

    results = run_validation(model, encode, decode, categories=categories, verbose=args.verbose)

    # Build and show the report
    report = generate_report("RippleGPT", results)
    print("\n" + format_report(report))

    if not args.no_save:
        save_results(report, results)

    # Map the overall accuracy to an exit code
    accuracy = report.overall_accuracy
    if accuracy >= 0.7:
        print("\n🎉 Validation passed successfully!")
        return 0
    if accuracy >= 0.5:
        print("\n⚠️ Validation passed partially. More training recommended.")
        return 0
    print("\n❌ Validation failed. Model needs more training.")
    return 1
if __name__ == '__main__':
    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the site module for interactive sessions and is missing when Python is
    # started with -S, whereas sys.exit is always available.
    sys.exit(main())