"""
Tensor Server A - Handles partial matrix multiplication for Tensor Parallelism (TP).

This server keeps one shard of the model weights in memory and performs the
MatMul for that shard (Y_shard = X @ W_shard). Shards are not read from disk
here: the control server pushes them to the ``/load_weights`` endpoint after
startup, and later requests ``/matmul`` against them.
"""

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import uvicorn
import json

SERVER_PORT = 8001
SERVER_ID = "A"

app = FastAPI()

# Global in-memory store for weight shards, populated by /load_weights.
# Layout: {layer_id: {weight_name: tensor}}
WEIGHT_SHARDS = {}


class MatMulRequest(BaseModel):
    """Request body for ``/matmul``."""

    layer_id: int
    weight_name: str  # e.g., "c_fc", "c_proj", "attn_c_attn", "attn_c_proj"
    input_tensor: list  # [batch, seq_len, hidden_dim] or flattened
    input_shape: list  # Original shape before flattening


@app.on_event("startup")
async def startup_event():
    """Log that the server is up; no weights are loaded at this point.

    Weight shards arrive later through ``/load_weights``.
    """
    print(f"[Tensor Server {SERVER_ID}] Starting up on port {SERVER_PORT}...")
    print(f"[Tensor Server {SERVER_ID}] Waiting for weight shards to be loaded...")
    # Weights will be loaded dynamically when the control server sends them


@app.post("/load_weights")
async def load_weights(data: dict):
    """Store one weight shard sent by the control server.

    Expected body::

        data = {
            "layer_id": int,
            "weight_name": str,
            "weight_shard": list (serialized tensor),
            "weight_shape": list
        }

    The shard is converted to a float32 tensor, reshaped to ``weight_shape``
    and stored in ``WEIGHT_SHARDS[layer_id][weight_name]``. A shard that was
    already loaded under the same key is overwritten.

    Returns:
        ``{"status": "success", "message": ...}`` on success.

    Raises:
        HTTPException(500): on any failure (missing key, bad shape, ...),
        with the underlying error message as ``detail``.
    """
    try:
        layer_id = data["layer_id"]
        weight_name = data["weight_name"]
        weight_shard_list = data["weight_shard"]
        weight_shape = data["weight_shape"]

        # Convert list back to tensor
        weight_tensor = torch.tensor(weight_shard_list, dtype=torch.float32).reshape(weight_shape)

        # Create the per-layer dict on first use, then store the shard.
        # (Only reached once the tensor was built successfully.)
        WEIGHT_SHARDS.setdefault(layer_id, {})[weight_name] = weight_tensor

        print(f"[Tensor Server {SERVER_ID}] Loaded weight shard: layer={layer_id}, weight={weight_name}, shape={weight_shape}")
        return {"status": "success", "message": f"Weight shard loaded: {weight_name}"}
    except Exception as e:
        print(f"[Tensor Server {SERVER_ID}] Error loading weights: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _partial_matmul(input_tensor: torch.Tensor, weight_shard: torch.Tensor) -> torch.Tensor:
    """Multiply ``input_tensor`` by ``weight_shard`` along the last dimension.

    All leading dimensions of the input are preserved, so ``[batch, seq, in]``
    yields ``[batch, seq, out_shard]``, ``[N, in]`` yields ``[N, out_shard]``
    and ``[in]`` yields ``[out_shard]``.
    """
    leading_shape = tuple(input_tensor.shape[:-1])
    # Collapse every leading dim into one row axis for a single 2-D matmul.
    flat_input = input_tensor.reshape(-1, input_tensor.shape[-1])
    flat_output = flat_input @ weight_shard
    # Use the explicit output width (not -1) so zero-row inputs reshape cleanly.
    return flat_output.reshape(*leading_shape, flat_output.shape[-1])


@app.post("/matmul")
async def matmul_op(data: MatMulRequest):
    """Perform the partial matrix multiplication ``Y_shard = X @ W_shard``.

    ``input_tensor`` is rebuilt as float32 with shape ``input_shape`` and
    multiplied by the stored shard along its last dimension. Leading
    dimensions are preserved, so 1-D, 2-D (flattened) and 3-D inputs all work.

    Returns:
        ``{"status": "success", "output_shard": list, "output_shape": list}``.

    Raises:
        HTTPException(500): if the shard is not loaded, shapes are
        incompatible, or any other error occurs.
    """
    try:
        layer_id = data.layer_id
        weight_name = data.weight_name

        # Retrieve the weight shard
        layer_shards = WEIGHT_SHARDS.get(layer_id, {})
        if weight_name not in layer_shards:
            raise ValueError(f"Weight shard not found: layer={layer_id}, weight={weight_name}")
        weight_shard = layer_shards[weight_name]

        # Convert input list back to tensor
        input_tensor = torch.tensor(data.input_tensor, dtype=torch.float32).reshape(data.input_shape)

        # X: [..., input_dim] @ weight_shard: [input_dim, output_dim/2]
        # -> output: [..., output_dim/2]
        output_tensor = _partial_matmul(input_tensor, weight_shard)

        # Convert result to list for JSON serialization
        return {
            "status": "success",
            "output_shard": output_tensor.tolist(),
            "output_shape": list(output_tensor.shape)
        }
    except Exception as e:
        print(f"[Tensor Server {SERVER_ID}] Error in matmul: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "server_id": SERVER_ID}


if __name__ == "__main__":
    # Pass the app object directly; `app:app` (unquoted) was a SyntaxError.
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)