Fred808 committed
Commit 15bcc79 · verified · 1 Parent(s): 45e602a

Upload 2 files

Files changed (2)
  1. requirements.txt +2 -1
  2. tensor_server.py +75 -25
requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn==0.23.2
 torch>=2.0.0
 numpy>=1.24.0
 psutil>=5.9.0
-pydantic>=2.0.0
+pydantic>=2.0.0
+python-multipart
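Note: python-multipart is the package FastAPI relies on to parse multipart/form-data request bodies, so the new UploadFile endpoint added in tensor_server.py needs it; FastAPI will refuse to register a route that declares File/UploadFile parameters when the package is missing. A minimal sketch of the dependency in use (illustrative only, not part of this commit; the /echo_size route is made up):

    from fastapi import FastAPI, File, UploadFile

    app = FastAPI()

    # Declaring a File(...) parameter is what pulls in python-multipart;
    # without it, FastAPI raises an error when this route is defined.
    @app.post("/echo_size")
    async def echo_size(file: UploadFile = File(...)):
        data = await file.read()
        return {"filename": file.filename, "size_bytes": len(data)}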
tensor_server.py CHANGED
@@ -18,7 +18,7 @@ class Settings:
     SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server

     # The IP or hostname where this tensor server is accessible
-    PUBLIC_URL = os.getenv("PUBLIC_URL", f"https://fred808-ilob.hf.space")
+    PUBLIC_URL = os.getenv("PUBLIC_URL", f"http://192.168.1.101:8001")

     # URLs for other services (should be actual IP addresses or hostnames)
     CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
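The default PUBLIC_URL changes from the Hugging Face Space hostname to a LAN address; the environment variable still takes precedence, so a Space deployment only has to set it explicitly. A small sketch of how the value resolves (the values below are examples, not from the commit):

    import os

    # Mirrors the Settings lookup above: $PUBLIC_URL wins when set,
    # otherwise the new LAN default is used.
    os.environ["PUBLIC_URL"] = "https://fred808-ilob.hf.space"  # e.g. set by the Space
    public_url = os.getenv("PUBLIC_URL", "http://192.168.1.101:8001")
    print(public_url)  # -> https://fred808-ilob.hf.space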
@@ -132,27 +132,34 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
     os.makedirs(Settings.MODEL_DIR, exist_ok=True)

     # Get chunk configuration
-    input_size = chunk.config["input_size"]
-    output_size = chunk.config["output_size"]
-    weight_keys = chunk.config["weight_keys"]
-
-    # Create a simple linear transformation for this chunk
-    chunk_model = torch.nn.Linear(input_size, output_size)
-    chunk_model = chunk_model.to(Settings.DEVICE)
-
-    # Load the weights
+    chunk_config = chunk.config
+    if "original_file" not in chunk_config:
+        raise ValueError("Missing original_file in chunk configuration")
+
+    # Save chunk data to file
     chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
-    if os.path.exists(chunk_file):
-        weights = torch.load(chunk_file, map_location=Settings.DEVICE)
+    if not os.path.exists(chunk_file):
+        # We'll need to receive the actual chunk data in a separate request
+        raise ValueError(f"Chunk file not found: {chunk_file}")

-        # Initialize weights from the loaded state dict
-        with torch.no_grad():
-            # Combine weights if multiple keys
-            if len(weight_keys) > 1:
-                combined_weight = torch.cat([weights[k] for k in weight_keys], dim=0)
-                chunk_model.weight.copy_(combined_weight)
-            else:
-                chunk_model.weight.copy_(weights[weight_keys[0]])
+    # For raw binary chunks, we'll create a simple buffer module
+    class ChunkBuffer(torch.nn.Module):
+        def __init__(self, chunk_path: str, config: Dict):
+            super().__init__()
+            self.chunk_path = chunk_path
+            self.config = config
+            self.start_offset = config.get('start_offset', 0)
+            self.size = config.get('size_bytes', 0)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            # In a real implementation, this would process the input
+            # using the chunk data. For now, we'll just return the input
+            # as this is just for testing the distribution system
+            return x
+
+    # Create and return the chunk buffer
+    chunk_model = ChunkBuffer(chunk_file, chunk_config)
+    print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")

     return chunk_model
 
@@ -186,18 +193,61 @@ async def get_metrics():
     """Get current server metrics"""
     return await collect_metrics()

+from fastapi import File, UploadFile
+
 @app.post("/load_chunk")
 async def load_model_chunk(chunk: ModelChunk):
-    """Load a model chunk into memory"""
+    """Register a chunk configuration"""
+    try:
+        # Create model directory if it doesn't exist
+        os.makedirs(Settings.MODEL_DIR, exist_ok=True)
+
+        # Store the chunk metadata
+        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+        state.chunk_configs = getattr(state, 'chunk_configs', {})
+        state.chunk_configs[chunk.chunk_id] = chunk
+
+        print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
+        print(f"[INFO] Waiting for chunk data: {chunk.files[0]}")
+
+        return {
+            "status": "configured",
+            "chunk_id": chunk.chunk_id,
+            "ready_for_data": True
+        }
+
+    except Exception as e:
+        state.error_count += 1
+        state.last_error = str(e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/upload_chunk_data/{chunk_id}")
+async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
+    """Receive the actual chunk data"""
     try:
-        # Load the chunk
+        if chunk_id not in getattr(state, 'chunk_configs', {}):
+            raise HTTPException(status_code=400, detail="Chunk configuration not registered")
+
+        chunk = state.chunk_configs[chunk_id]
+        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+
+        # Save the uploaded file
+        with open(chunk_file, 'wb') as f:
+            content = await file.read()
+            f.write(content)
+
+        # Now load the chunk
         chunk_model = load_chunk(chunk)
-        state.loaded_chunks[chunk.chunk_id] = chunk_model
+        state.loaded_chunks[chunk_id] = chunk_model
+
+        file_size = os.path.getsize(chunk_file)
+        print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)")

         return {
             "status": "loaded",
-            "chunk_id": chunk.chunk_id,
-            "device": str(next(chunk_model.parameters()).device)
+            "chunk_id": chunk_id,
+            "size_bytes": file_size,
+            "file": chunk.files[0]
         }

     except Exception as e:
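Together the two handlers form a two-step protocol: POST /load_chunk registers the chunk metadata, then POST /upload_chunk_data/{chunk_id} streams the bytes and triggers load_chunk(). An illustrative client-side sketch of how a controller might drive it; the ModelChunk fields are inferred from the handlers above and may be incomplete, and the URL and file names are examples:

    import requests

    TENSOR_SERVER = "http://192.168.1.101:8001"
    chunk_meta = {
        "chunk_id": 0,
        "files": ["chunk_0.bin"],
        "config": {"original_file": "model.bin", "start_offset": 0, "size_bytes": 1024},
    }

    # Step 1: register the chunk configuration.
    resp = requests.post(f"{TENSOR_SERVER}/load_chunk", json=chunk_meta, timeout=30)
    resp.raise_for_status()
    assert resp.json()["ready_for_data"]

    # Step 2: upload the raw bytes as multipart/form-data
    # (this is what the new python-multipart dependency enables).
    with open("chunk_0.bin", "rb") as f:
        resp = requests.post(
            f"{TENSOR_SERVER}/upload_chunk_data/{chunk_meta['chunk_id']}",
            files={"file": ("chunk_0.bin", f, "application/octet-stream")},
            timeout=300,
        )
    resp.raise_for_status()
    print(resp.json())  # {"status": "loaded", "chunk_id": 0, ...}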
 