# app.py — Tensor Server A (shard "A" of a tensor-parallel deployment)
"""
Tensor Server A - Handles partial matrix multiplication for Tensor Parallelism (TP)
This server loads a shard of the model weights from disk and performs MatMul operations.
"""
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import uvicorn
import json
SERVER_PORT = 8001
SERVER_ID = "A"
app = FastAPI()
# Global variables to hold weight shards
WEIGHT_SHARDS = {} # {layer_id: {weight_name: tensor}}
class MatMulRequest(BaseModel):
    """Payload for POST /matmul: an activation tensor to multiply against a stored weight shard."""
    layer_id: int  # transformer layer index; first key into WEIGHT_SHARDS
    weight_name: str # e.g., "c_fc", "c_proj", "attn_c_attn", "attn_c_proj"
    input_tensor: list # [batch, seq_len, hidden_dim] or flattened
    input_shape: list # Original shape before flattening
@app.on_event("startup")
async def startup_event():
    """Announce that the server is up; weight shards arrive later via /load_weights."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers — worth migrating eventually; behavior is unchanged here.
    global WEIGHT_SHARDS
    print(f"[Tensor Server {SERVER_ID}] Starting up on port {SERVER_PORT}...")
    print(f"[Tensor Server {SERVER_ID}] Waiting for weight shards to be loaded...")
    # Shards are pushed dynamically by the control server; nothing to preload here.
@app.post("/load_weights")
async def load_weights(data: dict):
    """
    Receive one weight shard from the control server and cache it in memory.

    Expected payload keys:
        layer_id (int), weight_name (str),
        weight_shard (nested list of floats), weight_shape (list of ints)

    Returns a success message; any failure is surfaced as HTTP 500.
    """
    global WEIGHT_SHARDS
    try:
        layer_id = data["layer_id"]
        weight_name = data["weight_name"]
        weight_shape = data["weight_shape"]
        # Rebuild the tensor from its JSON-serialized (nested-list) form.
        shard = torch.tensor(data["weight_shard"], dtype=torch.float32).reshape(weight_shape)
        # setdefault creates the per-layer dict on first use.
        WEIGHT_SHARDS.setdefault(layer_id, {})[weight_name] = shard
        print(f"[Tensor Server {SERVER_ID}] Loaded weight shard: layer={layer_id}, weight={weight_name}, shape={weight_shape}")
        return {"status": "success", "message": f"Weight shard loaded: {weight_name}"}
    except Exception as e:
        # Report the failure to the caller as a 500 with the error text.
        print(f"[Tensor Server {SERVER_ID}] Error loading weights: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/matmul")
async def matmul_op(data: MatMulRequest):
    """
    Perform the partial matrix multiplication Y_shard = X @ W_shard.

    The input activations arrive as a JSON-serialized nested list plus their
    original shape. All leading dimensions are flattened for the matmul and
    restored afterward.

    Fix/generalization: the original reshaped the output with
    `view(original_shape[0], original_shape[1], -1)`, which silently assumed a
    3-D input [batch, seq_len, hidden]. Using `*original_shape[:-1]` preserves
    the 3-D behavior exactly while also handling 2-D (or higher-rank) inputs.

    Raises HTTP 500 if the requested shard is missing or the shapes mismatch.
    """
    try:
        layer_id = data.layer_id
        weight_name = data.weight_name

        # Retrieve the weight shard (loaded earlier via /load_weights).
        if layer_id not in WEIGHT_SHARDS or weight_name not in WEIGHT_SHARDS[layer_id]:
            raise ValueError(f"Weight shard not found: layer={layer_id}, weight={weight_name}")
        weight_shard = WEIGHT_SHARDS[layer_id][weight_name]

        # Rebuild the activation tensor from its serialized form.
        input_tensor = torch.tensor(data.input_tensor, dtype=torch.float32).reshape(data.input_shape)

        # Collapse every leading dim into one row axis: [..., hidden] -> [prod(...), hidden].
        original_shape = input_tensor.shape
        X = input_tensor.reshape(-1, original_shape[-1])

        # X: [rows, input_dim]; weight_shard: [input_dim, shard_out]; Y_shard: [rows, shard_out].
        Y_shard = X @ weight_shard

        # Restore the leading dims; the last dim becomes this shard's output width.
        output_tensor = Y_shard.reshape(*original_shape[:-1], -1)

        # Serialize for the JSON response.
        return {
            "status": "success",
            "output_shard": output_tensor.tolist(),
            "output_shape": list(output_tensor.shape)
        }
    except Exception as e:
        print(f"[Tensor Server {SERVER_ID}] Error in matmul: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Liveness probe: report that this shard server is responsive."""
    response = {"status": "healthy", "server_id": SERVER_ID}
    return response
if __name__ == "__main__":
    # Fix: the original `uvicorn.run(app:app, ...)` is a SyntaxError — the file
    # could not even be imported. Pass the app object directly; the import-string
    # form ("module:app") is only needed for --reload or multi-worker mode.
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)