import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import os

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Tokenizer setup
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab + 1  # +1 for mask token
mask_token_id = enc.n_vocab
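
# Helper functions: decode() drops mask tokens before detokenizing, and
# format_masked_text() renders still-masked positions as a visible "[MASK]"
# placeholder so intermediate denoising steps can be streamed to the UI.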
def encode(s):
    return enc.encode(s)

def decode(l):
    return enc.decode([t for t in l if t != mask_token_id])

def format_masked_text(l):
    chunks = []
    current_chunk = []
    for t in l:
        if t == mask_token_id:
            if current_chunk:
                chunks.append(enc.decode(current_chunk))
                current_chunk = []
            chunks.append(" [MASK] ")
        else:
            current_chunk.append(t)
    if current_chunk:
        chunks.append(enc.decode(current_chunk))
    return "".join(chunks)
def norm(x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
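
# Rotary position embeddings: split each head's channels into two halves and
# rotate them by position-dependent angles. Expects x of shape
# (B, T, n_head, head_dim); cos/sin are precomputed by the model.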
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    return out.to(x.dtype)
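
# Bidirectional multi-head attention: no causal mask, since a diffusion LM may
# attend to every position when denoising. Applies RoPE, then QK-norm, before
# scaled dot-product attention.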
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x, cos_sin):
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.config.n_head, self.config.head_dim)
        k = self.c_k(x).view(B, T, self.config.n_head, self.config.head_dim)
        v = self.c_v(x).view(B, T, self.config.n_head, self.config.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
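
# SwiGLU feed-forward network with the customary 8/3 hidden expansion.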
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_dim = int(8 * config.n_embd / 3)
        self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        return self.c_proj(F.silu(self.w1(x)) * self.w2(x))
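
# Pre-norm transformer block: attention and MLP sublayers with residual adds.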
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin):
        x = x + self.attn(norm(x), cos_sin)
        x = x + self.mlp(norm(x))
        return x
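
# Masked-diffusion language model: a bidirectional transformer that predicts
# the original token at each masked position, optionally conditioned on the
# masking rate (the diffusion time step) through a small MLP embedding.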
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(vocab_size, config.n_embd)
        self.time_emb = nn.Sequential(
            nn.Linear(1, config.n_embd),
            nn.SiLU(),
            nn.Linear(config.n_embd, config.n_embd),
        )
        self.rotary_seq_len = config.block_size * 2
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight  # tie weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _precompute_rotary_embeddings(self, seq_len, base=10000, device=None):
        if device is None:
            device = self.token_emb.weight.device
        channel_range = torch.arange(0, self.config.head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / self.config.head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin
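
    # During training, `mask` selects the positions whose loss is counted and
    # `mask_rate` conditions the model on the noise level; at inference both
    # are None and the model scores every position.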
    def forward(self, idx, targets=None, mask=None, mask_rate=None):
        B, T = idx.size()
        x = self.token_emb(idx)
        if mask_rate is not None:
            t = mask_rate.float().view(B, 1, 1)  # (B, 1, 1): one noise-level scalar per sequence, broadcast over T
            x = x + self.time_emb(t)
        x = norm(x)
        cos_sin = (self.cos[:, :T], self.sin[:, :T])
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits_flat = logits.view(B * T, C)
            targets_flat = targets.view(B * T)
            if mask is not None:
                # Average the loss over masked positions only.
                mask_flat = mask.view(B * T)
                loss = F.cross_entropy(logits_flat, targets_flat, reduction="none")
                loss = (loss * mask_flat).sum() / mask_flat.sum()
            else:
                loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss

class Config:
    def __init__(self, model_type):
        self.block_size = 512
        if model_type == 'medium':
            self.n_embd = 512
            self.n_head = 8
            self.n_layer = 8
            self.weights_path = "tinystories_diffusion_med_dual.pt"
        elif model_type == 'gpt2':
            self.n_embd = 768
            self.n_head = 12
            self.n_layer = 12
            self.weights_path = "tinystories_diffusion_GPT2_dual.pt"
        else:
            raise ValueError("model_type must be 'medium' or 'gpt2'")
        self.head_dim = self.n_embd // self.n_head

# Dynamic loading: cache the most recently used model so that switching
# architectures in the UI does not reload weights on every request.
loaded_model_type = None
loaded_model = None

def get_model(model_type):
    global loaded_model_type, loaded_model
    if loaded_model_type == model_type and loaded_model is not None:
        return loaded_model, Config(model_type)
    print(f"Loading {model_type} model...")
    config = Config(model_type)
    model = Model(config)
    weights_path = config.weights_path
    if os.path.exists(weights_path):
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        # Strip the 'module.' prefix left by DataParallel, if present.
        unwrapped_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                unwrapped_state_dict[k[7:]] = v
            else:
                unwrapped_state_dict[k] = v
        model.load_state_dict(unwrapped_state_dict)
        print("Model loaded successfully!")
    else:
        print(f"Warning: {weights_path} not found. Running with randomly initialized parameters.")
    model.to(device)
    loaded_model = model
    loaded_model_type = model_type
    return model, config
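
# Confidence-based parallel decoding: start from a block of [MASK] tokens and,
# on each forward pass, commit every position whose top-k probability mass
# clears the confidence threshold (falling back to the single most confident
# masked position when none clears it). Committed tokens are sampled from the
# renormalized top-k distribution; repeat until the block is fully decoded.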
def generate_diffusion(prompt, max_new_tokens=100, mode="Direct Output", model_type="medium"):
    model, config = get_model(model_type)
    prompt_tokens = encode(prompt)
    model.eval()
    all_tokens = prompt_tokens.copy()
    temp = 1.0                   # softmax temperature
    confidence_threshold = 0.95  # minimum top-k mass needed to commit a position
    top_k = 3                    # committed tokens are sampled from the top-k candidates
    while len(all_tokens) - len(prompt_tokens) < max_new_tokens:
        curr_prompt_len = len(all_tokens)
        block_len = min(config.block_size - curr_prompt_len, len(prompt_tokens) + max_new_tokens - len(all_tokens))
        if block_len <= 0:
            break
        # Context tokens followed by a block of [MASK] placeholders to fill in.
        x = torch.full((1, config.block_size), mask_token_id, dtype=torch.long, device=device)
        x[0, :curr_prompt_len] = torch.tensor(all_tokens[-curr_prompt_len:], device=device)
        masked = torch.zeros(1, config.block_size, dtype=torch.bool, device=device)
        masked[0, curr_prompt_len : curr_prompt_len + block_len] = True
        while masked.any():
            with torch.no_grad():
                logits, _ = model(x)
            probs = F.softmax(logits / temp, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            confidences = top_k_probs.sum(dim=-1)
            decode_mask = (confidences >= confidence_threshold) & masked
            if not decode_mask.any():
                # Fall back to committing the single most confident masked position.
                masked_confidences = torch.where(masked, confidences, torch.tensor(-float('inf')).to(device))
                decode_mask.view(-1)[masked_confidences.argmax()] = True
            top_k_probs_norm = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
            sampled_k = torch.multinomial(top_k_probs_norm.view(-1, top_k), 1).view(1, config.block_size)
            sampled_tokens = torch.gather(top_k_indices, -1, sampled_k.unsqueeze(-1)).squeeze(-1)
            x = torch.where(decode_mask, sampled_tokens, x)
            masked = masked & ~decode_mask
            if mode == "Show Generation Process":
                current_block = x[0, curr_prompt_len : curr_prompt_len + block_len].tolist()
                yield format_masked_text(all_tokens + current_block)
        all_tokens.extend(x[0, curr_prompt_len : curr_prompt_len + block_len].tolist())
    full_output = decode(all_tokens)
    yield full_output
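
# A minimal usage sketch outside the Gradio UI (assumes the weights file for
# the chosen model is present next to this script):
#
#   for text in generate_diffusion("Once upon a time, there was", max_new_tokens=50):
#       print(text)
#
# Each yield is an intermediate snapshot in "Show Generation Process" mode;
# the final yield is the fully decoded story.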
def gradio_fn(prompt, display_mode, max_tokens, model_type):
    for text in generate_diffusion(prompt, max_new_tokens=max_tokens, mode=display_mode, model_type=model_type):
        yield text

# Gradio UI
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# TinyStories Diffusion LM")
    gr.Markdown("A non-autoregressive language model that generates text by parallel block decoding, built from SwiGLU transformer blocks.")
    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(lines=2, placeholder="Once upon a time, there was a little girl who", label="Prompt (approx. 10 words)")
            model_type_in = gr.Radio(["medium", "gpt2"], value="medium", label="Model Architecture")
            mode = gr.Radio(["Direct Output", "Show Generation Process"], value="Direct Output", label="Display Mode")
            max_tokens = gr.Slider(minimum=20, maximum=1000, value=100, step=1, label="Max Tokens")
            generate_btn = gr.Button("Generate Story", variant='primary')
        with gr.Column():
            output = gr.Textbox(lines=10, label="Output")
    generate_btn.click(fn=gradio_fn, inputs=[prompt_in, mode, max_tokens, model_type_in], outputs=output)

if __name__ == "__main__":
    demo.queue().launch()