import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
import json
import math

print("🚀 Starting Twitter Reply Bot...")

# Load configuration
with open('model_config.json', 'r') as f:
    config = json.load(f)
print(f"✅ Config loaded: {config['vocab_size']} vocab, {config['d_model']} d_model")

# Load tokenizer
tokenizer = Tokenizer.from_file("twitter_tokenizer.json")
print("✅ Tokenizer loaded")


# ==================== EXACT MODEL ARCHITECTURE FROM TRAINING ====================

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x_normed = x / rms
        return self.weight * x_normed


class RotaryPositionEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)"""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k):
        seq_len = q.shape[2]
        cos = self.cos_cached[:, :, :seq_len, ...]
        sin = self.sin_cached[:, :, :seq_len, ...]
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

    def _rotate_half(self, x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)


class MultiHeadAttention(nn.Module):
    """Multi-Head Self Attention with RoPE"""

    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.o_proj(attn_output)


class SwiGLU(nn.Module):
    """SwiGLU Activation Function"""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    """Single Transformer Block"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len)
        self.feed_forward = SwiGLU(d_model, d_ff)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x, mask=None):
        x = x + self.attention(self.norm1(x), mask)
        x = x + self.feed_forward(self.norm2(x))
        return x
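
# --- Optional shape sanity check (illustrative sketch, not part of the app) ---
# Mirrors the training dimensions used later in this file (d_model=256,
# n_heads=8, d_ff=1024, max_seq_len=128). Uncomment to verify that one block
# preserves the (batch, seq_len, d_model) shape through attention + SwiGLU:
# _block = TransformerBlock(d_model=256, n_heads=8, d_ff=1024, max_seq_len=128)
# _x = torch.randn(2, 16, 256)              # (batch, seq_len, d_model)
# assert _block(_x).shape == (2, 16, 256)   # residual blocks keep the shape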

class TwitterTransformer(nn.Module):
    """Twitter Reply Transformer Model - EXACT TRAINING ARCHITECTURE"""

    def __init__(self, vocab_size=8000, d_model=256, n_layers=6, n_heads=8,
                 d_ff=1024, max_seq_len=128, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, max_seq_len)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        # Create causal mask
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask * attention_mask
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            x = layer(x, causal_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, eos_token_id=None):
        self.eval()
        for _ in range(max_new_tokens):
            input_ids_cropped = input_ids[:, -self.max_seq_len:]
            logits = self(input_ids_cropped)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return input_ids


# ==================== LOAD MODEL ====================
print("📥 Loading model...")
model = TwitterTransformer(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    n_layers=6,   # From your training
    n_heads=8,    # From your training
    d_ff=1024,    # From your training
    max_seq_len=config['max_seq_len'],
    pad_token_id=config['pad_token_id']
)

# Load weights
model.load_state_dict(torch.load('twitter_reply_model_final.pt', map_location='cpu'))
model.eval()
print("✅ Model loaded successfully!")
print(f"📊 Parameters: {sum(p.numel() for p in model.parameters()):,}")
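
# --- Optional end-to-end smoke test (illustrative sketch) ---
# Mirrors the tokenize -> generate -> decode flow used by the Gradio handler
# below; "[WITTY]" is one of the personality tags listed in the UI, and the
# token IDs come from model_config.json. Uncomment to verify the checkpoint
# produces tokens on CPU before launching the interface:
# _prompt_ids = [config['bos_token_id']] + tokenizer.encode("[WITTY]I love your product!").ids
# _out = model.generate(torch.tensor([_prompt_ids], dtype=torch.long),
#                       max_new_tokens=20, temperature=0.8, top_k=40,
#                       eos_token_id=config['eos_token_id'])
# print(tokenizer.decode(_out[0].tolist()))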

# ==================== GENERATION FUNCTION ====================

# The exact reply-separator token from training could not be recovered from
# this source (the original code split on an empty string, which raises
# ValueError). "[REPLY]" is an assumed placeholder in the style of the other
# special tokens ([BOS], [EOS]); substitute the separator your tokenizer
# actually uses.
REPLY_SEP = "[REPLY]"


def generate_reply(tweet, personality, temperature, top_k):
    """Generate a reply to a tweet"""
    try:
        # Validate input
        if not tweet or len(tweet.strip()) < 3:
            return "⚠️ Please enter a valid tweet (at least 3 characters)"

        # Format input: personality tag + tweet + (assumed) reply separator
        input_text = f"{personality}{tweet.strip()}{REPLY_SEP}"

        # Tokenize
        input_ids = [config['bos_token_id']] + tokenizer.encode(input_text).ids
        input_ids = torch.tensor([input_ids], dtype=torch.long)

        # Generate
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=max(0.5, min(temperature, 1.5)),  # Clamp temperature
                top_k=int(top_k),
                eos_token_id=config['eos_token_id']
            )

        # Decode
        text = tokenizer.decode(output[0].tolist())

        # Extract reply: text after the separator, before the EOS token
        try:
            reply = text.split(REPLY_SEP)[1].split('[EOS]')[0].strip()
            # Remove any leftover special tokens
            reply = reply.replace('[BOS]', '').replace('[EOS]', '').replace(REPLY_SEP, '').strip()
        except IndexError:  # separator not found in the decoded text
            reply = text.strip()

        return reply if reply else "🤔 Hmm, try adjusting temperature or rephrasing the tweet!"

    except Exception as e:
        return f"❌ Error: {str(e)}\n\nTry refreshing the page or adjusting parameters."


# ==================== GRADIO INTERFACE ====================
examples = [
    ["Why is my internet so slow today?", "[HELPFUL]", 0.7, 40],
    ["Your customer service is terrible!", "[PROFESSIONAL]", 0.6, 40],
    ["I love your product!", "[WITTY]", 0.8, 50],
    ["This is the worst service ever", "[HUMOR]", 0.8, 40],
    ["How do I reset my password?", "[FRIENDLY]", 0.7, 40],
    ["My order hasn't arrived yet", "[PROFESSIONAL]", 0.6, 40],
]

# Create interface
with gr.Blocks(theme=gr.themes.Soft(), title="Twitter Reply Bot") as demo:
    gr.Markdown("""
    # 🤖 Twitter Reply Bot
    ## 8.34M Parameter Custom Transformer

    Generate witty, contextual replies to tweets using an AI model trained from scratch on 100K customer service conversations.

    **Training Stats:** Final Loss: 3.43 | 3 Epochs | 15 mins on T4 GPU
    """)

    with gr.Row():
        with gr.Column(scale=1):
            tweet_input = gr.Textbox(
                label="📱 Tweet",
                placeholder="Enter a tweet to reply to...",
                lines=4,
                max_lines=6
            )
            personality_dropdown = gr.Dropdown(
                choices=["[WITTY]", "[HUMOR]", "[FRIENDLY]", "[PROFESSIONAL]", "[HELPFUL]"],
                label="🎭 Reply Personality",
                value="[WITTY]",
                info="Choose the tone for the reply"
            )
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.5,
                    maximum=1.2,
                    value=0.7,
                    step=0.1,
                    label="🌡️ Temperature",
                    info="Higher = more creative"
                )
                top_k_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=40,
                    step=10,
                    label="🎯 Top-K",
                    info="Token selection diversity"
                )
            generate_btn = gr.Button("✨ Generate Reply", variant="primary", size="lg")

        with gr.Column(scale=1):
            output = gr.Textbox(
                label="🤖 Generated Reply",
                lines=6,
                max_lines=8,
                placeholder="Your AI-generated reply will appear here..."
            )
            gr.Markdown("""
            ### 💡 Personality Guide:
            - **🎭 WITTY**: Clever, playful, engaging
            - **😂 HUMOR**: Light-hearted, funny
            - **🤝 FRIENDLY**: Warm, conversational
            - **👔 PROFESSIONAL**: Formal, business tone
            - **🆘 HELPFUL**: Solution-focused, supportive

            ### ⚙️ Parameter Tips:
            - **Low temp (0.5-0.6)**: Consistent, safe replies
            - **Mid temp (0.7-0.8)**: Balanced creativity
            - **High temp (0.9-1.2)**: More creative, riskier
            """)

    # Examples section
    gr.Markdown("### 📝 Try These Examples:")
    gr.Examples(
        examples=examples,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output,
        fn=generate_reply,
        cache_examples=False,
    )

    # Connect button
    generate_btn.click(
        fn=generate_reply,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output
    )

    gr.Markdown("""
    ---
    **⚡ Model Architecture:** Custom Transformer with RoPE + SwiGLU + RMSNorm
    **📊 Training Data:** 945K customer service tweets
    **🛠️ Built with:** PyTorch, Tokenizers, Gradio
    **🚀 Deployed on:** HuggingFace Spaces (Free CPU)
    """)

# Launch
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
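
# Usage (local; assumes this file is saved as app.py alongside
# model_config.json, twitter_tokenizer.json, and twitter_reply_model_final.pt):
#   python app.py
# then open http://localhost:7860 in a browser. On HuggingFace Spaces,
# app.py is picked up automatically as the Space's entry point.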