Redhanuman committed on
Commit d245330 · verified · 1 Parent(s): 7a019b7

Update app.py

Files changed (1)
  1. app.py +244 -118
app.py CHANGED
@@ -4,115 +4,223 @@ import torch.nn as nn
  import torch.nn.functional as F
  from tokenizers import Tokenizer
  import json

  # Load configuration
  with open('model_config.json', 'r') as f:
      config = json.load(f)

  # Load tokenizer
  tokenizer = Tokenizer.from_file("twitter_tokenizer.json")
- # Model Architecture (copy from your training code)
  class TwitterTransformer(nn.Module):
-     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
-                  dim_feedforward=2048, max_seq_len=128, dropout=0.1,
-                  pad_token_id=0, bos_token_id=1, eos_token_id=2):
          super().__init__()
          self.d_model = d_model
          self.pad_token_id = pad_token_id
-         self.bos_token_id = bos_token_id
-         self.eos_token_id = eos_token_id
-
-         self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
-         self.pos_encoding = nn.Embedding(max_seq_len, d_model)
-
-         decoder_layer = nn.TransformerDecoderLayer(
-             d_model=d_model,
-             nhead=nhead,
-             dim_feedforward=dim_feedforward,
-             dropout=dropout,
-             batch_first=True
-         )
-         self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
-         self.fc_out = nn.Linear(d_model, vocab_size)
-         self.dropout = nn.Dropout(dropout)

-     def forward(self, x, mask=None):
-         batch_size, seq_len = x.shape
-         positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

-         x = self.embedding(x) * (self.d_model ** 0.5)
-         x = x + self.pos_encoding(positions)
-         x = self.dropout(x)

-         if mask is None:
-             mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=x.device)

-         memory = torch.zeros_like(x)
-         x = self.transformer(x, memory, tgt_mask=mask)
-         return self.fc_out(x)

      @torch.no_grad()
-     def generate(self, input_ids, max_new_tokens=50, temperature=1.0,
-                  top_k=50, top_p=0.9, eos_token_id=None):
          self.eval()
-         eos_token_id = eos_token_id or self.eos_token_id
-
          for _ in range(max_new_tokens):
-             logits = self(input_ids)
              logits = logits[:, -1, :] / temperature

-             # Top-k filtering
              if top_k > 0:
-                 top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
-                 logits = torch.full_like(logits, float('-inf'))
-                 logits.scatter_(1, top_k_indices, top_k_logits)
-
-             # Top-p filtering
-             if top_p < 1.0:
-                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-                 sorted_indices_to_remove = cumulative_probs > top_p
-                 sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-                 sorted_indices_to_remove[:, 0] = False
-
-                 for i in range(logits.size(0)):
-                     indices_to_remove = sorted_indices[i, sorted_indices_to_remove[i]]
-                     logits[i, indices_to_remove] = float('-inf')

              probs = F.softmax(logits, dim=-1)
              next_token = torch.multinomial(probs, num_samples=1)
              input_ids = torch.cat([input_ids, next_token], dim=1)

-             if next_token.item() == eos_token_id:
                  break
-
          return input_ids

- # Load model
- print("Loading model...")

  model = TwitterTransformer(
      vocab_size=config['vocab_size'],
      d_model=config['d_model'],
-     nhead=config['nhead'],
-     num_layers=config['num_layers'],
-     dim_feedforward=config['dim_feedforward'],
      max_seq_len=config['max_seq_len'],
-     dropout=config['dropout'],
-     pad_token_id=config['pad_token_id'],
-     bos_token_id=config['bos_token_id'],
-     eos_token_id=config['eos_token_id']
  )

  model.load_state_dict(torch.load('twitter_reply_model_final.pt', map_location='cpu'))
  model.eval()
- print("Model loaded successfully!")

- def generate_reply(tweet, personality, temperature, top_k, top_p):
-     """Generate a reply to a tweet with specified personality"""
      try:
          # Format input
-         input_text = f"{personality}{tweet}<SEP>"

          # Tokenize
          input_ids = [config['bos_token_id']] + tokenizer.encode(input_text).ids
@@ -123,62 +231,73 @@ def generate_reply(tweet, personality, temperature, top_k, top_p):
          output = model.generate(
              input_ids,
              max_new_tokens=50,
-             temperature=temperature,
              top_k=int(top_k),
-             top_p=top_p,
              eos_token_id=config['eos_token_id']
          )

          # Decode
          text = tokenizer.decode(output[0].tolist())
-         reply = text.split('<SEP>')[1].split('[EOS]')[0].strip()

-         return reply if reply else "Sorry, I couldn't generate a reply. Try adjusting the parameters!"

      except Exception as e:
-         return f"Error: {str(e)}"

- # Examples for the interface
  examples = [
-     ["Why is my internet so slow today?", "[HELPFUL]", 0.7, 40, 0.9],
-     ["Your app keeps crashing!", "[PROFESSIONAL]", 0.6, 40, 0.9],
-     ["I love your new feature!", "[FRIENDLY]", 0.8, 50, 0.9],
-     ["This is the worst service ever", "[WITTY]", 0.8, 40, 0.9],
-     ["How do I cancel my subscription?", "[HELPFUL]", 0.6, 40, 0.9],
  ]

- # Create Gradio interface
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown("""
-     # 🤖 Twitter Reply Bot - 8.34M Parameter Model
-     Generate contextual, personality-driven replies to tweets using a custom-trained transformer model.

-     **Model Stats:** 8.34M parameters | Training Loss: 3.43 | Trained on 100K tweet pairs
      """)

      with gr.Row():
-         with gr.Column():
              tweet_input = gr.Textbox(
-                 label="Tweet",
                  placeholder="Enter a tweet to reply to...",
-                 lines=3
              )

              personality_dropdown = gr.Dropdown(
                  choices=["[WITTY]", "[HUMOR]", "[FRIENDLY]", "[PROFESSIONAL]", "[HELPFUL]"],
-                 label="Reply Personality",
                  value="[WITTY]",
                  info="Choose the tone for the reply"
              )

-             with gr.Accordion("Advanced Settings", open=False):
                  temperature_slider = gr.Slider(
                      minimum=0.5,
-                     maximum=1.5,
                      value=0.7,
                      step=0.1,
-                     label="Temperature",
-                     info="Higher = more creative, Lower = more focused"
                  )

                  top_k_slider = gr.Slider(
@@ -186,56 +305,63 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                      maximum=100,
                      value=40,
                      step=10,
-                     label="Top-K",
-                     info="Number of top tokens to consider"
-                 )
-
-                 top_p_slider = gr.Slider(
-                     minimum=0.5,
-                     maximum=1.0,
-                     value=0.9,
-                     step=0.05,
-                     label="Top-P (Nucleus Sampling)",
-                     info="Cumulative probability threshold"
                  )

-             generate_btn = gr.Button("Generate Reply", variant="primary")

-         with gr.Column():
              output = gr.Textbox(
-                 label="Generated Reply",
-                 lines=5,
                  placeholder="Your AI-generated reply will appear here..."
              )

      gr.Markdown("""
-     ### 💡 Tips:
-     - **WITTY**: Clever, playful responses
-     - **HELPFUL**: Supportive, solution-oriented
-     - **PROFESSIONAL**: Formal, business-appropriate
-     - **FRIENDLY**: Warm, conversational
-     - **HUMOR**: Light-hearted, funny
      """)

      gr.Examples(
          examples=examples,
-         inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider, top_p_slider],
          outputs=output,
          fn=generate_reply,
-         cache_examples=True,
      )

      generate_btn.click(
          fn=generate_reply,
-         inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider, top_p_slider],
          outputs=output
      )

      gr.Markdown("""
      ---
-     **Model Details:** Custom transformer architecture trained on Twitter customer service data.
-     Built with PyTorch | Deployed on HuggingFace Spaces
      """)

  # Launch
- demo.launch()

@@ -4,115 +4,223 @@ import torch.nn as nn
  import torch.nn.functional as F
  from tokenizers import Tokenizer
  import json
+ import math
+
+ print("🚀 Starting Twitter Reply Bot...")

  # Load configuration
  with open('model_config.json', 'r') as f:
      config = json.load(f)
+ print(f"✅ Config loaded: {config['vocab_size']} vocab, {config['d_model']} d_model")

  # Load tokenizer
  tokenizer = Tokenizer.from_file("twitter_tokenizer.json")
+ print("✅ Tokenizer loaded")
+
+ # ==================== EXACT MODEL ARCHITECTURE FROM TRAINING ====================
+ class RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization"""
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+         x_normed = x / rms
+         return self.weight * x_normed
+ class RotaryPositionEmbedding(nn.Module):
35
+ """Rotary Position Embeddings (RoPE)"""
36
+ def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
37
+ super().__init__()
38
+ self.dim = dim
39
+ self.max_seq_len = max_seq_len
40
+ self.base = base
41
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
42
+ self.register_buffer("inv_freq", inv_freq)
43
+ self._build_cache(max_seq_len)
44
+
45
+ def _build_cache(self, seq_len: int):
46
+ t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
47
+ freqs = torch.outer(t, self.inv_freq)
48
+ emb = torch.cat((freqs, freqs), dim=-1)
49
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
50
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
51
+
52
+ def forward(self, q, k):
53
+ seq_len = q.shape[2]
54
+ cos = self.cos_cached[:, :, :seq_len, ...]
55
+ sin = self.sin_cached[:, :, :seq_len, ...]
56
+ q_rot = (q * cos) + (self._rotate_half(q) * sin)
57
+ k_rot = (k * cos) + (self._rotate_half(k) * sin)
58
+ return q_rot, k_rot
59
+
60
+ def _rotate_half(self, x):
61
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
62
+ return torch.cat((-x2, x1), dim=-1)
63
+
64
+
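The cos/sin tables rotate each (x1, x2) coordinate pair by a position-dependent angle, so RoPE re-orients query/key vectors without changing their lengths, which is what makes the resulting attention scores depend on relative position. A minimal check at initialization (assumes the RotaryPositionEmbedding class above is in scope):

import torch

rope = RotaryPositionEmbedding(dim=32, max_seq_len=128)
q = torch.randn(1, 8, 16, 32)                                # (batch, n_heads, seq, head_dim)
k = torch.randn(1, 8, 16, 32)
q_rot, k_rot = rope(q, k)
print(q_rot.shape)                                           # torch.Size([1, 8, 16, 32]) — shape unchanged
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True — pure rotation, norms preserved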
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-Head Self Attention with RoPE"""
+     def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+
+         self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)
+
+     def forward(self, x, mask=None):
+         batch_size, seq_len, d_model = x.shape
+
+         q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+         q, k = self.rope(q, k)
+
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, float('-inf'))
+
+         attn_weights = F.softmax(scores, dim=-1)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
+
+         return self.o_proj(attn_output)
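For orientation, the tensor shapes through this block, sketched with the hyperparameters the app constructs later (d_model=256, 8 heads, so head_dim=32); assumes the MultiHeadAttention class above is in scope:

import torch

mha = MultiHeadAttention(d_model=256, n_heads=8, max_seq_len=128)
x = torch.randn(2, 16, 256)          # (batch, seq, d_model)
out = mha(x)                         # q/k/v: (2, 8, 16, 32); attention scores: (2, 8, 16, 16)
print(out.shape)                     # torch.Size([2, 16, 256]) — projected back to the input shape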
+
+ class SwiGLU(nn.Module):
+     """SwiGLU Activation Function"""
+     def __init__(self, d_model: int, d_ff: int):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_ff, d_model, bias=False)
+         self.w3 = nn.Linear(d_model, d_ff, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
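SwiGLU gates one projection with the SiLU (x · sigmoid(x)) of another: out = w2(silu(w1·x) * w3·x). A small sketch of the activation identity and the parameter cost, assuming the SwiGLU class above is in scope:

import torch
import torch.nn.functional as F

x = torch.randn(4)
print(torch.allclose(F.silu(x), x * torch.sigmoid(x)))       # True: SiLU is x * sigmoid(x)

ff = SwiGLU(d_model=256, d_ff=1024)
print(sum(p.numel() for p in ff.parameters()))               # 786432 = 3 * 256 * 1024 (w1, w2, w3; no biases)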
+
+ class TransformerBlock(nn.Module):
+     """Single Transformer Block"""
+     def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int):
+         super().__init__()
+         self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len)
+         self.feed_forward = SwiGLU(d_model, d_ff)
+         self.norm1 = RMSNorm(d_model)
+         self.norm2 = RMSNorm(d_model)
+
+     def forward(self, x, mask=None):
+         x = x + self.attention(self.norm1(x), mask)
+         x = x + self.feed_forward(self.norm2(x))
+         return x

  class TwitterTransformer(nn.Module):
+     """Twitter Reply Transformer Model - EXACT TRAINING ARCHITECTURE"""
+     def __init__(self, vocab_size=8000, d_model=256, n_layers=6, n_heads=8,
+                  d_ff=1024, max_seq_len=128, pad_token_id=0):
          super().__init__()
+         self.vocab_size = vocab_size
          self.d_model = d_model
+         self.max_seq_len = max_seq_len
          self.pad_token_id = pad_token_id

+         self.token_embedding = nn.Embedding(vocab_size, d_model)
+         self.layers = nn.ModuleList([
+             TransformerBlock(d_model, n_heads, d_ff, max_seq_len)
+             for _ in range(n_layers)
+         ])
+         self.norm = RMSNorm(d_model)
+         self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+
+         # Weight tying
+         self.lm_head.weight = self.token_embedding.weight
+
+     def forward(self, input_ids, attention_mask=None):
+         batch_size, seq_len = input_ids.shape
+
+         # Create causal mask
+         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
+         causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+             causal_mask = causal_mask * attention_mask
+
+         x = self.token_embedding(input_ids)
+
+         for layer in self.layers:
+             x = layer(x, causal_mask)
+
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         return logits

      @torch.no_grad()
+     def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, eos_token_id=None):
          self.eval()

          for _ in range(max_new_tokens):
+             input_ids_cropped = input_ids[:, -self.max_seq_len:]
+             logits = self(input_ids_cropped)
              logits = logits[:, -1, :] / temperature

              if top_k > 0:
+                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                 logits[indices_to_remove] = float('-inf')

              probs = F.softmax(logits, dim=-1)
              next_token = torch.multinomial(probs, num_samples=1)
              input_ids = torch.cat([input_ids, next_token], dim=1)

+             if eos_token_id is not None and next_token.item() == eos_token_id:
                  break
+
          return input_ids
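The one-line top-k filter in generate() is compact; traced on a toy batch of logits (an illustrative sketch, separate from the diff), it keeps only the k largest values and pushes the rest to -inf before sampling:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])
kth_value = torch.topk(logits, 2)[0][..., -1, None]          # tensor([[2.]]) — the 2nd-largest logit
logits[logits < kth_value] = float('-inf')
print(logits)                                                # tensor([[2., -inf, -inf, 3., -inf]])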
+
+ # ==================== LOAD MODEL ====================
+ print("📥 Loading model...")
+
  model = TwitterTransformer(
      vocab_size=config['vocab_size'],
      d_model=config['d_model'],
+     n_layers=6,        # From your training
+     n_heads=8,         # From your training
+     d_ff=1024,         # From your training
      max_seq_len=config['max_seq_len'],
+     pad_token_id=config['pad_token_id']
  )

+ # Load weights
  model.load_state_dict(torch.load('twitter_reply_model_final.pt', map_location='cpu'))
  model.eval()

+ print("✅ Model loaded successfully!")
+ print(f"📊 Parameters: {sum(p.numel() for p in model.parameters()):,}")
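The 8.34M parameter figure quoted in the interface text is consistent with these hyperparameters, assuming model_config.json supplies the class defaults vocab_size=8000 and d_model=256 (an assumption; the config values are not shown in this diff). A back-of-the-envelope count:

vocab, d, n_layers, d_ff = 8000, 256, 6, 1024                  # assumed config values
embedding = vocab * d                          # 2,048,000 (lm_head is weight-tied, counted once)
per_layer = 4 * d * d + 3 * d * d_ff + 2 * d   # attention (q/k/v/o) + SwiGLU + two RMSNorms = 1,049,088
total = embedding + n_layers * per_layer + d   # plus the final RMSNorm
print(f"{total:,}")                            # 8,342,784 ≈ 8.34M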
+
+ # ==================== GENERATION FUNCTION ====================
+ def generate_reply(tweet, personality, temperature, top_k):
+     """Generate a reply to a tweet"""
      try:
+         # Validate input
+         if not tweet or len(tweet.strip()) < 3:
+             return "⚠️ Please enter a valid tweet (at least 3 characters)"
+
          # Format input
+         input_text = f"{personality}{tweet.strip()}<SEP>"

          # Tokenize
          input_ids = [config['bos_token_id']] + tokenizer.encode(input_text).ids
@@ -123,62 +231,73 @@ def generate_reply(tweet, personality, temperature, top_k, top_p):
          output = model.generate(
              input_ids,
              max_new_tokens=50,
+             temperature=max(0.5, min(temperature, 1.5)),   # Clamp temperature
              top_k=int(top_k),
              eos_token_id=config['eos_token_id']
          )

          # Decode
          text = tokenizer.decode(output[0].tolist())

+         # Extract reply
+         try:
+             reply = text.split('<SEP>')[1].split('[EOS]')[0].strip()
+             # Remove any leftover special tokens
+             reply = reply.replace('[BOS]', '').replace('[EOS]', '').replace('<SEP>', '').strip()
+         except:
+             reply = text.strip()
+
+         return reply if reply else "🤔 Hmm, try adjusting temperature or rephrasing the tweet!"

      except Exception as e:
+         return f"❌ Error: {str(e)}\n\nTry refreshing the page or adjusting parameters."
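For reference, how the prompt is framed and the reply recovered, traced on a made-up decoded string (illustrative only; the '<SEP>' and '[EOS]' spellings come from the code above, and the model output here is invented):

personality, tweet = "[HELPFUL]", "Why is my internet so slow today?"
input_text = f"{personality}{tweet}<SEP>"                           # conditioning text fed to the model
decoded = input_text + " Try restarting your router first. [EOS]"   # pretend decoder output
reply = decoded.split('<SEP>')[1].split('[EOS]')[0].strip()
print(reply)                                                        # Try restarting your router first.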
 
+
+ # ==================== GRADIO INTERFACE ====================
  examples = [
+     ["Why is my internet so slow today?", "[HELPFUL]", 0.7, 40],
+     ["Your customer service is terrible!", "[PROFESSIONAL]", 0.6, 40],
+     ["I love your product!", "[WITTY]", 0.8, 50],
+     ["This is the worst service ever", "[HUMOR]", 0.8, 40],
+     ["How do I reset my password?", "[FRIENDLY]", 0.7, 40],
+     ["My order hasn't arrived yet", "[PROFESSIONAL]", 0.6, 40],
  ]

+ # Create interface
+ with gr.Blocks(theme=gr.themes.Soft(), title="Twitter Reply Bot") as demo:
      gr.Markdown("""
+     # 🤖 Twitter Reply Bot
+     ## 8.34M Parameter Custom Transformer
+
+     Generate witty, contextual replies to tweets using an AI model trained from scratch on 100K customer service conversations.

+     **Training Stats:** Final Loss: 3.43 | 3 Epochs | 15 mins on T4 GPU
      """)

      with gr.Row():
+         with gr.Column(scale=1):
              tweet_input = gr.Textbox(
+                 label="📱 Tweet",
                  placeholder="Enter a tweet to reply to...",
+                 lines=4,
+                 max_lines=6
              )

              personality_dropdown = gr.Dropdown(
                  choices=["[WITTY]", "[HUMOR]", "[FRIENDLY]", "[PROFESSIONAL]", "[HELPFUL]"],
+                 label="🎭 Reply Personality",
                  value="[WITTY]",
                  info="Choose the tone for the reply"
              )

+             with gr.Row():
                  temperature_slider = gr.Slider(
                      minimum=0.5,
+                     maximum=1.2,
                      value=0.7,
                      step=0.1,
+                     label="🌡️ Temperature",
+                     info="Higher = more creative"
                  )

                  top_k_slider = gr.Slider(
@@ -186,56 +305,63 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                      maximum=100,
                      value=40,
                      step=10,
+                     label="🎯 Top-K",
+                     info="Token selection diversity"
                  )

+             generate_btn = gr.Button("✨ Generate Reply", variant="primary", size="lg")

+         with gr.Column(scale=1):
              output = gr.Textbox(
+                 label="🤖 Generated Reply",
+                 lines=6,
+                 max_lines=8,
                  placeholder="Your AI-generated reply will appear here..."
              )

      gr.Markdown("""
+     ### 💡 Personality Guide:
+     - **🎭 WITTY**: Clever, playful, engaging
+     - **😂 HUMOR**: Light-hearted, funny
+     - **🤝 FRIENDLY**: Warm, conversational
+     - **👔 PROFESSIONAL**: Formal, business tone
+     - **🆘 HELPFUL**: Solution-focused, supportive
+
+     ### ⚙️ Parameter Tips:
+     - **Low temp (0.5-0.6)**: Consistent, safe replies
+     - **Mid temp (0.7-0.8)**: Balanced creativity
+     - **High temp (0.9-1.2)**: More creative, riskier
      """)

+     # Examples section
+     gr.Markdown("### 📝 Try These Examples:")
      gr.Examples(
          examples=examples,
+         inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
          outputs=output,
          fn=generate_reply,
+         cache_examples=False,
      )

+     # Connect button
      generate_btn.click(
          fn=generate_reply,
+         inputs=[tweet_input, personality_dropdown, temperature_slider, top_k_slider],
          outputs=output
      )

      gr.Markdown("""
      ---
+     **⚡ Model Architecture:** Custom Transformer with RoPE + SwiGLU + RMSNorm
+     **📊 Training Data:** 945K customer service tweets
+     **🛠️ Built with:** PyTorch, Tokenizers, Gradio
+     **🚀 Deployed on:** HuggingFace Spaces (Free CPU)
      """)

  # Launch
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False
+     )