# Jyo-K's picture
# Upload 4 files
# 3eeedde verified
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import math
import os
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Tokenizer setup: GPT-2 BPE plus one extra id reserved for [MASK].
enc = tiktoken.get_encoding("gpt2")
mask_token_id = enc.n_vocab  # first id past the GPT-2 vocabulary
vocab_size = mask_token_id + 1  # GPT-2 ids plus the mask token
def encode(s):
    """Encode text ``s`` into a list of GPT-2 token ids.

    Special-token strings such as ``"<|endoftext|>"`` are encoded as ordinary
    text instead of raising. The prompt comes straight from the Gradio
    textbox, and tiktoken's default (``disallowed_special="all"``) raises
    ``ValueError`` whenever such a string appears in the input.
    """
    return enc.encode(s, disallowed_special=())
def decode(l):
    """Turn token ids ``l`` back into text, dropping every [MASK] id."""
    visible = [tok for tok in l if tok != mask_token_id]
    return enc.decode(visible)
def format_masked_text(l):
    """Render token ids ``l`` as text, showing each [MASK] id as " [MASK] ".

    Consecutive non-mask tokens are decoded together as one chunk.
    """
    pieces = []
    pending = []

    for tok in l:
        if tok != mask_token_id:
            pending.append(tok)
            continue
        # Flush the run of ordinary tokens collected before this mask.
        if pending:
            pieces.append(enc.decode(pending))
            pending = []
        pieces.append(" [MASK] ")

    if pending:
        pieces.append(enc.decode(pending))
    return "".join(pieces)
def norm(x):
    """Parameter-free RMS normalisation over the last dimension (eps = 1e-5)."""
    mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_square + 1e-5)
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by the given angles.

    ``x`` must be 4-D. Its last dimension is split into a first half ``a`` and
    a second half ``b``, which are mapped to ``(a*cos + b*sin, b*cos - a*sin)``.
    ``cos`` and ``sin`` must broadcast against each half. The result is cast
    back to ``x.dtype``.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    a = x[..., :half]
    b = x[..., half:]
    rotated = torch.cat([a * cos + b * sin, a * (-sin) + b * cos], 3)
    return rotated.to(x.dtype)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention without a causal mask, with rotary embeddings and QK-norm.

    ``config`` must provide ``n_embd``, ``n_head`` and ``head_dim``. The
    per-head views below require ``n_head * head_dim == n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Separate bias-free projections; the attribute names are state-dict keys.
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x, cos_sin):
        """Let every position of ``x`` attend to every other (``is_causal=False``).

        Args:
            x: tensor of shape (B, T, n_embd).
            cos_sin: ``(cos, sin)`` pair forwarded to ``apply_rotary_emb``.

        Returns:
            Tensor of shape (B, T, n_embd).
        """
        batch, seq_len, _ = x.size()
        n_head, head_dim = self.config.n_head, self.config.head_dim

        def heads(proj):
            # (B, T, n_embd) -> (B, T, n_head, head_dim)
            return proj(x).view(batch, seq_len, n_head, head_dim)

        cos, sin = cos_sin
        # Rotary position encoding first, then RMS-normalise queries and keys.
        q = norm(apply_rotary_emb(heads(self.c_q), cos, sin))
        k = norm(apply_rotary_emb(heads(self.c_k), cos, sin))
        v = heads(self.c_v)
        # scaled_dot_product_attention takes (B, n_head, T, head_dim).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
        )
        merged = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.c_proj(merged)
class MLP(nn.Module):
    """SwiGLU feed-forward layer computing ``c_proj(silu(w1(x)) * w2(x))``.

    The hidden width is ``int(8 * n_embd / 3)``; all projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.n_embd
        hidden_dim = int(8 * width / 3)
        self.w1 = nn.Linear(width, hidden_dim, bias=False)  # gated branch (SiLU)
        self.w2 = nn.Linear(width, hidden_dim, bias=False)  # linear branch
        self.c_proj = nn.Linear(hidden_dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.c_proj(gate * self.w2(x))
class Block(nn.Module):
    """Pre-norm transformer layer: attention, then the SwiGLU MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin):
        attended = x + self.attn(norm(x), cos_sin)
        return attended + self.mlp(norm(attended))
