Update app.py
Browse files
app.py
CHANGED
|
@@ -28,7 +28,7 @@ class Settings:
|
|
| 28 |
AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
|
| 29 |
|
| 30 |
# Model settings
|
| 31 |
-
MODEL_REPO = "https://huggingface.co/
|
| 32 |
|
| 33 |
# Server settings
|
| 34 |
TENSOR_SERVER_TIMEOUT = 30 # seconds
|
|
@@ -232,16 +232,30 @@ async def split_model_weights():
|
|
| 232 |
raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}")
|
| 233 |
|
| 234 |
# Create chunk metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
state.model_chunks[chunk_id] = ModelChunk(
|
| 236 |
chunk_id=chunk_id,
|
| 237 |
files=[f"chunk_{chunk_id}.bin"],
|
| 238 |
-
config=
|
| 239 |
-
"start_offset": start_pos,
|
| 240 |
-
"size_bytes": current_chunk_size,
|
| 241 |
-
"is_last_chunk": chunk_id == num_chunks - 1,
|
| 242 |
-
"total_chunks": num_chunks,
|
| 243 |
-
"original_file": os.path.basename(model_file)
|
| 244 |
-
},
|
| 245 |
size_bytes=current_chunk_size,
|
| 246 |
status="ready"
|
| 247 |
)
|
|
@@ -331,7 +345,10 @@ async def split_model_weights():
|
|
| 331 |
"size_bytes": chunk_total_size,
|
| 332 |
"num_parameters": sum(weights[k].nelement() for k in current_chunk),
|
| 333 |
"input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
|
| 334 |
-
"output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0
|
|
|
|
|
|
|
|
|
|
| 335 |
}
|
| 336 |
)
|
| 337 |
|
|
@@ -1130,9 +1147,11 @@ async def startup_event():
|
|
| 1130 |
)
|
| 1131 |
distribution_tasks.append(task)
|
| 1132 |
print(f"[INFO] Sending chunk {chunk_id} to {server_url}")
|
| 1133 |
-
|
| 1134 |
# Track assignments for future reference
|
| 1135 |
-
|
|
|
|
|
|
|
|
|
|
| 1136 |
|
| 1137 |
if distribution_tasks:
|
| 1138 |
print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")
|
|
|
|
| 28 |
AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
|
| 29 |
|
| 30 |
# Model settings
|
| 31 |
+
MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
|
| 32 |
|
| 33 |
# Server settings
|
| 34 |
TENSOR_SERVER_TIMEOUT = 30 # seconds
|
|
|
|
| 232 |
raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}")
|
| 233 |
|
| 234 |
# Create chunk metadata
|
| 235 |
+
# Assign vocab_offset based on cumulative sizes of earlier chunks
|
| 236 |
+
# so that chunks map to disjoint vocab ranges for aggregation.
|
| 237 |
+
cumulative = 0
|
| 238 |
+
for cid, c in state.model_chunks.items():
|
| 239 |
+
try:
|
| 240 |
+
cumulative += int(c.config.get('shard_dim', c.config.get('size_bytes', 1)))
|
| 241 |
+
except Exception:
|
| 242 |
+
cumulative += 1
|
| 243 |
+
|
| 244 |
+
cfg = {
|
| 245 |
+
"start_offset": start_pos,
|
| 246 |
+
"size_bytes": current_chunk_size,
|
| 247 |
+
"is_last_chunk": chunk_id == num_chunks - 1,
|
| 248 |
+
"total_chunks": num_chunks,
|
| 249 |
+
"original_file": os.path.basename(model_file),
|
| 250 |
+
# Minimal shard mapping: shard_dim below is just the chunk's byte size (or 1 if empty) as a placeholder; replace it with the shard's real local vocab size.
|
| 251 |
+
"vocab_offset": cumulative,
|
| 252 |
+
"shard_dim": int(current_chunk_size) if current_chunk_size > 0 else 1
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
state.model_chunks[chunk_id] = ModelChunk(
|
| 256 |
chunk_id=chunk_id,
|
| 257 |
files=[f"chunk_{chunk_id}.bin"],
|
| 258 |
+
config=cfg,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
size_bytes=current_chunk_size,
|
| 260 |
status="ready"
|
| 261 |
)
|
|
|
|
| 345 |
"size_bytes": chunk_total_size,
|
| 346 |
"num_parameters": sum(weights[k].nelement() for k in current_chunk),
|
| 347 |
"input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
|
| 348 |
+
"output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0,
|
| 349 |
+
# vocab_offset: running total of shard_dim (falling back to size_bytes, then 1) over the chunks already in state.model_chunks
|
| 350 |
+
"vocab_offset": sum(int(c.config.get('shard_dim', c.config.get('size_bytes', 1))) for c in state.model_chunks.values()),
|
| 351 |
+
"shard_dim": int(chunk_total_size)
|
| 352 |
}
|
| 353 |
)
|
| 354 |
|
|
|
|
| 1147 |
)
|
| 1148 |
distribution_tasks.append(task)
|
| 1149 |
print(f"[INFO] Sending chunk {chunk_id} to {server_url}")
|
|
|
|
| 1150 |
# Track assignments for future reference
|
| 1151 |
+
try:
|
| 1152 |
+
chunk.server_assignments.append(server_url)
|
| 1153 |
+
except Exception:
|
| 1154 |
+
pass
|
| 1155 |
|
| 1156 |
if distribution_tasks:
|
| 1157 |
print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")
|