Spaces:
Sleeping
Sleeping
File size: 9,678 Bytes
94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 94aff5c bbffad2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
"""
Scholar Sage - Improved Language Model Web Interface
Optimized for better text generation quality
"""
import torch
import gradio as gr
from transformers import AutoTokenizer
from model.transformer_explained import TinyTransformerLM
from generation_config import CONFIGS
class TextGenerator:
    """Load a TinyTransformerLM checkpoint and generate text via sampling.

    Wraps the model, the GPT-2 tokenizer, and a token-by-token sampling
    loop with temperature, top-k, top-p (nucleus) filtering, and a
    CTRL-style repetition penalty.
    """

    def __init__(self, model_path="models/best_model_FIXED.pt"):
        """Load weights from *model_path* and move the model to GPU if available.

        Args:
            model_path: Path to a saved state_dict checkpoint.
        """
        # NOTE(review): emoji below reconstructed from mojibake in the source.
        print("🔄 Loading model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = TinyTransformerLM(
            vocab_size=self.tokenizer.vocab_size,
            d_model=512, n_layers=6, num_heads=8, d_ff=2048, max_len=512
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        print(f"✅ Model loaded on {self.device}")

    def generate(self, prompt, max_length=50, temperature=0.7, top_k=40,
                 top_p=0.9, repetition_penalty=1.3, num_return_sequences=1):
        """Generate text continuations of *prompt* with optimized sampling.

        Args:
            prompt: Seed text; must be non-empty after stripping.
            max_length: Maximum number of new tokens to sample per sequence.
            temperature: Softmax temperature (> 0); lower is more deterministic.
            top_k: Keep only the k most likely tokens each step (0 disables).
            top_p: Nucleus sampling threshold (1.0 disables).
            repetition_penalty: Values > 1.0 penalize tokens already generated.
            num_return_sequences: Number of independent samples to draw.

        Returns:
            The generated text, or multiple texts joined by "---" separators.
        """
        if not prompt.strip():
            return "⚠️ Please enter a prompt!"

        enhanced_prompt = prompt.strip()
        outputs = []
        for _ in range(num_return_sequences):
            # Truncate over-long prompts so they fit the model's 512-token
            # context (the forward pass would otherwise exceed max_len).
            input_ids = self.tokenizer(
                enhanced_prompt, return_tensors="pt",
                truncation=True, max_length=512,
            )["input_ids"].to(self.device)
            with torch.no_grad():
                for step in range(max_length):
                    logits, _ = self.model(input_ids)
                    next_token_logits = logits[:, -1, :].clone()

                    # CTRL-style repetition penalty: push already-seen tokens'
                    # logits toward lower probability (sign-aware).
                    if repetition_penalty != 1.0:
                        for token_id in set(input_ids[0].tolist()):
                            if next_token_logits[0, token_id] < 0:
                                next_token_logits[0, token_id] *= repetition_penalty
                            else:
                                next_token_logits[0, token_id] /= repetition_penalty

                    next_token_logits = next_token_logits / temperature

                    # Top-k filtering: mask everything below the k-th largest logit.
                    if top_k > 0:
                        indices_to_remove = next_token_logits < torch.topk(
                            next_token_logits, min(top_k, next_token_logits.size(-1))
                        )[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Top-p (nucleus) filtering: keep the smallest token set whose
                    # cumulative probability exceeds top_p.
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold is kept.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove
                        )
                        next_token_logits[indices_to_remove] = float('-inf')

                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    # Stop at the context limit, on EOS, or (after a short grace
                    # period) on a blank line, for cleaner outputs.
                    if input_ids.size(1) >= 512:
                        break
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    if step > 10 and self.tokenizer.decode(input_ids[0, -2:]) == "\n\n":
                        break

            generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            outputs.append(generated_text)

        return outputs[0] if num_return_sequences == 1 else "\n\n---\n\n".join(outputs)
# Module-level singleton: load the model once at import time so every
# Gradio request reuses the same weights and tokenizer.
generator = TextGenerator()
def generate_with_preset(prompt, preset, max_length, custom_temp, custom_top_k,
                         custom_top_p, custom_rep_pen, num_outputs):
    """Generate text using a named preset from CONFIGS or custom slider values.

    Args:
        prompt: Seed text from the UI textbox.
        preset: One of the keys in CONFIGS, or "custom" to use the sliders.
        max_length: Maximum number of new tokens to sample.
        custom_temp, custom_top_k, custom_top_p, custom_rep_pen: Sampling
            parameters, used only when *preset* == "custom".
        num_outputs: Number of independent samples to return.

    Returns:
        The generated text, or a user-facing warning/error string.
    """
    if not prompt.strip():
        return "⚠️ Please enter a prompt!"

    # Presets take precedence; the custom sliders apply only when explicitly chosen.
    if preset != "custom":
        config = CONFIGS[preset]
        temp = config["temperature"]
        top_k = config["top_k"]
        top_p = config["top_p"]
        rep_pen = config["repetition_penalty"]
    else:
        temp = custom_temp
        top_k = custom_top_k
        top_p = custom_top_p
        rep_pen = custom_rep_pen

    # Broad catch is deliberate: this is the UI boundary, so surface any
    # failure as a message rather than crashing the Gradio app.
    try:
        result = generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temp),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(rep_pen),
            num_return_sequences=int(num_outputs)
        )
        return result
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Build Gradio Interface.
# NOTE(review): emoji throughout were mojibake-garbled in the source and have
# been reconstructed; exact original glyphs could not be recovered.
with gr.Blocks(title="Scholar Sage - Improved", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📚 Scholar Sage - Language Model (Optimized)
    A 45M parameter transformer trained on WikiText-2. **Use presets** for best results!
    💡 **Tips for Quality Output:**
    - Use **"Balanced" preset** for general use
    - Start with encyclopedia-style prompts (model trained on WikiText)
    - Try longer prompts (10-20 words) for better context
    - Experiment with different presets for different styles
    """)

    with gr.Row():
        # Left column: prompt entry plus preset/custom sampling controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="📝 Enter Your Prompt",
                placeholder="Example: The theory of relativity is a scientific theory that",
                lines=4
            )
            preset_selector = gr.Radio(
                choices=["balanced", "creative", "focused", "factual", "custom"],
                value="balanced",
                label="🎛️ Preset Configuration",
                info="Balanced is recommended for most uses"
            )
            max_length = gr.Slider(
                minimum=20, maximum=150, value=60, step=10,
                label="📏 Max Length (tokens)"
            )
            num_outputs = gr.Slider(
                minimum=1, maximum=3, value=1, step=1,
                label="🔢 Number of Outputs"
            )
            # Custom sliders are read only when the "custom" preset is selected.
            with gr.Accordion("⚙️ Custom Settings", open=False):
                gr.Markdown("*Only used when 'custom' preset is selected*")
                custom_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                custom_top_k = gr.Slider(0, 100, 40, step=5, label="Top-k")
                custom_top_p = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
                custom_rep_pen = gr.Slider(1.0, 2.0, 1.3, step=0.1, label="Repetition Penalty")
            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        # Right column: generated text output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="✨ Generated Text",
                lines=18,
                show_copy_button=True
            )

    # Example prompts optimized for WikiText-2 training.
    gr.Markdown("### 💡 Example Prompts (Optimized for this Model)")
    gr.Examples(
        examples=[
            ["The history of artificial intelligence began in", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["Python programming language is a high-level", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["In the field of quantum mechanics,", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
            ["The United States is a country located in", "factual", 60, 0.3, 20, 0.8, 1.4, 1],
            ["Machine learning algorithms can be used to", "balanced", 60, 0.7, 40, 0.9, 1.3, 1],
        ],
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
    )

    generate_btn.click(
        fn=generate_with_preset,
        inputs=[prompt_input, preset_selector, max_length, custom_temp, custom_top_k,
                custom_top_p, custom_rep_pen, num_outputs],
        outputs=output_text
    )

    gr.Markdown("""
    ---
    ### 📊 Understanding the Presets
    - **Balanced** (default): Best for general encyclopedia-style text
    - **Creative**: More diverse outputs, good for storytelling
    - **Focused**: Deterministic, good for factual content
    - **Factual**: Highest coherence, lowest creativity
    - **Custom**: Manual control over all parameters
    ### ⚠️ Model Limitations
    This is a 45M parameter research model (vs GPT-3's 175B). It works best with:
    - ✅ Encyclopedia-style content (trained on WikiText-2)
    - ✅ Factual, informative text
    - ✅ Short to medium generations (20-100 tokens)
    It struggles with:
    - ❌ Creative fiction or dialogue
    - ❌ Very long context understanding
    - ❌ Highly specialized technical content
    """)

if __name__ == "__main__":
    demo.launch()
|