Fred808 committed
Commit b64cd0c · verified · 1 Parent(s): 881bed2

Upload 2 files

Files changed (2):
  1. requirements.txt +7 -6
  2. tensor_server.py +320 -270
requirements.txt CHANGED
@@ -1,6 +1,7 @@
- fastapi==0.104.0
- uvicorn==0.23.2
- torch>=2.0.0
- numpy>=1.24.0
- psutil>=5.9.0
- pydantic>=2.0.0
+ fastapi==0.104.0
+ uvicorn==0.23.2
+ torch>=2.0.0
+ numpy>=1.24.0
+ psutil>=5.9.0
+ pydantic>=2.0.0
+ python-multipart
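
The new `python-multipart` entry is the dependency FastAPI needs to parse `multipart/form-data` request bodies; it is pulled in by the `UploadFile`-based `/upload_chunk_data/{chunk_id}` endpoint added to `tensor_server.py` below. Note it is the only unpinned requirement, so builds will pick up whatever version is current.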
tensor_server.py CHANGED
@@ -1,271 +1,321 @@
- import os
- import json
- import torch
- import psutil
- import asyncio
- from datetime import datetime
- from typing import Dict, List, Optional
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- import uvicorn
- import numpy as np
-
- # ===== Config =====
- class Settings:
-     # Server configuration
-     HOST = "0.0.0.0"  # Listen on all interfaces
-     PORT = 8001
-     SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server
-
-     # The IP or hostname where this tensor server is accessible
-     PUBLIC_URL = os.getenv("PUBLIC_URL", f"https://fred808-ilob.hf.space")
-
-     # URLs for other services (should be actual IP addresses or hostnames)
-     CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
-     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
-
-     # Model settings
-     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-     MAX_BATCH_SIZE = 32
-     METRICS_UPDATE_INTERVAL = 5  # seconds
-     MODEL_DIR = "model_chunks"
-
-     @classmethod
-     def from_env(cls):
-         """Load settings from environment variables"""
-         cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
-         cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
-         cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
-         cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
-         cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
-         return cls
-
- # ===== Models =====
- class ModelChunk(BaseModel):
-     """Represents a received model chunk configuration"""
-     chunk_id: int
-     files: List[str]
-     config: Dict
-
- class InferenceRequest(BaseModel):
-     """Represents an inference request"""
-     inputs: List[List[float]]
-     batch_size: Optional[int] = None
-
- class MetricsData(BaseModel):
-     """Server metrics data"""
-     cpu_usage: float
-     memory_usage: float
-     gpu_usage: Optional[float]
-     active_requests: int
-     total_requests: int
-     average_response_time: float
-     last_error: Optional[str]
-     error_count: int
-
- # ===== FastAPI App =====
- app = FastAPI(
-     title="Tensor Server",
-     description="Handles model chunk computations",
-     version="1.0.0"
- )
-
- # ===== State =====
- class ServerState:
-     def __init__(self):
-         self.loaded_chunks: Dict[int, torch.nn.Module] = {}
-         self.active_requests: int = 0
-         self.total_requests: int = 0
-         self.request_times: List[float] = []
-         self.error_count: int = 0
-         self.last_error: Optional[str] = None
-         self.is_computing: bool = False
-
- state = ServerState()
-
- # ===== Metrics Collection =====
- async def collect_metrics() -> MetricsData:
-     """Collect current server metrics"""
-     # CPU and memory metrics
-     cpu_usage = psutil.cpu_percent()
-     memory = psutil.virtual_memory()
-     memory_usage = memory.percent
-
-     # GPU metrics if available
-     gpu_usage = None
-     if torch.cuda.is_available():
-         try:
-             gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
-         except:
-             pass
-
-     # Calculate average response time
-     avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0
-
-     return MetricsData(
-         cpu_usage=cpu_usage,
-         memory_usage=memory_usage,
-         gpu_usage=gpu_usage,
-         active_requests=state.active_requests,
-         total_requests=state.total_requests,
-         average_response_time=avg_response_time,
-         last_error=state.last_error,
-         error_count=state.error_count
-     )
-
- async def update_metrics_loop():
-     """Background task to update metrics periodically"""
-     while True:
-         try:
-             metrics = await collect_metrics()
-             # Store metrics for health checks
-             state.current_metrics = metrics
-         except Exception as e:
-             print(f"[ERROR] Failed to update metrics: {str(e)}")
-         await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
-
- # ===== Helper Functions =====
- def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
-     """Load a model chunk into memory"""
-     try:
-         # Create chunk directory if it doesn't exist
-         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
-
-         # Get chunk configuration
-         input_size = chunk.config["input_size"]
-         output_size = chunk.config["output_size"]
-         weight_keys = chunk.config["weight_keys"]
-
-         # Create a simple linear transformation for this chunk
-         chunk_model = torch.nn.Linear(input_size, output_size)
-         chunk_model = chunk_model.to(Settings.DEVICE)
-
-         # Load the weights
-         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
-         if os.path.exists(chunk_file):
-             weights = torch.load(chunk_file, map_location=Settings.DEVICE)
-
-             # Initialize weights from the loaded state dict
-             with torch.no_grad():
-                 # Combine weights if multiple keys
-                 if len(weight_keys) > 1:
-                     combined_weight = torch.cat([weights[k] for k in weight_keys], dim=0)
-                     chunk_model.weight.copy_(combined_weight)
-                 else:
-                     chunk_model.weight.copy_(weights[weight_keys[0]])
-
-         return chunk_model
-
-     except Exception as e:
-         raise Exception(f"Failed to load chunk: {str(e)}")
-
- async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
-     """Process input tensor through the specified chunk"""
-     if chunk_id not in state.loaded_chunks:
-         raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
-
-     chunk_model = state.loaded_chunks[chunk_id]
-     with torch.no_grad():
-         outputs = chunk_model(inputs)
-     return outputs
-
- # ===== API Endpoints =====
- @app.get("/health")
- async def health_check():
-     """Health check endpoint"""
-     metrics = await collect_metrics()
-     return {
-         "status": "healthy",
-         "device": Settings.DEVICE,
-         "loaded_chunks": list(state.loaded_chunks.keys()),
-         "metrics": metrics.dict()
-     }
-
- @app.get("/metrics")
- async def get_metrics():
-     """Get current server metrics"""
-     return await collect_metrics()
-
- @app.post("/load_chunk")
- async def load_model_chunk(chunk: ModelChunk):
-     """Load a model chunk into memory"""
-     try:
-         # Load the chunk
-         chunk_model = load_chunk(chunk)
-         state.loaded_chunks[chunk.chunk_id] = chunk_model
-
-         return {
-             "status": "loaded",
-             "chunk_id": chunk.chunk_id,
-             "device": str(next(chunk_model.parameters()).device)
-         }
-
-     except Exception as e:
-         state.error_count += 1
-         state.last_error = str(e)
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/compute/{chunk_id}")
- async def compute(chunk_id: int, request: InferenceRequest):
-     """Perform computation on inputs using specified chunk"""
-     try:
-         start_time = datetime.now()
-         state.active_requests += 1
-         state.total_requests += 1
-
-         # Convert inputs to tensor
-         inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
-
-         # Split into batches if needed
-         batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
-         if len(inputs) > batch_size:
-             batches = torch.split(inputs, batch_size)
-             outputs = []
-             for batch in batches:
-                 batch_output = await process_tensor(chunk_id, batch)
-                 outputs.append(batch_output)
-             output_tensor = torch.cat(outputs, dim=0)
-         else:
-             output_tensor = await process_tensor(chunk_id, inputs)
-
-         # Convert output to list
-         output_list = output_tensor.cpu().numpy().tolist()
-
-         # Update metrics
-         end_time = datetime.now()
-         processing_time = (end_time - start_time).total_seconds()
-         state.request_times.append(processing_time)
-         # Keep only last 100 request times
-         state.request_times = state.request_times[-100:]
-
-         return {
-             "outputs": output_list,
-             "processing_time": processing_time
-         }
-
-     except Exception as e:
-         state.error_count += 1
-         state.last_error = str(e)
-         raise HTTPException(status_code=500, detail=str(e))
-
-     finally:
-         state.active_requests -= 1
-
- @app.on_event("startup")
- async def startup_event():
-     """Start background tasks"""
-     asyncio.create_task(update_metrics_loop())
-
- # ===== Main Execution =====
- if __name__ == "__main__":
-     port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
-     print(f"[INFO] Starting tensor server on port {port}")
-     print(f"[INFO] Using device: {Settings.DEVICE}")
-     print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
-
-     uvicorn.run(
-         "tensor_server:app",
-         host="0.0.0.0",
-         port=port,
-         reload=False
-     )
 
+ import os
+ import json
+ import torch
+ import psutil
+ import asyncio
+ from datetime import datetime
+ from typing import Dict, List, Optional
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import uvicorn
+ import numpy as np
+
+ # ===== Config =====
+ class Settings:
+     # Server configuration
+     HOST = "0.0.0.0"  # Listen on all interfaces
+     PORT = 8001
+     SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server
+
+     # The IP or hostname where this tensor server is accessible
+     PUBLIC_URL = os.getenv("PUBLIC_URL", f"http://192.168.1.101:8001")
+
+     # URLs for other services (should be actual IP addresses or hostnames)
+     CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
+     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
+
+     # Model settings
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     MAX_BATCH_SIZE = 32
+     METRICS_UPDATE_INTERVAL = 5  # seconds
+     MODEL_DIR = "model_chunks"
+
+     @classmethod
+     def from_env(cls):
+         """Load settings from environment variables"""
+         cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
+         cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
+         cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
+         cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
+         cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
+         return cls
+
+ # ===== Models =====
+ class ModelChunk(BaseModel):
+     """Represents a received model chunk configuration"""
+     chunk_id: int
+     files: List[str]
+     config: Dict
+
+ class InferenceRequest(BaseModel):
+     """Represents an inference request"""
+     inputs: List[List[float]]
+     batch_size: Optional[int] = None
+
+ class MetricsData(BaseModel):
+     """Server metrics data"""
+     cpu_usage: float
+     memory_usage: float
+     gpu_usage: Optional[float]
+     active_requests: int
+     total_requests: int
+     average_response_time: float
+     last_error: Optional[str]
+     error_count: int
+
+ # ===== FastAPI App =====
+ app = FastAPI(
+     title="Tensor Server",
+     description="Handles model chunk computations",
+     version="1.0.0"
+ )
+
+ # ===== State =====
+ class ServerState:
+     def __init__(self):
+         self.loaded_chunks: Dict[int, torch.nn.Module] = {}
+         self.active_requests: int = 0
+         self.total_requests: int = 0
+         self.request_times: List[float] = []
+         self.error_count: int = 0
+         self.last_error: Optional[str] = None
+         self.is_computing: bool = False
+
+ state = ServerState()
+
+ # ===== Metrics Collection =====
+ async def collect_metrics() -> MetricsData:
+     """Collect current server metrics"""
+     # CPU and memory metrics
+     cpu_usage = psutil.cpu_percent()
+     memory = psutil.virtual_memory()
+     memory_usage = memory.percent
+
+     # GPU metrics if available
+     gpu_usage = None
+     if torch.cuda.is_available():
+         try:
+             gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
+         except:
+             pass
+
+     # Calculate average response time
+     avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0
+
+     return MetricsData(
+         cpu_usage=cpu_usage,
+         memory_usage=memory_usage,
+         gpu_usage=gpu_usage,
+         active_requests=state.active_requests,
+         total_requests=state.total_requests,
+         average_response_time=avg_response_time,
+         last_error=state.last_error,
+         error_count=state.error_count
+     )
+
+ async def update_metrics_loop():
+     """Background task to update metrics periodically"""
+     while True:
+         try:
+             metrics = await collect_metrics()
+             # Store metrics for health checks
+             state.current_metrics = metrics
+         except Exception as e:
+             print(f"[ERROR] Failed to update metrics: {str(e)}")
+         await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
+
+ # ===== Helper Functions =====
+ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
+     """Load a model chunk into memory"""
+     try:
+         # Create chunk directory if it doesn't exist
+         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
+
+         # Get chunk configuration
+         chunk_config = chunk.config
+         if "original_file" not in chunk_config:
+             raise ValueError("Missing original_file in chunk configuration")
+
+         # Save chunk data to file
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+         if not os.path.exists(chunk_file):
+             # We'll need to receive the actual chunk data in a separate request
+             raise ValueError(f"Chunk file not found: {chunk_file}")
+
+         # For raw binary chunks, we'll create a simple buffer module
+         class ChunkBuffer(torch.nn.Module):
+             def __init__(self, chunk_path: str, config: Dict):
+                 super().__init__()
+                 self.chunk_path = chunk_path
+                 self.config = config
+                 self.start_offset = config.get('start_offset', 0)
+                 self.size = config.get('size_bytes', 0)
+
+             def forward(self, x: torch.Tensor) -> torch.Tensor:
+                 # In a real implementation, this would process the input
+                 # using the chunk data. For now, we'll just return the input
+                 # as this is just for testing the distribution system
+                 return x
+
+         # Create and return the chunk buffer
+         chunk_model = ChunkBuffer(chunk_file, chunk_config)
+         print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
+
+         return chunk_model
+
+     except Exception as e:
+         raise Exception(f"Failed to load chunk: {str(e)}")
+
+ async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
+     """Process input tensor through the specified chunk"""
+     if chunk_id not in state.loaded_chunks:
+         raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
+
+     chunk_model = state.loaded_chunks[chunk_id]
+     with torch.no_grad():
+         outputs = chunk_model(inputs)
+     return outputs
+
+ # ===== API Endpoints =====
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     metrics = await collect_metrics()
+     return {
+         "status": "healthy",
+         "device": Settings.DEVICE,
+         "loaded_chunks": list(state.loaded_chunks.keys()),
+         "metrics": metrics.dict()
+     }
+
+ @app.get("/metrics")
+ async def get_metrics():
+     """Get current server metrics"""
+     return await collect_metrics()
+
+ from fastapi import File, UploadFile
+
+ @app.post("/load_chunk")
+ async def load_model_chunk(chunk: ModelChunk):
+     """Register a chunk configuration"""
+     try:
+         # Create model directory if it doesn't exist
+         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
+
+         # Store the chunk metadata
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+         state.chunk_configs = getattr(state, 'chunk_configs', {})
+         state.chunk_configs[chunk.chunk_id] = chunk
+
+         print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
+         print(f"[INFO] Waiting for chunk data: {chunk.files[0]}")
+
+         return {
+             "status": "configured",
+             "chunk_id": chunk.chunk_id,
+             "ready_for_data": True
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/upload_chunk_data/{chunk_id}")
+ async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
+     """Receive the actual chunk data"""
+     try:
+         if chunk_id not in getattr(state, 'chunk_configs', {}):
+             raise HTTPException(status_code=400, detail="Chunk configuration not registered")
+
+         chunk = state.chunk_configs[chunk_id]
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+
+         # Save the uploaded file
+         with open(chunk_file, 'wb') as f:
+             content = await file.read()
+             f.write(content)
+
+         # Now load the chunk
+         chunk_model = load_chunk(chunk)
+         state.loaded_chunks[chunk_id] = chunk_model
+
+         file_size = os.path.getsize(chunk_file)
+         print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)")
+
+         return {
+             "status": "loaded",
+             "chunk_id": chunk_id,
+             "size_bytes": file_size,
+             "file": chunk.files[0]
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/compute/{chunk_id}")
+ async def compute(chunk_id: int, request: InferenceRequest):
+     """Perform computation on inputs using specified chunk"""
+     try:
+         start_time = datetime.now()
+         state.active_requests += 1
+         state.total_requests += 1
+
+         # Convert inputs to tensor
+         inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
+
+         # Split into batches if needed
+         batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
+         if len(inputs) > batch_size:
+             batches = torch.split(inputs, batch_size)
+             outputs = []
+             for batch in batches:
+                 batch_output = await process_tensor(chunk_id, batch)
+                 outputs.append(batch_output)
+             output_tensor = torch.cat(outputs, dim=0)
+         else:
+             output_tensor = await process_tensor(chunk_id, inputs)
+
+         # Convert output to list
+         output_list = output_tensor.cpu().numpy().tolist()
+
+         # Update metrics
+         end_time = datetime.now()
+         processing_time = (end_time - start_time).total_seconds()
+         state.request_times.append(processing_time)
+         # Keep only last 100 request times
+         state.request_times = state.request_times[-100:]
+
+         return {
+             "outputs": output_list,
+             "processing_time": processing_time
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+     finally:
+         state.active_requests -= 1
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Start background tasks"""
+     asyncio.create_task(update_metrics_loop())
+
+ # ===== Main Execution =====
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
+     print(f"[INFO] Starting tensor server on port {port}")
+     print(f"[INFO] Using device: {Settings.DEVICE}")
+     print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
+
+     uvicorn.run(
+         "tensor_server:app",
+         host="0.0.0.0",
+         port=port,
+         reload=False
+     )
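
Taken together, the rewrite replaces single-shot weight loading with a register-then-upload handshake: POST /load_chunk stores the chunk metadata, POST /upload_chunk_data/{chunk_id} receives the raw bytes (this multipart upload is why python-multipart joins requirements.txt), and only then is the chunk usable by /compute/{chunk_id}. A minimal client sketch of that flow, assuming the server runs at localhost:8001 and using the requests library (an illustration-only dependency, not in requirements.txt; the file and host names are placeholders):

# Sketch of the two-step chunk-loading flow introduced in this commit.
# `requests`, the host, and the file names are assumptions for illustration.
import requests

BASE = "http://localhost:8001"

# Step 1: register the chunk configuration (fields from the ModelChunk model;
# load_chunk() requires an "original_file" key in config).
meta = {
    "chunk_id": 0,
    "files": ["chunk_0.bin"],
    "config": {"original_file": "model.bin", "start_offset": 0, "size_bytes": 1024},
}
r = requests.post(f"{BASE}/load_chunk", json=meta)
assert r.json()["ready_for_data"]

# Step 2: upload the raw chunk bytes as multipart/form-data (field name "file"
# matches the UploadFile parameter); the server then calls load_chunk().
with open("chunk_0.bin", "rb") as f:
    r = requests.post(f"{BASE}/upload_chunk_data/0", files={"file": f})
print(r.json())  # {"status": "loaded", "chunk_id": 0, ...}

# Step 3: run inputs through the chunk. ChunkBuffer.forward() is currently a
# pass-through, so the outputs simply echo the inputs.
r = requests.post(f"{BASE}/compute/0", json={"inputs": [[1.0, 2.0, 3.0]]})
print(r.json()["outputs"])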