class Model(nn.Module):
    """Transformer stack (``n_layer`` x ``Block``) over the GPT-2 vocabulary plus [MASK].

    ``config`` must provide ``n_embd``, ``n_head``, ``head_dim``, ``n_layer``
    and ``block_size``. Logits cover ``vocab_size`` classes, and the output
    projection shares its weight with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(vocab_size, config.n_embd)
        # Maps a mask rate (last dim of size 1) to an n_embd-wide embedding
        # that is added to the token embeddings.
        self.time_emb = nn.Sequential(
            nn.Linear(1, config.n_embd),
            nn.SiLU(),
            nn.Linear(config.n_embd, config.n_embd),
        )
        # Rotary tables cover twice config.block_size. They are registered
        # non-persistent, so they are rebuilt here rather than loaded.
        self.rotary_seq_len = config.block_size * 2
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight  # tie weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) weights and zero biases for Linear/Embedding modules."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _precompute_rotary_embeddings(self, seq_len, base=10000, device=None):
        """Return float32 ``(cos, sin)`` tables shaped (1, seq_len, 1, head_dim // 2).

        The shape assumes an even ``head_dim``. ``device`` defaults to the
        token embedding's device.
        """
        if device is None:
            device = self.token_emb.weight.device
        channel_range = torch.arange(0, self.config.head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / self.config.head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def forward(self, idx, targets=None, mask=None, mask_rate=None):
        """Compute logits (and optionally a loss) for token ids ``idx``.

        Args:
            idx: (B, T) token ids with T <= ``self.rotary_seq_len``.
            targets: optional (B, T) target ids. When given, a cross-entropy
                loss is returned.
            mask: optional per-position weights (e.g. a bool tensor with
                B * T elements) selecting which positions contribute to the
                loss. The loss is then the weighted mean over those
                positions, or 0 if no position is selected.
            mask_rate: optional tensor fed to ``time_emb``. A scalar or a
                (B,) tensor is broadcast over all positions. Tensors of
                rank >= 2 (e.g. (B, 1) or (B, T)) get a trailing singleton
                axis, as before.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size), and
            ``loss`` as a scalar tensor, or ``None`` when ``targets`` is None.
        """
        _, T = idx.size()
        x = self.token_emb(idx)
        if mask_rate is not None:
            t = mask_rate.float()
            if t.dim() <= 1:
                # Scalar or (B,): make it (B, 1, 1) so the embedding broadcasts
                # over the sequence axis. A bare unsqueeze(-1) would yield
                # (B, n_embd), which misaligns with (B, T, n_embd).
                t = t.reshape(-1, 1, 1)
            else:
                t = t.unsqueeze(-1)  # (B, 1) -> (B, 1, 1); (B, T) -> (B, T, 1)
            x = x + self.time_emb(t)
        x = norm(x)
        cos_sin = (self.cos[:, :T], self.sin[:, :T])
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        # reshape (not view) so non-contiguous targets/masks are accepted.
        logits_flat = logits.reshape(-1, logits.size(-1))
        targets_flat = targets.reshape(-1)
        if mask is None:
            loss = F.cross_entropy(logits_flat, targets_flat)
        else:
            per_token = F.cross_entropy(logits_flat, targets_flat, reduction="none")
            weights = mask.reshape(-1).to(per_token.dtype)
            # The clamp avoids 0/0 = NaN when nothing is selected; the
            # numerator is then 0, so the loss is 0.
            loss = (per_token * weights).sum() / weights.sum().clamp(min=1e-8)
        return logits, loss
class Config:
    """Hyper-parameters and checkpoint path for one of the two model sizes.

    Args:
        model_type: ``'medium'`` (512-d, 8 heads, 8 layers) or ``'gpt2'``
            (768-d, 12 heads, 12 layers). Both use a 512-token block.

    Raises:
        ValueError: for any other ``model_type``.
    """

    def __init__(self, model_type):
        self.block_size = 512
        # name -> (n_embd, n_head, n_layer, weights_path)
        presets = (
            ('medium', (512, 8, 8, "tinystories_diffusion_med_dual.pt")),
            ('gpt2', (768, 12, 12, "tinystories_diffusion_GPT2_dual.pt")),
        )
        for name, preset in presets:
            if model_type == name:
                self.n_embd, self.n_head, self.n_layer, self.weights_path = preset
                break
        else:
            raise ValueError("model_type must be 'medium' or 'gpt2'")
        self.head_dim = self.n_embd // self.n_head
# Dynamic loading: only the most recently requested model is cached.
loaded_model_type = None
loaded_model = None
loaded_config = None

# Key prefixes added by common wrappers: DataParallel/DDP ("module.") and
# torch.compile ("_orig_mod.").
_WRAPPER_PREFIXES = ("module.", "_orig_mod.")


def _strip_wrapper_prefixes(key):
    """Repeatedly remove any leading wrapper prefix from a state-dict key."""
    changed = True
    while changed:
        changed = False
        for prefix in _WRAPPER_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


def get_model(model_type):
    """Return ``(model, config)`` for ``model_type``, building it on first use.

    The result is cached in module globals; asking for a different
    ``model_type`` builds a new model and replaces the cached one. Weights
    are read from ``config.weights_path`` when that file exists. Otherwise a
    warning is printed and the randomly initialised model is used.

    Raises:
        ValueError: from ``Config`` for an unknown ``model_type``.
    """
    global loaded_model_type, loaded_model, loaded_config
    if loaded_model_type == model_type and loaded_model is not None:
        return loaded_model, loaded_config
    print(f"Loading {model_type} model...")
    config = Config(model_type)
    model = Model(config)
    weights_path = config.weights_path
    if os.path.exists(weights_path):
        # Load onto the CPU: the tensors are copied into the (still CPU)
        # model parameters anyway. This avoids holding an extra copy of the
        # weights on the accelerator when the model is moved below.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        unwrapped_state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
        model.load_state_dict(unwrapped_state_dict)
        print("Model loaded successfully!")
    else:
        print(f"Warning: {weights_path} not found. Running with uninitialized random parameters.")
    model.to(device)
    loaded_model = model
    loaded_model_type = model_type
    loaded_config = config
    return model, config
@torch.no_grad()
def generate_diffusion(prompt, max_new_tokens=100, mode="Direct Output", model_type="medium"):
    """Continue ``prompt`` by block-wise iterative unmasking.

    Each round works on a ``config.block_size`` canvas filled with
    ``mask_token_id``:

    1. The most recent context tokens are placed at the start of the canvas.
    2. The next ``block_len`` slots are marked as to-be-generated.
    3. The model runs repeatedly over the whole canvas. Each pass commits
       every masked slot whose top-k probability mass reaches
       ``confidence_threshold``; if none do, it commits the single most
       confident masked slot. Committed tokens are sampled from the
       renormalised top-k distribution.

    When prompt + output no longer fits the canvas, the context slides
    forward, keeping the most recent tokens. This lets ``max_new_tokens``
    exceed the block size instead of generation silently stopping at
    ``block_size`` total tokens.

    Args:
        prompt: text to continue.
        max_new_tokens: number of tokens to generate after the prompt.
        mode: "Show Generation Process" yields the partially unmasked text
            after every denoising step; any other value yields only the
            final text.
        model_type: passed to ``get_model`` ("medium" or "gpt2").

    Yields:
        str: intermediate renderings (with " [MASK] " placeholders) in
        "Show Generation Process" mode, then the fully decoded text.
    """
    model, config = get_model(model_type)
    model.eval()
    max_new_tokens = int(max_new_tokens)  # UI widgets may pass floats
    prompt_tokens = encode(prompt)
    prompt_len = len(prompt_tokens)
    all_tokens = list(prompt_tokens)
    temp = 1.0
    confidence_threshold = 0.95
    top_k = 3
    # Once the canvas is full, cap each new block at half the canvas so the
    # following rounds still condition on a meaningful amount of context.
    max_block = max(1, config.block_size // 2)
    neg_inf = torch.tensor(float("-inf"), device=device)
    while True:
        remaining = max_new_tokens - (len(all_tokens) - prompt_len)
        if remaining <= 0:
            break
        # Keep as much recent context as fits while leaving room for
        # min(remaining, max_block) new tokens.
        ctx_len = min(len(all_tokens), config.block_size - min(remaining, max_block))
        block_len = min(config.block_size - ctx_len, remaining)
        if block_len <= 0:
            break
        # Explicit start index: a [-0:] slice would return the whole list.
        context = all_tokens[len(all_tokens) - ctx_len:]
        x = torch.full((1, config.block_size), mask_token_id, dtype=torch.long, device=device)
        x[0, :ctx_len] = torch.tensor(context, dtype=torch.long, device=device)
        masked = torch.zeros(1, config.block_size, dtype=torch.bool, device=device)
        masked[0, ctx_len:ctx_len + block_len] = True
        while masked.any():
            logits, _ = model(x)
            probs = F.softmax(logits / temp, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            confidences = top_k_probs.sum(dim=-1)
            decode_mask = (confidences >= confidence_threshold) & masked
            if not decode_mask.any():
                # Guarantee progress: commit the most confident masked slot.
                masked_confidences = torch.where(masked, confidences, neg_inf)
                decode_mask.view(-1)[masked_confidences.argmax()] = True
            top_k_probs_norm = top_k_probs / confidences.unsqueeze(-1)
            sampled_k = torch.multinomial(top_k_probs_norm.view(-1, top_k), 1).view(1, config.block_size)
            sampled_tokens = torch.gather(top_k_indices, -1, sampled_k.unsqueeze(-1)).squeeze(-1)
            x = torch.where(decode_mask, sampled_tokens, x)
            masked = masked & ~decode_mask
            if mode == "Show Generation Process":
                current_block = x[0, ctx_len:ctx_len + block_len].tolist()
                yield format_masked_text(all_tokens + current_block)
        all_tokens.extend(x[0, ctx_len:ctx_len + block_len].tolist())
    yield decode(all_tokens)
def gradio_fn(prompt, display_mode, max_tokens, model_type):
    """Gradio callback: re-yield every string produced by ``generate_diffusion``."""
    yield from generate_diffusion(
        prompt,
        max_new_tokens=max_tokens,
        mode=display_mode,
        model_type=model_type,
    )
# Gradio UI: a row with two columns - input controls, then the output box.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# TinyStories Diffusion LM")
    gr.Markdown("A non-autoregressive language model leveraging parallel block-decoding and SwiGLU networks.")
    with gr.Row():
        with gr.Column():
            # Input controls.
            prompt_in = gr.Textbox(lines=2, placeholder="Once upon a time, there was a little girl who", label="Prompt (approx 10 words)")
            model_type_in = gr.Radio(["medium", "gpt2"], value="medium", label="Model Architecture")
            mode = gr.Radio(["Direct Output", "Show Generation Process"], value="Direct Output", label="Display Mode")
            # Becomes generate_diffusion's max_new_tokens: tokens generated after the prompt.
            max_tokens = gr.Slider(minimum=20, maximum=1000, value=100, step=1, label="Max Tokens")
            generate_btn = gr.Button("Generate Story", variant='primary')
        with gr.Column():
            output = gr.Textbox(lines=10, label="Output")
    # The inputs order must match gradio_fn(prompt, display_mode, max_tokens, model_type).
    generate_btn.click(fn=gradio_fn, inputs=[prompt_in, mode, max_tokens, model_type_in], outputs=output)
# Enable request queueing and start the server when run as a script.
if __name__ == "__main__":
    demo.queue().launch()