Fred808 committed on
Commit 47b6b0c · verified · 1 Parent(s): b64cd0c

Update tensor_server.py
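Adds an optional top_k field to InferenceRequest and extends POST /compute/{chunk_id} to return per-chunk top-k diagnostics alongside the existing outputs: a chunk_details entry carrying the raw logits shard, per-position [token_id, score] pairs (shard-local indices shifted by the chunk's vocab_offset), and the shard width.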

Files changed (1)
  1. tensor_server.py +382 -320
tensor_server.py CHANGED
@@ -1,321 +1,383 @@
- import os
- import json
- import torch
- import psutil
- import asyncio
- from datetime import datetime
- from typing import Dict, List, Optional
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- import uvicorn
- import numpy as np
-
- # ===== Config =====
- class Settings:
-     # Server configuration
-     HOST = "0.0.0.0"  # Listen on all interfaces
-     PORT = 8001
-     SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server
-
-     # The IP or hostname where this tensor server is accessible
-     PUBLIC_URL = os.getenv("PUBLIC_URL", f"http://192.168.1.101:8001")
-
-     # URLs for other services (should be actual IP addresses or hostnames)
-     CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
-     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
-
-     # Model settings
-     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-     MAX_BATCH_SIZE = 32
-     METRICS_UPDATE_INTERVAL = 5  # seconds
-     MODEL_DIR = "model_chunks"
-
-     @classmethod
-     def from_env(cls):
-         """Load settings from environment variables"""
-         cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
-         cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
-         cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
-         cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
-         cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
-         return cls
-
- # ===== Models =====
- class ModelChunk(BaseModel):
-     """Represents a received model chunk configuration"""
-     chunk_id: int
-     files: List[str]
-     config: Dict
-
- class InferenceRequest(BaseModel):
-     """Represents an inference request"""
-     inputs: List[List[float]]
-     batch_size: Optional[int] = None
-
- class MetricsData(BaseModel):
-     """Server metrics data"""
-     cpu_usage: float
-     memory_usage: float
-     gpu_usage: Optional[float]
-     active_requests: int
-     total_requests: int
-     average_response_time: float
-     last_error: Optional[str]
-     error_count: int
-
- # ===== FastAPI App =====
- app = FastAPI(
-     title="Tensor Server",
-     description="Handles model chunk computations",
-     version="1.0.0"
- )
-
- # ===== State =====
- class ServerState:
-     def __init__(self):
-         self.loaded_chunks: Dict[int, torch.nn.Module] = {}
-         self.active_requests: int = 0
-         self.total_requests: int = 0
-         self.request_times: List[float] = []
-         self.error_count: int = 0
-         self.last_error: Optional[str] = None
-         self.is_computing: bool = False
-
- state = ServerState()
-
- # ===== Metrics Collection =====
- async def collect_metrics() -> MetricsData:
-     """Collect current server metrics"""
-     # CPU and memory metrics
-     cpu_usage = psutil.cpu_percent()
-     memory = psutil.virtual_memory()
-     memory_usage = memory.percent
-
-     # GPU metrics if available
-     gpu_usage = None
-     if torch.cuda.is_available():
-         try:
-             gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
-         except:
-             pass
-
-     # Calculate average response time
-     avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0
-
-     return MetricsData(
-         cpu_usage=cpu_usage,
-         memory_usage=memory_usage,
-         gpu_usage=gpu_usage,
-         active_requests=state.active_requests,
-         total_requests=state.total_requests,
-         average_response_time=avg_response_time,
-         last_error=state.last_error,
-         error_count=state.error_count
-     )
-
- async def update_metrics_loop():
-     """Background task to update metrics periodically"""
-     while True:
-         try:
-             metrics = await collect_metrics()
-             # Store metrics for health checks
-             state.current_metrics = metrics
-         except Exception as e:
-             print(f"[ERROR] Failed to update metrics: {str(e)}")
-         await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
-
- # ===== Helper Functions =====
- def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
-     """Load a model chunk into memory"""
-     try:
-         # Create chunk directory if it doesn't exist
-         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
-
-         # Get chunk configuration
-         chunk_config = chunk.config
-         if "original_file" not in chunk_config:
-             raise ValueError("Missing original_file in chunk configuration")
-
-         # Save chunk data to file
-         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
-         if not os.path.exists(chunk_file):
-             # We'll need to receive the actual chunk data in a separate request
-             raise ValueError(f"Chunk file not found: {chunk_file}")
-
-         # For raw binary chunks, we'll create a simple buffer module
-         class ChunkBuffer(torch.nn.Module):
-             def __init__(self, chunk_path: str, config: Dict):
-                 super().__init__()
-                 self.chunk_path = chunk_path
-                 self.config = config
-                 self.start_offset = config.get('start_offset', 0)
-                 self.size = config.get('size_bytes', 0)
-
-             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                 # In a real implementation, this would process the input
-                 # using the chunk data. For now, we'll just return the input
-                 # as this is just for testing the distribution system
-                 return x
-
-         # Create and return the chunk buffer
-         chunk_model = ChunkBuffer(chunk_file, chunk_config)
-         print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
-
-         return chunk_model
-
-     except Exception as e:
-         raise Exception(f"Failed to load chunk: {str(e)}")
-
- async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
-     """Process input tensor through the specified chunk"""
-     if chunk_id not in state.loaded_chunks:
-         raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
-
-     chunk_model = state.loaded_chunks[chunk_id]
-     with torch.no_grad():
-         outputs = chunk_model(inputs)
-     return outputs
-
- # ===== API Endpoints =====
- @app.get("/health")
- async def health_check():
-     """Health check endpoint"""
-     metrics = await collect_metrics()
-     return {
-         "status": "healthy",
-         "device": Settings.DEVICE,
-         "loaded_chunks": list(state.loaded_chunks.keys()),
-         "metrics": metrics.dict()
-     }
-
- @app.get("/metrics")
- async def get_metrics():
-     """Get current server metrics"""
-     return await collect_metrics()
-
- from fastapi import File, UploadFile
-
- @app.post("/load_chunk")
- async def load_model_chunk(chunk: ModelChunk):
-     """Register a chunk configuration"""
-     try:
-         # Create model directory if it doesn't exist
-         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
-
-         # Store the chunk metadata
-         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
-         state.chunk_configs = getattr(state, 'chunk_configs', {})
-         state.chunk_configs[chunk.chunk_id] = chunk
-
-         print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
-         print(f"[INFO] Waiting for chunk data: {chunk.files[0]}")
-
-         return {
-             "status": "configured",
-             "chunk_id": chunk.chunk_id,
-             "ready_for_data": True
-         }
-
-     except Exception as e:
-         state.error_count += 1
-         state.last_error = str(e)
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/upload_chunk_data/{chunk_id}")
- async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
-     """Receive the actual chunk data"""
-     try:
-         if chunk_id not in getattr(state, 'chunk_configs', {}):
-             raise HTTPException(status_code=400, detail="Chunk configuration not registered")
-
-         chunk = state.chunk_configs[chunk_id]
-         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
-
-         # Save the uploaded file
-         with open(chunk_file, 'wb') as f:
-             content = await file.read()
-             f.write(content)
-
-         # Now load the chunk
-         chunk_model = load_chunk(chunk)
-         state.loaded_chunks[chunk_id] = chunk_model
-
-         file_size = os.path.getsize(chunk_file)
-         print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)")
-
-         return {
-             "status": "loaded",
-             "chunk_id": chunk_id,
-             "size_bytes": file_size,
-             "file": chunk.files[0]
-         }
-
-     except Exception as e:
-         state.error_count += 1
-         state.last_error = str(e)
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/compute/{chunk_id}")
- async def compute(chunk_id: int, request: InferenceRequest):
-     """Perform computation on inputs using specified chunk"""
-     try:
-         start_time = datetime.now()
-         state.active_requests += 1
-         state.total_requests += 1
-
-         # Convert inputs to tensor
-         inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
-
-         # Split into batches if needed
-         batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
-         if len(inputs) > batch_size:
-             batches = torch.split(inputs, batch_size)
-             outputs = []
-             for batch in batches:
-                 batch_output = await process_tensor(chunk_id, batch)
-                 outputs.append(batch_output)
-             output_tensor = torch.cat(outputs, dim=0)
-         else:
-             output_tensor = await process_tensor(chunk_id, inputs)
-
-         # Convert output to list
-         output_list = output_tensor.cpu().numpy().tolist()
-
-         # Update metrics
-         end_time = datetime.now()
-         processing_time = (end_time - start_time).total_seconds()
-         state.request_times.append(processing_time)
-         # Keep only last 100 request times
-         state.request_times = state.request_times[-100:]
-
-         return {
-             "outputs": output_list,
-             "processing_time": processing_time
-         }
-
-     except Exception as e:
-         state.error_count += 1
-         state.last_error = str(e)
-         raise HTTPException(status_code=500, detail=str(e))
-
-     finally:
-         state.active_requests -= 1
-
- @app.on_event("startup")
- async def startup_event():
-     """Start background tasks"""
-     asyncio.create_task(update_metrics_loop())
-
- # ===== Main Execution =====
- if __name__ == "__main__":
-     port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
-     print(f"[INFO] Starting tensor server on port {port}")
-     print(f"[INFO] Using device: {Settings.DEVICE}")
-     print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
-
-     uvicorn.run(
-         "tensor_server:app",
-         host="0.0.0.0",
-         port=port,
-         reload=False
-     )
 
+ import os
+ import json
+ import torch
+ import psutil
+ import asyncio
+ from datetime import datetime
+ from typing import Dict, List, Optional
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import uvicorn
+ import numpy as np
+
+ # ===== Config =====
+ class Settings:
+     # Server configuration
+     HOST = "0.0.0.0"  # Listen on all interfaces
+     PORT = 8001
+     SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server
+
+     # The IP or hostname where this tensor server is accessible
+     PUBLIC_URL = os.getenv("PUBLIC_URL", f"http://192.168.1.101:8001")
+
+     # URLs for other services (should be actual IP addresses or hostnames)
+     CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
+     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
+
+     # Model settings
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     MAX_BATCH_SIZE = 32
+     METRICS_UPDATE_INTERVAL = 5  # seconds
+     MODEL_DIR = "model_chunks"
+
+     @classmethod
+     def from_env(cls):
+         """Load settings from environment variables"""
+         cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
+         cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
+         cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
+         cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
+         cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
+         return cls
+
+ # ===== Models =====
+ class ModelChunk(BaseModel):
+     """Represents a received model chunk configuration"""
+     chunk_id: int
+     files: List[str]
+     config: Dict
+
+ class InferenceRequest(BaseModel):
+     """Represents an inference request"""
+     inputs: List[List[float]]
+     batch_size: Optional[int] = None
+     top_k: Optional[int] = 5
+
+ class MetricsData(BaseModel):
+     """Server metrics data"""
+     cpu_usage: float
+     memory_usage: float
+     gpu_usage: Optional[float]
+     active_requests: int
+     total_requests: int
+     average_response_time: float
+     last_error: Optional[str]
+     error_count: int
+
+ # ===== FastAPI App =====
+ app = FastAPI(
+     title="Tensor Server",
+     description="Handles model chunk computations",
+     version="1.0.0"
+ )
+
+ # ===== State =====
+ class ServerState:
+     def __init__(self):
+         self.loaded_chunks: Dict[int, torch.nn.Module] = {}
+         self.active_requests: int = 0
+         self.total_requests: int = 0
+         self.request_times: List[float] = []
+         self.error_count: int = 0
+         self.last_error: Optional[str] = None
+         self.is_computing: bool = False
+
+ state = ServerState()
+
+ # ===== Metrics Collection =====
+ async def collect_metrics() -> MetricsData:
+     """Collect current server metrics"""
+     # CPU and memory metrics
+     cpu_usage = psutil.cpu_percent()
+     memory = psutil.virtual_memory()
+     memory_usage = memory.percent
+
+     # GPU metrics if available
+     gpu_usage = None
+     if torch.cuda.is_available():
+         try:
+             gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
+         except:
+             pass
+
+     # Calculate average response time
+     avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0
+
+     return MetricsData(
+         cpu_usage=cpu_usage,
+         memory_usage=memory_usage,
+         gpu_usage=gpu_usage,
+         active_requests=state.active_requests,
+         total_requests=state.total_requests,
+         average_response_time=avg_response_time,
+         last_error=state.last_error,
+         error_count=state.error_count
+     )
+
+ async def update_metrics_loop():
+     """Background task to update metrics periodically"""
+     while True:
+         try:
+             metrics = await collect_metrics()
+             # Store metrics for health checks
+             state.current_metrics = metrics
+         except Exception as e:
+             print(f"[ERROR] Failed to update metrics: {str(e)}")
+         await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
+
+ # ===== Helper Functions =====
+ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
+     """Load a model chunk into memory"""
+     try:
+         # Create chunk directory if it doesn't exist
+         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
+
+         # Get chunk configuration
+         chunk_config = chunk.config
+         if "original_file" not in chunk_config:
+             raise ValueError("Missing original_file in chunk configuration")
+
+         # Save chunk data to file
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+         if not os.path.exists(chunk_file):
+             # We'll need to receive the actual chunk data in a separate request
+             raise ValueError(f"Chunk file not found: {chunk_file}")
+
+         # For raw binary chunks, we'll create a simple buffer module
+         class ChunkBuffer(torch.nn.Module):
+             def __init__(self, chunk_path: str, config: Dict):
+                 super().__init__()
+                 self.chunk_path = chunk_path
+                 self.config = config
+                 self.start_offset = config.get('start_offset', 0)
+                 self.size = config.get('size_bytes', 0)
+
+             def forward(self, x: torch.Tensor) -> torch.Tensor:
+                 # In a real implementation, this would process the input
+                 # using the chunk data. For now, we'll just return the input
+                 # as this is just for testing the distribution system
+                 return x
+
+         # Create and return the chunk buffer
+         chunk_model = ChunkBuffer(chunk_file, chunk_config)
+         print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
+
+         return chunk_model
+
+     except Exception as e:
+         raise Exception(f"Failed to load chunk: {str(e)}")
+
+ async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
+     """Process input tensor through the specified chunk"""
+     if chunk_id not in state.loaded_chunks:
+         raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
+
+     chunk_model = state.loaded_chunks[chunk_id]
+     with torch.no_grad():
+         outputs = chunk_model(inputs)
+     return outputs
+
+ # ===== API Endpoints =====
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     metrics = await collect_metrics()
+     return {
+         "status": "healthy",
+         "device": Settings.DEVICE,
+         "loaded_chunks": list(state.loaded_chunks.keys()),
+         "metrics": metrics.dict()
+     }
+
+ @app.get("/metrics")
+ async def get_metrics():
+     """Get current server metrics"""
+     return await collect_metrics()
+
+ from fastapi import File, UploadFile
+
+ @app.post("/load_chunk")
+ async def load_model_chunk(chunk: ModelChunk):
+     """Register a chunk configuration"""
+     try:
+         # Create model directory if it doesn't exist
+         os.makedirs(Settings.MODEL_DIR, exist_ok=True)
+
+         # Store the chunk metadata
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+         state.chunk_configs = getattr(state, 'chunk_configs', {})
+         state.chunk_configs[chunk.chunk_id] = chunk
+
+         print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
+         print(f"[INFO] Waiting for chunk data: {chunk.files[0]}")
+
+         return {
+             "status": "configured",
+             "chunk_id": chunk.chunk_id,
+             "ready_for_data": True
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/upload_chunk_data/{chunk_id}")
+ async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
+     """Receive the actual chunk data"""
+     try:
+         if chunk_id not in getattr(state, 'chunk_configs', {}):
+             raise HTTPException(status_code=400, detail="Chunk configuration not registered")
+
+         chunk = state.chunk_configs[chunk_id]
+         chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
+
+         # Save the uploaded file
+         with open(chunk_file, 'wb') as f:
+             content = await file.read()
+             f.write(content)
+
+         # Now load the chunk
+         chunk_model = load_chunk(chunk)
+         state.loaded_chunks[chunk_id] = chunk_model
+
+         file_size = os.path.getsize(chunk_file)
+         print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)")
+
+         return {
+             "status": "loaded",
+             "chunk_id": chunk_id,
+             "size_bytes": file_size,
+             "file": chunk.files[0]
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/compute/{chunk_id}")
+ async def compute(chunk_id: int, request: InferenceRequest):
+     """Perform computation on inputs using specified chunk"""
+     try:
+         start_time = datetime.now()
+         state.active_requests += 1
+         state.total_requests += 1
+
+         # Convert inputs to tensor
+         inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
+
+         # Split into batches if needed
+         batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
+         if len(inputs) > batch_size:
+             batches = torch.split(inputs, batch_size)
+             outputs = []
+             for batch in batches:
+                 batch_output = await process_tensor(chunk_id, batch)
+                 outputs.append(batch_output)
+             output_tensor = torch.cat(outputs, dim=0)
+         else:
+             output_tensor = await process_tensor(chunk_id, inputs)
+
+         # Convert output to numpy for diagnostics
+         try:
+             shard_np = output_tensor.cpu().numpy()
+         except Exception:
+             shard_np = None
+
+         chunk_details = {}
+         try:
+             # Normalize to 2D: (seq_len, shard_dim)
+             if shard_np is None:
+                 raise ValueError("Unable to convert output tensor to numpy")
+
+             seq_len = shard_np.shape[0]
+             shard_2d = shard_np.reshape(seq_len, -1)
+
+             k = int(request.top_k or 5)
+             k = min(k, shard_2d.shape[1]) if shard_2d.shape[1] > 0 else 0
+
+             # Compute local top-k per position. argpartition is O(n) but does
+             # not sort within the top-k slice; consumers should re-rank by score.
+             if k > 0:
+                 topk_idx = np.argpartition(-shard_2d, k - 1, axis=1)[:, :k]
+                 topk_vals = np.take_along_axis(shard_2d, topk_idx, axis=1)
+             else:
+                 topk_idx = np.zeros((seq_len, 0), dtype=int)
+                 topk_vals = np.zeros((seq_len, 0), dtype=float)
+
+             # Determine vocab_offset: prefer the config attached to the loaded
+             # module, falling back to the registered chunk configuration.
+             cfg = None
+             chunk_model = state.loaded_chunks.get(chunk_id)
+             if chunk_model is not None:
+                 cfg = getattr(chunk_model, 'config', None)
+             if cfg is None and chunk_id in getattr(state, 'chunk_configs', {}):
+                 cfg = state.chunk_configs[chunk_id].config
+
+             vocab_offset = 0
+             if cfg is not None and hasattr(cfg, 'get'):
+                 vocab_offset = int(cfg.get('vocab_offset', 0))
+
+             # Translate shard-local indices into global token IDs
+             per_position_topk = []
+             for pos_idx in range(seq_len):
+                 toks = []
+                 for jj in range(topk_idx.shape[1]):
+                     local_idx = int(topk_idx[pos_idx, jj])
+                     token_id = int(vocab_offset + local_idx)
+                     score = float(topk_vals[pos_idx, jj])
+                     toks.append([token_id, score])
+                 per_position_topk.append(toks)
+
+             chunk_details[chunk_id] = {
+                 'logits_shard': shard_2d.tolist(),  # full shard; can be a large payload
+                 'topk': per_position_topk,
+                 'vocab_offset': vocab_offset,
+                 'shard_dim': shard_2d.shape[1]
+             }
+         except Exception as e:
+             # If diagnostics fail, include error info but keep main outputs
+             chunk_details = {chunk_id: {'error': str(e)}}
+
+         # Convert output to list for backward compatibility
+         output_list = output_tensor.cpu().numpy().tolist()
+
+         # Update metrics
+         end_time = datetime.now()
+         processing_time = (end_time - start_time).total_seconds()
+         state.request_times.append(processing_time)
+         # Keep only last 100 request times
+         state.request_times = state.request_times[-100:]
+
+         return {
+             "outputs": output_list,
+             "processing_time": processing_time,
+             "chunk_details": chunk_details
+         }
+
+     except Exception as e:
+         state.error_count += 1
+         state.last_error = str(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+     finally:
+         state.active_requests -= 1
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Start background tasks"""
+     asyncio.create_task(update_metrics_loop())
+
+ # ===== Main Execution =====
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
+     print(f"[INFO] Starting tensor server on port {port}")
+     print(f"[INFO] Using device: {Settings.DEVICE}")
+     print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
+
+     uvicorn.run(
+         "tensor_server:app",
+         host="0.0.0.0",
+         port=port,
+         reload=False
+     )
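
For reviewers who want to exercise the new top_k diagnostics, here is a minimal client sketch. It is not part of the commit: it assumes the server above is running locally on port 8001 and that the requests package is installed, and the chunk file name (chunk_0.bin) and payload values are invented purely for illustration.

import requests

BASE = "http://localhost:8001"

# 1) Register a chunk configuration. original_file is required by load_chunk();
#    vocab_offset is what the new diagnostics add to shard-local indices.
chunk_cfg = {
    "chunk_id": 0,
    "files": ["chunk_0.bin"],
    "config": {
        "original_file": "model.bin",
        "start_offset": 0,
        "size_bytes": 16,
        "vocab_offset": 0,
    },
}
requests.post(f"{BASE}/load_chunk", json=chunk_cfg).raise_for_status()

# 2) Upload the raw chunk bytes. Any bytes will do here, since ChunkBuffer's
#    forward() is currently an identity pass-through.
requests.post(
    f"{BASE}/upload_chunk_data/0",
    files={"file": ("chunk_0.bin", b"\x00" * 16)},
).raise_for_status()

# 3) Compute with the new top_k field and inspect chunk_details. FastAPI
#    serializes integer dict keys as strings, hence "0".
resp = requests.post(
    f"{BASE}/compute/0",
    json={"inputs": [[0.1, 0.9, 0.4], [0.7, 0.2, 0.5]], "top_k": 2},
)
resp.raise_for_status()
details = resp.json()["chunk_details"]["0"]
for pos, toks in enumerate(details["topk"]):
    # toks is a list of [token_id, score] pairs, where
    # token_id = vocab_offset + shard-local index.
    print(pos, toks)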