"""
Scholar Sage - Improved Language Model Web Interface
Optimized for better text generation quality
"""

import torch
import gradio as gr
from transformers import AutoTokenizer
from model.transformer_explained import TinyTransformerLM
from generation_config import CONFIGS


class TextGenerator:
    """Wraps a TinyTransformerLM checkpoint with sampling-based generation.

    The model and tokenizer are loaded once at construction; ``generate``
    performs autoregressive decoding with temperature, top-k, top-p and
    repetition-penalty controls.
    """

    # Maximum sequence length supported by the model's positional
    # embeddings (matches max_len passed to TinyTransformerLM below).
    MAX_CONTEXT = 512

    def __init__(self, model_path="models/best_model_FIXED.pt"):
        """Load the GPT-2 tokenizer and model weights onto GPU if available.

        Args:
            model_path: Path to a .pt checkpoint holding the state_dict.
        """
        print("πŸ”„ Loading model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

        self.model = TinyTransformerLM(
            vocab_size=self.tokenizer.vocab_size,
            d_model=512, n_layers=6, num_heads=8, d_ff=2048,
            max_len=self.MAX_CONTEXT,
        )
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True is safer on
        # torch >= 1.13).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        print(f"βœ… Model loaded on {self.device}")

    def generate(self, prompt, max_length=50, temperature=0.7, top_k=40,
                 top_p=0.9, repetition_penalty=1.3, num_return_sequences=1):
        """Generate text continuations of ``prompt`` with optimized sampling.

        Args:
            prompt: Seed text; must be non-empty after stripping.
            max_length: Maximum number of new tokens to sample per sequence.
            temperature: Softmax temperature; clamped > 0 to avoid div-by-zero.
            top_k: Keep only the k most likely tokens (0 disables).
            top_p: Nucleus-sampling probability mass (1.0 disables).
            repetition_penalty: Values > 1.0 discourage re-emitting seen tokens.
            num_return_sequences: Number of independent samples to draw.

        Returns:
            The generated text (prompt included), or all samples joined by a
            separator when num_return_sequences > 1; a warning string when the
            prompt is blank.
        """
        # Improved prompt preprocessing
        if not prompt.strip():
            return "⚠️ Please enter a prompt!"

        enhanced_prompt = prompt.strip()
        # FIX: clamp temperature so a caller-supplied 0 cannot divide by zero.
        temperature = max(float(temperature), 1e-6)

        outputs = []
        for _ in range(num_return_sequences):
            input_ids = self.tokenizer(enhanced_prompt, return_tensors="pt")["input_ids"].to(self.device)
            # FIX: truncate over-long prompts (keep the most recent tokens) so
            # the first forward pass never exceeds the positional-embedding
            # capacity; previously a >=512-token prompt hit the model unchecked.
            if input_ids.size(1) >= self.MAX_CONTEXT:
                input_ids = input_ids[:, -(self.MAX_CONTEXT - 1):]

            with torch.no_grad():
                for step in range(max_length):
                    logits, _ = self.model(input_ids)
                    next_token_logits = logits[:, -1, :].clone()

                    next_token_logits = self._penalize_repeats(
                        next_token_logits, input_ids, repetition_penalty)
                    next_token_logits = next_token_logits / temperature
                    next_token_logits = self._filter_top_k(next_token_logits, top_k)
                    next_token_logits = self._filter_top_p(next_token_logits, top_p)

                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    # Stopping conditions: context limit, EOS, or a blank line
                    # (double newline) once past the first few tokens.
                    if input_ids.size(1) >= self.MAX_CONTEXT:
                        break
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    if step > 10 and self.tokenizer.decode(input_ids[0, -2:]) == "\n\n":
                        break

            generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            outputs.append(generated_text)

        return outputs[0] if num_return_sequences == 1 else "\n\n---\n\n".join(outputs)

    @staticmethod
    def _penalize_repeats(logits, input_ids, penalty):
        """Scale down logits of already-emitted tokens (CTRL-style penalty).

        Negative logits are multiplied and positive ones divided so the
        penalty always pushes probability down. Mutates and returns ``logits``.
        """
        if penalty != 1.0:
            for token_id in set(input_ids[0].tolist()):
                if logits[0, token_id] < 0:
                    logits[0, token_id] *= penalty
                else:
                    logits[0, token_id] /= penalty
        return logits

    @staticmethod
    def _filter_top_k(logits, top_k):
        """Mask all but the top-k logits to -inf (no-op when top_k == 0)."""
        if top_k > 0:
            kth_value = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
            logits[logits < kth_value] = float('-inf')
        return logits

    @staticmethod
    def _filter_top_p(logits, top_p):
        """Nucleus filtering: mask tokens outside the top-p probability mass."""
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = 0
            indices_to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits[indices_to_remove] = float('-inf')
        return logits


generator = TextGenerator()


def generate_with_preset(prompt, preset, max_length, custom_temp, custom_top_k, 
                         custom_top_p, custom_rep_pen, num_outputs):
    """Generate text using a named preset or the custom slider values.

    Args:
        prompt: Seed text from the UI; blank input short-circuits with a warning.
        preset: A key of CONFIGS, or "custom" to use the slider values.
        max_length: Maximum number of tokens to generate (coerced to int).
        custom_temp, custom_top_k, custom_top_p, custom_rep_pen: Slider values,
            used only when preset == "custom".
        num_outputs: Number of sequences to generate (coerced to int).

    Returns:
        The generated text, or a user-facing warning/error string on failure.
    """
    if not prompt.strip():
        return "⚠️ Please enter a prompt!"
    
    try:
        # FIX: resolve the preset inside the try block so an unknown preset
        # name surfaces as a visible error message instead of an uncaught
        # KeyError (the lookup previously ran before the try).
        if preset != "custom":
            config = CONFIGS[preset]
            temp = config["temperature"]
            top_k = config["top_k"]
            top_p = config["top_p"]
            rep_pen = config["repetition_penalty"]
        else:
            temp = custom_temp
            top_k = custom_top_k
            top_p = custom_top_p
            rep_pen = custom_rep_pen

        result = generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temp),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(rep_pen),
            num_return_sequences=int(num_outputs)
        )
        return result
    except Exception as e:
        # Broad catch is deliberate: any generation failure is shown in the
        # UI output box rather than crashing the Gradio worker.
        return f"❌ Error: {str(e)}"


