# Author: Redhanuman
# Commit d245330 ("Update app.py"), verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
import json
import math
print("πŸš€ Starting Twitter Reply Bot...")

# Load configuration. JSON text is UTF-8, so decode it explicitly instead of
# relying on the platform's locale default encoding.
with open('model_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)
print(f"βœ… Config loaded: {config['vocab_size']} vocab, {config['d_model']} d_model")

# Load the serialized `tokenizers` Tokenizer that matches the trained model.
tokenizer = Tokenizer.from_file("twitter_tokenizer.json")
print("βœ… Tokenizer loaded")
# ==================== EXACT MODEL ARCHITECTURE FROM TRAINING ====================
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Divides each vector along the last dimension by its RMS (no mean
    subtraction, no bias) and multiplies by a learned per-feature gain.

    Args:
        dim: Size of the last dimension being normalized.
        eps: Added to the mean square before the square root for stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to ones. "weight" is a state_dict key.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        return self.weight * (x / denom)
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embeddings (RoPE), "rotate half" layout.

    Precomputes cos/sin tables for positions ``0 .. max_seq_len - 1`` and
    rotates query/key features by position-dependent angles. If a longer
    sequence arrives, the tables are rebuilt to cover it instead of failing
    with a shape/broadcast error.

    Args:
        dim: Per-head feature size (last dim of q/k); expected to be even.
        max_seq_len: Number of positions to precompute up front.
        base: Frequency base for the rotation angles.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # One inverse frequency per feature pair. Persistent buffer, so it is
        # part of the state_dict as "inv_freq".
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables for ``seq_len`` positions on inv_freq's device."""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate so feature i and i + dim/2 share an angle (rotate-half pairing).
        emb = torch.cat((freqs, freqs), dim=-1)
        # Shape (1, 1, seq_len, dim). Non-persistent: derived from inv_freq,
        # so not saved in the state_dict.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k):
        """Apply the rotation to ``q`` and ``k``.

        Expects tensors shaped (batch, heads, seq, dim). The sequence length
        is read from dim 2 of ``q``, and ``k`` must have the same length.

        Returns:
            Tuple ``(q_rot, k_rot)`` with the same shapes as the inputs.
        """
        seq_len = q.shape[2]
        if seq_len > self.cos_cached.shape[2]:
            # Longer than anything precomputed: extend the tables so the
            # slices below are not shorter than the sequence.
            self._build_cache(seq_len)
        cos = self.cos_cached[:, :, :seq_len, ...]
        sin = self.sin_cached[:, :, :seq_len, ...]
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

    def _rotate_half(self, x):
        """Map the last-dim halves (x1, x2) to (-x2, x1)."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Queries, keys and values are bias-free linear projections of the same
    input. RoPE is applied to queries and keys before scaled dot-product
    attention. Any causal or padding mask is supplied by the caller.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_seq_len: Number of positions the RoPE tables are precomputed for.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Attribute names are state_dict keys and must match the checkpoint.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Input of shape (batch, seq, d_model).
            mask: Optional tensor broadcastable to (batch, heads, seq, seq).
                Entries equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        batch_size, seq_len, d_model = x.shape
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # NOTE: a query row whose keys are all masked becomes all -inf, and
            # softmax then yields NaN for it. A lower-triangular causal mask
            # alone never does this, because the diagonal stays unmasked.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Merge heads: (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.o_proj(attn_output)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` produces the SiLU-activated gate, ``w3`` the linear value branch,
    and ``w2`` projects the gated product back to ``d_model``. All three
    projections are bias-free.

    Args:
        d_model: Input/output width.
        d_ff: Hidden width of the gated branch.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer.

    Applies RMSNorm followed by attention, then RMSNorm followed by the
    SwiGLU feed-forward. Each sub-layer is wrapped in a residual connection.
    Submodule names are state_dict keys and must match the checkpoint.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len)
        self.feed_forward = SwiGLU(d_model, d_ff)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x, mask=None):
        attended = x + self.attention(self.norm1(x), mask)
        return attended + self.feed_forward(self.norm2(attended))
class TwitterTransformer(nn.Module):
    """Decoder-only language model used to generate tweet replies.

    Layer stack:
      - token embedding
      - ``n_layers`` pre-norm TransformerBlocks
      - RMSNorm
      - a bias-free LM head whose weight is tied to the token embedding

    The submodule names (``token_embedding``, ``layers``, ``norm``,
    ``lm_head``) are state_dict keys and must match the trained checkpoint.
    """

    def __init__(self, vocab_size=8000, d_model=256, n_layers=6, n_heads=8,
                 d_ff=1024, max_seq_len=128, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id  # stored for callers; not read in this class
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, max_seq_len)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (batch, seq, vocab_size).

        Args:
            input_ids: (batch, seq) integer token ids.
            attention_mask: Optional (batch, seq) tensor. Key positions where
                it is 0 are masked out in addition to the causal mask.
        """
        _, seq_len = input_ids.shape
        # Lower-triangular (1, 1, seq, seq) mask: position i attends to j <= i.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): masks keys for every head/query.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask * attention_mask
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            x = layer(x, causal_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, eos_token_id=None):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Args:
            input_ids: (batch, seq) prompt token ids.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature. Values <= 0 select greedy
                (argmax) decoding instead of dividing by zero.
            top_k: Sample only from the ``top_k`` highest logits. 0 or a
                negative value disables the filter. Values above the
                vocabulary size are clamped to it.
            eos_token_id: If given, stop once every sequence in the batch has
                produced this token.

        Returns:
            A (batch, seq + n) tensor holding the prompt followed by the n
            generated tokens.

        The module is switched to eval mode while generating, and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        try:
            for _ in range(max_new_tokens):
                # Condition on at most the last max_seq_len tokens.
                context = input_ids[:, -self.max_seq_len:]
                logits = self(context)[:, -1, :]
                if temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k > 0:
                        # torch.topk raises if k exceeds the vocabulary size.
                        k = min(int(top_k), logits.size(-1))
                        kth_best = torch.topk(logits, k, dim=-1).values[..., -1, None]
                        logits = logits.masked_fill(logits < kth_best, float('-inf'))
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                if eos_token_id is not None:
                    # Track EOS per row. Tensor.item() only works for batch size 1.
                    finished |= next_token.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break
        finally:
            self.train(was_training)
        return input_ids
# ==================== LOAD MODEL ====================
print("πŸ“₯ Loading model...")
# Architecture hyperparameters. Values recorded in model_config.json win when
# present; otherwise fall back to the values the checkpoint was trained with.
# This keeps the model from silently diverging from the config if the
# checkpoint changes.
model = TwitterTransformer(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    n_layers=config.get('n_layers', 6),
    n_heads=config.get('n_heads', 8),
    d_ff=config.get('d_ff', 1024),
    max_seq_len=config['max_seq_len'],
    pad_token_id=config['pad_token_id'],
)
# Load weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. Consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('twitter_reply_model_final.pt', map_location='cpu'))
model.eval()
print("βœ… Model loaded successfully!")
print(f"πŸ“Š Parameters: {sum(p.numel() for p in model.parameters()):,}")
# ==================== GENERATION FUNCTION ====================
def generate_reply(tweet, personality, temperature, top_k):
    """Generate a reply to ``tweet`` in the tone selected by ``personality``.

    Args:
        tweet: Tweet text to reply to.
        personality: Tag prepended to the prompt, e.g. "[WITTY]".
        temperature: Sampling temperature; clamped to [0.5, 1.5] before use.
        top_k: Top-k cutoff; converted with int() because Gradio sliders
            deliver floats.

    Returns:
        The generated reply. Instead returns a user-facing message when the
        tweet is too short, nothing usable was generated, or an exception
        occurred.
    """
    try:
        # Validate input
        if not tweet or len(tweet.strip()) < 3:
            return "⚠️ Please enter a valid tweet (at least 3 characters)"

        # Prompt layout: BOS id followed by "<personality><tweet><SEP>".
        # The model continues with the reply after <SEP>.
        input_text = f"{personality}{tweet.strip()}<SEP>"
        prompt_ids = [config['bos_token_id']] + tokenizer.encode(input_text).ids
        input_ids = torch.tensor([prompt_ids], dtype=torch.long)

        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=max(0.5, min(temperature, 1.5)),  # Clamp temperature
                top_k=int(top_k),
                eos_token_id=config['eos_token_id'],
            )

        # Decode only the tokens generated after the prompt. Decoding the
        # whole sequence and splitting on "<SEP>" is fragile:
        # Tokenizer.decode() skips special tokens by default. If <SEP> is
        # registered as one, it never appears in the text, and the prompt
        # (tag + tweet) would end up in the reply.
        text = tokenizer.decode(output[0, input_ids.shape[1]:].tolist())

        # The reply ends at the first end-of-sequence or separator marker.
        for marker in ('[EOS]', '<SEP>'):
            text = text.split(marker, 1)[0]
        reply = text.replace('[BOS]', '').strip()

        return reply if reply else "πŸ€” Hmm, try adjusting temperature or rephrasing the tweet!"
    except Exception as e:
        # UI boundary: show the error in the output box instead of crashing the app.
        return f"❌ Error: {str(e)}\n\nTry refreshing the page or adjusting parameters."
