import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
import json
import math

print("Starting Twitter Reply Bot...")

# Load configuration
with open('model_config.json', 'r') as f:
    config = json.load(f)
print(f"Config loaded: {config['vocab_size']} vocab, {config['d_model']} d_model")

# Load tokenizer
tokenizer = Tokenizer.from_file("twitter_tokenizer.json")
print("Tokenizer loaded")
# ==================== EXACT MODEL ARCHITECTURE FROM TRAINING ====================

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization"""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x_normed = x / rms
        return self.weight * x_normed
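
# Note: unlike LayerNorm, RMSNorm skips mean-centering and the bias term;
# it only rescales by the root-mean-square of the last dimension:
#   y = weight * x / sqrt(mean(x**2) + eps)
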
class RotaryPositionEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)"""
    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k):
        seq_len = q.shape[2]
        cos = self.cos_cached[:, :, :seq_len, ...]
        sin = self.sin_cached[:, :, :seq_len, ...]
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

    def _rotate_half(self, x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
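
# RoPE intuition: each head dimension i is paired with dimension i + dim/2,
# and (x * cos + rotate_half(x) * sin) rotates every pair by an angle
# proportional to the token position, so the q.k dot product depends only
# on the relative distance between positions.
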
class MultiHeadAttention(nn.Module):
    """Multi-Head Self Attention with RoPE"""
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.o_proj(attn_output)
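
# Note: on PyTorch >= 2.0 the explicit matmul -> softmax -> matmul sequence
# above could be swapped for F.scaled_dot_product_attention; it is written
# out longhand here, presumably to mirror the training code exactly.
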
class SwiGLU(nn.Module):
    """SwiGLU Gated Feed-Forward Network"""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
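
# SwiGLU(x) = W2(SiLU(W1 x) * W3 x): SiLU(W1 x) acts as a smooth gate on the
# W3 x branch, the same gated-FFN layout used in LLaMA-family models.
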
class TransformerBlock(nn.Module):
    """Single Transformer Block"""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len)
        self.feed_forward = SwiGLU(d_model, d_ff)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x, mask=None):
        x = x + self.attention(self.norm1(x), mask)
        x = x + self.feed_forward(self.norm2(x))
        return x
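
# Pre-norm residual layout: each sublayer normalizes its own input and adds
# its output back to the residual stream, which is generally more stable to
# train than the original post-norm arrangement.
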
class TwitterTransformer(nn.Module):
    """Twitter Reply Transformer Model - EXACT TRAINING ARCHITECTURE"""
    def __init__(self, vocab_size=8000, d_model=256, n_layers=6, n_heads=8,
                 d_ff=1024, max_seq_len=128, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, max_seq_len)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        # Create causal mask
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask * attention_mask
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            x = layer(x, causal_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits
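
    # Mask semantics: causal_mask is (1, 1, seq_len, seq_len) and broadcasts
    # over batch and heads; multiplying in the optional padding mask zeroes
    # entries for pad tokens, and MultiHeadAttention fills mask == 0 scores
    # with -inf so they get zero attention weight after softmax.
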
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, eos_token_id=None):
        self.eval()
        for _ in range(max_new_tokens):
            input_ids_cropped = input_ids[:, -self.max_seq_len:]
            logits = self(input_ids_cropped)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return input_ids
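
# Top-k filtering trick: torch.topk(logits, k)[0][..., -1, None] is the k-th
# largest logit; masking everything strictly below it to -inf restricts
# multinomial sampling to the k most probable tokens at each step.
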
# ==================== LOAD MODEL ====================
print("Loading model...")
model = TwitterTransformer(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    n_layers=6,    # From your training
    n_heads=8,     # From your training
    d_ff=1024,     # From your training
    max_seq_len=config['max_seq_len'],
    pad_token_id=config['pad_token_id']
)

# Load weights
model.load_state_dict(torch.load('twitter_reply_model_final.pt', map_location='cpu'))
model.eval()
print("Model loaded successfully!")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
# ==================== GENERATION FUNCTION ====================
def generate_reply(tweet, personality, temperature, top_k):
    """Generate a reply to a tweet"""
    try:
        # Validate input
        if not tweet or len(tweet.strip()) < 3:
            return "Please enter a valid tweet (at least 3 characters)"

        # Format input
        input_text = f"{personality}{tweet.strip()}<SEP>"

        # Tokenize
        input_ids = [config['bos_token_id']] + tokenizer.encode(input_text).ids
        input_ids = torch.tensor([input_ids], dtype=torch.long)

        # Generate
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=max(0.5, min(temperature, 1.5)),  # Clamp temperature
                top_k=int(top_k),
                eos_token_id=config['eos_token_id']
            )

        # Decode
        text = tokenizer.decode(output[0].tolist())

        # Extract the reply between <SEP> and [EOS]
        try:
            reply = text.split('<SEP>')[1].split('[EOS]')[0].strip()
            # Remove any leftover special tokens
            reply = reply.replace('[BOS]', '').replace('[EOS]', '').replace('<SEP>', '').strip()
        except IndexError:
            # No <SEP> in the decoded text; fall back to the raw output
            reply = text.strip()

        return reply if reply else "Hmm, try adjusting temperature or rephrasing the tweet!"

    except Exception as e:
        return f"Error: {str(e)}\n\nTry refreshing the page or adjusting parameters."
# ==================== GRADIO INTERFACE ====================
examples = [
    ["Why is my internet so slow today?", "[HELPFUL]", 0.7, 40],
    ["Your customer service is terrible!", "[PROFESSIONAL]", 0.6, 40],
    ["I love your product!", "[WITTY]", 0.8, 50],
    ["This is the worst service ever", "[HUMOR]", 0.8, 40],
    ["How do I reset my password?", "[FRIENDLY]", 0.7, 40],
    ["My order hasn't arrived yet", "[PROFESSIONAL]", 0.6, 40],
]
# Create interface
with gr.Blocks(theme=gr.themes.Soft(), title="Twitter Reply Bot") as demo:
    gr.Markdown("""
    # Twitter Reply Bot
    ## 8.34M Parameter Custom Transformer

    Generate witty, contextual replies to tweets using an AI model trained from scratch on 100K customer service conversations.

    **Training Stats:** Final Loss: 3.43 | 3 Epochs | 15 mins on T4 GPU
    """)
    with gr.Row():
        with gr.Column(scale=1):
            tweet_input = gr.Textbox(
                label="Tweet",
                placeholder="Enter a tweet to reply to...",
                lines=4,
                max_lines=6
            )
            personality_dropdown = gr.Dropdown(
                choices=["[WITTY]", "[HUMOR]", "[FRIENDLY]", "[PROFESSIONAL]", "[HELPFUL]"],
                label="Reply Personality",
                value="[WITTY]",
                info="Choose the tone for the reply"
            )
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.5,
                    maximum=1.2,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative"
                )
                top_k_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=40,
                    step=10,
                    label="Top-K",
                    info="Token selection diversity"
                )
            generate_btn = gr.Button("Generate Reply", variant="primary", size="lg")
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Generated Reply",
                lines=6,
                max_lines=8,
                placeholder="Your AI-generated reply will appear here..."
            )
            gr.Markdown("""
            ### Personality Guide:
            - **WITTY**: Clever, playful, engaging
            - **HUMOR**: Light-hearted, funny
            - **FRIENDLY**: Warm, conversational
            - **PROFESSIONAL**: Formal, business tone
            - **HELPFUL**: Solution-focused, supportive

            ### Parameter Tips:
            - **Low temp (0.5-0.6)**: Consistent, safe replies
            - **Mid temp (0.7-0.8)**: Balanced creativity
            - **High temp (0.9-1.2)**: More creative, riskier
            """)
    # Examples section
    gr.Markdown("### Try These Examples:")
    gr.Examples(
        examples=examples,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output,
        fn=generate_reply,
        cache_examples=False,
    )

    # Connect button
    generate_btn.click(
        fn=generate_reply,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output
    )
| gr.Markdown(""" | |
| --- | |
| **β‘ Model Architecture:** Custom Transformer with RoPE + SwiGLU + RMSNorm | |
| **π Training Data:** 945K customer service tweets | |
| **π οΈ Built with:** PyTorch, Tokenizers, Gradio | |
| **π Deployed on:** HuggingFace Spaces (Free CPU) | |
| """) | |
# Launch
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )