"""
Scholar Sage - Improved Language Model Web Interface
Optimized for better text generation quality
"""
import torch
import gradio as gr
from transformers import AutoTokenizer

from model.transformer_explained import TinyTransformerLM
from generation_config import CONFIGS

# Sequence length the model is constructed with (``max_len`` below). The
# forward pass is only ever fed at most this many tokens.
MAX_CONTEXT = 512


class TextGenerator:
    """Wrap a TinyTransformerLM checkpoint and the GPT-2 tokenizer for sampling."""

    def __init__(self, model_path="models/best_model_FIXED.pt"):
        """Load the tokenizer and model weights onto the best available device.

        Args:
            model_path: Path to a state dict loadable by ``torch.load``.
        """
        print("🔄 Loading model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # These hyper-parameters must match the ones the checkpoint was trained
        # with; otherwise load_state_dict() (strict by default) raises.
        self.model = TinyTransformerLM(
            vocab_size=self.tokenizer.vocab_size,
            d_model=512,
            n_layers=6,
            num_heads=8,
            d_ff=2048,
            max_len=MAX_CONTEXT,
        )
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        # On torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print(f"✅ Model loaded on {self.device}")

    @staticmethod
    def _apply_repetition_penalty(logits, input_ids, penalty):
        """Return ``logits`` with every token already in ``input_ids`` penalised.

        Negative logits are multiplied by ``penalty`` and non-negative ones
        divided by it, so for ``penalty > 1`` a seen token always becomes less
        likely. Duplicate ids in ``input_ids`` scatter the same value, so the
        result matches penalising each unique token exactly once.
        """
        scores = torch.gather(logits, 1, input_ids)
        scores = torch.where(scores < 0, scores * penalty, scores / penalty)
        return logits.scatter(1, input_ids, scores)

    @staticmethod
    def _sample(logits, temperature, top_k, top_p):
        """Choose the next token id (shape ``(1, 1)``) from last-position logits.

        Args:
            logits: Logits for the last position, ``(batch, vocab)``.
            temperature: Values <= 0 select greedy (argmax) decoding.
            top_k: Keep only the ``top_k`` highest logits (0 disables).
            top_p: Nucleus threshold (values >= 1.0 disable).
        """
        if temperature <= 0:
            # Dividing by a zero temperature would produce inf/NaN
            # probabilities and crash torch.multinomial; treat it as greedy.
            return torch.argmax(logits, dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the single most likely token.
            sorted_remove = cumulative_probs > top_p
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))

        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(self, prompt, max_length=50, temperature=0.7, top_k=40,
                 top_p=0.9, repetition_penalty=1.3, num_return_sequences=1):
        """Generate text with optimized sampling.

        Args:
            prompt: Text to continue; surrounding whitespace is stripped.
            max_length: Maximum number of new tokens sampled per output.
            temperature: Softmax temperature; values <= 0 mean greedy decoding.
            top_k: Keep only the k most likely tokens (0 disables).
            top_p: Nucleus-sampling threshold (1.0 disables).
            repetition_penalty: Values > 1 discourage tokens already present
                in the sequence (1.0 disables).
            num_return_sequences: Number of independent samples (at least 1).

        Returns:
            The decoded text, prompt included. Several samples are joined with
            a ``---`` separator. An empty prompt returns a warning string.
        """
        if not prompt or not prompt.strip():
            return "⚠️ Please enter a prompt!"

        # Tokenize once; every sample starts from the same prompt ids.
        prompt_ids = self.tokenizer(prompt.strip(), return_tensors="pt")["input_ids"].to(self.device)
        eos_id = self.tokenizer.eos_token_id

        outputs = []
        for _ in range(max(1, int(num_return_sequences))):
            input_ids = prompt_ids
            with torch.no_grad():
                for step in range(int(max_length)):
                    # Only the most recent MAX_CONTEXT tokens are fed to the
                    # model, so long prompts/outputs never exceed its max_len.
                    logits, _ = self.model(input_ids[:, -MAX_CONTEXT:])
                    next_token_logits = logits[:, -1, :]

                    if repetition_penalty != 1.0:
                        next_token_logits = self._apply_repetition_penalty(
                            next_token_logits, input_ids, repetition_penalty
                        )

                    next_token = self._sample(next_token_logits, temperature, top_k, top_p)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    if next_token.item() == eos_id:
                        break
                    # Stop on a blank line for cleaner outputs. GPT-2 has a
                    # dedicated "\n\n" token that decodes together with the
                    # previous token, so test the suffix rather than equality.
                    if step > 10 and self.tokenizer.decode(input_ids[0, -2:]).endswith("\n\n"):
                        break

            outputs.append(self.tokenizer.decode(input_ids[0], skip_special_tokens=True))

        return outputs[0] if len(outputs) == 1 else "\n\n---\n\n".join(outputs)


