|
|
""" |
|
|
Tensor Server A - Handles partial matrix multiplication for Tensor Parallelism (TP) |
|
|
This server loads a shard of the model weights from disk and performs MatMul operations. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import numpy as np |
|
|
import uvicorn |
|
|
import json |
|
|
|
|
|
# TCP port this tensor-parallel shard server listens on.
SERVER_PORT = 8001


# Identifier for this shard, used to tag log lines (this is server "A").
SERVER_ID = "A"


app = FastAPI()


# In-memory cache of weight shards, populated at runtime via POST /load_weights:
# {layer_id: {weight_name: torch.Tensor}}
WEIGHT_SHARDS = {}
|
|
|
|
|
class MatMulRequest(BaseModel):
    """Request body for POST /matmul: a serialized input tensor plus metadata
    identifying which cached weight shard to multiply it against."""

    # Layer whose weight shard should be used (key into WEIGHT_SHARDS).
    layer_id: int
    # Name of the weight within that layer (second-level key into WEIGHT_SHARDS).
    weight_name: str
    # Input activations serialized as a (possibly nested) Python list.
    input_tensor: list
    # Shape to reshape input_tensor into before the matmul.
    input_shape: list
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log that the server is up; weight shards arrive later via POST /load_weights.

    Despite the original wording, nothing is read from disk here — the server
    starts empty and waits for the control server to push shards.

    NOTE(review): `@app.on_event` is deprecated in recent FastAPI versions in
    favor of lifespan handlers — confirm the pinned FastAPI version before migrating.
    """
    # `global` is a no-op here (WEIGHT_SHARDS is never rebound in this function);
    # kept byte-identical.
    global WEIGHT_SHARDS
    print(f"[Tensor Server {SERVER_ID}] Starting up on port {SERVER_PORT}...")
    print(f"[Tensor Server {SERVER_ID}] Waiting for weight shards to be loaded...")
|
|
|
|
|
|
|
|
@app.post("/load_weights")
async def load_weights(data: dict):
    """
    Receive one weight shard from the control server and cache it in memory.

    Expected payload:
        {
            "layer_id": int,
            "weight_name": str,
            "weight_shard": list (serialized tensor),
            "weight_shape": list
        }

    Returns a success message, or raises HTTP 500 on any failure
    (missing key, malformed shard, shape mismatch).
    """
    global WEIGHT_SHARDS
    try:
        layer_id = data["layer_id"]
        weight_name = data["weight_name"]
        shard_values = data["weight_shard"]
        weight_shape = data["weight_shape"]

        # Deserialize the nested list back into a float32 tensor of the
        # declared shape.
        shard_tensor = torch.tensor(shard_values, dtype=torch.float32).reshape(weight_shape)

        # File under WEIGHT_SHARDS[layer_id][weight_name], creating the
        # per-layer dict on first use.
        WEIGHT_SHARDS.setdefault(layer_id, {})[weight_name] = shard_tensor

        print(f"[Tensor Server {SERVER_ID}] Loaded weight shard: layer={layer_id}, weight={weight_name}, shape={weight_shape}")
        return {"status": "success", "message": f"Weight shard loaded: {weight_name}"}
    except Exception as e:
        print(f"[Tensor Server {SERVER_ID}] Error loading weights: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/matmul")
async def matmul_op(data: MatMulRequest):
    """
    Perform this shard's part of a matrix multiplication: Y_shard = X @ W_shard.

    The input may have any number of leading (batch/sequence) dimensions; only
    the last dimension participates in the matmul. Returns the partial product
    serialized as a list plus its shape; raises HTTP 500 when the requested
    shard is missing or the shapes are incompatible.
    """
    try:
        layer_id = data.layer_id
        weight_name = data.weight_name

        # Look up the cached shard; a missing shard surfaces as a 500 via the
        # except below (preserves the original error contract).
        if layer_id not in WEIGHT_SHARDS or weight_name not in WEIGHT_SHARDS[layer_id]:
            raise ValueError(f"Weight shard not found: layer={layer_id}, weight={weight_name}")

        weight_shard = WEIGHT_SHARDS[layer_id][weight_name]

        # Deserialize the input activations into a float32 tensor.
        input_tensor = torch.tensor(data.input_tensor, dtype=torch.float32).reshape(data.input_shape)

        # Collapse all leading dimensions so the matmul is a plain 2-D GEMM.
        original_shape = input_tensor.shape
        X = input_tensor.view(-1, original_shape[-1])

        # Partial product against this server's shard of the weights.
        Y_shard = X @ weight_shard

        # Restore the original leading dimensions.
        # Fix: the previous code did view(original_shape[0], original_shape[1], -1),
        # which hard-coded a 3-D input and raised IndexError for 2-D (or other
        # rank) inputs; this form handles any rank and is identical for 3-D.
        output_tensor = Y_shard.view(*original_shape[:-1], -1)

        return {
            "status": "success",
            "output_shard": output_tensor.tolist(),
            "output_shape": list(output_tensor.shape)
        }
    except Exception as e:
        print(f"[Tensor Server {SERVER_ID}] Error in matmul: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: report that this shard server is up, along with its ID."""
    return {
        "status": "healthy",
        "server_id": SERVER_ID,
    }
|
|
|
|
|
if __name__ == "__main__":
    # Fix: the original `uvicorn.run(app:app, ...)` is a Python syntax error
    # (`app:app` is not a valid expression); pass the app object directly.
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)
|
|
|