Fred808 committed on
Commit
972a07c
·
verified ·
1 Parent(s): 1a079db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -28,7 +28,7 @@ class Settings:
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
- MODEL_REPO = "https://huggingface.co/microsoft/florence-2-large"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
@@ -232,16 +232,30 @@ async def split_model_weights():
232
  raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}")
233
 
234
  # Create chunk metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  state.model_chunks[chunk_id] = ModelChunk(
236
  chunk_id=chunk_id,
237
  files=[f"chunk_{chunk_id}.bin"],
238
- config={
239
- "start_offset": start_pos,
240
- "size_bytes": current_chunk_size,
241
- "is_last_chunk": chunk_id == num_chunks - 1,
242
- "total_chunks": num_chunks,
243
- "original_file": os.path.basename(model_file)
244
- },
245
  size_bytes=current_chunk_size,
246
  status="ready"
247
  )
@@ -331,7 +345,10 @@ async def split_model_weights():
331
  "size_bytes": chunk_total_size,
332
  "num_parameters": sum(weights[k].nelement() for k in current_chunk),
333
  "input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
334
- "output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0
 
 
 
335
  }
336
  )
337
 
@@ -1130,9 +1147,11 @@ async def startup_event():
1130
  )
1131
  distribution_tasks.append(task)
1132
  print(f"[INFO] Sending chunk {chunk_id} to {server_url}")
1133
-
1134
  # Track assignments for future reference
1135
- chunk.server_assignments.append(server_url)
 
 
 
1136
 
1137
  if distribution_tasks:
1138
  print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")
 
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
+ MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
 
232
  raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}")
233
 
234
  # Create chunk metadata
235
+ # Assign vocab_offset based on cumulative sizes of earlier chunks
236
+ # so that chunks map to disjoint vocab ranges for aggregation.
237
+ cumulative = 0
238
+ for cid, c in state.model_chunks.items():
239
+ try:
240
+ cumulative += int(c.config.get('shard_dim', c.config.get('size_bytes', 1)))
241
+ except Exception:
242
+ cumulative += 1
243
+
244
+ cfg = {
245
+ "start_offset": start_pos,
246
+ "size_bytes": current_chunk_size,
247
+ "is_last_chunk": chunk_id == num_chunks - 1,
248
+ "total_chunks": num_chunks,
249
+ "original_file": os.path.basename(model_file),
250
+ # minimal shard mapping; users should adjust shard_dim to real local vocab size
251
+ "vocab_offset": cumulative,
252
+ "shard_dim": int(current_chunk_size) if current_chunk_size > 0 else 1
253
+ }
254
+
255
  state.model_chunks[chunk_id] = ModelChunk(
256
  chunk_id=chunk_id,
257
  files=[f"chunk_{chunk_id}.bin"],
258
+ config=cfg,
 
 
 
 
 
 
259
  size_bytes=current_chunk_size,
260
  status="ready"
261
  )
 
345
  "size_bytes": chunk_total_size,
346
  "num_parameters": sum(weights[k].nelement() for k in current_chunk),
347
  "input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
348
+ "output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0,
349
+ # assign a vocab_offset cumulatively
350
+ "vocab_offset": sum(int(c.config.get('shard_dim', c.config.get('size_bytes', 1))) for c in state.model_chunks.values()),
351
+ "shard_dim": int(chunk_total_size)
352
  }
353
  )
354
 
 
1147
  )
1148
  distribution_tasks.append(task)
1149
  print(f"[INFO] Sending chunk {chunk_id} to {server_url}")
 
1150
  # Track assignments for future reference
1151
+ try:
1152
+ chunk.server_assignments.append(server_url)
1153
+ except Exception:
1154
+ pass
1155
 
1156
  if distribution_tasks:
1157
  print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")