app.py (CHANGED)
Commit: Upload app.py
@@ -114,35 +114,97 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 config = GPTConfig()
 model = GPT(config)
 
+model_loaded = False
+
 # Try to load model from HuggingFace Model Hub first, then local file
 try:
     from huggingface_hub import hf_hub_download
     import os
 
     # Try to get model path from environment variable or use default
-    repo_id = os.getenv('HF_MODEL_REPO', '
+    repo_id = os.getenv('HF_MODEL_REPO', 'shwethd/gpt2-shakespeare-124m')
 
     try:
-
-
-
-
-
-
-
-
+        print(f"Attempting to load from HuggingFace Hub: {repo_id}")
+
+        # Try SafeTensors first (more secure, no pickle issues)
+        try:
+            from safetensors.torch import load_file
+            try:
+                model_path = hf_hub_download(
+                    repo_id=repo_id,
+                    filename="model.safetensors",
+                    cache_dir=None
+                )
+                state_dict = load_file(model_path, device=device)
+                model.load_state_dict(state_dict)
+                model_loaded = True
+                print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
+            except Exception as e:
+                print(f"SafeTensors not found ({e}), trying .pt file...")
+                # Fallback to .pt file
+                model_path = hf_hub_download(
+                    repo_id=repo_id,
+                    filename="model_checkpoint_final.pt",
+                    cache_dir=None
+                )
+                checkpoint = torch.load(model_path, map_location=device)
+
+                # Handle different checkpoint formats
+                if 'model_state_dict' in checkpoint:
+                    model.load_state_dict(checkpoint['model_state_dict'])
+                elif 'state_dict' in checkpoint:
+                    model.load_state_dict(checkpoint['state_dict'])
+                else:
+                    # If checkpoint is the state dict itself
+                    model.load_state_dict(checkpoint)
+
+                model_loaded = True
+                print(f"✅ Model loaded successfully from HuggingFace Hub: {repo_id}")
+        except ImportError:
+            # safetensors not installed, use .pt file
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="model_checkpoint_final.pt",
+                cache_dir=None
+            )
+            checkpoint = torch.load(model_path, map_location=device)
+
+            # Handle different checkpoint formats
+            if 'model_state_dict' in checkpoint:
+                model.load_state_dict(checkpoint['model_state_dict'])
+            elif 'state_dict' in checkpoint:
+                model.load_state_dict(checkpoint['state_dict'])
+            else:
+                # If checkpoint is the state dict itself
+                model.load_state_dict(checkpoint)
+
+            model_loaded = True
+            print(f"✅ Model loaded successfully from HuggingFace Hub: {repo_id}")
     except Exception as e:
-        print(f"Could not load from Hub ({e}), trying local file...")
-
-
-
-
+        print(f"⚠️ Could not load from Hub ({e}), trying local file...")
+        try:
+            # Fallback to local file
+            checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
+            if 'model_state_dict' in checkpoint:
+                model.load_state_dict(checkpoint['model_state_dict'])
+            elif 'state_dict' in checkpoint:
+                model.load_state_dict(checkpoint['state_dict'])
+            else:
+                model.load_state_dict(checkpoint)
+            model_loaded = True
+            print("✅ Model loaded from local checkpoint")
+        except Exception as e2:
+            print(f"❌ Could not load from local file either: {e2}")
 except FileNotFoundError:
-    print("Warning: Model checkpoint not found. Using untrained model.")
-    # Model will be randomly initialized - not ideal but won't crash
+    print("❌ Warning: Model checkpoint not found. Using untrained model.")
 except Exception as e:
-    print(f"Error loading model: {e}")
-    print("Using untrained model as fallback
+    print(f"❌ Error loading model: {e}")
+    print("⚠️ Using untrained model as fallback - output will be random!")
+
+if not model_loaded:
+    print("⚠️ WARNING: Model is using random weights! Generation will be nonsensical.")
+    print("Please ensure model_checkpoint_final.pt is uploaded to HuggingFace Model Hub.")
 
 model.to(device)
 model.eval()
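The loader above prefers model.safetensors and only then falls back to model_checkpoint_final.pt. If the repo holds only the .pt file, the SafeTensors file can be produced from it once, offline. A minimal sketch, not part of this commit; it assumes the checkpoint is either a raw state dict or wraps one under 'model_state_dict'/'state_dict', mirroring the branches above:

import torch
from safetensors.torch import save_file

# Load the pickle checkpoint on CPU and unwrap the state dict,
# mirroring the format handling in app.py above.
ckpt = torch.load("model_checkpoint_final.pt", map_location="cpu")
for key in ("model_state_dict", "state_dict"):
    if isinstance(ckpt, dict) and key in ckpt:
        ckpt = ckpt[key]
        break

# safetensors refuses tensors that share storage (e.g. tied input/output
# embeddings), so clone every tensor before saving.
state = {k: v.clone().contiguous() for k, v in ckpt.items()}
save_file(state, "model.safetensors")

Uploading the resulting file next to the .pt in shwethd/gpt2-shakespeare-124m would let the Space take the pickle-free path.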
@@ -154,21 +216,47 @@ enc = tiktoken.get_encoding('gpt2')
 def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
     """Generate text from prompt"""
     try:
+        if not model_loaded:
+            return "❌ Error: Model not loaded correctly. Please check that model_checkpoint_final.pt is uploaded to HuggingFace Model Hub (shwethd/gpt2-shakespeare-124m)."
+
+        # Validate inputs
+        if not prompt or len(prompt.strip()) == 0:
+            return "Please enter a prompt."
+
+        temperature = max(0.1, min(2.0, temperature))  # Clamp temperature
+        top_k = max(1, min(100, int(top_k)))  # Clamp top_k
+        max_new_tokens = max(1, min(200, int(max_new_tokens)))  # Clamp max tokens
+
         # Encode prompt
         tokens = enc.encode(prompt)
+        if len(tokens) == 0:
+            return "Error: Could not encode prompt."
+
         tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
 
         # Generate
         with torch.no_grad():
-            for
+            for i in range(max_new_tokens):
                 # Forward pass
                 logits, _ = model(tokens)
-                logits = logits[:, -1, :] / temperature
+                logits = logits[:, -1, :] / max(temperature, 0.1)  # Avoid division by zero
+
+                # Apply top-k filtering
+                if top_k < logits.size(-1):
+                    topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+                    # Create filtered logits
+                    filtered_logits = torch.full_like(logits, float('-inf'))
+                    filtered_logits.scatter_(-1, topk_indices, topk_logits)
+                    logits = filtered_logits
+
+                # Sample from distribution
+                probs = F.softmax(logits, dim=-1)
+
+                # Avoid NaN
+                if torch.isnan(probs).any():
+                    probs = torch.ones_like(probs) / probs.size(-1)
 
-
-                topk_probs, topk_indices = torch.topk(F.softmax(logits, dim=-1), top_k, dim=-1)
-                ix = torch.multinomial(topk_probs, 1)
-                next_token = torch.gather(topk_indices, -1, ix)
+                next_token = torch.multinomial(probs, 1)
 
                 # Append to sequence
                 tokens = torch.cat([tokens, next_token], dim=1)
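The rewritten sampling step is equivalent to the removed one but reorders the operations: instead of a softmax over the full vocabulary followed by torch.topk on the probabilities, multinomial, and gather, it masks everything outside the top k to -inf and takes a single softmax, so multinomial can sample token ids directly. Both draw from the same renormalized top-k distribution. A standalone sketch of one sampling step, assuming a (1, vocab_size) logits tensor:

import torch
import torch.nn.functional as F

def sample_top_k(logits, top_k=50, temperature=0.8):
    # Temperature-scale, then keep only the top_k logits
    logits = logits / max(temperature, 0.1)
    if top_k < logits.size(-1):
        topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
        masked = torch.full_like(logits, float('-inf'))
        masked.scatter_(-1, topk_indices, topk_logits)
        logits = masked
    probs = F.softmax(logits, dim=-1)   # masked entries get probability 0
    return torch.multinomial(probs, 1)  # (1, 1) tensor holding the token id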
@@ -181,14 +269,21 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
         generated_text = enc.decode(tokens[0].tolist())
         return generated_text
     except Exception as e:
-
+        import traceback
+        return f"❌ Error during generation: {str(e)}\n\nPlease check:\n1. Model is uploaded to HuggingFace Model Hub\n2. Repository name is correct: shwethd/gpt2-shakespeare-124m\n3. File name is exactly: model_checkpoint_final.pt"
 
 
 # Create Gradio interface
 with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
-
+    # Status indicator
+    status_color = "🟢" if model_loaded else "🔴"
+    status_text = "Model loaded successfully!" if model_loaded else "⚠️ Model not loaded - check HuggingFace Model Hub!"
+
+    gr.Markdown(f"""
     # 🎭 GPT-2 124M Shakespeare Language Model
 
+    {status_color} **Status:** {status_text}
+
     This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.
 
     **Training Results:**
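One detail worth noting in this hunk: import traceback is added inside the except block, but none of the visible lines call it; the handler returns the checklist string without the stack trace. Presumably the intent was to also log the failure server-side, something like this hypothetical line right after the import:

import traceback
print(traceback.format_exc())  # full stack trace goes to the Space logs, not the UI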
@@ -197,6 +292,8 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
     - Training Steps: 1,637
 
     Enter a prompt below to generate Shakespeare-style text!
+
+    {"⚠️ **Note:** If you see garbled/random text, the model may not have loaded correctly. Check the logs and ensure the model is uploaded to HuggingFace Model Hub: `shwethd/gpt2-shakespeare-124m`" if not model_loaded else ""}
     """)
 
     with gr.Row():
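Note that the conditional inside the Markdown f-string is evaluated once, when the Blocks layout is built, so the warning reflects model_loaded at startup rather than at request time. A minimal sketch of the same pattern:

model_loaded = False
banner = f"""Enter a prompt below to generate text!
{"⚠️ Note: the model did not load." if not model_loaded else ""}"""
print(banner)  # the warning line is present only because model_loaded is False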
@@ -238,7 +335,7 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
     )
 
     # Example prompts
-    gr.Markdown("### Example Prompts:")
+    gr.Markdown("### Example Prompts (Click to try):")
     examples = gr.Examples(
         examples=[
             ["First Citizen:"],
@@ -246,6 +343,15 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
             ["To be or not"],
             ["HAMLET:"],
             ["MACBETH:"],
+            ["JULIET:"],
+            ["KING:"],
+            ["LADY MACBETH:"],
+            ["OTHELLO:"],
+            ["What light through yonder"],
+            ["All the world's a stage"],
+            ["Double, double toil and trouble"],
+            ["Friends, Romans, countrymen"],
+            ["A rose by any other name"],
         ],
         inputs=prompt_input
     )
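The diff ends here; the launch call is outside the hunks shown. For context, a Gradio Space like this one conventionally finishes app.py by launching the Blocks object (an assumption about the untouched remainder of the file, not part of this commit):

if __name__ == "__main__":
    demo.launch()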