Upload 3 files
- Dockerfile +44 -0
- requirements.txt +6 -0
- tensor_server.py +271 -0
Dockerfile
ADDED
@@ -0,0 +1,44 @@
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04

# NOTE: the CUDA stage above is never referenced again, so the image below
# is the effective base and the final image carries no CUDA runtime.
FROM python:3.11-slim-bullseye

WORKDIR /app

# Avoid interactive prompts during package installation
ENV DEBIAN_FRONTEND=noninteractive

# Enable contrib and non-free repos, and install system dependencies
RUN sed -i 's/main/main contrib non-free/' /etc/apt/sources.list && \
    apt-get update && \
    apt-get install -y --no-install-recommends \
        unrar \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Python + dependencies
RUN apt-get update && apt-get install -y python3 python3-pip git && \
    pip3 install --upgrade pip

# Copy and install requirements
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Copy app code
COPY . .

# Make the entire /app directory fully writeable for all users
RUN chmod -R 777 /app

# Ensure the app runs as the same user as the Space UI
RUN useradd -m -u 1000 user
USER user

# Launch the FastAPI tensor server on container start
CMD ["uvicorn", "tensor_server:app", "--host", "0.0.0.0", "--port", "7860"]
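Note: the CMD starts uvicorn on port 7860 (the port Hugging Face Spaces expects), which bypasses the `if __name__ == "__main__"` block in tensor_server.py and its PORT/Settings.PORT default of 8001; that default only applies when the script is run directly. To try the image locally, something like `docker build -t tensor-server .` followed by `docker run -p 7860:7860 tensor-server` should work (the image name is illustrative).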
requirements.txt
ADDED
@@ -0,0 +1,6 @@
fastapi==0.104.0
uvicorn==0.23.2
torch>=2.0.0
numpy>=1.24.0
psutil>=5.9.0
pydantic>=2.0.0
tensor_server.py
ADDED
@@ -0,0 +1,271 @@
import os
import json
import torch
import psutil
import asyncio
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import numpy as np

# ===== Config =====
class Settings:
    # Server configuration
    HOST = "0.0.0.0"  # Listen on all interfaces
    PORT = 8001
    SERVER_ID = os.getenv("SERVER_ID", "tensor1")  # Unique ID for this tensor server

    # The IP or hostname where this tensor server is accessible
    PUBLIC_URL = os.getenv("PUBLIC_URL", "https://fred808-ilob.hf.space")

    # URLs for other services (should be actual IP addresses or hostnames)
    CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
    AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")

    # Model settings
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MAX_BATCH_SIZE = 32
    METRICS_UPDATE_INTERVAL = 5  # seconds
    MODEL_DIR = "model_chunks"

    @classmethod
    def from_env(cls):
        """Load settings from environment variables"""
        cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
        cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
        cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
        cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
        cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
        return cls

# ===== Models =====
class ModelChunk(BaseModel):
    """Represents a received model chunk configuration"""
    chunk_id: int
    files: List[str]
    config: Dict

class InferenceRequest(BaseModel):
    """Represents an inference request"""
    inputs: List[List[float]]
    batch_size: Optional[int] = None

class MetricsData(BaseModel):
    """Server metrics data"""
    cpu_usage: float
    memory_usage: float
    gpu_usage: Optional[float]
    active_requests: int
    total_requests: int
    average_response_time: float
    last_error: Optional[str]
    error_count: int

# ===== FastAPI App =====
app = FastAPI(
    title="Tensor Server",
    description="Handles model chunk computations",
    version="1.0.0"
)

# ===== State =====
class ServerState:
    def __init__(self):
        self.loaded_chunks: Dict[int, torch.nn.Module] = {}
        self.active_requests: int = 0
        self.total_requests: int = 0
        self.request_times: List[float] = []
        self.error_count: int = 0
        self.last_error: Optional[str] = None
        self.is_computing: bool = False
        self.current_metrics: Optional[MetricsData] = None  # refreshed by the metrics loop

state = ServerState()

# ===== Metrics Collection =====
async def collect_metrics() -> MetricsData:
    """Collect current server metrics"""
    # CPU and memory metrics
    cpu_usage = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    memory_usage = memory.percent

    # GPU metrics if available
    gpu_usage = None
    if torch.cuda.is_available():
        try:
            gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
        except Exception:
            pass

    # Calculate average response time
    avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0

    return MetricsData(
        cpu_usage=cpu_usage,
        memory_usage=memory_usage,
        gpu_usage=gpu_usage,
        active_requests=state.active_requests,
        total_requests=state.total_requests,
        average_response_time=avg_response_time,
        last_error=state.last_error,
        error_count=state.error_count
    )

async def update_metrics_loop():
    """Background task to update metrics periodically"""
    while True:
        try:
            metrics = await collect_metrics()
            # Store metrics for health checks
            state.current_metrics = metrics
        except Exception as e:
            print(f"[ERROR] Failed to update metrics: {e}")
        await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)

# ===== Helper Functions =====
def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
    """Load a model chunk into memory"""
    try:
        # Create chunk directory if it doesn't exist
        os.makedirs(Settings.MODEL_DIR, exist_ok=True)

        # Get chunk configuration
        input_size = chunk.config["input_size"]
        output_size = chunk.config["output_size"]
        weight_keys = chunk.config["weight_keys"]

        # Create a simple linear transformation for this chunk
        chunk_model = torch.nn.Linear(input_size, output_size)
        chunk_model = chunk_model.to(Settings.DEVICE)

        # Load the weights; if the file is missing, the randomly
        # initialized layer is returned as-is
        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
        if os.path.exists(chunk_file):
            weights = torch.load(chunk_file, map_location=Settings.DEVICE)

            # Initialize weights from the loaded state dict
            with torch.no_grad():
                # Combine weights if multiple keys
                if len(weight_keys) > 1:
                    combined_weight = torch.cat([weights[k] for k in weight_keys], dim=0)
                    chunk_model.weight.copy_(combined_weight)
                else:
                    chunk_model.weight.copy_(weights[weight_keys[0]])

        return chunk_model

    except Exception as e:
        raise RuntimeError(f"Failed to load chunk: {e}") from e

async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
    """Process input tensor through the specified chunk"""
    if chunk_id not in state.loaded_chunks:
        raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")

    chunk_model = state.loaded_chunks[chunk_id]
    with torch.no_grad():
        outputs = chunk_model(inputs)
    return outputs

# ===== API Endpoints =====
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    metrics = await collect_metrics()
    return {
        "status": "healthy",
        "device": Settings.DEVICE,
        "loaded_chunks": list(state.loaded_chunks.keys()),
        "metrics": metrics.model_dump()
    }

@app.get("/metrics")
async def get_metrics():
    """Get current server metrics"""
    return await collect_metrics()

@app.post("/load_chunk")
async def load_model_chunk(chunk: ModelChunk):
    """Load a model chunk into memory"""
    try:
        # Load the chunk
        chunk_model = load_chunk(chunk)
        state.loaded_chunks[chunk.chunk_id] = chunk_model

        return {
            "status": "loaded",
            "chunk_id": chunk.chunk_id,
            "device": str(next(chunk_model.parameters()).device)
        }

    except Exception as e:
        state.error_count += 1
        state.last_error = str(e)
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/compute/{chunk_id}")
async def compute(chunk_id: int, request: InferenceRequest):
    """Perform computation on inputs using specified chunk"""
    try:
        start_time = datetime.now()
        state.active_requests += 1
        state.total_requests += 1

        # Convert inputs to tensor
        inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)

        # Split into batches if needed
        batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
        if len(inputs) > batch_size:
            batches = torch.split(inputs, batch_size)
            outputs = []
            for batch in batches:
                batch_output = await process_tensor(chunk_id, batch)
                outputs.append(batch_output)
            output_tensor = torch.cat(outputs, dim=0)
        else:
            output_tensor = await process_tensor(chunk_id, inputs)

        # Convert output to list
        output_list = output_tensor.cpu().numpy().tolist()

        # Update metrics
        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        state.request_times.append(processing_time)
        # Keep only last 100 request times
        state.request_times = state.request_times[-100:]

        return {
            "outputs": output_list,
            "processing_time": processing_time
        }

    except HTTPException:
        # Propagate client errors (e.g. chunk not loaded) unchanged
        raise

    except Exception as e:
        state.error_count += 1
        state.last_error = str(e)
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        state.active_requests -= 1

@app.on_event("startup")
async def startup_event():
    """Start background tasks"""
    asyncio.create_task(update_metrics_loop())

# ===== Main Execution =====
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8001))  # Default to 8001 to avoid conflict with controller
    print(f"[INFO] Starting tensor server on port {port}")
    print(f"[INFO] Using device: {Settings.DEVICE}")
    print(f"[INFO] API Documentation available at http://localhost:{port}/docs")

    uvicorn.run(
        "tensor_server:app",
        host="0.0.0.0",
        port=port,
        reload=False
    )
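For reference, a minimal end-to-end client sketch against the endpoints above, assuming a local run where the client shares the server's working directory (so the saved file lands in the server's MODEL_DIR). The file name chunk_0.pt, the weight key w0, and the 4-in/2-out sizes are illustrative; only the standard library plus torch is used.

import json
import os
import urllib.request

import torch

BASE = "http://localhost:7860"  # port used by the Dockerfile CMD


def post(path: str, payload: dict) -> dict:
    """POST a JSON payload to the tensor server and return the decoded response."""
    req = urllib.request.Request(
        f"{BASE}{path}",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


# Write a weight file the server can pick up: a state dict with one key
# holding a 2x4 matrix, matching nn.Linear's (out_features, in_features) layout.
os.makedirs("model_chunks", exist_ok=True)
torch.save({"w0": torch.randn(2, 4)}, "model_chunks/chunk_0.pt")

# Register the chunk with the server.
print(post("/load_chunk", {
    "chunk_id": 0,
    "files": ["chunk_0.pt"],
    "config": {"input_size": 4, "output_size": 2, "weight_keys": ["w0"]},
}))

# Run two 4-dimensional inputs through chunk 0.
print(post("/compute/0", {
    "inputs": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
}))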