Update app.py
Browse files
app.py
CHANGED
|
@@ -28,7 +28,7 @@ class Settings:
|
|
| 28 |
AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
|
| 29 |
|
| 30 |
# Model settings
|
| 31 |
-
MODEL_REPO = "https://huggingface.co/
|
| 32 |
|
| 33 |
# Server settings
|
| 34 |
TENSOR_SERVER_TIMEOUT = 30 # seconds
|
|
@@ -145,17 +145,19 @@ async def split_model_weights():
|
|
| 145 |
import math
|
| 146 |
|
| 147 |
# Load the full model weights
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# Calculate total model size and chunks
|
| 161 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|
|
|
|
| 28 |
AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
|
| 29 |
|
| 30 |
# Model settings
|
| 31 |
+
MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
|
| 32 |
|
| 33 |
# Server settings
|
| 34 |
TENSOR_SERVER_TIMEOUT = 30 # seconds
|
|
|
|
| 145 |
import math
|
| 146 |
|
| 147 |
# Load the full model weights
|
| 148 |
+
import torch
|
| 149 |
+
from safetensors.torch import load_file as load_safetensors
|
| 150 |
+
|
| 151 |
+
# Try safetensors first, then fallback to pytorch
|
| 152 |
+
try:
|
| 153 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
|
| 154 |
+
print(f"[INFO] Loading weights from safetensors file: {model_file}")
|
| 155 |
+
weights = load_safetensors(model_file)
|
| 156 |
+
except StopIteration:
|
| 157 |
+
# No safetensors file found, try pytorch
|
| 158 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
|
| 159 |
+
print(f"[INFO] Loading weights from PyTorch file: {model_file}")
|
| 160 |
+
weights = torch.load(model_file, map_location='cpu')
|
| 161 |
|
| 162 |
# Calculate total model size and chunks
|
| 163 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|