#!/usr/bin/env python3
"""
OrbGen Training Script

Fine-tunes a base model to generate valid Orbital schemas (.orb files).

Usage:
    python train.py --config config.yaml
    python train.py --config config.yaml --debug --max_steps 100
"""

import os
from typing import Optional

import yaml
import fire
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
import wandb


# System message prepended to every training conversation.
SYSTEM_PROMPT = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.

Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""


def load_config(config_path: str) -> dict:
    """Load configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail here with a clear message
    # instead of a TypeError at the first cfg[...] lookup.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the top level"
        )
    return cfg


def format_example(example: dict, tokenizer) -> str:
    """Format a single training example as a chat conversation.

    Args:
        example: Mapping with 'prompt' and 'completion' entries.
        tokenizer: Accepted for interface compatibility; not used, because the
            conversation is rendered with a fixed ChatML-style template.

    Returns:
        The system, user and assistant turns joined into one training string.
    """
    return f"""<|im_start|>system
{SYSTEM_PROMPT}
<|im_end|>
<|im_start|>user
{example['prompt']}
<|im_end|>
<|im_start|>assistant
{example['completion']}
<|im_end|>"""


def main(
    config: str = "config.yaml",
    debug: bool = False,
    max_steps: int = -1,
    resume_from_checkpoint: Optional[str] = None,
):
    """Main training function.

    Args:
        config: Path to the YAML configuration file.
        debug: When true, disables wandb reporting and the Hub push, and trains
            for a single epoch.
        max_steps: Overrides the number of training steps when positive.
        resume_from_checkpoint: Checkpoint path forwarded to ``trainer.train``.
    """
    # Load configuration
    cfg = load_config(config)
    output_dir = cfg['model']['output_dir']
    final_dir = f"{output_dir}/final"

    print("=" * 60)
    print("OrbGen Training")
    print("=" * 60)
    print(f"Base model: {cfg['model']['base_model']}")
    print(f"Output dir: {output_dir}")
    print(f"Debug mode: {debug}")
    print("=" * 60)

    # Initialize wandb
    if not debug:
        wandb.init(
            project=cfg['wandb']['project'],
            entity=cfg['wandb'].get('entity'),
            name=cfg['wandb']['run_name'],
            config=cfg,
        )

    # Load tokenizer
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg['model']['base_model'],
        trust_remote_code=True,
    )
    # Only fall back to EOS when the tokenizer has no pad token of its own, so
    # a dedicated pad token is not replaced by EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load model with optional quantization
    print("Loading model...")
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }

    # `or {}` also covers a `quantization:` key that is present but empty.
    quant_cfg = cfg.get('quantization') or {}
    # One flag drives both the 4-bit load and the k-bit preparation below, so
    # the two steps cannot disagree.
    use_4bit = bool(
        quant_cfg.get('enabled', False) and quant_cfg.get('load_in_4bit', False)
    )

    if use_4bit:
        print("Using 4-bit quantization (QLoRA)...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=getattr(
                torch, quant_cfg.get('bnb_4bit_compute_dtype', 'bfloat16')
            ),
            bnb_4bit_quant_type=quant_cfg.get('bnb_4bit_quant_type', 'nf4'),
            bnb_4bit_use_double_quant=quant_cfg.get('bnb_4bit_use_double_quant', True),
        )
        model_kwargs["quantization_config"] = bnb_config
    else:
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        cfg['model']['base_model'],
        **model_kwargs,
    )

    # Prepare model for training
    model.config.use_cache = False

    # Only a model that was actually loaded quantized gets k-bit preparation.
    if use_4bit:
        model = prepare_model_for_kbit_training(model)
    else:
        model.enable_input_require_grads()

    # Configure LoRA
    if cfg['lora']['enabled']:
        print("Configuring LoRA...")
        lora_config = LoraConfig(
            r=cfg['lora']['r'],
            lora_alpha=cfg['lora']['lora_alpha'],
            lora_dropout=cfg['lora']['lora_dropout'],
            target_modules=cfg['lora']['target_modules'],
            bias=cfg['lora']['bias'],
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Load dataset
    print("\nLoading dataset...")

    # Support both HuggingFace dataset and local files
    if 'train_file' in cfg['data']:
        # Load from local JSONL files
        data_files = {
            'train': cfg['data']['train_file'],
            'validation': cfg['data']['eval_file'],
        }
        dataset = load_dataset('json', data_files=data_files)
        train_dataset = dataset['train']
        eval_dataset = dataset['validation']
    else:
        # Load from HuggingFace Hub
        dataset = load_dataset(cfg['data']['dataset'])
        train_dataset = dataset[cfg['data']['train_split']]
        eval_dataset = dataset[cfg['data']['eval_split']]

    print(f"Train examples: {len(train_dataset)}")
    print(f"Eval examples: {len(eval_dataset)}")

    # Format dataset
    def format_dataset(examples):
        """Render a batch of prompt/completion columns into a 'text' column."""
        texts = [
            format_example({'prompt': prompt, 'completion': completion}, tokenizer)
            for prompt, completion in zip(examples['prompt'], examples['completion'])
        ]
        return {'text': texts}

    train_dataset = train_dataset.map(
        format_dataset,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    eval_dataset = eval_dataset.map(
        format_dataset,
        batched=True,
        remove_columns=eval_dataset.column_names,
    )

    # Hub settings are resolved before the training arguments are built so the
    # configured repo id reaches the trainer and the final push targets the
    # same repo that is printed afterwards.
    hub_cfg = cfg.get('hub') or {}
    push_to_hub = bool(hub_cfg.get('push_to_hub', False)) and not debug
    hub_model_id = hub_cfg.get('hub_model_id', 'orbital-ai/orbgen-1.5b')

    # Training arguments
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=cfg['training']['num_epochs'] if not debug else 1,
        per_device_train_batch_size=cfg['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=cfg['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=cfg['training']['gradient_accumulation_steps'],
        learning_rate=cfg['training']['learning_rate'],
        warmup_ratio=cfg['training']['warmup_ratio'],
        weight_decay=cfg['training']['weight_decay'],
        max_grad_norm=cfg['training']['max_grad_norm'],
        logging_steps=cfg['training']['logging_steps'],
        eval_strategy="steps",
        eval_steps=cfg['training']['eval_steps'],
        save_steps=cfg['training']['save_steps'],
        save_total_limit=cfg['training']['save_total_limit'],
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=cfg['model']['max_seq_length'],
        dataset_text_field="text",
        report_to="wandb" if not debug else "none",
        max_steps=max_steps if max_steps > 0 else -1,
        # push_to_hub is left unset here: the only upload is the explicit
        # push after training finishes.
        hub_model_id=hub_model_id if push_to_hub else None,
    )

    # Create trainer (TRL v0.27+ API)
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Train
    print("\nStarting training...")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final model
    print("\nSaving model...")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)

    # Push to HuggingFace Hub if configured
    if push_to_hub:
        print("\nPushing model to HuggingFace Hub...")
        trainer.push_to_hub(commit_message="Final model after SFT training")
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")

    # Finish wandb
    if not debug:
        wandb.finish()

    print("\nTraining complete!")
    print(f"Model saved to: {final_dir}")


if __name__ == "__main__":
    fire.Fire(main)