|
|
""" |
|
|
validate_code.py - Executes the complete code completion validation suite. |
|
|
|
|
|
This script: |
|
|
1. Loads the trained model |
|
|
2. Executes all test cases |
|
|
3. Calculates evaluation metrics |
|
|
4. Generates a detailed report |
|
|
|
|
|
Usage: |
|
|
python validation/validate_code.py |
|
|
python validation/validate_code.py --verbose |
|
|
python validation/validate_code.py --category brackets |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import pickle |
|
|
import argparse |
|
|
import json |
|
|
from datetime import datetime |
|
|
from typing import List, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the project root importable when this file is run as a script.
# NOTE(review): three dirname() calls climb three levels up from this file —
# verify this actually lands on the directory containing `src`/`validation`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|
|
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
from validation.code.test_cases import get_all_test_cases, get_tests_by_category, get_categories, TestCase |
|
|
from validation.code.metrics import ( |
|
|
TestResult, |
|
|
evaluate_test_case, |
|
|
generate_report, |
|
|
format_report, |
|
|
check_brackets_balanced |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directories resolved relative to this file so the script works from any CWD.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')  # holds meta.pkl (vocab)
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')  # model checkpoints
RESULTS_DIR = os.path.join(os.path.dirname(__file__), 'results')  # JSON result files

# Device preference: CUDA, then Apple MPS, then CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
|
|
|
|
def load_model(checkpoint_path: Optional[str] = None) -> tuple:
    """
    Loads the model and returns (model, encode_fn, decode_fn).

    Args:
        checkpoint_path: Explicit path to a checkpoint file. When None,
            falls back to 'ckpt_best.pt', then 'ckpt_final.pt' in CKPT_DIR.

    Returns:
        Tuple of (model, encode, decode) where encode maps str -> list[int]
        and decode maps list[int] -> str (character-level codec).

    Raises:
        FileNotFoundError: If no checkpoint is found in CKPT_DIR.
    """
    if checkpoint_path is None:
        best_path = os.path.join(CKPT_DIR, 'ckpt_best.pt')
        final_path = os.path.join(CKPT_DIR, 'ckpt_final.pt')

        if os.path.exists(best_path):
            checkpoint_path = best_path
        elif os.path.exists(final_path):
            checkpoint_path = final_path
        else:
            raise FileNotFoundError(
                f"No checkpoint found in {CKPT_DIR}\n"
                "Run first: python validation/train_code.py"
            )

    print(f"📦 Loading model from: {checkpoint_path}")

    # weights_only=False is required because the checkpoint embeds the config
    # object; only load checkpoints produced by this project (trusted source),
    # since this deserializes arbitrary pickled objects.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']

    model = RippleGPT(config)

    # torch.compile wraps parameter names with '_orig_mod.'; strip the prefix
    # so the state dict matches the uncompiled module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    # Character-level vocabulary produced by the data-preparation step.
    meta_path = os.path.join(DATA_DIR, 'meta.pkl')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)

    stoi = meta['stoi']
    itos = meta['itos']

    # Characters missing from the vocabulary fall back to '?', then space,
    # then token id 0.
    unknown_token = stoi.get('?', stoi.get(' ', 0))

    def encode(s: str) -> List[int]:
        """Map a string to a list of token ids."""
        return [stoi.get(c, unknown_token) for c in s]

    def decode(l: List[int]) -> str:
        """Map a list of token ids back to a string."""
        return ''.join(itos.get(i, '?') for i in l)

    print(f" ✅ Model loaded ({model.get_num_params()/1e6:.2f}M parameters)")

    return model, encode, decode
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_completion(
    model: RippleGPT,
    prompt: str,
    encode,
    decode,
    max_tokens: int = 50,
    temperature: float = 0.7,
    top_k: int = 50
) -> str:
    """
    Generates completion for a prompt.

    Returns only the newly generated text (the prompt echo is stripped).
    """
    prompt_ids = torch.tensor(encode(prompt), dtype=torch.long, device=DEVICE)
    batch = prompt_ids.unsqueeze(0)  # add batch dimension

    sampled = model.generate(
        batch, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k
    )

    # The model output includes the prompt; keep only the continuation.
    decoded = decode(sampled[0].tolist())
    return decoded[len(prompt):]
|
|
|
|
|
|
|
|
def run_test_case(
    model: RippleGPT,
    test: TestCase,
    encode,
    decode,
    verbose: bool = False
) -> TestResult:
    """
    Executes a test case and returns the result.

    Generates a completion for the test prompt, scores it against the
    expected/forbidden patterns, and optionally prints a per-test summary.
    """
    completion = generate_completion(
        model, test.prompt, encode, decode,
        max_tokens=test.max_tokens
    )

    passed, score, matched, failed, forbidden = evaluate_test_case(
        prompt=test.prompt,
        generated=completion,
        expected_patterns=test.expected_patterns,
        forbidden_patterns=test.forbidden_patterns
    )

    outcome = TestResult(
        test_name=test.name,
        category=test.category,
        passed=passed,
        prompt=test.prompt,
        generated=completion,
        expected_patterns=test.expected_patterns,
        matched_patterns=matched,
        failed_patterns=failed,
        forbidden_matches=forbidden,
        score=score
    )

    if verbose:
        marker = "✅" if passed else "❌"
        print(f"\n{marker} {test.name} ({test.category})")
        print(f" Prompt: {repr(test.prompt[:50])}...")
        print(f" Generated: {repr(completion[:50])}...")
        print(f" Score: {score:.2f}")
        if failed:
            print(f" Missing patterns: {failed}")

    return outcome
|
|
|
|
|
|
|
|
def run_validation(
    model: RippleGPT,
    encode,
    decode,
    categories: Optional[List[str]] = None,
    verbose: bool = False
) -> List[TestResult]:
    """
    Executes all validation tests.

    When *categories* is given, only tests from those categories run;
    otherwise the full suite is executed.
    """
    if categories:
        selected = []
        for name in categories:
            selected.extend(get_tests_by_category(name))
    else:
        selected = get_all_test_cases()

    print(f"\n🧪 Running {len(selected)} tests...")

    outcomes: List[TestResult] = []
    for index, case in enumerate(selected):
        if not verbose:
            # Single-line progress counter, overwritten in place via '\r'.
            print(f"\r Progress: {index+1}/{len(selected)}", end="", flush=True)

        outcomes.append(run_test_case(model, case, encode, decode, verbose=verbose))

    if not verbose:
        print()  # terminate the progress line

    return outcomes
|
|
|
|
|
|
|
|
def save_results(report, results: List[TestResult]):
    """Saves results to a JSON file."""
    os.makedirs(RESULTS_DIR, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Aggregate accuracy figures from the report.
    summary = {
        'total_tests': report.total_tests,
        'passed': report.total_passed,
        'accuracy': report.overall_accuracy,
        'bracket_accuracy': report.bracket_accuracy,
        'indentation_accuracy': report.indentation_accuracy,
        'structure_accuracy': report.structure_accuracy,
    }

    # Per-category breakdown.
    categories = {}
    for name, cat in report.category_results.items():
        categories[name] = {
            'total': cat.total_tests,
            'passed': cat.passed_tests,
            'accuracy': cat.accuracy,
        }

    # One record per executed test, including the raw prompt/output.
    tests = []
    for r in results:
        tests.append({
            'name': r.test_name,
            'category': r.category,
            'passed': r.passed,
            'score': r.score,
            'prompt': r.prompt,
            'generated': r.generated,
            'matched': r.matched_patterns,
            'failed': r.failed_patterns,
        })

    payload = {
        'timestamp': stamp,
        'model': report.model_name,
        'summary': summary,
        'categories': categories,
        'tests': tests,
    }

    results_path = os.path.join(RESULTS_DIR, f'validation_{stamp}.json')
    with open(results_path, 'w') as f:
        json.dump(payload, f, indent=2)

    print(f"\n💾 Results saved to: {results_path}")

    return results_path
|
|
|
|
|
|
|
|
def main():
    """CLI entry point. Returns a process exit code (0 = pass, 1 = fail)."""
    parser = argparse.ArgumentParser(description='RippleGPT Code Completion Validation')
    parser.add_argument('--checkpoint', type=str, help='Path to specific checkpoint')
    parser.add_argument('--category', type=str, choices=get_categories(), help='Run only one category')
    parser.add_argument('--verbose', '-v', action='store_true', help='Show details for each test')
    parser.add_argument('--no-save', action='store_true', help='Do not save results to file')
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("🧪 CODE COMPLETION VALIDATION - RippleGPT")
    print(banner)

    # A missing checkpoint is an expected user error, not a crash.
    try:
        model, encode, decode = load_model(args.checkpoint)
    except FileNotFoundError as err:
        print(f"\n❌ {err}")
        return 1

    selected = [args.category] if args.category else None
    results = run_validation(model, encode, decode, categories=selected, verbose=args.verbose)

    report = generate_report("RippleGPT", results)
    print("\n" + format_report(report))

    if not args.no_save:
        save_results(report, results)

    # Thresholds: >=70% is a pass, >=50% is a soft pass, below that fails.
    accuracy = report.overall_accuracy
    if accuracy >= 0.7:
        print("\n🎉 Validation passed successfully!")
        return 0
    if accuracy >= 0.5:
        print("\n⚠️ Validation passed partially. More training recommended.")
        return 0
    print("\n❌ Validation failed. Model needs more training.")
    return 1
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # sys.exit is the documented way to set the process exit status; the
    # bare `exit()` builtin is injected by the site module and is not
    # guaranteed to exist (e.g. when run with `python -S`).
    sys.exit(main())
|
|
|