#!/usr/bin/env python3
"""
OrbGen Evaluation Script
Evaluates a trained model on the test set with Orbital validation metrics.
Usage:
python evaluate.py --checkpoint ./orbgen-1.5b/final
python evaluate.py --checkpoint ./orbgen-1.5b/final --use_validator
"""
import os
import json
import fire
import torch
import subprocess
import tempfile
from pathlib import Path
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from tqdm import tqdm


def validate_schema(schema_json: str) -> tuple[bool, list[str]]:
    """Validate schema using the orbital CLI."""
    # Check valid JSON first
    try:
        json.loads(schema_json)
    except json.JSONDecodeError as e:
        return False, [f"Invalid JSON: {e}"]

    # Write to temp file and validate
    with tempfile.NamedTemporaryFile(mode='w', suffix='.orb', delete=False) as f:
        f.write(schema_json)
        temp_path = f.name

    try:
        # Find orbital binary - check multiple locations
        orbital_cmd = 'orbital'
        for path in ['/usr/local/bin/orbital', os.path.expanduser('~/kflow.ai.builder/orbital-rust/target/release/orbital')]:
            if os.path.exists(path):
                orbital_cmd = path
                break

        result = subprocess.run(
            [orbital_cmd, 'validate', temp_path],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 or 'Schema is valid' in result.stdout:
            return True, []
        else:
            errors = [line for line in result.stderr.split('\n') if line.strip()]
            return False, errors[:5]
    except subprocess.TimeoutExpired:
        return False, ["Validation timeout"]
    except FileNotFoundError:
        return False, ["Orbital CLI not found - install it or use --use_validator=False"]
    except Exception as e:
        return False, [f"Validation error: {e}"]
    finally:
        Path(temp_path).unlink(missing_ok=True)


def extract_completion(generated_text: str) -> str:
    """Extract the completion from generated text."""
    # Try to find assistant response
    if '<|im_start|>assistant' in generated_text:
        parts = generated_text.split('<|im_start|>assistant')
        if len(parts) > 1:
            completion = parts[-1]
            if '<|im_end|>' in completion:
                completion = completion.split('<|im_end|>')[0]
            return completion.strip()

    # Try to find JSON object
    start = generated_text.find('{')
    if start != -1:
        # Find matching closing brace
        depth = 0
        for i, char in enumerate(generated_text[start:]):
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return generated_text[start:start + i + 1]

    return generated_text
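
# Illustrative usage of extract_completion (hypothetical strings, not dataset rows):
#   extract_completion('<|im_start|>assistant\n{"name": "app"}<|im_end|>')  ->  '{"name": "app"}'
#   extract_completion('noise {"a": {"b": 1}} trailing text')               ->  '{"a": {"b": 1}}'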


def main(
    checkpoint: str = "./orbgen-1.5b/final",
    dataset: str = "orbital-ai/orbital-schemas",
    split: str = "test",
    use_validator: bool = False,
    max_samples: int = -1,
    output_file: str = "evaluation_results.json",
):
    """Evaluate model on test set."""
    print("=" * 60)
    print("OrbGen Evaluation")
    print("=" * 60)
    print(f"Checkpoint: {checkpoint}")
    print(f"Dataset: {dataset}")
    print(f"Split: {split}")
    print(f"Use Validator: {use_validator}")
    print("=" * 60)

    # Load tokenizer and model
    print("\nLoading model...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

    # Load dataset
    print("Loading dataset...")
    ds = load_dataset(dataset)
    test_data = ds[split]
    if max_samples > 0:
        test_data = test_data.select(range(min(max_samples, len(test_data))))
    print(f"Evaluating on {len(test_data)} examples...")

    # Metrics
    metrics = {
        'total': len(test_data),
        'valid_json': 0,
        'valid_schema': 0,
        'generation_errors': 0,
    }
    results = []

    system_prompt = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.
Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""
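
    # Illustrative only: a rough skeleton of the JSON shape implied by the rules
    # above. Concrete values and the exact transition keys ("from"/"to"/"on") are
    # assumptions for illustration, not taken from the Orbital spec.
    #
    # {
    #   "name": "todo_app", "version": "1.0.0",
    #   "orbitals": [{
    #     "name": "tasks",
    #     "entity": {"name": "Task", "collection": "tasks", "fields": [...]},
    #     "traits": [{
    #       "name": "complete_task", "category": "interaction", "linkedEntity": "Task",
    #       "stateMachine": {
    #         "states": [{"name": "open", "isInitial": true}, {"name": "done"}],
    #         "events": ["COMPLETE"],
    #         "transitions": [{"from": "open", "to": "done", "on": "COMPLETE",
    #                          "effects": [["set", "status", "done"]]}]
    #       }
    #     }],
    #     "pages": [{"name": "Tasks", "path": "/tasks", "entity": "Task", "traits": ["complete_task"]}]
    #   }]
    # }
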
    for i, example in enumerate(tqdm(test_data)):
        prompt = example['prompt']
        expected = example['completion']

        # Format input
        input_text = f"""<|im_start|>system
{system_prompt}
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""

        try:
            # Generate
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=4096,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
            completion = extract_completion(generated)

            # Check valid JSON
            is_valid_json = False
            is_valid_schema = False
            errors = []
            try:
                json.loads(completion)
                is_valid_json = True
                metrics['valid_json'] += 1

                # Check valid schema
                if use_validator:
                    is_valid_schema, errors = validate_schema(completion)
                    if is_valid_schema:
                        metrics['valid_schema'] += 1
                else:
                    # Basic structural check
                    parsed = json.loads(completion)
                    if 'name' in parsed and 'orbitals' in parsed:
                        is_valid_schema = True
                        metrics['valid_schema'] += 1
            except json.JSONDecodeError as e:
                errors = [f"JSON error: {e}"]

            results.append({
                'prompt': prompt,
                'expected': (expected[:500] + '...') if len(expected) > 500 else expected,
                'generated': (completion[:500] + '...') if len(completion) > 500 else completion,
                'valid_json': is_valid_json,
                'valid_schema': is_valid_schema,
                'errors': errors,
            })
        except Exception as e:
            metrics['generation_errors'] += 1
            results.append({
                'prompt': prompt,
                'error': str(e),
                'valid_json': False,
                'valid_schema': False,
            })
    # Calculate percentages
    metrics['valid_json_pct'] = metrics['valid_json'] / metrics['total'] * 100
    metrics['valid_schema_pct'] = metrics['valid_schema'] / metrics['total'] * 100

    # Print results
    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f"Total examples: {metrics['total']}")
    print(f"Valid JSON: {metrics['valid_json']} ({metrics['valid_json_pct']:.1f}%)")
    print(f"Valid Schema: {metrics['valid_schema']} ({metrics['valid_schema_pct']:.1f}%)")
    print(f"Generation errors: {metrics['generation_errors']}")

    # Save results
    output = {
        'metrics': metrics,
        'results': results,
    }
    with open(output_file, 'w') as f:
        json.dump(output, f, indent=2)
    print(f"\nResults saved to: {output_file}")

    return metrics


if __name__ == "__main__":
    fire.Fire(main)