Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -114,14 +114,35 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
| 114 |
config = GPTConfig()
|
| 115 |
model = GPT(config)
|
| 116 |
|
| 117 |
-
# Try to load model
|
| 118 |
try:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
except FileNotFoundError:
|
| 123 |
print("Warning: Model checkpoint not found. Using untrained model.")
|
| 124 |
# Model will be randomly initialized - not ideal but won't crash
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
model.to(device)
|
| 127 |
model.eval()
|
|
|
|
| 114 |
config = GPTConfig()
model = GPT(config)

# Load trained weights into the model. The HuggingFace Model Hub is tried
# first, then a local checkpoint file (path relative to the working
# directory). If neither works the model keeps its random initialisation so
# the app still starts.
try:
    import os

    # The Hub repo id is read from the HF_MODEL_REPO environment variable;
    # the default below is a placeholder.
    repo_id = os.getenv('HF_MODEL_REPO', 'YOUR_USERNAME/gpt2-shakespeare-124m')  # Update with your repo

    try:
        # Imported inside the Hub attempt on purpose: if huggingface_hub is
        # not installed, the ImportError must fall through to the local-file
        # fallback below instead of skipping it.
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="model_checkpoint_final.pt",
        )
        # NOTE(security): torch.load unpickles the file, and this checkpoint
        # is downloaded data. Consider passing weights_only=True if the
        # checkpoint holds only tensors/primitives — TODO confirm contents.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from HuggingFace Hub: {repo_id}")
    except Exception as e:
        # Best-effort: any Hub-side failure (missing package, network error,
        # unknown repo, bad checkpoint) triggers the local fallback.
        print(f"Could not load from Hub ({e}), trying local file...")
        # Fallback to local file. Errors raised here propagate to the outer
        # handlers below.
        checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Model loaded from local checkpoint")
except FileNotFoundError:
    print("Warning: Model checkpoint not found. Using untrained model.")
    # Model will be randomly initialized - not ideal but won't crash
except Exception as e:
    print(f"Error loading model: {e}")
    print("Using untrained model as fallback.")

model.to(device)
model.eval()
|