import os import json import torch import psutil import asyncio from datetime import datetime from typing import Dict, List, Optional from fastapi import FastAPI, HTTPException from pydantic import BaseModel import uvicorn import numpy as np import torch.nn as nn import torch.nn.functional as F # ===== Config ===== class Settings: # Server configuration HOST = "0.0.0.0" # Listen on all interfaces PORT = 8001 SERVER_ID = os.getenv("SERVER_ID", "tensor1") # Unique ID for this tensor server # The IP or hostname where this tensor server is accessible PUBLIC_URL = os.getenv("PUBLIC_URL", f"http://192.168.1.101:8001") # URLs for other services (should be actual IP addresses or hostnames) CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000") AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002") # Model settings DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MAX_BATCH_SIZE = 32 METRICS_UPDATE_INTERVAL = 5 # seconds MODEL_DIR = "model_chunks" @classmethod def from_env(cls): """Load settings from environment variables""" cls.HOST = os.getenv("TENSOR_HOST", cls.HOST) cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT)) cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID) cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL) cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL) return cls # ===== Models ===== class ModelChunk(BaseModel): """Represents a received model chunk configuration""" chunk_id: int files: List[str] config: Dict class InferenceRequest(BaseModel): """Represents an inference request""" inputs: List[List[float]] batch_size: Optional[int] = None top_k: Optional[int] = 5 class MetricsData(BaseModel): """Server metrics data""" cpu_usage: float memory_usage: float gpu_usage: Optional[float] active_requests: int total_requests: int average_response_time: float last_error: Optional[str] error_count: int # ===== FastAPI App ===== app = FastAPI( title="Tensor Server", description="Handles model chunk computations", version="1.0.0" ) # ===== State ===== class ServerState: def __init__(self): self.loaded_chunks: Dict[int, torch.nn.Module] = {} self.active_requests: int = 0 self.total_requests: int = 0 self.request_times: List[float] = [] self.error_count: int = 0 self.last_error: Optional[str] = None self.is_computing: bool = False state = ServerState() # ===== Metrics Collection ===== async def collect_metrics() -> MetricsData: """Collect current server metrics""" # CPU and memory metrics cpu_usage = psutil.cpu_percent() memory = psutil.virtual_memory() memory_usage = memory.percent # GPU metrics if available gpu_usage = None if torch.cuda.is_available(): try: gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100 except: pass # Calculate average response time avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0 return MetricsData( cpu_usage=cpu_usage, memory_usage=memory_usage, gpu_usage=gpu_usage, active_requests=state.active_requests, total_requests=state.total_requests, average_response_time=avg_response_time, last_error=state.last_error, error_count=state.error_count ) async def update_metrics_loop(): """Background task to update metrics periodically""" while True: try: metrics = await collect_metrics() # Store metrics for health checks state.current_metrics = metrics except Exception as e: print(f"[ERROR] Failed to update metrics: {str(e)}") await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL) # ===== Helper Functions ===== def load_chunk(chunk: ModelChunk) -> torch.nn.Module: """Load a model chunk into memory""" try: # Create chunk directory if it doesn't exist os.makedirs(Settings.MODEL_DIR, exist_ok=True) # Get chunk configuration chunk_config = chunk.config if "original_file" not in chunk_config: raise ValueError("Missing original_file in chunk configuration") # Save chunk data to file chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0]) if not os.path.exists(chunk_file): # We'll need to receive the actual chunk data in a separate request raise ValueError(f"Chunk file not found: {chunk_file}") # For raw binary chunks, we'll create a simple buffer module class ChunkBuffer(nn.Module): """ A single Florence-2 caption chunk that receives pre-encoded image embeddings and produces partial vocabulary logits. """ def __init__(self, chunk_path: str, config: dict): super().__init__() # Get dimensions from config input_dim = config.get("input_dim", 1024) # Florence-2 embedding dim output_dim = config.get("output_dim", 1000) # size of vocab shard dropout = config.get("dropout", 0.1) # Optional: chunk_path can point to pretrained weights self.chunk_path = chunk_path # Main projection layer: embedding → partial vocab logits self.linear = nn.Linear(input_dim, output_dim) # Optional normalization + dropout (stabilizes training or inference variance) self.norm = nn.LayerNorm(input_dim) self.dropout = nn.Dropout(dropout) # Initialize weights (small variance, stable logits) nn.init.xavier_uniform_(self.linear.weight) nn.init.zeros_(self.linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Florence-2 image embedding tensor, shape [batch, 1024] Returns: logits for this vocab shard, shape [batch, output_dim] """ # Normalize + dropout x = self.norm(x) x = self.dropout(x) # Linear projection to vocab slice logits = self.linear(x) # (Optional) softmax for probabilities, but usually the main model handles this # probs = F.softmax(logits, dim=-1) return logits # Create and return the chunk buffer chunk_model = ChunkBuffer(chunk_file, chunk_config) # Ensure the chunk_model.config is the up-to-date config (including any assigned offsets) chunk_model.config = chunk_config print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}") return chunk_model except Exception as e: raise Exception(f"Failed to load chunk: {str(e)}") async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor: """Process input tensor through the specified chunk""" if chunk_id not in state.loaded_chunks: raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded") chunk_model = state.loaded_chunks[chunk_id] with torch.no_grad(): outputs = chunk_model(inputs) return outputs # ===== API Endpoints ===== @app.get("/health") async def health_check(): """Health check endpoint""" metrics = await collect_metrics() return { "status": "healthy", "device": Settings.DEVICE, "loaded_chunks": list(state.loaded_chunks.keys()), "metrics": metrics.dict() } @app.get("/metrics") async def get_metrics(): """Get current server metrics""" return await collect_metrics() from fastapi import File, UploadFile @app.post("/load_chunk") async def load_model_chunk(chunk: ModelChunk): """Register a chunk configuration""" try: # Create model directory if it doesn't exist os.makedirs(Settings.MODEL_DIR, exist_ok=True) # Store the chunk metadata chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0]) state.chunk_configs = getattr(state, 'chunk_configs', {}) # Ensure a vocab_offset is present; if not, assign a non-overlapping offset cfg = chunk.config or {} if 'vocab_offset' not in cfg: # Compute next available offset from existing registered chunks max_end = 0 for existing in state.chunk_configs.values(): try: e_cfg = existing.config if hasattr(existing, 'config') else existing e_offset = int(e_cfg.get('vocab_offset', 0)) e_shard = int(e_cfg.get('shard_dim', e_cfg.get('size', 1) or 1)) max_end = max(max_end, e_offset + e_shard) except Exception: continue # If this chunk declares a shard_dim, use it; otherwise default to 1 shard_dim = int(cfg.get('shard_dim', cfg.get('size', 1) or 1)) cfg['vocab_offset'] = max_end cfg['shard_dim'] = cfg.get('shard_dim', shard_dim) # Store back the possibly-updated config chunk.config = cfg state.chunk_configs[chunk.chunk_id] = chunk print(f"[INFO] Registered chunk {chunk.chunk_id} configuration") print(f"[INFO] Waiting for chunk data: {chunk.files[0]}") return { "status": "configured", "chunk_id": chunk.chunk_id, "ready_for_data": True } except Exception as e: state.error_count += 1 state.last_error = str(e) raise HTTPException(status_code=500, detail=str(e)) @app.post("/upload_chunk_data/{chunk_id}") async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)): """Receive the actual chunk data""" try: if chunk_id not in getattr(state, 'chunk_configs', {}): raise HTTPException(status_code=400, detail="Chunk configuration not registered") chunk = state.chunk_configs[chunk_id] chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0]) # Save the uploaded file with open(chunk_file, 'wb') as f: content = await file.read() f.write(content) # Now load the chunk chunk_model = load_chunk(chunk) # Ensure the loaded module has the registered config (including vocab_offset) try: registered = getattr(state, 'chunk_configs', {}).get(chunk_id) if registered is not None: # registered is a ModelChunk; merge config into module reg_cfg = registered.config or {} if hasattr(chunk_model, 'config'): chunk_model.config.update(reg_cfg) else: chunk_model.config = reg_cfg # expose vocab_offset on module chunk_model.vocab_offset = int(reg_cfg.get('vocab_offset', 0)) except Exception: pass state.loaded_chunks[chunk_id] = chunk_model file_size = os.path.getsize(chunk_file) print(f"[INFO] Received and loaded chunk {chunk_id} data ({file_size} bytes)") return { "status": "loaded", "chunk_id": chunk_id, "size_bytes": file_size, "file": chunk.files[0] } except Exception as e: state.error_count += 1 state.last_error = str(e) raise HTTPException(status_code=500, detail=str(e)) @app.post("/compute/{chunk_id}") async def compute(chunk_id: int, request: InferenceRequest): """Perform computation on inputs using specified chunk""" try: start_time = datetime.now() state.active_requests += 1 state.total_requests += 1 # Convert inputs to tensor inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE) # Split into batches if needed batch_size = request.batch_size or Settings.MAX_BATCH_SIZE if len(inputs) > batch_size: batches = torch.split(inputs, batch_size) outputs = [] for batch in batches: batch_output = await process_tensor(chunk_id, batch) outputs.append(batch_output) output_tensor = torch.cat(outputs, dim=0) else: output_tensor = await process_tensor(chunk_id, inputs) # Convert output to numpy for diagnostics try: shard_np = output_tensor.cpu().numpy() except Exception: shard_np = None chunk_details = {} try: # Normalize to 2D: (seq_len, shard_dim) if shard_np is None: raise ValueError("Unable to convert output tensor to numpy") seq_len = shard_np.shape[0] shard_2d = shard_np.reshape(seq_len, -1) k = int(request.top_k or 5) k = min(k, shard_2d.shape[1]) if shard_2d.shape[1] > 0 else 0 # compute local top-k per position if k > 0: topk_idx = np.argpartition(-shard_2d, k-1, axis=1)[:, :k] topk_vals = np.take_along_axis(shard_2d, topk_idx, axis=1) else: topk_idx = np.zeros((seq_len, 0), dtype=int) topk_vals = np.zeros((seq_len, 0), dtype=float) # determine vocab_offset from the loaded chunk config if available cfg = None try: chunk_model = state.loaded_chunks.get(chunk_id) cfg = getattr(chunk_model, 'config', None) or getattr(state, 'chunk_configs', {}).get(chunk_id, {}).config if chunk_id in getattr(state, 'chunk_configs', {}) else None except Exception: cfg = None vocab_offset = 0 if isinstance(cfg, dict): vocab_offset = int(cfg.get('vocab_offset', 0)) elif cfg is not None and hasattr(cfg, 'get'): vocab_offset = int(cfg.get('vocab_offset', 0)) per_position_topk = [] for pos_idx in range(seq_len): toks = [] for jj in range(topk_idx.shape[1]): local_idx = int(topk_idx[pos_idx, jj]) token_id = int(vocab_offset + local_idx) score = float(topk_vals[pos_idx, jj]) toks.append([token_id, score]) per_position_topk.append(toks) chunk_details[chunk_id] = { 'logits_shard': shard_2d.tolist(), 'topk': per_position_topk, 'vocab_offset': vocab_offset, 'shard_dim': shard_2d.shape[1] } except Exception as e: # If diagnostics fail, include error info but keep main outputs chunk_details = {chunk_id: {'error': str(e)}} # Convert output to list for backward compatibility output_list = output_tensor.cpu().numpy().tolist() # Update metrics end_time = datetime.now() processing_time = (end_time - start_time).total_seconds() state.request_times.append(processing_time) # Keep only last 100 request times state.request_times = state.request_times[-100:] return { "outputs": output_list, "processing_time": processing_time, "chunk_details": chunk_details } except Exception as e: state.error_count += 1 state.last_error = str(e) raise HTTPException(status_code=500, detail=str(e)) finally: state.active_requests -= 1 @app.on_event("startup") async def startup_event(): """Start background tasks""" asyncio.create_task(update_metrics_loop()) # ===== Main Execution ===== if __name__ == "__main__": port = int(os.getenv("PORT", 8001)) # Default to 8001 to avoid conflict with controller print(f"[INFO] Starting tensor server on port {port}") print(f"[INFO] Using device: {Settings.DEVICE}") print(f"[INFO] API Documentation available at http://localhost:{port}/docs") uvicorn.run( "tensor_server:app", host="0.0.0.0", port=port, reload=False )