# Scholar-Sage / app.py
# Author: TheCodeKat — commit bbffad2 "Add preset configurations for better quality"
"""
Scholar Sage - Improved Language Model Web Interface
Optimized for better text generation quality
"""
import torch
import gradio as gr
from transformers import AutoTokenizer
from model.transformer_explained import TinyTransformerLM
from generation_config import CONFIGS
class TextGenerator:
    """Load a TinyTransformerLM checkpoint and generate text via sampling.

    Holds the model, GPT-2 tokenizer, and torch device. ``generate`` runs an
    autoregressive sampling loop with temperature scaling, top-k filtering,
    nucleus (top-p) filtering, and a CTRL-style repetition penalty.
    """

    def __init__(self, model_path: str = "models/best_model_FIXED.pt"):
        """Load the tokenizer and model weights, moving the model to GPU if available.

        Args:
            model_path: Path to a state-dict checkpoint compatible with the
                TinyTransformerLM hyperparameters hard-coded below.
        """
        print("πŸ”„ Loading model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # GPT-2 BPE tokenizer; the model's embedding table is sized to match it.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # Hyperparameters must match those used at training time for the
        # state dict to load — presumably they do for this checkpoint.
        self.model = TinyTransformerLM(
            vocab_size=self.tokenizer.vocab_size,
            d_model=512, n_layers=6, num_heads=8, d_ff=2048, max_len=512
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()  # inference only — disables dropout etc.
        print(f"βœ… Model loaded on {self.device}")

    def generate(self, prompt: str, max_length: int = 50, temperature: float = 0.7,
                 top_k: int = 40,
                 top_p: float = 0.9, repetition_penalty: float = 1.3,
                 num_return_sequences: int = 1) -> str:
        """Generate text with optimized sampling.

        Args:
            prompt: Seed text; echoed back as the start of the output.
            max_length: Maximum number of NEW tokens to sample per sequence.
            temperature: Softmax temperature (> 0); lower is more deterministic.
            top_k: Keep only the k highest logits before sampling (0 disables).
            top_p: Nucleus-sampling cumulative-probability cutoff (1.0 disables).
            repetition_penalty: Values > 1.0 penalize tokens already present in
                the sequence (CTRL-style: positive logits divided, negative
                logits multiplied).
            num_return_sequences: Number of independent samples to draw.

        Returns:
            The generated text, or multiple samples joined by "---" separators;
            a warning string if the prompt is empty.
        """
        # Improved prompt preprocessing
        if not prompt.strip():
            return "⚠️ Please enter a prompt!"
        # Add context hints for better generation
        enhanced_prompt = prompt.strip()
        outputs = []
        for _ in range(num_return_sequences):
            # Each sample restarts from the raw prompt tokens.
            input_ids = self.tokenizer(enhanced_prompt, return_tensors="pt")["input_ids"].to(self.device)
            with torch.no_grad():
                for step in range(max_length):
                    # NOTE(review): model appears to return (logits, aux) — confirm
                    # against TinyTransformerLM; only logits are used here.
                    logits, _ = self.model(input_ids)
                    # Last position's logits; clone so in-place penalty edits
                    # below don't touch the model's output tensor.
                    next_token_logits = logits[:, -1, :].clone()
                    # Enhanced repetition penalty (CTRL-style): discourage every
                    # token already in the sequence, in the correct direction
                    # for both signs of the logit.
                    if repetition_penalty != 1.0:
                        for token_id in set(input_ids[0].tolist()):
                            if next_token_logits[0, token_id] < 0:
                                next_token_logits[0, token_id] *= repetition_penalty
                            else:
                                next_token_logits[0, token_id] /= repetition_penalty
                    next_token_logits = next_token_logits / temperature
                    # Top-k filtering: mask everything below the k-th largest logit.
                    if top_k > 0:
                        indices_to_remove = next_token_logits < torch.topk(
                            next_token_logits, min(top_k, next_token_logits.size(-1))
                        )[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')
                    # Top-p (nucleus) filtering: drop the low-probability tail
                    # whose cumulative mass exceeds top_p.
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold
                        # is kept; always keep the single most likely token.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    input_ids = torch.cat([input_ids, next_token], dim=1)
                    # Better stopping conditions
                    # Hard cap at the model's max_len (512) context window.
                    if input_ids.size(1) >= 512:
                        break
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    # Stop on double newline for cleaner outputs
                    if step > 10 and self.tokenizer.decode(input_ids[0, -2:]) == "\n\n":
                        break
            generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            outputs.append(generated_text)
        return outputs[0] if num_return_sequences == 1 else "\n\n---\n\n".join(outputs)
# Process-wide singleton, loaded once at import time so the model is ready
# before the Gradio UI starts serving requests.
generator = TextGenerator()
def generate_with_preset(prompt, preset, max_length, custom_temp, custom_top_k,
                         custom_top_p, custom_rep_pen, num_outputs):
    """Gradio handler: generate text using a named preset or custom sliders.

    Args:
        prompt: User prompt text.
        preset: Key into CONFIGS ("balanced", "creative", ...) or "custom"
            to use the slider values instead.
        max_length: Maximum number of tokens to generate (raw UI value).
        custom_temp: Temperature slider value (used only when preset == "custom").
        custom_top_k: Top-k slider value (custom preset only).
        custom_top_p: Top-p slider value (custom preset only).
        custom_rep_pen: Repetition-penalty slider value (custom preset only).
        num_outputs: Number of samples to generate.

    Returns:
        The generated text, or a user-facing warning/error string.
    """
    if not prompt.strip():
        return "⚠️ Please enter a prompt!"
    try:
        # FIX: the CONFIGS lookup is inside the try block so an unknown
        # preset key surfaces as a friendly error string instead of an
        # unhandled KeyError crashing the Gradio callback.
        if preset != "custom":
            config = CONFIGS[preset]
            temp = config["temperature"]
            top_k = config["top_k"]
            top_p = config["top_p"]
            rep_pen = config["repetition_penalty"]
        else:
            temp = custom_temp
            top_k = custom_top_k
            top_p = custom_top_p
            rep_pen = custom_rep_pen
        # Coerce UI values explicitly — sliders may deliver floats for
        # integer-valued parameters.
        return generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temp),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(rep_pen),
            num_return_sequences=int(num_outputs)
        )
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: report the failure
        # to the user rather than crashing the app.
        return f"❌ Error: {str(e)}"
# Build Gradio Interface
# Runs at import time: constructs the Blocks app and wires the Generate
# button to generate_with_preset; `demo` is launched from the guard below.
with gr.Blocks(title="Scholar Sage - Improved", theme=gr.themes.Soft()) as demo:
    # Header / usage tips.
    gr.Markdown("""
# πŸŽ“ Scholar Sage - Language Model (Optimized)
A 45M parameter transformer trained on WikiText-2. **Use presets** for best results!
πŸ’‘ **Tips for Quality Output:**
- Use **"Balanced" preset** for general use
- Start with encyclopedia-style prompts (model trained on WikiText)
- Try longer prompts (10-20 words) for better context
- Experiment with different presets for different styles
""")
    with gr.Row():
        # Left column: inputs and sampling controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="πŸ“ Enter Your Prompt",
                placeholder="Example: The theory of relativity is a scientific theory that",
                lines=4
            )
            # Choices must match the keys of CONFIGS plus "custom".
            preset_selector = gr.Radio(
                choices=["balanced", "creative", "focused", "factual", "custom"],
                value="balanced",
                label="🎚️ Preset Configuration",
                info="Balanced is recommended for most uses"
            )
            max_length = gr.Slider(
                minimum=20, maximum=150, value=60, step=10,
                label="πŸ“ Max Length (tokens)"
            )
            num_outputs = gr.Slider(
                minimum=1, maximum=3, value=1, step=1,
                label="πŸ”’ Number of Outputs"
            )
            # Manual sampling controls, consulted only when preset == "custom".
            with gr.Accordion("βš™οΈ Custom Settings", open=False):
                gr.Markdown("*Only used when 'custom' preset is selected*")
                custom_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                custom_top_k = gr.Slider(0, 100, 40, step=5, label="Top-k")
                custom_top_p = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
                custom_rep_pen = gr.Slider(1.0, 2.0, 1.3, step=0.1, label="Repetition Penalty")
            generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="lg")
        # Right column: output display.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="✨ Generated Text",
                lines=18,
                show_copy_button=True
            )
    # Example prompts optimized for WikiText-2 training
    gr.Markdown("### πŸ’‘ Example Prompts (Optimized for this Model)")
    # Each example row supplies every input component in order:
    # prompt, preset, max_length, temp, top_k, top_p, rep_pen, num_outputs.
    gr.Examples(
        examples=[
            ["The history of artificial intelligence began in", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["Python programming language is a high-level", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["In the field of quantum mechanics,", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["The United States is a country located in", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["Machine learning algorithms can be used to", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
        ],
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
    )
    # Wire the button: inputs are passed positionally to generate_with_preset.
    generate_btn.click(
        fn=generate_with_preset,
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
        outputs=output_text
    )
    # Footer: preset descriptions and known model limitations.
    gr.Markdown("""
---
### πŸ“Œ Understanding the Presets
- **Balanced** (default): Best for general encyclopedia-style text
- **Creative**: More diverse outputs, good for storytelling
- **Focused**: Deterministic, good for factual content
- **Factual**: Highest coherence, lowest creativity
- **Custom**: Manual control over all parameters
### ⚠️ Model Limitations
This is a 45M parameter research model (vs GPT-3's 175B). It works best with:
- βœ… Encyclopedia-style content (trained on WikiText-2)
- βœ… Factual, informative text
- βœ… Short to medium generations (20-100 tokens)
It struggles with:
- ❌ Creative fiction or dialogue
- ❌ Very long context understanding
- ❌ Highly specialized technical content
""")

if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()