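"""Training pipeline for the compact AI model.

Provides a simple text dataset, a synthetic-data generator for quick
smoke tests, and an Accelerate-based Trainer with cosine LR warmup,
checkpointing, and optional Weights & Biases logging.
"""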
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import os
import json
from typing import Dict, List, Optional, Any
from pathlib import Path
import wandb
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup
import logging
from ..configs.config import TrainingConfig
from ..architecture.model import CompactAIModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextDataset(Dataset):
"""Dataset for text training data."""
def __init__(self, data: List[Dict[str, Any]], tokenizer=None, max_length: int = 1024):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# Handle different data formats
if isinstance(item, dict) and "text" in item:
text = item["text"]
elif isinstance(item, str):
text = item
else:
raise ValueError(f"Unsupported data format: {type(item)}")
# Tokenize if tokenizer is provided
        if self.tokenizer:
            tokens = self.tokenizer.encode(text, max_length=self.max_length, truncation=True, padding="max_length")
            input_ids = torch.tensor(tokens, dtype=torch.long)
            # Mask out padding: an all-ones mask would let attention and the
            # loss see pad tokens whenever padding="max_length" kicks in
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            mask = (input_ids != pad_id).long() if pad_id is not None else torch.ones_like(input_ids)
            return {"input_ids": input_ids, "attention_mask": mask}
else:
# Return raw text for processing later
return {"text": text}
def create_sample_data(num_samples: int = 1000) -> List[Dict[str, str]]:
"""Create sample training data for demonstration."""
import random
templates = [
"Question: {question}\nAnswer: {answer}",
"Solve: {problem}\nSolution: {solution}",
"Explain: {topic}\nExplanation: {explanation}",
"Translate: {text}\nTranslation: {translation}",
]
questions = [
"What is 2 + 2?", "What is the capital of France?", "How does photosynthesis work?",
"What is machine learning?", "Explain quantum computing", "What is the speed of light?"
]
answers = [
"4", "Paris", "Plants convert sunlight into energy using chlorophyll",
"A type of artificial intelligence", "Computing using quantum mechanics",
"Approximately 299,792,458 meters per second"
]
data = []
    for _ in range(num_samples):
        template = random.choice(templates)
        if "{question}" in template:
            # Draw question and answer at the same index: the two lists are
            # parallel, and independent random choices would mismatch them
            pair_idx = random.randrange(len(questions))
            text = template.format(question=questions[pair_idx], answer=answers[pair_idx])
elif "{problem}" in template:
text = template.format(problem="2x + 5 = 15", solution="x = 5")
elif "{topic}" in template:
text = template.format(topic="gravity", explanation="The force that attracts objects with mass")
else:
text = template.format(text="Hello", translation="Hola")
data.append({"text": text})
return data
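# Each generated record is a flat {"text": "..."} dict, e.g.
# {"text": "Question: What is 2 + 2?\nAnswer: 4"} -- the same shape
# TextDataset expects when real data is loaded from JSON.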
class Trainer:
"""Training class for the compact AI model."""
def __init__(
self,
model: CompactAIModel,
training_config: TrainingConfig,
accelerator: Optional[Accelerator] = None,
use_wandb: bool = False,
output_dir: str = "checkpoints"
):
self.model = model
self.config = training_config
self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
# Initialize accelerator
if accelerator is None:
accelerator = Accelerator(
mixed_precision="fp16" if training_config.mixed_precision else "no",
gradient_accumulation_steps=training_config.gradient_accumulation_steps,
)
self.accelerator = accelerator
# Prepare model
self.model = self.accelerator.prepare(self.model)
# Optimizer
self.optimizer = AdamW(
self.model.parameters(),
lr=training_config.learning_rate,
weight_decay=training_config.weight_decay,
)
self.optimizer = self.accelerator.prepare(self.optimizer)
# Learning rate scheduler
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps=training_config.warmup_steps,
            num_training_steps=training_config.num_epochs * 1000,  # rough estimate; ideally num_epochs * len(train_loader)
)
# Loss function
self.criterion = nn.CrossEntropyLoss()
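        # nn.CrossEntropyLoss ignores targets set to -100 by default, which
        # is the value train_epoch/evaluate use to mask padded positions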
# Initialize wandb
self.use_wandb = use_wandb
if use_wandb:
wandb.init(project="compact-ai-model", config=training_config.__dict__)
# Training state
self.global_step = 0
self.best_loss = float('inf')
def save_checkpoint(self, epoch: int, loss: float):
"""Save model checkpoint."""
checkpoint_path = self.output_dir / f"checkpoint_epoch_{epoch}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)
# Save model
unwrapped_model = self.accelerator.unwrap_model(self.model)
torch.save(unwrapped_model.state_dict(), checkpoint_path / "pytorch_model.bin")
# Save optimizer state
torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.bin")
# Save training state
training_state = {
"epoch": epoch,
"global_step": self.global_step,
"best_loss": self.best_loss,
"current_loss": loss,
}
with open(checkpoint_path / "training_state.json", "w") as f:
json.dump(training_state, f)
logger.info(f"Saved checkpoint to {checkpoint_path}")
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint."""
checkpoint_path = Path(checkpoint_path)
# Load model state
model_state = torch.load(checkpoint_path / "pytorch_model.bin", map_location="cpu")
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.load_state_dict(model_state)
# Load optimizer state
optimizer_state = torch.load(checkpoint_path / "optimizer.bin", map_location="cpu")
self.optimizer.load_state_dict(optimizer_state)
# Load training state
with open(checkpoint_path / "training_state.json", "r") as f:
training_state = json.load(f)
self.global_step = training_state["global_step"]
self.best_loss = training_state["best_loss"]
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def train_epoch(self, train_loader: DataLoader) -> float:
"""Train for one epoch."""
self.model.train()
total_loss = 0.0
num_batches = 0
for batch_idx, batch in enumerate(train_loader):
with self.accelerator.accumulate(self.model):
# Forward pass
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
outputs = self.model(input_ids, attention_mask, use_thinking=True)
logits = outputs["logits"]
                # Shift for next-token prediction: position t predicts token t+1
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                if attention_mask is not None:
                    # Mask padded positions out of the loss (-100 is the
                    # default ignore_index of nn.CrossEntropyLoss)
                    shift_mask = attention_mask[..., 1:].contiguous()
                    shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)
                # Compute loss
                loss = self.criterion(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
                # Backward pass
                self.accelerator.backward(loss)
                # Gradient clipping: only when gradients are actually synced
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                # Optimizer step (Accelerate makes this a no-op on accumulation
                # micro-batches); advance the LR schedule only on real updates
                # so it does not run ahead under gradient accumulation
                self.optimizer.step()
                if self.accelerator.sync_gradients:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
self.global_step += 1
# Logging
if batch_idx % self.config.log_interval == 0:
current_lr = self.lr_scheduler.get_last_lr()[0]
logger.info(
f"Step {self.global_step}: Loss = {loss.item():.4f}, LR = {current_lr:.6f}"
)
if self.use_wandb:
wandb.log({
"train/loss": loss.item(),
"train/learning_rate": current_lr,
"train/global_step": self.global_step,
})
return total_loss / num_batches
def evaluate(self, eval_loader: DataLoader) -> float:
"""Evaluate the model."""
self.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in eval_loader:
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
outputs = self.model(input_ids, attention_mask, use_thinking=False) # Eval without thinking for speed
logits = outputs["logits"]
                # Shift for next-token prediction
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                if attention_mask is not None:
                    # Keep eval consistent with training: no loss on padding
                    shift_mask = attention_mask[..., 1:].contiguous()
                    shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)
                loss = self.criterion(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
if self.use_wandb:
wandb.log({"eval/loss": avg_loss})
return avg_loss
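    # Note: avg_loss is mean per-token cross-entropy, so validation
    # perplexity (if wanted) is simply math.exp(avg_loss).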
def train(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None):
"""Main training loop."""
logger.info("Starting training...")
for epoch in range(self.config.num_epochs):
logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
# Train
train_loss = self.train_epoch(train_loader)
# Evaluate
if eval_loader is not None:
eval_loss = self.evaluate(eval_loader)
logger.info(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Eval Loss = {eval_loss:.4f}")
# Save best model
if eval_loss < self.best_loss:
self.best_loss = eval_loss
self.save_checkpoint(epoch, eval_loss)
# Save regular checkpoints
if (epoch + 1) % 5 == 0:
self.save_checkpoint(epoch, train_loss)
logger.info("Training completed!")
def main():
"""Main training function."""
import argparse
parser = argparse.ArgumentParser(description="Train Compact AI Model")
parser.add_argument("--data_path", type=str, default="training_data.json", help="Path to training data")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
parser.add_argument("--output_dir", type=str, default="checkpoints", help="Output directory")
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases logging")
parser.add_argument("--model_size", type=str, default="small", choices=["tiny", "small", "medium"], help="Model size")
parser.add_argument("--resume_from", type=str, help="Resume training from checkpoint")
args = parser.parse_args()
# Create model
from ..architecture.model import create_compact_model
model = create_compact_model(args.model_size)
# Create training config
training_config = TrainingConfig(
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
)
# Initialize trainer
trainer = Trainer(
model=model,
training_config=training_config,
use_wandb=args.use_wandb,
output_dir=args.output_dir,
)
# Load data
if os.path.exists(args.data_path):
with open(args.data_path, "r") as f:
data = json.load(f)
else:
logger.info("Creating sample training data...")
data = create_sample_data(10000)
with open(args.data_path, "w") as f:
json.dump(data, f)
# Create dataset and dataloader
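    # NOTE: no tokenizer is passed here, so TextDataset yields raw
    # {"text": ...} items while Trainer.train_epoch expects "input_ids".
    # Wire in the project's tokenizer before training on real data.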
dataset = TextDataset(data, max_length=args.max_length)
train_loader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
)
# Resume training if specified
if args.resume_from:
trainer.load_checkpoint(args.resume_from)
# Start training
trainer.train(train_loader)
if __name__ == "__main__":
main()
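# Example invocation -- must be run as a module because of the relative
# imports (the package path below is illustrative; adjust to this repo):
#   python -m src.training.train --batch_size 8 --num_epochs 10 --use_wandb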