Spaces:
Sleeping
Sleeping
| """ | |
| Scholar Sage - Improved Language Model Web Interface | |
| Optimized for better text generation quality | |
| """ | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer | |
| from model.transformer_explained import TinyTransformerLM | |
| from generation_config import CONFIGS | |
class TextGenerator:
    """Wrap the trained TinyTransformerLM with an autoregressive sampler.

    Loads the GPT-2 tokenizer and the model checkpoint once at construction,
    then exposes ``generate()``, which samples token-by-token with
    temperature scaling, top-k and top-p (nucleus) filtering, and a
    CTRL-style repetition penalty.
    """

    def __init__(self, model_path="models/best_model_FIXED.pt"):
        # model_path: a state_dict checkpoint saved with torch.save().
        print("π Loading model...")
        # Prefer GPU when available; model and inputs are both moved here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # GPT-2 BPE tokenizer — presumably the same vocabulary the model was
        # trained with (vocab_size below is taken from it).
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # Architecture hyperparameters must match the checkpoint exactly,
        # otherwise load_state_dict() below will raise.
        self.model = TinyTransformerLM(
            vocab_size=self.tokenizer.vocab_size,
            d_model=512, n_layers=6, num_heads=8, d_ff=2048, max_len=512
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout etc.
        print(f"β Model loaded on {self.device}")

    def generate(self, prompt, max_length=50, temperature=0.7, top_k=40,
                 top_p=0.9, repetition_penalty=1.3, num_return_sequences=1):
        """Generate text with optimized sampling.

        Args:
            prompt: Seed text; echoed at the start of the decoded output.
            max_length: Maximum number of NEW tokens sampled per sequence.
            temperature: Softmax temperature; lower is more deterministic.
                Must be non-zero (used as a divisor).
            top_k: Keep only the k highest-probability tokens (0 disables).
            top_p: Nucleus-sampling cumulative-probability cutoff
                (1.0 disables).
            repetition_penalty: Values > 1.0 discourage tokens already in the
                context (CTRL-style; see inline comment).
            num_return_sequences: Number of independent samples to draw.

        Returns:
            A single generated string, or multiple samples joined by
            "---" separators when num_return_sequences > 1; a warning
            string when the prompt is blank.
        """
        # Improved prompt preprocessing
        if not prompt.strip():
            return "β οΈ Please enter a prompt!"
        # Add context hints for better generation
        enhanced_prompt = prompt.strip()
        outputs = []
        for _ in range(num_return_sequences):
            # Each sample restarts from the bare prompt.
            input_ids = self.tokenizer(enhanced_prompt, return_tensors="pt")["input_ids"].to(self.device)
            with torch.no_grad():
                for step in range(max_length):
                    # Forward pass — assumes the model returns (logits, aux)
                    # with logits shaped (batch, seq, vocab); TODO confirm
                    # against TinyTransformerLM.
                    logits, _ = self.model(input_ids)
                    # Only the distribution for the NEXT token matters;
                    # clone so the in-place edits below don't touch `logits`.
                    next_token_logits = logits[:, -1, :].clone()
                    # Enhanced repetition penalty
                    if repetition_penalty != 1.0:
                        for token_id in set(input_ids[0].tolist()):
                            # CTRL-style: shrink positive logits toward zero,
                            # push negative logits further down — both reduce
                            # the token's probability.
                            if next_token_logits[0, token_id] < 0:
                                next_token_logits[0, token_id] *= repetition_penalty
                            else:
                                next_token_logits[0, token_id] /= repetition_penalty
                    next_token_logits = next_token_logits / temperature
                    # Top-k filtering
                    if top_k > 0:
                        # Mask everything strictly below the k-th largest logit.
                        indices_to_remove = next_token_logits < torch.topk(
                            next_token_logits, min(top_k, next_token_logits.size(-1))
                        )[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')
                    # Top-p filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        # Drop tokens past the nucleus boundary...
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # ...shifted right by one so the token that crosses the
                        # threshold is kept (at least one token always survives).
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        # Map the mask from sorted order back to vocab order.
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    input_ids = torch.cat([input_ids, next_token], dim=1)
                    # Better stopping conditions
                    if input_ids.size(1) >= 512:  # model context limit (max_len=512)
                        break
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    # Stop on double newline for cleaner outputs
                    if step > 10 and self.tokenizer.decode(input_ids[0, -2:]) == "\n\n":
                        break
            generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            outputs.append(generated_text)
        # Single sample: return it bare; multiple: join with a visual divider.
        return outputs[0] if num_return_sequences == 1 else "\n\n---\n\n".join(outputs)
| generator = TextGenerator() | |
def generate_with_preset(prompt, preset, max_length, custom_temp, custom_top_k,
                         custom_top_p, custom_rep_pen, num_outputs):
    """Run generation with either a named preset or the custom sliders.

    When *preset* is anything other than "custom", the sampling parameters
    come from CONFIGS[preset] and the custom_* arguments are ignored;
    otherwise the custom slider values are used. Any exception raised during
    generation is converted to an error string for display in the UI.
    """
    # Guard clause: blank prompts short-circuit before touching the model.
    if not prompt.strip():
        return "β οΈ Please enter a prompt!"

    # Resolve the four sampling knobs from preset or sliders.
    if preset == "custom":
        temp, top_k, top_p, rep_pen = (
            custom_temp, custom_top_k, custom_top_p, custom_rep_pen
        )
    else:
        config = CONFIGS[preset]
        temp, top_k, top_p, rep_pen = (
            config["temperature"],
            config["top_k"],
            config["top_p"],
            config["repetition_penalty"],
        )

    try:
        # Coerce slider values, which may arrive as floats/strings from the UI.
        return generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temp),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(rep_pen),
            num_return_sequences=int(num_outputs),
        )
    except Exception as e:
        # Surface the failure in the output textbox rather than crashing the app.
        return f"β Error: {str(e)}"
