Fred808 committed on
Commit
94a6cd4
·
verified ·
1 Parent(s): e8316fa

Update tensor_server.py

Browse files
Files changed (1) hide show
  1. tensor_server.py +41 -1
tensor_server.py CHANGED
@@ -151,6 +151,8 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
151
  self.config = config
152
  self.start_offset = config.get('start_offset', 0)
153
  self.size = config.get('size_bytes', 0)
 
 
154
 
155
  def forward(self, x: torch.Tensor) -> torch.Tensor:
156
  # In a real implementation, this would process the input
@@ -160,8 +162,10 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
160
 
161
  # Create and return the chunk buffer
162
  chunk_model = ChunkBuffer(chunk_file, chunk_config)
 
 
163
  print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
164
-
165
  return chunk_model
166
 
167
  except Exception as e:
@@ -206,6 +210,28 @@ async def load_model_chunk(chunk: ModelChunk):
206
  # Store the chunk metadata
207
  chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
208
  state.chunk_configs = getattr(state, 'chunk_configs', {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  state.chunk_configs[chunk.chunk_id] = chunk
210
 
211
  print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
@@ -239,6 +265,20 @@ async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
239
 
240
  # Now load the chunk
241
  chunk_model = load_chunk(chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  state.loaded_chunks[chunk_id] = chunk_model
243
 
244
  file_size = os.path.getsize(chunk_file)
 
151
  self.config = config
152
  self.start_offset = config.get('start_offset', 0)
153
  self.size = config.get('size_bytes', 0)
154
+ # expose vocab_offset on the module for aggregator use
155
+ self.vocab_offset = int(config.get('vocab_offset', 0))
156
 
157
  def forward(self, x: torch.Tensor) -> torch.Tensor:
158
  # In a real implementation, this would process the input
 
162
 
163
  # Create and return the chunk buffer
164
  chunk_model = ChunkBuffer(chunk_file, chunk_config)
165
+ # Ensure the chunk_model.config is the up-to-date config (including any assigned offsets)
166
+ chunk_model.config = chunk_config
167
  print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
168
+
169
  return chunk_model
170
 
171
  except Exception as e:
 
210
  # Store the chunk metadata
211
  chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
212
  state.chunk_configs = getattr(state, 'chunk_configs', {})
213
+
214
+ # Ensure a vocab_offset is present; if not, assign a non-overlapping offset
215
+ cfg = chunk.config or {}
216
+ if 'vocab_offset' not in cfg:
217
+ # Compute next available offset from existing registered chunks
218
+ max_end = 0
219
+ for existing in state.chunk_configs.values():
220
+ try:
221
+ e_cfg = existing.config if hasattr(existing, 'config') else existing
222
+ e_offset = int(e_cfg.get('vocab_offset', 0))
223
+ e_shard = int(e_cfg.get('shard_dim', e_cfg.get('size', 1) or 1))
224
+ max_end = max(max_end, e_offset + e_shard)
225
+ except Exception:
226
+ continue
227
+
228
+ # If this chunk declares a shard_dim, use it; otherwise default to 1
229
+ shard_dim = int(cfg.get('shard_dim', cfg.get('size', 1) or 1))
230
+ cfg['vocab_offset'] = max_end
231
+ cfg['shard_dim'] = cfg.get('shard_dim', shard_dim)
232
+
233
+ # Store back the possibly-updated config
234
+ chunk.config = cfg
235
  state.chunk_configs[chunk.chunk_id] = chunk
236
 
237
  print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
 
265
 
266
  # Now load the chunk
267
  chunk_model = load_chunk(chunk)
268
+ # Ensure the loaded module has the registered config (including vocab_offset)
269
+ try:
270
+ registered = getattr(state, 'chunk_configs', {}).get(chunk_id)
271
+ if registered is not None:
272
+ # registered is a ModelChunk; merge config into module
273
+ reg_cfg = registered.config or {}
274
+ if hasattr(chunk_model, 'config'):
275
+ chunk_model.config.update(reg_cfg)
276
+ else:
277
+ chunk_model.config = reg_cfg
278
+ # expose vocab_offset on module
279
+ chunk_model.vocab_offset = int(reg_cfg.get('vocab_offset', 0))
280
+ except Exception:
281
+ pass
282
  state.loaded_chunks[chunk_id] = chunk_model
283
 
284
  file_size = os.path.getsize(chunk_file)