generator = TextGenerator()


def generate_with_preset(prompt, preset, max_length, custom_temp, custom_top_k,
                         custom_top_p, custom_rep_pen, num_outputs):
    """Generate using preset or custom parameters.

    When ``preset`` is ``"custom"`` the ``custom_*`` slider values are used;
    otherwise the sampling parameters come from ``CONFIGS[preset]``. Errors
    are returned as a message string so the UI can display them.
    """
    if not prompt or not prompt.strip():
        return "⚠️ Please enter a prompt!"

    if preset != "custom" and preset not in CONFIGS:
        return f"❌ Error: unknown preset '{preset}'"

    try:
        # Kept inside the try so a malformed config entry is reported in the
        # UI rather than crashing the event handler.
        if preset != "custom":
            config = CONFIGS[preset]
            temp = config["temperature"]
            top_k = config["top_k"]
            top_p = config["top_p"]
            rep_pen = config["repetition_penalty"]
        else:
            temp = custom_temp
            top_k = custom_top_k
            top_p = custom_top_p
            rep_pen = custom_rep_pen

        return generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temp),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(rep_pen),
            num_return_sequences=int(num_outputs),
        )
    except Exception as e:  # UI boundary: surface any failure as text
        return f"❌ Error: {str(e)}"


# Build Gradio Interface
with gr.Blocks(title="Scholar Sage - Improved", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎓 Scholar Sage - Language Model (Optimized)

    A 45M parameter transformer trained on WikiText-2. **Use presets** for best results!

    💡 **Tips for Quality Output:**
    - Use **"Balanced" preset** for general use
    - Start with encyclopedia-style prompts (model trained on WikiText)
    - Try longer prompts (10-20 words) for better context
    - Experiment with different presets for different styles
    """)

    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="📝 Enter Your Prompt",
                placeholder="Example: The theory of relativity is a scientific theory that",
                lines=4,
            )

            preset_selector = gr.Radio(
                choices=["balanced", "creative", "focused", "factual", "custom"],
                value="balanced",
                label="🎚️ Preset Configuration",
                info="Balanced is recommended for most uses",
            )

            max_length = gr.Slider(
                minimum=20, maximum=150, value=60, step=10,
                label="📏 Max Length (tokens)",
            )

            num_outputs = gr.Slider(
                minimum=1, maximum=3, value=1, step=1,
                label="🔢 Number of Outputs",
            )

            with gr.Accordion("⚙️ Custom Settings", open=False):
                gr.Markdown("*Only used when 'custom' preset is selected*")
                custom_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                custom_top_k = gr.Slider(0, 100, 40, step=5, label="Top-k")
                custom_top_p = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
                custom_rep_pen = gr.Slider(1.0, 2.0, 1.3, step=0.1, label="Repetition Penalty")

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="✨ Generated Text",
                lines=18,
                show_copy_button=True,
            )

    # Example prompts optimized for WikiText-2 training
    gr.Markdown("### 💡 Example Prompts (Optimized for this Model)")
    gr.Examples(
        examples=[
            ["The history of artificial intelligence began in", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["Python programming language is a high-level", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["In the field of quantum mechanics,", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["The United States is a country located in", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["Machine learning algorithms can be used to", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
        ],
        inputs=[prompt_input, preset_selector, max_length, custom_temp,
                custom_top_k, custom_top_p, custom_rep_pen, num_outputs],
    )

    generate_btn.click(
        fn=generate_with_preset,
        inputs=[prompt_input, preset_selector, max_length, custom_temp,
                custom_top_k, custom_top_p, custom_rep_pen, num_outputs],
        outputs=output_text,
    )

    gr.Markdown("""
    ---
    ### 📌 Understanding the Presets
    - **Balanced** (default): Best for general encyclopedia-style text
    - **Creative**: More diverse outputs, good for storytelling
    - **Focused**: Deterministic, good for factual content
    - **Factual**: Highest coherence, lowest creativity
    - **Custom**: Manual control over all parameters

    ### ⚠️ Model Limitations
    This is a 45M parameter research model (vs GPT-3's 175B). It works best with:
    - ✅ Encyclopedia-style content (trained on WikiText-2)
    - ✅ Factual, informative text
    - ✅ Short to medium generations (20-100 tokens)

    It struggles with:
    - ❌ Creative fiction or dialogue
    - ❌ Very long context understanding
    - ❌ Highly specialized technical content
    """)


if __name__ == "__main__":
    demo.launch()