Fred808 committed on
Commit
63bc7d5
·
verified ·
1 Parent(s): e4f843d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -28,7 +28,7 @@ class Settings:
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
- MODEL_REPO = "https://huggingface.co/microsoft/florence-2-large"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
@@ -145,17 +145,19 @@ async def split_model_weights():
145
  import math
146
 
147
  # Load the full model weights
148
- model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
149
- loaded_weights = torch.load(model_file, weights_only=False) # Explicitly allow non-weights for compatibility
150
-
151
- # For OPT models, weights are stored under model.decoder
152
- if isinstance(loaded_weights, dict):
153
- if 'model.decoder' in str(list(loaded_weights.keys())):
154
- # Get all weights that are part of the decoder
155
- weights = {k: v for k, v in loaded_weights.items() if k.startswith('model.decoder')}
156
- else:
157
- # Just use all weights if no decoder prefix found
158
- weights = loaded_weights
 
 
159
 
160
  # Calculate total model size and chunks
161
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
 
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
+ MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
 
145
  import math
146
 
147
  # Load the full model weights
148
+ import torch
149
+ from safetensors.torch import load_file as load_safetensors
150
+
151
+ # Try safetensors first, then fallback to pytorch
152
+ try:
153
+ model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
154
+ print(f"[INFO] Loading weights from safetensors file: {model_file}")
155
+ weights = load_safetensors(model_file)
156
+ except StopIteration:
157
+ # No safetensors file found, try pytorch
158
+ model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
159
+ print(f"[INFO] Loading weights from PyTorch file: {model_file}")
160
+ weights = torch.load(model_file, map_location='cpu')
161
 
162
  # Calculate total model size and chunks
163
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())