Update app.py
Browse files
app.py
CHANGED
|
@@ -144,9 +144,9 @@ async def split_model_weights():
|
|
| 144 |
import torch
|
| 145 |
import math
|
| 146 |
|
| 147 |
-
# Load the full model weights
|
| 148 |
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
|
| 149 |
-
weights = torch.load(model_file, map_location='cpu')  # [line truncated in diff rendering; a CPU map_location was removed per the replacement comment — verify against repo history]
|
| 150 |
|
| 151 |
# Calculate total model size and chunks
|
| 152 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|
|
|
|
| 144 |
import torch
|
| 145 |
import math
|
| 146 |
|
| 147 |
+
# Load the full model weights without forcing CPU - let tensor servers handle device placement
|
| 148 |
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
|
| 149 |
+
weights = torch.load(model_file, weights_only=False)  # weights_only=False permits full unpickling (needed for checkpoints containing non-tensor objects) — NOTE: this executes arbitrary code if the file is untrusted
|
| 150 |
|
| 151 |
# Calculate total model size and chunks
|
| 152 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|