TheCodeKat committed on
Commit
94aff5c
·
verified ·
1 Parent(s): 841149c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +287 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Scholar Sage - Language Model Web Interface
3
+ Interactive text generation using the trained Transformer model
4
+ """
5
+
6
+ import torch
7
+ import gradio as gr
8
+ from transformers import AutoTokenizer
9
+ from model.transformer_explained import TinyTransformerLM
10
+
11
+
12
class TextGenerator:
    """Wraps the trained TinyTransformerLM for interactive text generation.

    The GPT-2 tokenizer and the trained checkpoint are loaded once at
    construction; `generate` then performs autoregressive sampling with
    temperature, top-k, top-p (nucleus) and repetition-penalty controls.
    """

    # Context window size; must match the max_len used during training.
    MAX_CONTEXT = 512

    def __init__(self, model_path="models/best_model_FIXED.pt"):
        """Initialize the text generator with the trained model.

        Args:
            model_path: Path to the checkpoint file holding the state_dict.
        """
        print("🔄 Loading model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer (GPT-2 BPE — the vocabulary the model was trained on)
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        vocab_size = self.tokenizer.vocab_size

        # Create model with the same architecture as training
        self.model = TinyTransformerLM(
            vocab_size=vocab_size,
            d_model=512,
            n_layers=6,
            num_heads=8,
            d_ff=2048,
            max_len=self.MAX_CONTEXT
        )

        # Load trained weights
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"✅ Model loaded! ({total_params:,} parameters)")
        print(f"🖥️ Device: {self.device}")

    def generate(
        self,
        prompt,
        max_length=50,
        temperature=0.8,
        top_k=40,
        top_p=0.92,
        repetition_penalty=1.2,
        num_return_sequences=1
    ):
        """
        Generate text based on the prompt with advanced sampling.

        Args:
            prompt: Input text to start generation
            max_length: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Top-k sampling parameter (0 disables it)
            top_p: Top-p (nucleus) sampling parameter (1.0 disables it)
            repetition_penalty: Penalty for repeating tokens (>1.0 discourages repetition)
            num_return_sequences: Number of different outputs to generate

        Returns:
            The generated text (prompt included), or — when
            num_return_sequences > 1 — all outputs joined by a "=" divider.
        """
        if not prompt.strip():
            return "⚠️ Please enter a prompt!"

        outputs = []

        for _ in range(num_return_sequences):
            # Tokenize input
            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
            # Robustness: if the prompt alone already fills the context
            # window, keep only the most recent tokens so the forward pass
            # (and the position embeddings) stay within max_len.
            if input_ids.size(1) >= self.MAX_CONTEXT:
                input_ids = input_ids[:, -(self.MAX_CONTEXT - 1):]

            with torch.no_grad():
                for _ in range(max_length):
                    # Logits for the next token come from the last position
                    logits, _ = self.model(input_ids)
                    next_token_logits = logits[:, -1, :].clone()

                    # Apply repetition penalty (vectorized): negative scores
                    # are multiplied (pushed further down), positive scores
                    # divided (shrunk), for every token already generated.
                    if repetition_penalty != 1.0:
                        seen = torch.unique(input_ids[0])
                        scores = next_token_logits[0, seen]
                        next_token_logits[0, seen] = torch.where(
                            scores < 0,
                            scores * repetition_penalty,
                            scores / repetition_penalty,
                        )

                    # Apply temperature (guard against division by zero)
                    next_token_logits = next_token_logits / max(temperature, 1e-6)

                    # Apply top-k filtering: mask everything below the k-th logit
                    if top_k > 0:
                        kth = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))[0][..., -1, None]
                        next_token_logits[next_token_logits < kth] = float('-inf')

                    # Apply top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens with cumulative probability above the
                        # threshold; the shift keeps at least the top token.
                        sorted_indices_to_remove = cumulative_probs > top_p
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Sample from the filtered distribution
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                    # Append to sequence
                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    # Stop if we hit the model's max context length
                    if input_ids.size(1) >= self.MAX_CONTEXT:
                        break

                    # Stop if we generate the end-of-sequence token
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

            # Decode the generated sequence
            generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            outputs.append(generated_text)

        if num_return_sequences == 1:
            return outputs[0]
        # BUG FIX: the divider must be the *separator* passed to join().
        # The old expression '"\n\n" + "="*70 + "\n\n".join(outputs)' bound
        # .join to only the last "\n\n", prepending a stray divider and
        # separating the outputs with bare blank lines.
        separator = "\n\n" + "=" * 70 + "\n\n"
        return separator.join(outputs)
135
+
136
+
137
# Initialize generator
# NOTE: the checkpoint is loaded eagerly at import time so the app is ready
# before Gradio starts serving; this raises (e.g. FileNotFoundError) here if
# models/best_model_FIXED.pt is missing rather than on the first request.
generator = TextGenerator()
139
+
140
+
141
def generate_text(prompt, max_length, temperature, top_k, top_p, repetition_penalty, num_outputs):
    """Gradio callback: coerce slider values to their proper types and
    delegate to the shared `generator`.

    Any exception is converted into a user-visible error string so the UI
    never shows a raw traceback.
    """
    try:
        # Sliders may hand back floats for integer-valued controls; cast
        # everything explicitly before calling into the model.
        sampling_kwargs = {
            "prompt": prompt,
            "max_length": int(max_length),
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
            "repetition_penalty": float(repetition_penalty),
            "num_return_sequences": int(num_outputs),
        }
        return generator.generate(**sampling_kwargs)
    except Exception as e:
        return f"❌ Error: {str(e)}"
156
+
157
+
158
# Create Gradio interface
# Two-column layout: prompt + sampling controls on the left, generated text
# on the right, followed by clickable example prompts and usage tips.
with gr.Blocks(title="Scholar Sage - Language Model", theme=gr.themes.Soft()) as demo:
    # Header / model card shown at the top of the page
    gr.Markdown(
        """
        # 🎓 Scholar Sage - Language Model

        A transformer-based language model trained on WikiText-2 with **causal masking**.

        **Model Details:**
        - 45M parameters (6 layers, 512 hidden dim, 8 attention heads)
        - Trained with proper causal attention masking
        - Best model from epoch 3/5

        ⚠️ **Note**: This is a small research model (~45M params vs GPT-3's 175B). For best results:
        - Use **Repetition Penalty = 1.2-1.5** to prevent repetitive text
        - Keep prompts clear and specific
        - Expect limited context understanding compared to large commercial models
        """
    )

    with gr.Row():
        # Left column: prompt box, sampling controls, generate button
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="📝 Enter your prompt",
                placeholder="Start typing... (e.g., 'Machine learning is')",
                lines=3
            )

            # Sampling hyper-parameters, collapsed by default
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=50,
                    step=10,
                    label="Max Length (tokens to generate)"
                )

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                    label="Temperature (higher = more random)"
                )

                top_k = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=40,
                    step=5,
                    label="Top-k (0 = disabled)"
                )

                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.92,
                    step=0.02,
                    label="Top-p / Nucleus Sampling"
                )

                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.2,
                    step=0.1,
                    label="Repetition Penalty (higher = less repetition)"
                )

                num_outputs = gr.Slider(
                    minimum=1,
                    maximum=3,
                    value=1,
                    step=1,
                    label="Number of outputs"
                )

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        # Right column: generation result
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="✨ Generated Text",
                lines=15,
                show_copy_button=True
            )

    # Examples: clicking one fills all seven inputs. cache_examples=False
    # because sampling is non-deterministic, so cached outputs would mislead.
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        examples=[
            ["Machine learning is", 50, 0.8, 40, 0.92, 1.2, 1],
            ["The future of artificial intelligence", 50, 0.8, 40, 0.92, 1.2, 1],
            ["Natural language processing", 50, 0.8, 40, 0.92, 1.2, 1],
            ["In the field of computer science", 50, 0.8, 40, 0.92, 1.2, 1],
            ["Researchers have discovered that", 50, 0.8, 40, 0.92, 1.2, 1],
        ],
        inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, num_outputs],
        outputs=output_text,
        fn=generate_text,
        cache_examples=False
    )

    # Connect the button to the generation callback
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, num_outputs],
        outputs=output_text
    )

    # Footer: sampling guidance for users
    gr.Markdown(
        """
        ---
        **Tips for Better Generation:**
        - 🌑️ **Temperature**: Lower (0.5-0.7) = more focused, Higher (1.0-1.5) = more creative
        - 🎯 **Top-k**: Limits vocabulary to top k most likely tokens (try 30-50)
        - 🔬 **Top-p**: Nucleus sampling - keeps smallest set of tokens with cumulative probability > p (try 0.9-0.95)
        - 🔁 **Repetition Penalty**: Higher values (1.2-1.5) reduce repetition (IMPORTANT for this model!)

        **For best results**: Use temperature=0.8, top-k=40, top-p=0.92, repetition_penalty=1.2-1.5
        """
    )
279
+
280
+
281
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) so the app is reachable from outside
    # a container/Space; 7860 is Gradio's conventional default port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,       # no public Gradio share tunnel
        show_error=True    # surface server-side exceptions in the UI
    )
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.30
3
+ datasets
4
+ sentence-transformers
5
+ tokenizers
6
+ huggingface-hub
7
+ gradio
8
+ fastapi
9
+ uvicorn
10
+ matplotlib
11
+ PyQt5