Fred808 committed on
Commit
ca8b343
·
verified ·
1 Parent(s): 7bcdb30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -144,9 +144,18 @@ async def split_model_weights():
144
  import torch
145
  import math
146
 
147
- # Load the full model weights without forcing CPU - let tensor servers handle device placement
148
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
149
- weights = torch.load(model_file, weights_only=False) # Explicitly allow non-weights for compatibility
 
 
 
 
 
 
 
 
 
150
 
151
  # Calculate total model size and chunks
152
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
 
144
  import torch
145
  import math
146
 
147
+ # Load the full model weights
148
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
149
+ loaded_weights = torch.load(model_file, weights_only=False) # Explicitly allow non-weights for compatibility
150
+
151
+ # For OPT models, weights are stored under model.decoder
152
+ if isinstance(loaded_weights, dict):
153
+ if 'model.decoder' in str(list(loaded_weights.keys())):
154
+ # Get all weights that are part of the decoder
155
+ weights = {k: v for k, v in loaded_weights.items() if k.startswith('model.decoder')}
156
+ else:
157
+ # Just use all weights if no decoder prefix found
158
+ weights = loaded_weights
159
 
160
  # Calculate total model size and chunks
161
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())