Update tensor_server.py

tensor_server.py  (+41 -1)

@@ -151,6 +151,8 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
             self.config = config
             self.start_offset = config.get('start_offset', 0)
             self.size = config.get('size_bytes', 0)
+            # expose vocab_offset on the module for aggregator use
+            self.vocab_offset = int(config.get('vocab_offset', 0))
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             # In a real implementation, this would process the input
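
The new attribute is described as being "for aggregator use", but no aggregator appears in this diff. A minimal sketch of how such a consumer might use it, assuming each chunk's forward pass returns logits for its own vocabulary shard only (aggregate_logits, vocab_size, and hidden are illustrative names, not taken from tensor_server.py):

import torch

def aggregate_logits(chunks, vocab_size: int, hidden: torch.Tensor) -> torch.Tensor:
    # Hypothetical consumer: place each chunk's shard of logits at its
    # vocab_offset along the full vocabulary axis.
    full = torch.full((hidden.shape[0], vocab_size), float("-inf"))
    for chunk in chunks:
        shard = chunk(hidden)                    # (batch, shard_dim) from the chunk's forward
        start = chunk.vocab_offset               # attribute added in this commit
        full[:, start:start + shard.shape[-1]] = shard
    return full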

@@ -160,8 +162,10 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
 
         # Create and return the chunk buffer
         chunk_model = ChunkBuffer(chunk_file, chunk_config)
+        # Ensure the chunk_model.config is the up-to-date config (including any assigned offsets)
+        chunk_model.config = chunk_config
         print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
-
+
         return chunk_model
 
     except Exception as e:

@@ -206,6 +210,28 @@ async def load_model_chunk(chunk: ModelChunk):
     # Store the chunk metadata
     chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
     state.chunk_configs = getattr(state, 'chunk_configs', {})
+
+    # Ensure a vocab_offset is present; if not, assign a non-overlapping offset
+    cfg = chunk.config or {}
+    if 'vocab_offset' not in cfg:
+        # Compute next available offset from existing registered chunks
+        max_end = 0
+        for existing in state.chunk_configs.values():
+            try:
+                e_cfg = existing.config if hasattr(existing, 'config') else existing
+                e_offset = int(e_cfg.get('vocab_offset', 0))
+                e_shard = int(e_cfg.get('shard_dim', e_cfg.get('size', 1) or 1))
+                max_end = max(max_end, e_offset + e_shard)
+            except Exception:
+                continue
+
+        # If this chunk declares a shard_dim, use it; otherwise default to 1
+        shard_dim = int(cfg.get('shard_dim', cfg.get('size', 1) or 1))
+        cfg['vocab_offset'] = max_end
+        cfg['shard_dim'] = cfg.get('shard_dim', shard_dim)
+
+        # Store back the possibly-updated config
+        chunk.config = cfg
     state.chunk_configs[chunk.chunk_id] = chunk
 
     print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")

@@ -239,6 +265,20 @@ async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
 
     # Now load the chunk
     chunk_model = load_chunk(chunk)
+    # Ensure the loaded module has the registered config (including vocab_offset)
+    try:
+        registered = getattr(state, 'chunk_configs', {}).get(chunk_id)
+        if registered is not None:
+            # registered is a ModelChunk; merge config into module
+            reg_cfg = registered.config or {}
+            if hasattr(chunk_model, 'config'):
+                chunk_model.config.update(reg_cfg)
+            else:
+                chunk_model.config = reg_cfg
+            # expose vocab_offset on module
+            chunk_model.vocab_offset = int(reg_cfg.get('vocab_offset', 0))
+    except Exception:
+        pass
     state.loaded_chunks[chunk_id] = chunk_model
 
     file_size = os.path.getsize(chunk_file)
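
Taken together, the two handlers imply a register-then-upload flow. A hypothetical client could drive it as below; the base URL, route paths, and payload shape are inferred from the handler names and the ModelChunk fields used above, not confirmed by this diff:

import requests

BASE = "http://localhost:8000"  # assumed server address

# 1) Register metadata; with no vocab_offset in the config, the server
#    assigns the next non-overlapping offset (0 for the first chunk).
meta = {
    "chunk_id": 0,
    "files": ["chunk_0.bin"],
    "config": {"shard_dim": 16000},
}
requests.post(f"{BASE}/load_model_chunk", json=meta).raise_for_status()

# 2) Upload the chunk data; the server loads it via load_chunk() and
#    copies the registered config, vocab_offset included, onto the module.
with open("chunk_0.bin", "rb") as f:
    requests.post(f"{BASE}/upload_chunk_data/0", files={"file": f}).raise_for_status()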