# NOTE(review): removed stray paste/extraction artifacts ("Spaces:" and two
# "Runtime error" lines) that were not Python and broke the module at import.
| #!/usr/bin/env python3 | |
| """ | |
| OrbGen Training Script | |
| Fine-tunes a base model to generate valid Orbital schemas (.orb files). | |
| Usage: | |
| python train.py --config config.yaml | |
| python train.py --config config.yaml --debug --max_steps 100 | |
| """ | |
| import os | |
| import yaml | |
| import fire | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| TrainingArguments, | |
| DataCollatorForSeq2Seq, | |
| BitsAndBytesConfig, | |
| ) | |
| from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training | |
| from trl import SFTTrainer, SFTConfig | |
| import wandb | |
def load_config(config_path: str) -> dict:
    """Parse the YAML file at *config_path* and return it as a dict."""
    with open(config_path) as handle:
        return yaml.safe_load(handle)
def format_example(example: dict, tokenizer) -> str:
    """Render one prompt/completion pair as a ChatML conversation string.

    The ``tokenizer`` parameter is accepted for interface compatibility but
    is not consulted — the chat markers are emitted directly as text.
    """
    system_prompt = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.
Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""
    # Assemble the three turns (system, user, assistant) and join them.
    turns = (
        "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n",
        "<|im_start|>user\n" + example["prompt"] + "\n<|im_end|>\n",
        "<|im_start|>assistant\n" + example["completion"] + "\n<|im_end|>",
    )
    return "".join(turns)
