""" validate_code.py - Executes the complete code completion validation suite. This script: 1. Loads the trained model 2. Executes all test cases 3. Calculates evaluation metrics 4. Generates a detailed report Usage: python validation/validate_code.py python validation/validate_code.py --verbose python validation/validate_code.py --category brackets """ import os import sys import pickle import argparse import json from datetime import datetime from typing import List, Optional import torch # Add root directory to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from src.model import RippleGPT from src.config import RippleConfig from validation.code.test_cases import get_all_test_cases, get_tests_by_category, get_categories, TestCase from validation.code.metrics import ( TestResult, evaluate_test_case, generate_report, format_report, check_brackets_balanced ) # ----------------------------------------------------------------------------- # Configuration # ----------------------------------------------------------------------------- DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') RESULTS_DIR = os.path.join(os.path.dirname(__file__), 'results') DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' def load_model(checkpoint_path: str = None) -> tuple: """ Loads the model and returns (model, encode_fn, decode_fn). """ # Find checkpoint if checkpoint_path is None: best_path = os.path.join(CKPT_DIR, 'ckpt_best.pt') final_path = os.path.join(CKPT_DIR, 'ckpt_final.pt') if os.path.exists(best_path): checkpoint_path = best_path elif os.path.exists(final_path): checkpoint_path = final_path else: raise FileNotFoundError( f"No checkpoint found in {CKPT_DIR}\n" "Run first: python validation/train_code.py" ) print(f"๐Ÿ“ฆ Loading model from: {checkpoint_path}") # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False) config = checkpoint['config'] # Initialize model model = RippleGPT(config) # Clean compiled models prefix state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k in list(state_dict.keys()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) model.to(DEVICE) model.eval() # Load vocabulary meta_path = os.path.join(DATA_DIR, 'meta.pkl') with open(meta_path, 'rb') as f: meta = pickle.load(f) stoi = meta['stoi'] itos = meta['itos'] # Encode/decode functions (with fallback for unknown characters) unknown_token = stoi.get('?', stoi.get(' ', 0)) encode = lambda s: [stoi.get(c, unknown_token) for c in s] decode = lambda l: ''.join([itos.get(i, '?') for i in l]) print(f" โœ… Model loaded ({model.get_num_params()/1e6:.2f}M parameters)") return model, encode, decode @torch.no_grad() def generate_completion( model: RippleGPT, prompt: str, encode, decode, max_tokens: int = 50, temperature: float = 0.7, top_k: int = 50 ) -> str: """ Generates completion for a prompt. 
""" # Encode prompt input_ids = encode(prompt) x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE).unsqueeze(0) # Generate output = model.generate(x, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k) # Decode only the generated part full_text = decode(output[0].tolist()) generated = full_text[len(prompt):] return generated def run_test_case( model: RippleGPT, test: TestCase, encode, decode, verbose: bool = False ) -> TestResult: """ Executes a test case and returns the result. """ # Generate completion generated = generate_completion( model, test.prompt, encode, decode, max_tokens=test.max_tokens ) # Evaluate result passed, score, matched, failed, forbidden = evaluate_test_case( prompt=test.prompt, generated=generated, expected_patterns=test.expected_patterns, forbidden_patterns=test.forbidden_patterns ) result = TestResult( test_name=test.name, category=test.category, passed=passed, prompt=test.prompt, generated=generated, expected_patterns=test.expected_patterns, matched_patterns=matched, failed_patterns=failed, forbidden_matches=forbidden, score=score ) if verbose: status = "โœ…" if passed else "โŒ" print(f"\n{status} {test.name} ({test.category})") print(f" Prompt: {repr(test.prompt[:50])}...") print(f" Generated: {repr(generated[:50])}...") print(f" Score: {score:.2f}") if failed: print(f" Missing patterns: {failed}") return result def run_validation( model: RippleGPT, encode, decode, categories: Optional[List[str]] = None, verbose: bool = False ) -> List[TestResult]: """ Executes all validation tests. """ # Select tests if categories: tests = [] for cat in categories: tests.extend(get_tests_by_category(cat)) else: tests = get_all_test_cases() print(f"\n๐Ÿงช Running {len(tests)} tests...") results = [] for i, test in enumerate(tests): if not verbose: print(f"\r Progress: {i+1}/{len(tests)}", end="", flush=True) result = run_test_case(model, test, encode, decode, verbose=verbose) results.append(result) if not verbose: print() # New line after progress return results def save_results(report, results: List[TestResult]): """Saves results to a JSON file.""" os.makedirs(RESULTS_DIR, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Save detailed results results_data = { 'timestamp': timestamp, 'model': report.model_name, 'summary': { 'total_tests': report.total_tests, 'passed': report.total_passed, 'accuracy': report.overall_accuracy, 'bracket_accuracy': report.bracket_accuracy, 'indentation_accuracy': report.indentation_accuracy, 'structure_accuracy': report.structure_accuracy }, 'categories': { name: { 'total': cat.total_tests, 'passed': cat.passed_tests, 'accuracy': cat.accuracy } for name, cat in report.category_results.items() }, 'tests': [ { 'name': r.test_name, 'category': r.category, 'passed': r.passed, 'score': r.score, 'prompt': r.prompt, 'generated': r.generated, 'matched': r.matched_patterns, 'failed': r.failed_patterns } for r in results ] } results_path = os.path.join(RESULTS_DIR, f'validation_{timestamp}.json') with open(results_path, 'w') as f: json.dump(results_data, f, indent=2) print(f"\n๐Ÿ’พ Results saved to: {results_path}") return results_path def main(): parser = argparse.ArgumentParser(description='RippleGPT Code Completion Validation') parser.add_argument('--checkpoint', type=str, help='Path to specific checkpoint') parser.add_argument('--category', type=str, choices=get_categories(), help='Run only one category') parser.add_argument('--verbose', '-v', action='store_true', help='Show details for each test') 
def main():
    parser = argparse.ArgumentParser(description='RippleGPT Code Completion Validation')
    parser.add_argument('--checkpoint', type=str, help='Path to a specific checkpoint')
    parser.add_argument('--category', type=str, choices=get_categories(),
                        help='Run only one category')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Show details for each test')
    parser.add_argument('--no-save', action='store_true',
                        help='Do not save results to a file')
    args = parser.parse_args()

    print("=" * 60)
    print("🧪 CODE COMPLETION VALIDATION - RippleGPT")
    print("=" * 60)

    # Load model
    try:
        model, encode, decode = load_model(args.checkpoint)
    except FileNotFoundError as e:
        print(f"\n❌ {e}")
        return 1

    # Select categories
    categories = [args.category] if args.category else None

    # Run validation
    results = run_validation(model, encode, decode, categories=categories, verbose=args.verbose)

    # Generate report
    report = generate_report("RippleGPT", results)

    # Print report
    print("\n" + format_report(report))

    # Save results
    if not args.no_save:
        save_results(report, results)

    # Return exit code based on the result
    if report.overall_accuracy >= 0.7:
        print("\n🎉 Validation passed successfully!")
        return 0
    elif report.overall_accuracy >= 0.5:
        print("\n⚠️ Validation partially passed. More training is recommended.")
        return 0
    else:
        print("\n❌ Validation failed. The model needs more training.")
        return 1


if __name__ == '__main__':
    sys.exit(main())