# ==================== GRADIO INTERFACE ====================
# Preset rows for gr.Examples. Each row is ordered as
# (tweet text, personality tag, temperature, top-k).
examples = [
    list(row)
    for row in (
        ("Why is my internet so slow today?", "[HELPFUL]", 0.7, 40),
        ("Your customer service is terrible!", "[PROFESSIONAL]", 0.6, 40),
        ("I love your product!", "[WITTY]", 0.8, 50),
        ("This is the worst service ever", "[HUMOR]", 0.8, 40),
        ("How do I reset my password?", "[FRIENDLY]", 0.7, 40),
        ("My order hasn't arrived yet", "[PROFESSIONAL]", 0.6, 40),
    )
]
# Create interface
# Build the Gradio page. The layout follows the nesting of the `with`
# contexts below, and components appear in the order they are created.
with gr.Blocks(theme=gr.themes.Soft(), title="Twitter Reply Bot") as demo:
    # Page header.
    # NOTE(review): the header says the model was trained on "100K customer
    # service conversations", but the footer says "945K customer service
    # tweets". "8.34M" is also hard-coded rather than taken from the parameter
    # count printed at startup. Confirm these figures.
    gr.Markdown("""
# πŸ€– Twitter Reply Bot
## 8.34M Parameter Custom Transformer
Generate witty, contextual replies to tweets using an AI model trained from scratch on 100K customer service conversations.
**Training Stats:** Final Loss: 3.43 | 3 Epochs | 15 mins on T4 GPU
""")
    with gr.Row():
        # Left column: inputs and the generate button.
        with gr.Column(scale=1):
            tweet_input = gr.Textbox(
                label="πŸ“± Tweet",
                placeholder="Enter a tweet to reply to...",
                lines=4,
                max_lines=6
            )
            # The chosen value is prepended verbatim to the prompt by generate_reply.
            personality_dropdown = gr.Dropdown(
                choices=["[WITTY]", "[HUMOR]", "[FRIENDLY]", "[PROFESSIONAL]", "[HELPFUL]"],
                label="🎭 Reply Personality",
                value="[WITTY]",
                info="Choose the tone for the reply"
            )
            with gr.Row():
                # Slider range is 0.5-1.2. generate_reply also clamps to [0.5, 1.5].
                temperature_slider = gr.Slider(
                    minimum=0.5,
                    maximum=1.2,
                    value=0.7,
                    step=0.1,
                    label="🌑️ Temperature",
                    info="Higher = more creative"
                )
                top_k_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=40,
                    step=10,
                    label="🎯 Top-K",
                    info="Token selection diversity"
                )
            generate_btn = gr.Button("✨ Generate Reply", variant="primary", size="lg")
        # Right column: the generated reply plus usage notes.
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="πŸ€– Generated Reply",
                lines=6,
                max_lines=8,
                placeholder="Your AI-generated reply will appear here..."
            )
            gr.Markdown("""
### πŸ’‘ Personality Guide:
- **🎭 WITTY**: Clever, playful, engaging
- **πŸ˜‚ HUMOR**: Light-hearted, funny
- **🀝 FRIENDLY**: Warm, conversational
- **πŸ‘” PROFESSIONAL**: Formal, business tone
- **πŸ†˜ HELPFUL**: Solution-focused, supportive
### βš™οΈ Parameter Tips:
- **Low temp (0.5-0.6)**: Consistent, safe replies
- **Mid temp (0.7-0.8)**: Balanced creativity
- **High temp (0.9-1.2)**: More creative, riskier
""")
    # Examples section
    gr.Markdown("### πŸ“ Try These Examples:")
    # Each example row fills the four inputs, in order:
    # tweet, personality, temperature, top-k.
    gr.Examples(
        examples=examples,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output,
        fn=generate_reply,
        cache_examples=False,
    )
    # Connect button: run generate_reply on the current inputs and show the
    # result in the output textbox.
    generate_btn.click(
        fn=generate_reply,
        inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
        outputs=output
    )
    # Footer.
    gr.Markdown("""
---
**⚑ Model Architecture:** Custom Transformer with RoPE + SwiGLU + RMSNorm
**πŸ“Š Training Data:** 945K customer service tweets
**πŸ› οΈ Built with:** PyTorch, Tokenizers, Gradio
**πŸš€ Deployed on:** HuggingFace Spaces (Free CPU)
""")
# Launch
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)