Fred808 committed on
Commit
45e602a
·
verified ·
1 Parent(s): 7bd9c49

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +44 -0
  2. requirements.txt +6 -0
  3. tensor_server.py +271 -0
Dockerfile ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): the original file had two FROM lines; the first
# (nvidia/cuda:11.8.0-runtime-ubuntu22.04) began a stage that the second FROM
# discarded entirely, so the built image never contained CUDA. If GPU support
# is required, base the image on the CUDA image instead of dropping to slim.
FROM python:3.11-slim-bullseye

ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

# Enable contrib and non-free repos, and install system dependencies in one
# cached layer (unrar lives in non-free; libgl1/libglib2.0-0 are common
# runtime deps; git was previously installed in a separate apt layer).
RUN sed -i 's/main/main contrib non-free/' /etc/apt/sources.list && \
    apt-get update && \
    apt-get install -y --no-install-recommends \
        git \
        unrar \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# python:3.11-slim already ships python3 + pip, so the original
# "apt-get install python3 python3-pip" layer is redundant; just keep pip fresh.
RUN pip install --no-cache-dir --upgrade pip

# Copy and install requirements first so this layer caches across code edits
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Copy app code
COPY . .

# Make the entire /app directory fully writeable for all users
RUN chmod -R 777 /app

# Ensure the app runs as the same user as the Space UI
RUN useradd -m -u 1000 user
USER user

