| #!/usr/bin/env python3 | |
| """ | |
| OrbGen Evaluation Script | |
| Evaluates a trained model on the test set with Orbital validation metrics. | |
| Usage: | |
| python evaluate.py --checkpoint ./orbgen-1.5b/final | |
| python evaluate.py --checkpoint ./orbgen-1.5b/final --use_validator | |
| """ | |
| import os | |
| import json | |
| import fire | |
| import torch | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| from tqdm import tqdm | |
def validate_schema(schema_json: str) -> tuple[bool, list[str]]:
    """Validate a schema string with the Orbital CLI.

    Returns (is_valid, errors); errors is empty on success and holds at
    most five diagnostic lines otherwise.
    """
    # Reject anything that is not even parseable JSON before shelling out.
    try:
        json.loads(schema_json)
    except json.JSONDecodeError as exc:
        return False, [f"Invalid JSON: {exc}"]

    # The CLI validates files, so stage the schema in a temporary .orb file.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.orb', delete=False) as tmp:
        tmp.write(schema_json)
        schema_path = tmp.name

    try:
        # Prefer a known install location; otherwise rely on PATH lookup.
        candidates = (
            '/usr/local/bin/orbital',
            os.path.expanduser('~/kflow.ai.builder/orbital-rust/target/release/orbital'),
        )
        binary = next((c for c in candidates if os.path.exists(c)), 'orbital')

        proc = subprocess.run(
            [binary, 'validate', schema_path],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if proc.returncode == 0 or 'Schema is valid' in proc.stdout:
            return True, []
        # Keep only the first few non-empty stderr lines as diagnostics.
        diagnostics = [ln for ln in proc.stderr.split('\n') if ln.strip()]
        return False, diagnostics[:5]
    except subprocess.TimeoutExpired:
        return False, ["Validation timeout"]
    except FileNotFoundError:
        return False, ["Orbital CLI not found - install it or use --use_validator=False"]
    except Exception as exc:
        return False, [f"Validation error: {exc}"]
    finally:
        # Always remove the staged temp file, even on error paths.
        Path(schema_path).unlink(missing_ok=True)
def extract_completion(generated_text: str) -> str:
    """Pull the model's completion out of the full decoded text.

    Tries, in order:
      1. The text after the last '<|im_start|>assistant' marker,
         trimmed at '<|im_end|>' if present.
      2. The first balanced top-level JSON object. Unlike a naive
         brace counter, braces appearing inside JSON string literals
         (including escaped quotes) are ignored, so schemas whose
         string values contain '{' or '}' are extracted intact.
      3. The raw text unchanged, as a last resort.
    """
    # Chat-template path: take the last assistant turn.
    if '<|im_start|>assistant' in generated_text:
        parts = generated_text.split('<|im_start|>assistant')
        if len(parts) > 1:
            completion = parts[-1]
            if '<|im_end|>' in completion:
                completion = completion.split('<|im_end|>')[0]
            return completion.strip()

    # JSON path: scan for the first balanced {...} span.
    start = generated_text.find('{')
    if start != -1:
        depth = 0
        in_string = False   # currently inside a JSON string literal
        escaped = False     # previous char was a backslash inside a string
        for i, char in enumerate(generated_text[start:]):
            if in_string:
                # Braces inside string values must not affect nesting depth.
                if escaped:
                    escaped = False
                elif char == '\\':
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return generated_text[start:start + i + 1]

    # Fallback: nothing recognizable, return the text as-is.
    return generated_text
def main(
    checkpoint: str = "./orbgen-1.5b/final",
    dataset: str = "orbital-ai/orbital-schemas",
    split: str = "test",
    use_validator: bool = False,
    max_samples: int = -1,
    output_file: str = "evaluation_results.json",
):
    """Evaluate a checkpoint on a dataset split and write a JSON report.

    Args:
        checkpoint: Path or hub id of the model to evaluate.
        dataset: HF dataset id with 'prompt'/'completion' columns.
        split: Dataset split to evaluate on.
        use_validator: If True, run the Orbital CLI validator; otherwise
            fall back to a basic structural check ('name' and 'orbitals'
            keys present).
        max_samples: Evaluate at most this many examples (-1 = all).
        output_file: Where to write metrics plus per-example results.

    Returns:
        The metrics dict (counts and percentages).
    """
    print("=" * 60)
    print("OrbGen Evaluation")
    print("=" * 60)
    print(f"Checkpoint: {checkpoint}")
    print(f"Dataset: {dataset}")
    print(f"Split: {split}")
    print(f"Use Validator: {use_validator}")
    print("=" * 60)

    # Load tokenizer and model.
    print("\nLoading model...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

    # Load and optionally subsample the evaluation split.
    print("Loading dataset...")
    ds = load_dataset(dataset)
    test_data = ds[split]
    if max_samples > 0:
        test_data = test_data.select(range(min(max_samples, len(test_data))))
    print(f"Evaluating on {len(test_data)} examples...")

    # Aggregate counters; percentages are added after the loop.
    metrics = {
        'total': len(test_data),
        'valid_json': 0,
        'valid_schema': 0,
        'generation_errors': 0,
    }
    results = []

    system_prompt = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.
Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""

    for example in tqdm(test_data):
        prompt = example['prompt']
        expected = example['completion']

        # ChatML-formatted input; the trailing assistant tag cues generation.
        input_text = f"""<|im_start|>system
{system_prompt}
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
        try:
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=4096,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
            completion = extract_completion(generated)

            is_valid_json = False
            is_valid_schema = False
            errors = []
            try:
                # Parse once and reuse for the structural check below
                # (the original parsed the same string twice).
                parsed = json.loads(completion)
                is_valid_json = True
                metrics['valid_json'] += 1
                if use_validator:
                    is_valid_schema, errors = validate_schema(completion)
                    if is_valid_schema:
                        metrics['valid_schema'] += 1
                elif 'name' in parsed and 'orbitals' in parsed:
                    # Basic structural check when the CLI validator is off.
                    is_valid_schema = True
                    metrics['valid_schema'] += 1
            except json.JSONDecodeError as e:
                errors = [f"JSON error: {e}"]

            results.append({
                'prompt': prompt,
                'expected': expected[:500] + '...' if len(expected) > 500 else expected,
                'generated': completion[:500] + '...' if len(completion) > 500 else completion,
                'valid_json': is_valid_json,
                'valid_schema': is_valid_schema,
                'errors': errors,
            })
        except Exception as e:
            # Count hard generation failures but keep evaluating the rest.
            metrics['generation_errors'] += 1
            results.append({
                'prompt': prompt,
                'error': str(e),
                'valid_json': False,
                'valid_schema': False,
            })

    # Guard against ZeroDivisionError when the split is empty.
    denom = max(metrics['total'], 1)
    metrics['valid_json_pct'] = metrics['valid_json'] / denom * 100
    metrics['valid_schema_pct'] = metrics['valid_schema'] / denom * 100

    # Print summary.
    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f"Total examples: {metrics['total']}")
    print(f"Valid JSON: {metrics['valid_json']} ({metrics['valid_json_pct']:.1f}%)")
    print(f"Valid Schema: {metrics['valid_schema']} ({metrics['valid_schema_pct']:.1f}%)")
    print(f"Generation errors: {metrics['generation_errors']}")

    # Save full report (metrics + per-example results).
    output = {
        'metrics': metrics,
        'results': results,
    }
    with open(output_file, 'w') as f:
        json.dump(output, f, indent=2)
    print(f"\nResults saved to: {output_file}")
    return metrics
# CLI entry point: python-fire exposes main()'s keyword arguments as flags.
if __name__ == "__main__":
    fire.Fire(main)