#!/usr/bin/env python3 """ OrbGen Evaluation Script Evaluates a trained model on the test set with Orbital validation metrics. Usage: python evaluate.py --checkpoint ./orbgen-1.5b/final python evaluate.py --checkpoint ./orbgen-1.5b/final --use_validator """ import os import json import fire import torch import subprocess import tempfile from pathlib import Path from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel from tqdm import tqdm def validate_schema(schema_json: str) -> tuple[bool, list[str]]: """Validate schema using orbital CLI.""" # Check valid JSON first try: json.loads(schema_json) except json.JSONDecodeError as e: return False, [f"Invalid JSON: {e}"] # Write to temp file and validate with tempfile.NamedTemporaryFile(mode='w', suffix='.orb', delete=False) as f: f.write(schema_json) temp_path = f.name try: # Find orbital binary - check multiple locations orbital_cmd = 'orbital' for path in ['/usr/local/bin/orbital', os.path.expanduser('~/kflow.ai.builder/orbital-rust/target/release/orbital')]: if os.path.exists(path): orbital_cmd = path break result = subprocess.run( [orbital_cmd, 'validate', temp_path], capture_output=True, text=True, timeout=30, ) if result.returncode == 0 or 'Schema is valid' in result.stdout: return True, [] else: errors = [line for line in result.stderr.split('\n') if line.strip()] return False, errors[:5] except subprocess.TimeoutExpired: return False, ["Validation timeout"] except FileNotFoundError: return False, ["Orbital CLI not found - install it or use --use_validator=False"] except Exception as e: return False, [f"Validation error: {e}"] finally: Path(temp_path).unlink(missing_ok=True) def extract_completion(generated_text: str) -> str: """Extract the completion from generated text.""" # Try to find assistant response if '<|im_start|>assistant' in generated_text: parts = generated_text.split('<|im_start|>assistant') if len(parts) > 1: completion = parts[-1] if '<|im_end|>' in completion: completion = completion.split('<|im_end|>')[0] return completion.strip() # Try to find JSON object start = generated_text.find('{') if start != -1: # Find matching closing brace depth = 0 for i, char in enumerate(generated_text[start:]): if char == '{': depth += 1 elif char == '}': depth -= 1 if depth == 0: return generated_text[start:start + i + 1] return generated_text def main( checkpoint: str = "./orbgen-1.5b/final", dataset: str = "orbital-ai/orbital-schemas", split: str = "test", use_validator: bool = False, max_samples: int = -1, output_file: str = "evaluation_results.json", ): """Evaluate model on test set.""" print("=" * 60) print("OrbGen Evaluation") print("=" * 60) print(f"Checkpoint: {checkpoint}") print(f"Dataset: {dataset}") print(f"Split: {split}") print(f"Use Validator: {use_validator}") print("=" * 60) # Load tokenizer and model print("\nLoading model...") tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( checkpoint, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) model.eval() # Load dataset print("Loading dataset...") ds = load_dataset(dataset) test_data = ds[split] if max_samples > 0: test_data = test_data.select(range(min(max_samples, len(test_data)))) print(f"Evaluating on {len(test_data)} examples...") # Metrics metrics = { 'total': len(test_data), 'valid_json': 0, 'valid_schema': 0, 'generation_errors': 0, } results = [] system_prompt = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions. Rules: 1. Output ONLY valid JSON - no explanations, no markdown code blocks 2. Every schema must have: name, version, orbitals array 3. Each orbital must have: name, entity, traits, pages 4. Each entity must have: name, collection (or runtime/singleton), fields 5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine 6. State machines must have: states (with one isInitial:true), events, transitions 7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}] 8. Pages must have: name, path, entity, traits""" for i, example in enumerate(tqdm(test_data)): prompt = example['prompt'] expected = example['completion'] # Format input input_text = f"""<|im_start|>system {system_prompt} <|im_end|> <|im_start|>user {prompt} <|im_end|> <|im_start|>assistant """ try: # Generate inputs = tokenizer(input_text, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=4096, temperature=0.7, top_p=0.95, do_sample=True, pad_token_id=tokenizer.eos_token_id, ) generated = tokenizer.decode(outputs[0], skip_special_tokens=False) completion = extract_completion(generated) # Check valid JSON is_valid_json = False is_valid_schema = False errors = [] try: json.loads(completion) is_valid_json = True metrics['valid_json'] += 1 # Check valid schema if use_validator: is_valid_schema, errors = validate_schema(completion) if is_valid_schema: metrics['valid_schema'] += 1 else: # Basic structural check parsed = json.loads(completion) if 'name' in parsed and 'orbitals' in parsed: is_valid_schema = True metrics['valid_schema'] += 1 except json.JSONDecodeError as e: errors = [f"JSON error: {e}"] results.append({ 'prompt': prompt, 'expected': expected[:500] + '...' if len(expected) > 500 else expected, 'generated': completion[:500] + '...' if len(completion) > 500 else completion, 'valid_json': is_valid_json, 'valid_schema': is_valid_schema, 'errors': errors, }) except Exception as e: metrics['generation_errors'] += 1 results.append({ 'prompt': prompt, 'error': str(e), 'valid_json': False, 'valid_schema': False, }) # Calculate percentages metrics['valid_json_pct'] = metrics['valid_json'] / metrics['total'] * 100 metrics['valid_schema_pct'] = metrics['valid_schema'] / metrics['total'] * 100 # Print results print("\n" + "=" * 60) print("Results") print("=" * 60) print(f"Total examples: {metrics['total']}") print(f"Valid JSON: {metrics['valid_json']} ({metrics['valid_json_pct']:.1f}%)") print(f"Valid Schema: {metrics['valid_schema']} ({metrics['valid_schema_pct']:.1f}%)") print(f"Generation errors: {metrics['generation_errors']}") # Save results output = { 'metrics': metrics, 'results': results, } with open(output_file, 'w') as f: json.dump(output, f, indent=2) print(f"\nResults saved to: {output_file}") return metrics if __name__ == "__main__": fire.Fire(main)