# Build Gradio Interface
# Layout: two columns — controls on the left, generated text on the right —
# with examples and explanatory markdown underneath.
with gr.Blocks(title="Scholar Sage - Improved", theme=gr.themes.Soft()) as demo:
    # Page header and usage tips.
    gr.Markdown("""
# π Scholar Sage - Language Model (Optimized)
A 45M parameter transformer trained on WikiText-2. **Use presets** for best results!
π‘ **Tips for Quality Output:**
- Use **"Balanced" preset** for general use
- Start with encyclopedia-style prompts (model trained on WikiText)
- Try longer prompts (10-20 words) for better context
- Experiment with different presets for different styles
""")
    with gr.Row():
        # Left column: prompt and all sampling controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="π Enter Your Prompt",
                placeholder="Example: The theory of relativity is a scientific theory that",
                lines=4
            )
            # Preset choices must match the keys of CONFIGS (plus "custom",
            # which activates the sliders in the accordion below).
            preset_selector = gr.Radio(
                choices=["balanced", "creative", "focused", "factual", "custom"],
                value="balanced",
                label="ποΈ Preset Configuration",
                info="Balanced is recommended for most uses"
            )
            max_length = gr.Slider(
                minimum=20, maximum=150, value=60, step=10,
                label="π Max Length (tokens)"
            )
            num_outputs = gr.Slider(
                minimum=1, maximum=3, value=1, step=1,
                label="π’ Number of Outputs"
            )
            # Manual sampling parameters — consulted by generate_with_preset
            # only when the "custom" preset is selected.
            with gr.Accordion("βοΈ Custom Settings", open=False):
                gr.Markdown("*Only used when 'custom' preset is selected*")
                custom_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                custom_top_k = gr.Slider(0, 100, 40, step=5, label="Top-k")
                custom_top_p = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
                custom_rep_pen = gr.Slider(1.0, 2.0, 1.3, step=0.1, label="Repetition Penalty")
            generate_btn = gr.Button("π Generate", variant="primary", size="lg")
        # Right column: read-only output area.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="β¨ Generated Text",
                lines=18,
                show_copy_button=True
            )
    # Example prompts optimized for WikiText-2 training
    gr.Markdown("### π‘ Example Prompts (Optimized for this Model)")
    # Each example row supplies values for ALL inputs wired below (prompt,
    # preset, length, four custom sliders, output count) in that order.
    gr.Examples(
        examples=[
            ["The history of artificial intelligence began in", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["Python programming language is a high-level", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["In the field of quantum mechanics,", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["The United States is a country located in", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["Machine learning algorithms can be used to", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
        ],
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
    )
    # Wire the button to the generation handler.
    generate_btn.click(
        fn=generate_with_preset,
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
        outputs=output_text
    )
    # Footer: preset glossary and honest model limitations.
    gr.Markdown("""
---
### π Understanding the Presets
- **Balanced** (default): Best for general encyclopedia-style text
- **Creative**: More diverse outputs, good for storytelling
- **Focused**: Deterministic, good for factual content
- **Factual**: Highest coherence, lowest creativity
- **Custom**: Manual control over all parameters
### β οΈ Model Limitations
This is a 45M parameter research model (vs GPT-3's 175B). It works best with:
- β Encyclopedia-style content (trained on WikiText-2)
- β Factual, informative text
- β Short to medium generations (20-100 tokens)
It struggles with:
- β Creative fiction or dialogue
- β Very long context understanding
- β Highly specialized technical content
""")

# Launch the web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()