#!/usr/bin/env python3
"""
OrbGen Inference Script

Generate Orbital schemas from natural language prompts.

Usage:
    python generate.py --prompt "Create a task management app"
    python generate.py --prompt "..." --checkpoint ./orbgen-1.5b/final
    python generate.py --interactive
"""
import os
import json

import fire
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
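
# Third-party dependencies: fire (CLI), torch, and transformers must be
# installed, e.g. `pip install fire torch transformers` (versions are not
# pinned in this script).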

SYSTEM_PROMPT = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.

Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""
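
# For orientation, a minimal schema obeying the rules above might look like
# the following (hypothetical illustration, not taken from the source or
# training data; elided parts are marked with ...):
#
# {
#   "name": "TaskApp",
#   "version": "1.0.0",
#   "orbitals": [
#     {
#       "name": "Tasks",
#       "entity": {"name": "Task", "collection": "tasks", "fields": [...]},
#       "traits": [{"name": "Complete", "category": "interaction",
#                   "linkedEntity": "Task", "stateMachine": {...}}],
#       "pages": [{"name": "TaskList", "path": "/tasks",
#                  "entity": "Task", "traits": [...]}]
#     }
#   ]
# }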


class OrbGen:
    """OrbGen schema generator."""

    def __init__(
        self,
        checkpoint: str = "orbital-ai/orbgen-1.5b",
        device: str = "auto",
    ):
        """Initialize the generator."""
        print(f"Loading model from {checkpoint}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            trust_remote_code=True,
        )
        # The checkpoint's tokenizer may lack a dedicated pad token; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token
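        # bfloat16 roughly halves memory vs. float32 but needs hardware support;
        # if loading fails on older GPUs/CPUs, torch.float16 or torch.float32
        # are the usual fallbacks. device_map requires the accelerate package.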
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        self.model.eval()  # inference only: disable dropout etc.
        print("Model loaded!")

    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_new_tokens: int = 4096,
    ) -> str:
        """Generate a schema from a prompt."""
        input_text = f"""<|im_start|>system
{SYSTEM_PROMPT}
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Keep special tokens so the ChatML markers can be used as delimiters.
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Preferred path: slice out the assistant turn between the ChatML markers.
        if '<|im_start|>assistant' in generated:
            parts = generated.split('<|im_start|>assistant')
            if len(parts) > 1:
                completion = parts[-1]
                if '<|im_end|>' in completion:
                    completion = completion.split('<|im_end|>')[0]
                return completion.strip()

        # Fallback: return the first balanced JSON object found in the output.
        start = generated.find('{')
        if start != -1:
            depth = 0
            for i, char in enumerate(generated[start:]):
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return generated[start:start + i + 1]
        return generated
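

# Example of programmatic use without the CLI (hypothetical invocation):
#   gen = OrbGen(checkpoint="orbital-ai/orbgen-1.5b")
#   schema_json = gen.generate("Create a task management app")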


def main(
    prompt: str = None,
    checkpoint: str = "orbital-ai/orbgen-1.5b",
    output: str = None,
    temperature: float = 0.7,
    top_p: float = 0.95,
    interactive: bool = False,
    validate: bool = False,
):
    """Generate Orbital schemas."""
    generator = OrbGen(checkpoint=checkpoint)

    if interactive:
        print("\n" + "=" * 60)
        print("OrbGen Interactive Mode")
        print("=" * 60)
        print("Enter prompts to generate schemas. Type 'quit' to exit.\n")
        while True:
            try:
                prompt = input("Prompt> ").strip()
                if prompt.lower() in ['quit', 'exit', 'q']:
                    break
                if not prompt:
                    continue
                print("\nGenerating...")
                result = generator.generate(prompt, temperature=temperature, top_p=top_p)
                # Pretty-print if the output parses as JSON; otherwise show it raw.
                try:
                    parsed = json.loads(result)
                    print(json.dumps(parsed, indent=2))
                except json.JSONDecodeError:
                    print(result)
                print()
            except KeyboardInterrupt:
                print("\nExiting...")
                break
    elif prompt:
        print(f"\nPrompt: {prompt}\n")
        print("Generating...")
        result = generator.generate(prompt, temperature=temperature, top_p=top_p)
        try:
            parsed = json.loads(result)
            formatted = json.dumps(parsed, indent=2)
            if output:
                with open(output, 'w') as f:
                    f.write(formatted)
                print(f"Schema saved to: {output}")
            else:
                print(formatted)

            # Validate with the Orbital CLI if requested.
            if validate:
                import subprocess
                import tempfile
                from pathlib import Path

                with tempfile.NamedTemporaryFile(mode='w', suffix='.orb', delete=False) as f:
                    f.write(formatted)
                    temp_path = f.name
                try:
                    # Note: cwd is a hardcoded project path; adjust for your checkout.
                    # `proc` (not `result`) avoids shadowing the generated schema text.
                    proc = subprocess.run(
                        ['orbital', 'validate', temp_path],
                        capture_output=True,
                        text=True,
                        cwd=os.path.expanduser('~/kflow.ai.builder/builder'),
                    )
                    print("\nValidation:")
                    print(proc.stdout or proc.stderr)
                finally:
                    Path(temp_path).unlink(missing_ok=True)
        except json.JSONDecodeError as e:
            print(f"Warning: Generated invalid JSON: {e}")
            print(result)
    else:
        print("Usage: python generate.py --prompt 'Your prompt here'")
        print("       python generate.py --interactive")


if __name__ == "__main__":
    fire.Fire(main)