Fred808 commited on
Commit
7bcdb30
·
verified ·
1 Parent(s): 501d5b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -144,9 +144,9 @@ async def split_model_weights():
144
  import torch
145
  import math
146
 
147
- # Load the full model weights
148
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
149
- weights = torch.load(model_file, map_location='cpu')
150
 
151
  # Calculate total model size and chunks
152
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
 
144
  import torch
145
  import math
146
 
147
+ # Load the full model weights without forcing CPU - let tensor servers handle device placement
148
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
149
+ weights = torch.load(model_file, weights_only=False) # Explicitly allow non-weights for compatibility
150
 
151
  # Calculate total model size and chunks
152
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())