# image_generator/train_moe_conditional.py
import math
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
from tqdm import tqdm
import lion_pytorch

# -------------------------------------------------------------------
# Config classes
# -------------------------------------------------------------------
@dataclass
class MoEPixelTransformerConfig:
    vocab_size: int = 10  # 10 discrete pixel bins; the MNIST digit labels share this range
    image_size: int = 28
    n_layers: int = 8
    d_model: int = 256
    n_heads: int = 8
    dropout: float = 0.1
    gating_dropout: float = 0.1
    expert_count: int = 4
    expert_capacity: int = 4  # unused by the dense (soft) MoE below; kept for compatibility
    max_position_embeddings: int = 28 * 28
    lr: float = 1e-3
    batch_size: int = 64
    epochs: int = 10
    warmup_steps: int = 500
    device: str = field(default="mps" if torch.backends.mps.is_available() else "cpu")

    @classmethod
    def from_pretrained(cls, path: str):
        # Placeholder logic for loading a config; adapt to your real scenario.
        config_path = os.path.join(path, "config.pt")
        if not os.path.exists(config_path):
            raise ValueError(f"No config found at {config_path}")
        config_dict = torch.load(config_path)
        return cls(**config_dict)

    def save_pretrained(self, path: str):
        os.makedirs(path, exist_ok=True)
        config_path = os.path.join(path, "config.pt")
        torch.save(self.__dict__, config_path)

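# Round-trip sketch for the config (illustrative only; "ckpt" is a placeholder
# directory): save_pretrained dumps the dataclass fields to config.pt, and
# from_pretrained rebuilds the dataclass from that dict.
#
#     cfg = MoEPixelTransformerConfig(n_layers=4)
#     cfg.save_pretrained("ckpt")
#     assert MoEPixelTransformerConfig.from_pretrained("ckpt").n_layers == 4
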
# -------------------------------------------------------------------
# Mixture-of-Experts Block
# -------------------------------------------------------------------
class Expert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        return self.net(x)


class GatingNetwork(nn.Module):
    def __init__(self, d_model, expert_count, gating_dropout=0.1):
        super().__init__()
        self.expert_count = expert_count
        self.linear = nn.Linear(d_model, expert_count)
        self.dropout = nn.Dropout(gating_dropout)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, d_model)
        logits = self.linear(hidden_states)  # (batch, seq_len, expert_count)
        logits = self.dropout(logits)
        return logits

class MoEBlock(nn.Module):
    """Dense (soft) mixture of experts: every expert processes every token and
    the outputs are combined with softmax gate weights. No top-k routing or
    capacity limiting is applied, so ``expert_capacity`` is unused here."""

    def __init__(self, config: MoEPixelTransformerConfig):
        super().__init__()
        self.experts = nn.ModuleList([Expert(config.d_model) for _ in range(config.expert_count)])
        self.gating_network = GatingNetwork(config.d_model, config.expert_count, gating_dropout=config.gating_dropout)
        self.expert_count = config.expert_count
        self.layernorm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states):
        # hidden_states: (B, Seq, d_model)
        normed = self.layernorm(hidden_states)
        gating_logits = self.gating_network(normed)   # (B, Seq, expert_count)
        gates = torch.softmax(gating_logits, dim=-1)  # (B, Seq, expert_count)
        # For each token, compute a weighted combination of the expert outputs.
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            gate_i = gates[..., i].unsqueeze(-1)  # (B, Seq, 1)
            exp_out = expert(normed)              # (B, Seq, d_model)
            expert_outputs.append(exp_out * gate_i)
        combined = torch.stack(expert_outputs, dim=-1).sum(dim=-1)  # (B, Seq, d_model)
        hidden_states = hidden_states + self.dropout(combined)  # residual connection
        return hidden_states

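# The block above is a dense mixture: every expert runs on every token. As a
# contrast, the sketch below (illustrative only; nothing in this file calls it,
# and `k` is a hypothetical hyperparameter) shows how sparse top-k gating would
# reshape the gate weights before the expert combination.
def _topk_gates(gating_logits: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Keep only the top-k gate logits per token, then renormalize with softmax."""
    topk_vals, topk_idx = gating_logits.topk(k, dim=-1)
    sparse = torch.full_like(gating_logits, float("-inf"))
    sparse.scatter_(-1, topk_idx, topk_vals)
    # -inf entries become exactly 0 after softmax, so each token mixes only k experts.
    return torch.softmax(sparse, dim=-1)  # (B, Seq, expert_count), mostly zeros
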
# -------------------------------------------------------------------
# Self-Attention Block
# -------------------------------------------------------------------
class SelfAttention(nn.Module):
    def __init__(self, config: MoEPixelTransformerConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.o_proj = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer(
            "mask",
            torch.tril(
                torch.ones(config.max_position_embeddings, config.max_position_embeddings)
            ).view(1, 1, config.max_position_embeddings, config.max_position_embeddings),
        )

    def forward(self, x):
        B, Seq, D = x.shape
        qkv = self.qkv(x)  # (B, Seq, 3*d_model)
        q, k, v = qkv.split(D, dim=-1)
        # Reshape for multi-head attention.
        q = q.view(B, Seq, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, Seq, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, Seq, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with a causal mask.
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, Seq, Seq)
        attn_scores = attn_scores.masked_fill(self.mask[:, :, :Seq, :Seq] == 0, float("-inf"))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        out = attn_weights @ v  # (B, n_heads, Seq, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, Seq, D)
        out = self.o_proj(out)
        return out

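# Note: on PyTorch >= 2.0 the masked softmax attention above could be replaced
# by torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=...,
# is_causal=True); the explicit version is kept here for readability.
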
class TransformerBlock(nn.Module):
    def __init__(self, config: MoEPixelTransformerConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = SelfAttention(config)
        self.dropout1 = nn.Dropout(config.dropout)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, 4 * config.d_model),
            nn.GELU(),
            nn.Linear(4 * config.d_model, config.d_model),
        )
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, x):
        # Self-attention sublayer (pre-LayerNorm, residual).
        a = self.ln1(x)
        x = x + self.dropout1(self.attn(a))
        # Feed-forward sublayer (pre-LayerNorm, residual).
        m = self.ln2(x)
        x = x + self.dropout2(self.mlp(m))
        return x

# -------------------------------------------------------------------
# Full MoEPixelTransformer model
# -------------------------------------------------------------------
class MoEPixelTransformer(nn.Module):
    def __init__(self, config: MoEPixelTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Parameter(torch.zeros(1, config.max_position_embeddings, config.d_model))
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        # A single MoE block applied after all transformer blocks.
        self.moe_block = MoEBlock(config)
        self.ln_f = nn.LayerNorm(config.d_model)
        # Predict a distribution over the vocab_size (=10) discrete pixel bins.
        self.output_head = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, x):
        # x: (B, Seq) indices in 0..9, one per pixel (plus an optional leading label token).
        B, Seq = x.shape
        token_emb = self.embedding(x)  # (B, Seq, d_model)
        position_emb = self.pos_embedding[:, :Seq, :]
        hidden_states = token_emb + position_emb
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)
        # MoE block after the transformer stack.
        hidden_states = self.moe_block(hidden_states)
        hidden_states = self.ln_f(hidden_states)
        logits = self.output_head(hidden_states)  # (B, Seq, vocab_size)
        return logits

    def generate_digit_stream(self, digit: int):
        """Generator that yields one sampled pixel at a time, conditioned on
        the digit label as the first token of the sequence."""
        self.eval()
        device = next(self.parameters()).device  # follow the model's device instead of hardcoding CPU
        seq = [digit]  # label conditioning: the digit is the first token
        for _ in range(self.config.image_size * self.config.image_size):
            x_in = torch.tensor([seq], dtype=torch.long, device=device)  # (1, len(seq))
            with torch.no_grad():
                logits = self.forward(x_in)
            # Distribution for the next pixel, shape (vocab_size,).
            last_logits = logits[0, -1, :]
            probs = torch.softmax(last_logits, dim=-1)
            # Sample (swap in argmax for greedy decoding).
            next_token = torch.multinomial(probs, num_samples=1).item()
            seq.append(next_token)
            yield next_token
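
    def generate_digit_image(self, digit: int) -> torch.Tensor:
        # Convenience wrapper (an illustrative addition, not used elsewhere in
        # this script): collect the full pixel stream and map the bins 0..9
        # back to floats in [0, 1] as a (28, 28) image.
        pixels = list(self.generate_digit_stream(digit))
        img = torch.tensor(pixels, dtype=torch.float32) / 9.0
        return img.view(self.config.image_size, self.config.image_size)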

    @classmethod
    def from_pretrained(
        cls,
        path: str,
        config: Optional[MoEPixelTransformerConfig] = None,
        device: str = "cpu",
    ):
        """Load a saved model onto the specified device (default CPU).

        Checkpoints saved on macOS may reference the ``mps`` device, and
        loading them in environments without MPS support used to raise
        ``torch.UntypedStorage`` errors. Loading the weights on CPU first and
        then moving the model to the requested device avoids this."""
        if config is None:
            config = MoEPixelTransformerConfig.from_pretrained(path)
        # Update the config to reflect the runtime device.
        config.device = device
        model_path = os.path.join(path, "model.pt")
        model = cls(config)
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(device)
        return model

    def save_pretrained(self, path: str):
        os.makedirs(path, exist_ok=True)
        model_path = os.path.join(path, "model.pt")
        torch.save(self.state_dict(), model_path)
        self.config.save_pretrained(path)

# -------------------------------------------------------------------
# Training Code
# -------------------------------------------------------------------
def train_moe_pixel_transformer(config: MoEPixelTransformerConfig):
    # Simple MNIST-based dataset; pixel values are discretized into 10 bins.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

    model = MoEPixelTransformer(config).to(config.device)
    # Lion optimizer; optim.AdamW(model.parameters(), lr=config.lr,
    # weight_decay=0.01) works as a drop-in alternative.
    optimizer = lion_pytorch.Lion(model.parameters(), lr=config.lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Simple linear warmup + linear decay scheduler.
    total_steps = config.epochs * len(train_loader)
    warmup_steps = config.warmup_steps

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        return max(0.0, float(total_steps - step) / float(max(1, total_steps - warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model.train()
    global_step = 0
    try:
        for epoch in range(config.epochs):
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
            for imgs, labels in pbar:
                imgs = imgs.to(config.device)
                labels = labels.to(config.device)
                # Discretize pixel values in [0, 1] into bins 0..9.
                imgs_discrete = torch.floor(imgs * 9).long().squeeze(1)  # (B, 28, 28)
                B, H, W = imgs_discrete.shape
                imgs_discrete = imgs_discrete.view(B, H * W)  # flatten to (B, 784)
                # Prepend the digit label so training matches the label-first
                # conditioning used in generate_digit_stream. Labels 0..9
                # share the 10-entry vocabulary with the pixel bins.
                seq = torch.cat([labels.unsqueeze(1), imgs_discrete], dim=1)  # (B, 785)
                # Next-token prediction: inputs are all but the last token,
                # targets are shifted by one.
                logits = model(seq[:, :-1])
                targets = seq[:, 1:].contiguous()
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
                global_step += 1
    except KeyboardInterrupt:
        print("\nEmergency save triggered by keyboard interrupt...")
        model.save_pretrained("my_moe_model")
        print("Model saved to my_moe_model/")
        return model

    model.save_pretrained("my_moe_model")
    return model

if __name__ == "__main__":
    config = MoEPixelTransformerConfig(
        epochs=1,
        n_layers=8,
        expert_count=64,
        expert_capacity=4,
        d_model=256,
        batch_size=4,
        dropout=0.1,
        gating_dropout=0.1,
        lr=1e-3,
        warmup_steps=500,
    )
    model = train_moe_pixel_transformer(config)
    model.save_pretrained("my_moe_model")
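
# Illustrative follow-up (commented out so running this script is unchanged):
# reload the checkpoint on CPU and stream a sample conditioned on a digit.
#
#     model = MoEPixelTransformer.from_pretrained("my_moe_model", device="cpu")
#     img = model.generate_digit_image(3)  # (28, 28) floats in [0, 1]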