# Build Gradio Interface
# NOTE: every component created inside this `with` block is registered on
# `demo`; creation order and Row/Column nesting determine the on-screen layout.
with gr.Blocks(title="Scholar Sage - Improved", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # πŸŽ“ Scholar Sage - Language Model (Optimized)
    
    A 45M parameter transformer trained on WikiText-2. **Use presets** for best results!
    
    πŸ’‘ **Tips for Quality Output:**
    - Use **"Balanced" preset** for general use
    - Start with encyclopedia-style prompts (model trained on WikiText)
    - Try longer prompts (10-20 words) for better context
    - Experiment with different presets for different styles
    """)
    
    with gr.Row():
        # Left column: prompt plus all sampling controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="πŸ“ Enter Your Prompt",
                placeholder="Example: The theory of relativity is a scientific theory that",
                lines=4
            )
            
            # Choices must match the keys of CONFIGS (plus the "custom"
            # sentinel handled in generate_with_preset).
            preset_selector = gr.Radio(
                choices=["balanced", "creative", "focused", "factual", "custom"],
                value="balanced",
                label="🎚️ Preset Configuration",
                info="Balanced is recommended for most uses"
            )
            
            max_length = gr.Slider(
                minimum=20, maximum=150, value=60, step=10,
                label="πŸ“ Max Length (tokens)"
            )
            
            num_outputs = gr.Slider(
                minimum=1, maximum=3, value=1, step=1,
                label="πŸ”’ Number of Outputs"
            )
            
            # These sliders are read only when the "custom" preset is active;
            # with any other preset their values are ignored.
            with gr.Accordion("βš™οΈ Custom Settings", open=False):
                gr.Markdown("*Only used when 'custom' preset is selected*")
                custom_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                custom_top_k = gr.Slider(0, 100, 40, step=5, label="Top-k")
                custom_top_p = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
                custom_rep_pen = gr.Slider(1.0, 2.0, 1.3, step=0.1, label="Repetition Penalty")
            
            generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="lg")
        
        # Right column: generation output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="✨ Generated Text",
                lines=18,
                show_copy_button=True
            )
    
    # Example prompts optimized for WikiText-2 training
    # Each row supplies values for all eight inputs listed below, in order.
    gr.Markdown("### πŸ’‘ Example Prompts (Optimized for this Model)")
    gr.Examples(
        examples=[
            ["The history of artificial intelligence began in", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["Python programming language is a high-level", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["In the field of quantum mechanics,", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["The United States is a country located in", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["Machine learning algorithms can be used to", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
        ],
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k, 
                custom_top_p, custom_rep_pen, num_outputs],
    )
    
    # Wire the button: all control values are passed positionally to
    # generate_with_preset, matching its parameter order.
    generate_btn.click(
        fn=generate_with_preset,
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
        outputs=output_text
    )
    
    gr.Markdown("""
    ---
    ### πŸ“Œ Understanding the Presets
    
    - **Balanced** (default): Best for general encyclopedia-style text
    - **Creative**: More diverse outputs, good for storytelling
    - **Focused**: Deterministic, good for factual content
    - **Factual**: Highest coherence, lowest creativity
    - **Custom**: Manual control over all parameters
    
    ### ⚠️ Model Limitations
    
    This is a 45M parameter research model (vs GPT-3's 175B). It works best with:
    - βœ… Encyclopedia-style content (trained on WikiText-2)
    - βœ… Factual, informative text
    - βœ… Short to medium generations (20-100 tokens)
    
    It struggles with:
    - ❌ Creative fiction or dialogue
    - ❌ Very long context understanding
    - ❌ Highly specialized technical content
    """)


if __name__ == "__main__":
    # Launch the Gradio server (blocks until the process is stopped).
    demo.launch()