#!/usr/bin/env python3
"""
OrbGen Inference Script
Generate Orbital schemas from natural language prompts.
Usage:
python generate.py --prompt "Create a task management app"
python generate.py --prompt "..." --checkpoint ./orbgen-1.5b/final
python generate.py --interactive
"""
import os
import json
import fire
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
SYSTEM_PROMPT = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.
Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""
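
# Illustrative sketch of the schema shape the rules above describe. Everything in
# this constant (the "TaskApp" name, the field/transition layouts, the boolean
# "collection" flag) is an assumption made for documentation only; it is not used
# by this script and is not a schema produced or validated by the OrbGen model.
_EXAMPLE_SCHEMA = {
    "name": "TaskApp",
    "version": "1.0.0",
    "orbitals": [
        {
            "name": "Tasks",
            "entity": {
                "name": "Task",
                "collection": True,
                "fields": [{"name": "title", "type": "string"}],
            },
            "traits": [
                {
                    "name": "Complete",
                    "category": "interaction",
                    "linkedEntity": "Task",
                    "stateMachine": {
                        "states": [{"name": "open", "isInitial": True}, {"name": "done"}],
                        "events": ["COMPLETE"],
                        "transitions": [
                            {
                                "from": "open",
                                "to": "done",
                                "on": "COMPLETE",
                                "effects": [["set", "status", "done"]],
                            }
                        ],
                    },
                }
            ],
            "pages": [
                {"name": "TaskList", "path": "/tasks", "entity": "Task", "traits": ["Complete"]}
            ],
        }
    ],
}
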
class OrbGen:
    """OrbGen schema generator."""

    def __init__(
        self,
        checkpoint: str = "orbital-ai/orbgen-1.5b",
        device: str = "auto",
    ):
        """Initialize the generator."""
        print(f"Loading model from {checkpoint}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            trust_remote_code=True,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        self.model.eval()
        print("Model loaded!")
    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_new_tokens: int = 4096,
    ) -> str:
        """Generate a schema from a prompt."""
        input_text = f"""<|im_start|>system
{SYSTEM_PROMPT}
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract completion
        if '<|im_start|>assistant' in generated:
            parts = generated.split('<|im_start|>assistant')
            if len(parts) > 1:
                completion = parts[-1]
                if '<|im_end|>' in completion:
                    completion = completion.split('<|im_end|>')[0]
                return completion.strip()

        # Try to find JSON
        start = generated.find('{')
        if start != -1:
            depth = 0
            for i, char in enumerate(generated[start:]):
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return generated[start:start + i + 1]
        return generated
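
# Example of programmatic use (a minimal sketch; the prompt below is a placeholder,
# and json.loads will raise json.JSONDecodeError if the model emits non-JSON text):
#
#   gen = OrbGen(checkpoint="orbital-ai/orbgen-1.5b")
#   raw = gen.generate("Create a task management app")
#   schema = json.loads(raw)
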
def main(
    prompt: str = None,
    checkpoint: str = "orbital-ai/orbgen-1.5b",
    output: str = None,
    temperature: float = 0.7,
    top_p: float = 0.95,
    interactive: bool = False,
    validate: bool = False,
):
    """Generate Orbital schemas."""
    generator = OrbGen(checkpoint=checkpoint)

    if interactive:
        print("\n" + "=" * 60)
        print("OrbGen Interactive Mode")
        print("=" * 60)
        print("Enter prompts to generate schemas. Type 'quit' to exit.\n")
        while True:
            try:
                prompt = input("Prompt> ").strip()
                if prompt.lower() in ['quit', 'exit', 'q']:
                    break
                if not prompt:
                    continue
                print("\nGenerating...")
                result = generator.generate(prompt, temperature=temperature, top_p=top_p)
                try:
                    parsed = json.loads(result)
                    print(json.dumps(parsed, indent=2))
                except json.JSONDecodeError:
                    print(result)
                print()
            except KeyboardInterrupt:
                print("\nExiting...")
                break
    elif prompt:
        print(f"\nPrompt: {prompt}\n")
        print("Generating...")
        result = generator.generate(prompt, temperature=temperature, top_p=top_p)
        try:
            parsed = json.loads(result)
            formatted = json.dumps(parsed, indent=2)
            if output:
                with open(output, 'w') as f:
                    f.write(formatted)
                print(f"Schema saved to: {output}")
            else:
                print(formatted)
            # Validate if requested
            if validate:
                import subprocess
                import tempfile
                from pathlib import Path
                with tempfile.NamedTemporaryFile(mode='w', suffix='.orb', delete=False) as f:
                    f.write(formatted)
                    temp_path = f.name
                try:
                    proc = subprocess.run(
                        ['orbital', 'validate', temp_path],
                        capture_output=True,
                        text=True,
                        cwd=os.path.expanduser('~/kflow.ai.builder/builder'),
                    )
                    print("\nValidation:")
                    print(proc.stdout or proc.stderr)
                finally:
                    Path(temp_path).unlink(missing_ok=True)
        except json.JSONDecodeError as e:
            print(f"Warning: Generated invalid JSON: {e}")
            print(result)
    else:
        print("Usage: python generate.py --prompt 'Your prompt here'")
        print("       python generate.py --interactive")


if __name__ == "__main__":
    fire.Fire(main)