# Launch FastAPI download server on container start
CMD ["uvicorn", "tensor_server:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Web stack is pinned exactly; runtime/numeric deps use >= floors.
fastapi==0.104.0
uvicorn==0.23.2
torch>=2.0.0
numpy>=1.24.0
psutil>=5.9.0
# NOTE: pydantic v2 API (model_dump etc.) is what this floor selects.
pydantic>=2.0.0
tensor_server.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import psutil
5
+ import asyncio
6
+ from datetime import datetime
7
+ from typing import Dict, List, Optional
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+ import uvicorn
11
+ import numpy as np
12
+
13
# ===== Config =====
class Settings:
    """Static configuration for the tensor server (class attributes only)."""

    # Server configuration
    HOST = "0.0.0.0"  # Listen on all interfaces
    PORT = 8001
    SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server

    # The IP or hostname where this tensor server is accessible.
    # (Fix: was an f-string with no placeholders — plain literal is equivalent.)
    PUBLIC_URL = os.getenv("PUBLIC_URL", "https://fred808-ilob.hf.space")

    # URLs for other services (should be actual IP addresses or hostnames)
    CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
    AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")

    # Model settings
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MAX_BATCH_SIZE = 32
    METRICS_UPDATE_INTERVAL = 5  # seconds
    MODEL_DIR = "model_chunks"

    @classmethod
    def from_env(cls):
        """Overwrite class-level settings from environment variables.

        Returns:
            The Settings class itself, so callers can chain or assign it.
        """
        cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
        cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
        cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
        cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
        cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
        return cls
42
+
43
# ===== Models =====
class ModelChunk(BaseModel):
    """Represents a received model chunk configuration."""
    # Index identifying this chunk within the partitioned model
    chunk_id: int
    # Weight file names; load_chunk() reads files[0] from Settings.MODEL_DIR
    files: List[str]
    # Expected keys: "input_size", "output_size", "weight_keys" (see load_chunk)
    config: Dict
49
+
50
class InferenceRequest(BaseModel):
    """Represents an inference request."""
    # Row-major 2D batch: one inner list of floats per sample
    inputs: List[List[float]]
    # Per-request batch split size; None falls back to Settings.MAX_BATCH_SIZE
    batch_size: Optional[int] = None
54
+
55
class MetricsData(BaseModel):
    """Server metrics data returned by /metrics and embedded in /health."""
    cpu_usage: float  # system-wide CPU percent (psutil.cpu_percent)
    memory_usage: float  # virtual-memory percent in use
    gpu_usage: Optional[float]  # allocated/max-allocated CUDA memory %, None without GPU
    active_requests: int  # requests currently in flight
    total_requests: int  # requests seen since process start
    average_response_time: float  # mean of the rolling request_times window, seconds
    last_error: Optional[str]  # message of the most recent failure, if any
    error_count: int  # failures seen since process start
65
+
66
# ===== FastAPI App =====
# Single application instance; the endpoint functions below attach to it
# via decorators, and the Dockerfile serves it with uvicorn.
app = FastAPI(
    title="Tensor Server",
    description="Handles model chunk computations",
    version="1.0.0"
)
72
+
73
# ===== State =====
class ServerState:
    """Mutable in-process state shared by all request handlers."""

    def __init__(self):
        # chunk_id -> loaded torch module, populated via /load_chunk
        self.loaded_chunks: Dict[int, torch.nn.Module] = {}
        self.active_requests: int = 0
        self.total_requests: int = 0
        # Rolling window of recent request durations in seconds (last 100 kept)
        self.request_times: List[float] = []
        self.error_count: int = 0
        self.last_error: Optional[str] = None
        self.is_computing: bool = False
        # Fix: update_metrics_loop() assigns this attribute but it was never
        # initialized, so any read before the first metrics tick raised
        # AttributeError. Holds the latest MetricsData snapshot (or None).
        self.current_metrics = None

state = ServerState()
85
+
86
# ===== Metrics Collection =====
async def collect_metrics() -> MetricsData:
    """Collect current server metrics (CPU/RAM/GPU plus request counters)."""
    # CPU and memory metrics
    cpu_usage = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    memory_usage = memory.percent

    # GPU metrics if available
    gpu_usage = None
    if torch.cuda.is_available():
        try:
            max_allocated = torch.cuda.max_memory_allocated()
            # Fix: max_memory_allocated() is 0 before any allocation, which
            # previously raised ZeroDivisionError that a bare `except:`
            # silently swallowed. Guard the division and narrow the except.
            if max_allocated:
                gpu_usage = torch.cuda.memory_allocated() / max_allocated * 100
        except Exception:
            # Best-effort metric; GPU introspection must never break metrics.
            pass

    # Average response time over the rolling window (0 when no samples yet)
    avg_response_time = (
        sum(state.request_times) / len(state.request_times)
        if state.request_times else 0
    )

    return MetricsData(
        cpu_usage=cpu_usage,
        memory_usage=memory_usage,
        gpu_usage=gpu_usage,
        active_requests=state.active_requests,
        total_requests=state.total_requests,
        average_response_time=avg_response_time,
        last_error=state.last_error,
        error_count=state.error_count
    )
115
+
116
async def update_metrics_loop():
    """Background task: refresh state.current_metrics every update interval."""
    while True:
        try:
            # Snapshot metrics so /health and others can read them cheaply.
            state.current_metrics = await collect_metrics()
        except Exception as e:
            print(f"[ERROR] Failed to update metrics: {str(e)}")
        await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
126
+
127
# ===== Helper Functions =====
def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
    """Load a model chunk into memory as a single nn.Linear layer.

    chunk.config must provide "input_size", "output_size" and "weight_keys";
    weights are read from the first file in chunk.files under
    Settings.MODEL_DIR. If the weight file is absent the layer keeps its
    random initialization (original behavior, preserved).

    Raises:
        RuntimeError: on any failure, with the original cause chained.
    """
    try:
        # Create chunk directory if it doesn't exist
        os.makedirs(Settings.MODEL_DIR, exist_ok=True)

        # Get chunk configuration
        input_size = chunk.config["input_size"]
        output_size = chunk.config["output_size"]
        weight_keys = chunk.config["weight_keys"]

        # Create a simple linear transformation for this chunk
        chunk_model = torch.nn.Linear(input_size, output_size)
        chunk_model = chunk_model.to(Settings.DEVICE)

        # Load the weights. SECURITY NOTE: torch.load deserializes pickled
        # data — only feed it trusted chunk files.
        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
        if os.path.exists(chunk_file):
            weights = torch.load(chunk_file, map_location=Settings.DEVICE)

            # Initialize weights from the loaded state dict
            with torch.no_grad():
                # Combine weights if multiple keys map into this chunk
                if len(weight_keys) > 1:
                    combined_weight = torch.cat([weights[k] for k in weight_keys], dim=0)
                    chunk_model.weight.copy_(combined_weight)
                else:
                    chunk_model.weight.copy_(weights[weight_keys[0]])

        return chunk_model

    except Exception as e:
        # Fix: chain the cause instead of discarding the traceback, and raise
        # RuntimeError rather than bare Exception (callers catching Exception
        # are unaffected).
        raise RuntimeError(f"Failed to load chunk: {str(e)}") from e
161
+
162
async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
    """Apply the already-loaded chunk identified by *chunk_id* to *inputs*.

    Raises HTTPException(400) when the chunk has not been loaded yet.
    """
    model = state.loaded_chunks.get(chunk_id)
    if model is None:
        raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        return model(inputs)
171
+
172
# ===== API Endpoints =====
@app.get("/health")
async def health_check():
    """Health check endpoint: status, device, loaded chunk ids, live metrics."""
    metrics = await collect_metrics()
    return {
        "status": "healthy",
        "device": Settings.DEVICE,
        "loaded_chunks": list(state.loaded_chunks.keys()),
        # Fix: .dict() is deprecated under pydantic v2 (the version floor in
        # requirements.txt); model_dump() is the supported replacement.
        "metrics": metrics.model_dump()
    }
183
+
184
+ @app.get("/metrics")
185
+ async def get_metrics():
186
+ """Get current server metrics"""
187
+ return await collect_metrics()
188
+
189
+ @app.post("/load_chunk")
190
+ async def load_model_chunk(chunk: ModelChunk):
191
+ """Load a model chunk into memory"""
192
+ try:
193
+ # Load the chunk
194
+ chunk_model = load_chunk(chunk)
195
+ state.loaded_chunks[chunk.chunk_id] = chunk_model
196
+
197
+ return {
198
+ "status": "loaded",
199
+ "chunk_id": chunk.chunk_id,
200
+ "device": str(next(chunk_model.parameters()).device)
201
+ }
202
+
203
+ except Exception as e:
204
+ state.error_count += 1
205
+ state.last_error = str(e)
206
+ raise HTTPException(status_code=500, detail=str(e))
207
+
208
+ @app.post("/compute/{chunk_id}")
209
+ async def compute(chunk_id: int, request: InferenceRequest):
210
+ """Perform computation on inputs using specified chunk"""
211
+ try:
212
+ start_time = datetime.now()
213
+ state.active_requests += 1
214
+ state.total_requests += 1
215
+
216
+ # Convert inputs to tensor
217
+ inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
218
+
219
+ # Split into batches if needed
220
+ batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
221
+ if len(inputs) > batch_size:
222
+ batches = torch.split(inputs, batch_size)
223
+ outputs = []
224
+ for batch in batches:
225
+ batch_output = await process_tensor(chunk_id, batch)
226
+ outputs.append(batch_output)
227
+ output_tensor = torch.cat(outputs, dim=0)
228
+ else:
229
+ output_tensor = await process_tensor(chunk_id, inputs)
230
+
231
+ # Convert output to list
232
+ output_list = output_tensor.cpu().numpy().tolist()
233
+
234
+ # Update metrics
235
+ end_time = datetime.now()
236
+ processing_time = (end_time - start_time).total_seconds()
237
+ state.request_times.append(processing_time)
238
+ # Keep only last 100 request times
239
+ state.request_times = state.request_times[-100:]
240
+
241
+ return {
242
+ "outputs": output_list,
243
+ "processing_time": processing_time
244
+ }
245
+
246
+ except Exception as e:
247
+ state.error_count += 1
248
+ state.last_error = str(e)
249
+ raise HTTPException(status_code=500, detail=str(e))
250
+
251
+ finally:
252
+ state.active_requests -= 1
253
+
254
+ @app.on_event("startup")
255
+ async def startup_event():
256
+ """Start background tasks"""
257
+ asyncio.create_task(update_metrics_loop())
258
+
259
# ===== Main Execution =====
# Local-run entry point only: the Dockerfile CMD launches uvicorn directly on
# port 7860, so this path is used when running `python tensor_server.py`.
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
    print(f"[INFO] Starting tensor server on port {port}")
    print(f"[INFO] Using device: {Settings.DEVICE}")
    print(f"[INFO] API Documentation available at http://localhost:{port}/docs")

    uvicorn.run(
        "tensor_server:app",
        host="0.0.0.0",
        port=port,
        reload=False
    )