Update app.py
Browse files
app.py
CHANGED
|
@@ -144,9 +144,18 @@ async def split_model_weights():
|
|
| 144 |
import torch
|
| 145 |
import math
|
| 146 |
|
| 147 |
-
# Load the full model weights
|
| 148 |
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
# Calculate total model size and chunks
|
| 152 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|
|
|
|
| 144 |
import torch
|
| 145 |
import math
|
| 146 |
|
| 147 |
+
# Load the full model weights
|
| 148 |
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
|
| 149 |
+
# NOTE(review): model_file may be a '.safetensors' path (the selection above accepts both
# '.safetensors' and '.bin'); torch.load reads pickle-based checkpoints — confirm that
# safetensors files are handled before reaching this call.
# weights_only=False enables full pickle deserialization, so only trusted files should be loaded here.
loaded_weights = torch.load(model_file, weights_only=False)
|
| 150 |
+
|
| 151 |
+
# For OPT models, weights are stored under model.decoder
|
| 152 |
+
if isinstance(loaded_weights, dict):
|
| 153 |
+
if 'model.decoder' in str(list(loaded_weights.keys())):
|
| 154 |
+
# Get all weights that are part of the decoder
|
| 155 |
+
weights = {k: v for k, v in loaded_weights.items() if k.startswith('model.decoder')}
|
| 156 |
+
else:
|
| 157 |
+
# Just use all weights if no decoder prefix found
|
| 158 |
+
weights = loaded_weights
|
| 159 |
|
| 160 |
# Calculate total model size and chunks
|
| 161 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|