#!/usr/bin/env python3
"""
OrbGen Training Script
Fine-tunes a base model to generate valid Orbital schemas (.orb files).
Usage:
python train.py --config config.yaml
python train.py --config config.yaml --debug --max_steps 100
"""
import yaml
import fire
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
import wandb
def load_config(config_path: str) -> dict:
"""Load configuration from YAML file."""
with open(config_path, 'r') as f:
return yaml.safe_load(f)
def format_example(example: dict, tokenizer) -> str:
"""Format a single training example as a chat conversation."""
system_prompt = """You are OrbGen, a specialized AI that generates valid Orbital schemas (.orb files) from natural language descriptions.
Rules:
1. Output ONLY valid JSON - no explanations, no markdown code blocks
2. Every schema must have: name, version, orbitals array
3. Each orbital must have: name, entity, traits, pages
4. Each entity must have: name, collection (or runtime/singleton), fields
5. Each trait must have: name, category (interaction/integration), linkedEntity, stateMachine
6. State machines must have: states (with one isInitial:true), events, transitions
7. Use S-expression arrays for effects: ["set", "field", "value"], ["emit", "EVENT", {}], ["render-ui", "slot", {...}]
8. Pages must have: name, path, entity, traits"""
return f"""<|im_start|>system
{system_prompt}
<|im_end|>
<|im_start|>user
{example['prompt']}
<|im_end|>
<|im_start|>assistant
{example['completion']}
<|im_end|>"""
def main(
config: str = "config.yaml",
debug: bool = False,
max_steps: int = -1,
    resume_from_checkpoint: str | None = None,
):
"""Main training function."""
# Load configuration
cfg = load_config(config)
print("=" * 60)
print("OrbGen Training")
print("=" * 60)
print(f"Base model: {cfg['model']['base_model']}")
print(f"Output dir: {cfg['model']['output_dir']}")
print(f"Debug mode: {debug}")
print("=" * 60)
# Initialize wandb
if not debug:
wandb.init(
project=cfg['wandb']['project'],
entity=cfg['wandb'].get('entity'),
name=cfg['wandb']['run_name'],
config=cfg,
)
# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
cfg['model']['base_model'],
trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Load model with optional quantization
print("Loading model...")
model_kwargs = {
"trust_remote_code": True,
"device_map": "auto",
}
# Check if 4-bit quantization is enabled
quant_cfg = cfg.get('quantization', {})
if quant_cfg.get('enabled', False) and quant_cfg.get('load_in_4bit', False):
print("Using 4-bit quantization (QLoRA)...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=getattr(torch, quant_cfg.get('bnb_4bit_compute_dtype', 'bfloat16')),
bnb_4bit_quant_type=quant_cfg.get('bnb_4bit_quant_type', 'nf4'),
bnb_4bit_use_double_quant=quant_cfg.get('bnb_4bit_use_double_quant', True),
)
model_kwargs["quantization_config"] = bnb_config
else:
model_kwargs["torch_dtype"] = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
cfg['model']['base_model'],
**model_kwargs,
)
# Prepare model for training
model.config.use_cache = False
    # For quantized models, use prepare_model_for_kbit_training; check for the
    # actual quantization_config so this matches the 4-bit branch above
    if "quantization_config" in model_kwargs:
model = prepare_model_for_kbit_training(model)
else:
model.enable_input_require_grads()
# Configure LoRA
if cfg['lora']['enabled']:
print("Configuring LoRA...")
lora_config = LoraConfig(
r=cfg['lora']['r'],
lora_alpha=cfg['lora']['lora_alpha'],
lora_dropout=cfg['lora']['lora_dropout'],
target_modules=cfg['lora']['target_modules'],
bias=cfg['lora']['bias'],
task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Load dataset
print("\nLoading dataset...")
# Support both HuggingFace dataset and local files
if 'train_file' in cfg['data']:
# Load from local JSONL files
data_files = {
'train': cfg['data']['train_file'],
'validation': cfg['data']['eval_file'],
}
dataset = load_dataset('json', data_files=data_files)
train_dataset = dataset['train']
eval_dataset = dataset['validation']
else:
# Load from HuggingFace Hub
dataset = load_dataset(cfg['data']['dataset'])
train_dataset = dataset[cfg['data']['train_split']]
eval_dataset = dataset[cfg['data']['eval_split']]
print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
# Format dataset
def format_dataset(examples):
texts = []
for i in range(len(examples['prompt'])):
example = {
'prompt': examples['prompt'][i],
'completion': examples['completion'][i],
}
texts.append(format_example(example, tokenizer))
return {'text': texts}
train_dataset = train_dataset.map(
format_dataset,
batched=True,
remove_columns=train_dataset.column_names,
)
eval_dataset = eval_dataset.map(
format_dataset,
batched=True,
remove_columns=eval_dataset.column_names,
)
# Training arguments
training_args = SFTConfig(
output_dir=cfg['model']['output_dir'],
num_train_epochs=cfg['training']['num_epochs'] if not debug else 1,
per_device_train_batch_size=cfg['training']['per_device_train_batch_size'],
per_device_eval_batch_size=cfg['training']['per_device_eval_batch_size'],
gradient_accumulation_steps=cfg['training']['gradient_accumulation_steps'],
learning_rate=cfg['training']['learning_rate'],
warmup_ratio=cfg['training']['warmup_ratio'],
weight_decay=cfg['training']['weight_decay'],
max_grad_norm=cfg['training']['max_grad_norm'],
logging_steps=cfg['training']['logging_steps'],
eval_strategy="steps",
eval_steps=cfg['training']['eval_steps'],
save_steps=cfg['training']['save_steps'],
save_total_limit=cfg['training']['save_total_limit'],
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=cfg['model']['max_seq_length'],
dataset_text_field="text",
report_to="wandb" if not debug else "none",
max_steps=max_steps if max_steps > 0 else -1,
)
# Create trainer (TRL v0.27+ API)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
# Train
print("\nStarting training...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# Save final model
print("\nSaving model...")
trainer.save_model(f"{cfg['model']['output_dir']}/final")
tokenizer.save_pretrained(f"{cfg['model']['output_dir']}/final")
# Push to HuggingFace Hub if configured
hub_cfg = cfg.get('hub', {})
if hub_cfg.get('push_to_hub', False) and not debug:
print("\nPushing model to HuggingFace Hub...")
        hub_model_id = hub_cfg.get('hub_model_id', 'orbital-ai/orbgen-1.5b')
        # Point the trainer at the configured repo before pushing; otherwise it
        # defaults to the output_dir name and the printed URL below could be wrong
        trainer.args.hub_model_id = hub_model_id
        trainer.push_to_hub(commit_message="Final model after SFT training")
print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
# Finish wandb
if not debug:
wandb.finish()
print("\nTraining complete!")
print(f"Model saved to: {cfg['model']['output_dir']}/final")
if __name__ == "__main__":
fire.Fire(main)