import os
import asyncio
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
import psutil
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel


class Settings:
    """Static configuration for the tensor server; values can be overridden via environment variables."""

    HOST = "0.0.0.0"
    PORT = 8001
    SERVER_ID = os.getenv("SERVER_ID", "tensor1")

    # URL that other nodes in the cluster use to reach this server.
    PUBLIC_URL = os.getenv("PUBLIC_URL", "http://192.168.1.101:8001")

    # Peer services in the cluster.
    CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
    AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MAX_BATCH_SIZE = 32
    METRICS_UPDATE_INTERVAL = 5  # seconds between metrics refreshes
    MODEL_DIR = "model_chunks"

    @classmethod
    def from_env(cls):
        """Load settings from environment variables."""
        cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
        cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
        cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
        cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
        cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
        return cls
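
# Usage sketch (deployment values are placeholders, not defined by this file):
#
#   SERVER_ID=tensor2 TENSOR_PORT=8003 PUBLIC_URL=http://192.168.1.102:8003 \
#       python tensor_server.py
#
# Settings.from_env() mutates the class in place and returns it, so call sites
# can write `settings = Settings.from_env()` or simply read the class attributes.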


class ModelChunk(BaseModel):
    """Represents a received model chunk configuration."""

    chunk_id: int
    files: List[str]
    config: Dict


class InferenceRequest(BaseModel):
    """Represents an inference request."""

    inputs: List[List[float]]
    batch_size: Optional[int] = None
    top_k: Optional[int] = 5
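
# Example /compute request body (illustrative values; rows are truncated for
# brevity -- each inner list should match the chunk's input_dim, 1024 by default):
# {
#   "inputs": [[0.12, -0.48, 0.07], [0.91, 0.33, -0.25]],
#   "batch_size": 16,
#   "top_k": 5
# }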


class MetricsData(BaseModel):
    """Server metrics data."""

    cpu_usage: float
    memory_usage: float
    gpu_usage: Optional[float]
    active_requests: int
    total_requests: int
    average_response_time: float
    last_error: Optional[str]
    error_count: int
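
# Sample /metrics response (illustrative numbers only):
# {"cpu_usage": 12.5, "memory_usage": 41.2, "gpu_usage": 8.7,
#  "active_requests": 1, "total_requests": 240,
#  "average_response_time": 0.031, "last_error": null, "error_count": 0}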


app = FastAPI(
    title="Tensor Server",
    description="Handles model chunk computations",
    version="1.0.0",
)


class ServerState:
    """Mutable runtime state shared across request handlers."""

    def __init__(self):
        self.loaded_chunks: Dict[int, torch.nn.Module] = {}
        self.chunk_configs: Dict[int, ModelChunk] = {}
        self.active_requests: int = 0
        self.total_requests: int = 0
        self.request_times: List[float] = []
        self.error_count: int = 0
        self.last_error: Optional[str] = None
        self.is_computing: bool = False
        self.current_metrics: Optional[MetricsData] = None


state = ServerState()


async def collect_metrics() -> MetricsData:
    """Collect current server metrics."""
    cpu_usage = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    memory_usage = memory.percent

    # Note: this is allocated-vs-peak CUDA memory, not true GPU utilization
    # (which would require e.g. NVML); treat it as a rough load indicator.
    gpu_usage = None
    if torch.cuda.is_available():
        try:
            peak = torch.cuda.max_memory_allocated()
            if peak > 0:
                gpu_usage = torch.cuda.memory_allocated() / peak * 100
        except Exception:
            pass

    avg_response_time = (
        sum(state.request_times) / len(state.request_times)
        if state.request_times
        else 0.0
    )

    return MetricsData(
        cpu_usage=cpu_usage,
        memory_usage=memory_usage,
        gpu_usage=gpu_usage,
        active_requests=state.active_requests,
        total_requests=state.total_requests,
        average_response_time=avg_response_time,
        last_error=state.last_error,
        error_count=state.error_count,
    )


async def update_metrics_loop():
    """Background task to update metrics periodically."""
    while True:
        try:
            state.current_metrics = await collect_metrics()
        except Exception as e:
            print(f"[ERROR] Failed to update metrics: {e}")
        await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)


def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
    """Load a model chunk into memory."""
    try:
        os.makedirs(Settings.MODEL_DIR, exist_ok=True)

        chunk_config = chunk.config
        if "original_file" not in chunk_config:
            raise ValueError("Missing original_file in chunk configuration")

        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
        if not os.path.exists(chunk_file):
            raise ValueError(f"Chunk file not found: {chunk_file}")

        class ChunkBuffer(nn.Module):
            """
            A single Florence-2 caption chunk that receives pre-encoded image
            embeddings and produces partial vocabulary logits.

            Note: the chunk file path is recorded, but this buffer initializes
            its own weights (Xavier) rather than deserializing them here.
            """

            def __init__(self, chunk_path: str, config: dict):
                super().__init__()

                input_dim = config.get("input_dim", 1024)
                output_dim = config.get("output_dim", 1000)
                dropout = config.get("dropout", 0.1)

                self.chunk_path = chunk_path
                self.linear = nn.Linear(input_dim, output_dim)
                self.norm = nn.LayerNorm(input_dim)
                self.dropout = nn.Dropout(dropout)

                nn.init.xavier_uniform_(self.linear.weight)
                nn.init.zeros_(self.linear.bias)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                """
                Args:
                    x: Florence-2 image embedding tensor, shape [batch, 1024]
                Returns:
                    Logits for this vocab shard, shape [batch, output_dim]
                """
                x = self.norm(x)
                x = self.dropout(x)
                return self.linear(x)

        # Move the chunk to the compute device and switch to inference mode
        # so dropout is disabled during serving.
        chunk_model = ChunkBuffer(chunk_file, chunk_config).to(Settings.DEVICE)
        chunk_model.eval()
        chunk_model.config = chunk_config
        print(
            f"[INFO] Loaded chunk {chunk.chunk_id} "
            f"({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}"
        )
        return chunk_model

    except Exception as e:
        raise RuntimeError(f"Failed to load chunk: {e}") from e
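
# A chunk config is expected to look roughly like the following (field names
# inferred from the lookups above; the authoritative schema lives with the
# controller that ships the chunks, and the file name is a placeholder):
# {
#   "original_file": "model.bin",   # source checkpoint
#   "input_dim": 1024,              # Florence-2 embedding width
#   "output_dim": 1000,             # width of this vocab shard
#   "dropout": 0.1,
#   "vocab_offset": 0,              # first global token id in the shard
#   "shard_dim": 1000               # number of token ids in the shard
# }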


async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
    """Process an input tensor through the specified chunk."""
    if chunk_id not in state.loaded_chunks:
        raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")

    chunk_model = state.loaded_chunks[chunk_id]
    with torch.no_grad():
        return chunk_model(inputs)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    metrics = await collect_metrics()
    return {
        "status": "healthy",
        "device": Settings.DEVICE,
        "loaded_chunks": list(state.loaded_chunks.keys()),
        "metrics": metrics.dict(),
    }


@app.get("/metrics")
async def get_metrics():
    """Get current server metrics."""
    return await collect_metrics()


@app.post("/load_chunk")
async def load_model_chunk(chunk: ModelChunk):
    """Register a chunk configuration; the weight data arrives separately."""
    try:
        os.makedirs(Settings.MODEL_DIR, exist_ok=True)

        cfg = chunk.config or {}
        if "vocab_offset" not in cfg:
            # Auto-assign this shard's offset: place it directly after the
            # furthest-right shard registered so far.
            max_end = 0
            for existing in state.chunk_configs.values():
                try:
                    e_cfg = existing.config if hasattr(existing, "config") else existing
                    e_offset = int(e_cfg.get("vocab_offset", 0))
                    e_shard = int(e_cfg.get("shard_dim", e_cfg.get("size", 1) or 1))
                    max_end = max(max_end, e_offset + e_shard)
                except Exception:
                    continue

            shard_dim = int(cfg.get("shard_dim", cfg.get("size", 1) or 1))
            cfg["vocab_offset"] = max_end
            cfg["shard_dim"] = cfg.get("shard_dim", shard_dim)

        chunk.config = cfg
        state.chunk_configs[chunk.chunk_id] = chunk

        print(f"[INFO] Registered chunk {chunk.chunk_id} configuration")
        print(f"[INFO] Waiting for chunk data: {chunk.files[0]}")

        return {
            "status": "configured",
            "chunk_id": chunk.chunk_id,
            "ready_for_data": True,
        }

    except Exception as e:
        state.error_count += 1
        state.last_error = str(e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/upload_chunk_data/{chunk_id}")
async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
    """Receive the actual chunk data and load it."""
    try:
        if chunk_id not in state.chunk_configs:
            raise HTTPException(status_code=400, detail="Chunk configuration not registered")

        chunk = state.chunk_configs[chunk_id]
        chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])

        # Persist the uploaded bytes to disk before loading.
        content = await file.read()
        with open(chunk_file, "wb") as f:
            f.write(content)

        chunk_model = load_chunk(chunk)

        # Propagate the registered config (including any auto-assigned
        # vocab_offset) onto the loaded module.
        try:
            reg_cfg = chunk.config or {}
            if hasattr(chunk_model, "config"):
                chunk_model.config.update(reg_cfg)
            else:
                chunk_model.config = reg_cfg
            chunk_model.vocab_offset = int(reg_cfg.get("vocab_offset", 0))
        except Exception:
            pass

        state.loaded_chunks[chunk_id] = chunk_model

        file_size = os.path.getsize(chunk_file)
        print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)")

        return {
            "status": "loaded",
            "chunk_id": chunk_id,
            "size_bytes": file_size,
            "file": chunk.files[0],
        }

    except HTTPException:
        raise
    except Exception as e:
        state.error_count += 1
        state.last_error = str(e)
        raise HTTPException(status_code=500, detail=str(e))
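
# Two-step handshake sketch (curl; host and file names are placeholders):
#
#   curl -X POST http://192.168.1.101:8001/load_chunk \
#        -H 'Content-Type: application/json' \
#        -d '{"chunk_id": 0, "files": ["chunk_0.bin"],
#             "config": {"original_file": "model.bin", "output_dim": 1000}}'
#
#   curl -X POST http://192.168.1.101:8001/upload_chunk_data/0 \
#        -F 'file=@chunk_0.bin'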


@app.post("/compute/{chunk_id}")
async def compute(chunk_id: int, request: InferenceRequest):
    """Perform computation on inputs using the specified chunk."""
    try:
        start_time = datetime.now()
        state.active_requests += 1
        state.total_requests += 1

        inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)

        # Split oversized inputs into batches to bound memory use.
        batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
        if len(inputs) > batch_size:
            batches = torch.split(inputs, batch_size)
            outputs = []
            for batch in batches:
                batch_output = await process_tensor(chunk_id, batch)
                outputs.append(batch_output)
            output_tensor = torch.cat(outputs, dim=0)
        else:
            output_tensor = await process_tensor(chunk_id, inputs)

        try:
            shard_np = output_tensor.cpu().numpy()
        except Exception:
            shard_np = None

        chunk_details = {}
        try:
            if shard_np is None:
                raise ValueError("Unable to convert output tensor to numpy")

            seq_len = shard_np.shape[0]
            shard_2d = shard_np.reshape(seq_len, -1)

            k = int(request.top_k or 5)
            k = min(k, shard_2d.shape[1]) if shard_2d.shape[1] > 0 else 0

            if k > 0:
                # argpartition finds the k largest entries per row in O(n);
                # a final argsort orders those k entries descending by score.
                topk_idx = np.argpartition(-shard_2d, k - 1, axis=1)[:, :k]
                topk_vals = np.take_along_axis(shard_2d, topk_idx, axis=1)
                order = np.argsort(-topk_vals, axis=1)
                topk_idx = np.take_along_axis(topk_idx, order, axis=1)
                topk_vals = np.take_along_axis(topk_vals, order, axis=1)
            else:
                topk_idx = np.zeros((seq_len, 0), dtype=int)
                topk_vals = np.zeros((seq_len, 0), dtype=float)

            # Resolve this shard's config: prefer the loaded module's copy,
            # then fall back to the registered configuration.
            cfg = None
            try:
                chunk_model = state.loaded_chunks.get(chunk_id)
                cfg = getattr(chunk_model, "config", None)
                if cfg is None and chunk_id in state.chunk_configs:
                    cfg = state.chunk_configs[chunk_id].config
            except Exception:
                cfg = None

            vocab_offset = 0
            if isinstance(cfg, dict):
                vocab_offset = int(cfg.get("vocab_offset", 0))
            elif cfg is not None and hasattr(cfg, "get"):
                vocab_offset = int(cfg.get("vocab_offset", 0))

            # Translate shard-local indices into global token ids.
            per_position_topk = []
            for pos_idx in range(seq_len):
                toks = []
                for jj in range(topk_idx.shape[1]):
                    local_idx = int(topk_idx[pos_idx, jj])
                    token_id = int(vocab_offset + local_idx)
                    score = float(topk_vals[pos_idx, jj])
                    toks.append([token_id, score])
                per_position_topk.append(toks)

            chunk_details[chunk_id] = {
                "logits_shard": shard_2d.tolist(),
                "topk": per_position_topk,
                "vocab_offset": vocab_offset,
                "shard_dim": shard_2d.shape[1],
            }
        except Exception as e:
            chunk_details = {chunk_id: {"error": str(e)}}

        output_list = output_tensor.cpu().numpy().tolist()

        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        state.request_times.append(processing_time)
        # Keep only the most recent 100 samples for the rolling average.
        state.request_times = state.request_times[-100:]

        return {
            "outputs": output_list,
            "processing_time": processing_time,
            "chunk_details": chunk_details,
        }

    except HTTPException:
        raise
    except Exception as e:
        state.error_count += 1
        state.last_error = str(e)
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        state.active_requests -= 1
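
# Downstream reassembly sketch (assumed aggregator behavior; not defined in
# this file): each server returns logits covering [vocab_offset, vocab_offset
# + shard_dim), and the aggregator stitches shards into full vocabulary rows.
#
#   full = np.zeros((seq_len, total_vocab), dtype=np.float32)
#   for detail in shard_responses:          # one chunk_details entry per server
#       off, dim = detail["vocab_offset"], detail["shard_dim"]
#       full[:, off:off + dim] = np.asarray(detail["logits_shard"])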


@app.on_event("startup")
async def startup_event():
    """Start background tasks."""
    asyncio.create_task(update_metrics_loop())


if __name__ == "__main__":
    port = int(os.getenv("PORT", 8001))
    print(f"[INFO] Starting tensor server on port {port}")
    print(f"[INFO] Using device: {Settings.DEVICE}")
    print(f"[INFO] API documentation available at http://localhost:{port}/docs")

    uvicorn.run(
        "tensor_server:app",
        host="0.0.0.0",
        port=port,
        reload=False,
    )