def main(
    config: str = "config.yaml",
    debug: bool = False,
    max_steps: int = -1,
    resume_from_checkpoint: str | None = None,
):
    """Run supervised fine-tuning of the OrbGen base model.

    Args:
        config: Path to the YAML config describing model/data/training/hub.
        debug: When True, skip wandb and Hub pushes and train for one epoch.
        max_steps: If > 0, hard cap on optimizer steps (overrides epochs).
        resume_from_checkpoint: Checkpoint path forwarded to ``trainer.train``.
    """
    cfg = load_config(config)

    print("=" * 60)
    print("OrbGen Training")
    print("=" * 60)
    print(f"Base model: {cfg['model']['base_model']}")
    print(f"Output dir: {cfg['model']['output_dir']}")
    print(f"Debug mode: {debug}")
    print("=" * 60)

    # Experiment tracking is disabled in debug runs (report_to below agrees).
    if not debug:
        wandb.init(
            project=cfg['wandb']['project'],
            entity=cfg['wandb'].get('entity'),
            name=cfg['wandb']['run_name'],
            config=cfg,
        )

    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg['model']['base_model'],
        trust_remote_code=True,
    )
    # Many causal-LM tokenizers ship without a pad token; reuse EOS. Right
    # padding keeps completions aligned at the start of each sequence.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print("Loading model...")
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }

    # Optional QLoRA path: 4-bit NF4 weights, bf16 compute by default.
    quant_cfg = cfg.get('quantization', {})
    if quant_cfg.get('enabled', False) and quant_cfg.get('load_in_4bit', False):
        print("Using 4-bit quantization (QLoRA)...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=getattr(torch, quant_cfg.get('bnb_4bit_compute_dtype', 'bfloat16')),
            bnb_4bit_quant_type=quant_cfg.get('bnb_4bit_quant_type', 'nf4'),
            bnb_4bit_use_double_quant=quant_cfg.get('bnb_4bit_use_double_quant', True),
        )
        model_kwargs["quantization_config"] = bnb_config
    else:
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        cfg['model']['base_model'],
        **model_kwargs,
    )

    # The KV-cache is inference-only and conflicts with gradient checkpointing.
    model.config.use_cache = False
    if quant_cfg.get('enabled', False):
        # Prepares a k-bit model for training (fp32 norms, input grads, etc.).
        model = prepare_model_for_kbit_training(model)
    else:
        # Required so gradient checkpointing produces gradients when the base
        # weights are frozen and only LoRA adapters train.
        model.enable_input_require_grads()

    # Wrap the base model with LoRA adapters if configured.
    if cfg['lora']['enabled']:
        print("Configuring LoRA...")
        lora_config = LoraConfig(
            r=cfg['lora']['r'],
            lora_alpha=cfg['lora']['lora_alpha'],
            lora_dropout=cfg['lora']['lora_dropout'],
            target_modules=cfg['lora']['target_modules'],
            bias=cfg['lora']['bias'],
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    print("\nLoading dataset...")
    # Local JSONL files take precedence over a HuggingFace Hub dataset id.
    if 'train_file' in cfg['data']:
        data_files = {
            'train': cfg['data']['train_file'],
            'validation': cfg['data']['eval_file'],
        }
        dataset = load_dataset('json', data_files=data_files)
        train_dataset = dataset['train']
        eval_dataset = dataset['validation']
    else:
        dataset = load_dataset(cfg['data']['dataset'])
        train_dataset = dataset[cfg['data']['train_split']]
        eval_dataset = dataset[cfg['data']['eval_split']]

    print(f"Train examples: {len(train_dataset)}")
    print(f"Eval examples: {len(eval_dataset)}")

    # Collapse prompt/completion columns into a single ChatML 'text' column.
    def format_dataset(examples):
        texts = []
        for i in range(len(examples['prompt'])):
            example = {
                'prompt': examples['prompt'][i],
                'completion': examples['completion'][i],
            }
            texts.append(format_example(example, tokenizer))
        return {'text': texts}

    train_dataset = train_dataset.map(
        format_dataset,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    eval_dataset = eval_dataset.map(
        format_dataset,
        batched=True,
        remove_columns=eval_dataset.column_names,
    )

    # Read hub config up front so the trainer knows the target repo.
    hub_cfg = cfg.get('hub', {})
    hub_model_id = hub_cfg.get('hub_model_id', 'orbital-ai/orbgen-1.5b')

    training_args = SFTConfig(
        output_dir=cfg['model']['output_dir'],
        num_train_epochs=cfg['training']['num_epochs'] if not debug else 1,
        per_device_train_batch_size=cfg['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=cfg['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=cfg['training']['gradient_accumulation_steps'],
        learning_rate=cfg['training']['learning_rate'],
        warmup_ratio=cfg['training']['warmup_ratio'],
        weight_decay=cfg['training']['weight_decay'],
        max_grad_norm=cfg['training']['max_grad_norm'],
        logging_steps=cfg['training']['logging_steps'],
        eval_strategy="steps",
        eval_steps=cfg['training']['eval_steps'],
        save_steps=cfg['training']['save_steps'],
        save_total_limit=cfg['training']['save_total_limit'],
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=cfg['model']['max_seq_length'],
        dataset_text_field="text",
        report_to="wandb" if not debug else "none",
        max_steps=max_steps if max_steps > 0 else -1,
        # Fix: the configured repo id was previously never handed to the
        # trainer, so push_to_hub derived the repo name from output_dir and
        # the URL printed below could point at the wrong repository.
        hub_model_id=hub_model_id,
    )

    # Create trainer (TRL v0.27+ API)
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    print("\nStarting training...")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final adapter/model weights and the tokenizer side by side.
    print("\nSaving model...")
    trainer.save_model(f"{cfg['model']['output_dir']}/final")
    tokenizer.save_pretrained(f"{cfg['model']['output_dir']}/final")

    # Push to HuggingFace Hub if configured (never from debug runs).
    if hub_cfg.get('push_to_hub', False) and not debug:
        print("\nPushing model to HuggingFace Hub...")
        trainer.push_to_hub(commit_message="Final model after SFT training")
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")

    if not debug:
        wandb.finish()

    print("\nTraining complete!")
    print(f"Model saved to: {cfg['model']['output_dir']}/final")
if __name__ == "__main__":
    # Expose `main`'s keyword arguments as CLI flags via python-fire
    # (e.g. `python train.py --config config.yaml --debug --max_steps 100`).
    fire.Fire(main)