|
|
""" |
|
|
WAN-Distributed JAX Inference on Hugging Face Spaces |
|
|
Each Space runs this app and acts as either the head node or a worker node, selected by the NODE_ROLE environment variable ("head" or "worker"; default "worker").
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import threading |
|
|
import queue |
|
|
from typing import Dict, List, Optional, Any |
|
|
from dataclasses import dataclass, field |
|
|
import hashlib |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import requests |
|
|
|
|
|
|
|
|
# Force JAX onto the CPU backend (this overwrites any JAX_PLATFORMS already in
# the environment). It is assigned before `import jax` below so the setting is
# in place when JAX initialises its backends.
os.environ["JAX_PLATFORMS"] = "cpu"


import jax


import jax.numpy as jnp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class NodeConfig:
    """Per-node settings read from environment variables.

    Every field is resolved when an instance is created (``default_factory``),
    not once when the class is defined, so a ``NodeConfig()`` built after the
    environment changes sees the new values.

    Environment variables:
        NODE_ROLE     -- "head" or "worker" (default "worker"); surrounding
                         whitespace and letter case are ignored.
        NODE_ID       -- node identifier; when unset or empty, 8 random hex
                         characters are generated.
        HEAD_URL      -- base URL of the head node (default "").
        SECRET_TOKEN  -- bearer token sent with inter-node HTTP requests.
        PORT          -- port passed to the server on launch (default 7860).
    """

    role: str = field(
        default_factory=lambda: os.environ.get("NODE_ROLE", "worker").strip().lower()
    )
    # 4 random bytes -> 8 hex chars, the same length as the former md5-based id;
    # avoids hashlib.md5, which FIPS-restricted Python builds may refuse.
    node_id: str = field(
        default_factory=lambda: os.environ.get("NODE_ID") or os.urandom(4).hex()
    )
    head_url: str = field(default_factory=lambda: os.environ.get("HEAD_URL", ""))
    # NOTE(review): the fallback token is public knowledge; set SECRET_TOKEN
    # to the same private value on every Space of a cluster.
    secret_token: str = field(
        default_factory=lambda: os.environ.get("SECRET_TOKEN", "default-token")
    )
    port: int = field(default_factory=lambda: int(os.environ.get("PORT", "7860")))


CONFIG = NodeConfig()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClusterState:
    """In-memory registry of workers plus the model shards held by this node.

    ``workers`` and ``shards`` are guarded by ``lock``; every method below
    takes the lock before touching them.
    """

    def __init__(self):
        # worker_id -> {"url", "info", "registered_at", "last_seen", "status"}
        self.workers: Dict[str, Dict] = {}
        # shard name -> array stored on this node
        self.shards: Dict[str, np.ndarray] = {}
        self.lock = threading.Lock()
        # Displayed as "Registered" in the status view; not set by this class.
        self.is_initialized = False
        # NOTE(review): neither of these is read or written by this class's methods.
        self.pending_results: Dict[str, Any] = {}
        self.request_queue: queue.Queue = queue.Queue()

    def register_worker(self, worker_id: str, url: str, info: Dict) -> bool:
        """Add (or overwrite) ``worker_id`` as an active worker reachable at ``url``.

        Re-registering an existing id resets both timestamps. Always returns True.
        """
        now = time.time()
        with self.lock:
            self.workers[worker_id] = {
                "url": url,
                "info": info,
                "registered_at": now,
                "last_seen": now,
                "status": "active",
            }
        return True

    def get_workers(self) -> List[Dict]:
        """Return a snapshot of active workers, each dict extended with ``worker_id``."""
        with self.lock:
            return [
                {"worker_id": wid, **winfo}
                for wid, winfo in self.workers.items()
                if winfo.get("status") == "active"
            ]

    def store_shard(self, name: str, data: np.ndarray):
        """Store (or replace) the shard ``name`` on this node."""
        with self.lock:
            self.shards[name] = data

    def get_shard(self, name: str) -> Optional[np.ndarray]:
        """Return the shard ``name``, or None if it is not stored here."""
        with self.lock:
            return self.shards.get(name)

    def heartbeat(self, worker_id: str) -> bool:
        """Refresh ``last_seen`` for ``worker_id``.

        Returns:
            True if the worker is registered here. False if it is unknown
            (for example because this process restarted and lost its
            in-memory registry); nothing is updated in that case.
        """
        with self.lock:
            entry = self.workers.get(worker_id)
            if entry is None:
                return False
            entry["last_seen"] = time.time()
            return True


STATE = ClusterState()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_request(url: str, endpoint: str, data: Dict, timeout: int = 30) -> Optional[Dict]:
    """POST ``data`` as JSON to ``{url}/api/{endpoint}`` on another Space.

    The request carries ``Authorization: Bearer <CONFIG.secret_token>``.

    Args:
        url: Base URL of the target node; a trailing slash is tolerated.
        endpoint: Endpoint name appended after ``/api/``.
        data: JSON-serialisable request payload.
        timeout: Seconds passed to ``requests.post``.

    Returns:
        The decoded JSON body on HTTP 200, otherwise None. Network errors,
        non-200 statuses and undecodable bodies are printed and mapped to
        None so callers can treat every failure the same way.
    """
    full_url = f"{url.rstrip('/')}/api/{endpoint}"
    headers = {"Authorization": f"Bearer {CONFIG.secret_token}"}

    try:
        response = requests.post(full_url, json=data, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        # Connection errors, timeouts, invalid URLs, unserialisable payloads.
        print(f"Request error: {e}")
        return None

    if response.status_code != 200:
        # Error pages can be large HTML documents; keep the log line readable.
        print(f"Request failed: {response.status_code} - {response.text[:500]}")
        return None

    try:
        return response.json()
    except ValueError as e:
        # Body was not valid JSON (requests' JSONDecodeError subclasses ValueError).
        print(f"Request error: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def worker_register_with_head():
    """Register this worker with the head node at ``CONFIG.head_url``.

    The URL advertised to the head is, in order of preference: the
    ``SPACE_URL`` environment variable, ``https://`` + ``SPACE_HOST`` (the
    public hostname Hugging Face sets inside a Space), or
    ``http://localhost:<CONFIG.port>``.

    On success ``STATE.is_initialized`` is set to True, so the status view
    reports this worker as registered.

    Returns:
        True if the head acknowledged the registration, otherwise False.
    """
    if not CONFIG.head_url:
        print("No HEAD_URL configured, cannot register")
        return False

    space_url = os.environ.get("SPACE_URL")
    if not space_url:
        space_host = os.environ.get("SPACE_HOST")
        space_url = f"https://{space_host}" if space_host else f"http://localhost:{CONFIG.port}"

    result = make_request(
        CONFIG.head_url,
        "register_worker",
        {
            "worker_id": CONFIG.node_id,
            "worker_url": space_url,
            "info": {
                "jax_devices": len(jax.devices()),
                "platform": jax.default_backend(),
            },
        },
    )

    if result and result.get("success"):
        STATE.is_initialized = True
        print(f"Registered with head at {CONFIG.head_url}")
        return True

    return False
|
|
|
|
|
|
|
|
def worker_heartbeat_loop(interval: float = 30.0):
    """Send a heartbeat to the head every ``interval`` seconds, forever.

    Meant to run on a daemon thread. When a heartbeat is not acknowledged
    (no response, or a reply without a truthy ``"success"``), the next tick
    re-registers with the head instead of sending another heartbeat. The head
    keeps its worker registry only in memory, so after a head restart or
    outage (or a failed initial registration) this worker would otherwise
    never become known to it.

    Args:
        interval: Seconds to sleep between attempts (default 30).
    """
    needs_register = False
    while True:
        time.sleep(interval)
        if not CONFIG.head_url:
            continue

        if needs_register:
            needs_register = not worker_register_with_head()
            continue

        result = make_request(
            CONFIG.head_url,
            "heartbeat",
            {"worker_id": CONFIG.node_id},
        )
        if not (result and result.get("success")):
            needs_register = True
|
|
|
|
|
|
|
|
def _natural_key(text: str) -> List[Any]:
    """Split *text* into alternating str/int chunks so "layer_2" sorts before "layer_10"."""
    import re

    # re.split with a capturing group puts digit runs at odd indices, so the
    # chunk types line up position-by-position between any two keys.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", text))]


def _shard_order_key(name: str, weight: np.ndarray) -> tuple:
    """Sort key applying layers in numeric order, a layer's 2-D weight before its 1-D bias.

    Assumes shard names look like ``<layer prefix>_<kind>`` (e.g. ``layer_3_weight``);
    the prefix is everything before the last underscore.
    """
    prefix = name.rsplit("_", 1)[0]
    return (_natural_key(prefix), -weight.ndim, _natural_key(name))


def worker_forward_pass(input_data: np.ndarray) -> np.ndarray:
    """Run ``input_data`` through every shard stored on this node.

    Shards are applied in layer order (``_shard_order_key``): plain string
    sorting would run ``layer_10`` before ``layer_1`` and add each bias
    before its weight matmul. A 2-D shard is matrix-multiplied and a 1-D
    shard is added, each only when its leading dimension equals the current
    last axis; other shards are skipped. One ReLU is applied after all local
    shards.

    Returns:
        The activations as a NumPy array.
    """
    # Snapshot under the lock so a concurrent store_shard cannot mutate the
    # dict while it is being iterated.
    with STATE.lock:
        shards = list(STATE.shards.items())
    shards.sort(key=lambda item: _shard_order_key(*item))

    x = jnp.array(input_data)
    for _name, weight in shards:
        if weight.ndim not in (1, 2) or x.shape[-1] != weight.shape[0]:
            continue
        x = x @ weight if weight.ndim == 2 else x + weight

    x = jax.nn.relu(x)
    return np.array(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _group_params(param_list: List[tuple]) -> List[List[tuple]]:
    """Group consecutive ``(name, value)`` pairs sharing the name prefix before the
    last underscore (e.g. ``layer_0_weight`` with ``layer_0_bias``)."""
    groups: List[List[tuple]] = []
    prev_prefix = None
    for name, value in param_list:
        prefix = name.rsplit("_", 1)[0]
        if groups and prefix == prev_prefix:
            groups[-1].append((name, value))
        else:
            groups.append([(name, value)])
        prev_prefix = prefix
    return groups


def _partition_params(param_list: List[tuple], num_workers: int) -> List[List[tuple]]:
    """Split ``param_list`` into ``num_workers`` contiguous, order-preserving chunks.

    Layer groups (``_group_params``) are never split, and chunk layer counts
    differ by at most one; trailing chunks are empty when there are more
    workers than layers.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    groups = _group_params(param_list)
    base, extra = divmod(len(groups), num_workers)
    chunks: List[List[tuple]] = []
    start = 0
    for i in range(num_workers):
        size = base + (1 if i < extra else 0)
        chunks.append([pair for group in groups[start:start + size] for pair in group])
        start += size
    return chunks


def head_distribute_model(params: Dict[str, np.ndarray]) -> bool:
    """Send model parameters to the active workers, one contiguous slice each.

    Slices follow ``params`` order and the ``STATE.get_workers()`` order, keep
    each layer's weight and bias on the same worker, and are balanced to
    within one layer. The old scheme gave the last worker the whole remainder
    and could split a layer across two workers.

    Returns:
        True when every shard was acknowledged; False when there are no
        workers or on the first failed upload (earlier uploads are not
        rolled back).

    NOTE(review): shards left on a worker from a previous, larger model are
    not removed by this function.
    """
    workers = STATE.get_workers()
    if not workers:
        print("No workers available")
        return False

    chunks = _partition_params(list(params.items()), len(workers))

    for worker, worker_shards in zip(workers, chunks):
        for shard_name, shard_data in worker_shards:
            result = make_request(
                worker["url"],
                "store_shard",
                {
                    "name": shard_name,
                    "data": shard_data.tolist(),
                    "shape": list(shard_data.shape),
                    "dtype": str(shard_data.dtype),
                },
                timeout=60,
            )
            if not result or not result.get("success"):
                print(f"Failed to send shard {shard_name} to worker {worker['worker_id']}")
                return False

    print(f"Distributed {len(params)} shards to {len(workers)} workers")
    return True
|
|
|
|
|
|
|
|
def head_run_inference(input_data: np.ndarray) -> np.ndarray:
    """Pipe ``input_data`` through every active worker, in ``STATE.get_workers()`` order.

    Each worker's ``forward`` endpoint receives the previous hop's output.
    With no workers registered, this node runs ``worker_forward_pass`` on
    its own shards. If a worker call fails (no reply, or no ``"output"``
    key), ``worker_forward_pass`` on this node stands in for that hop.

    NOTE(review): that fallback applies this node's local shards, which are
    not necessarily the shards the failed worker held.
    """
    active = STATE.get_workers()
    if not active:
        # Standalone: nothing to fan out to.
        return worker_forward_pass(input_data)

    activations = input_data
    for node in active:
        payload = {
            "data": activations.tolist(),
            "shape": list(activations.shape),
        }
        reply = make_request(node["url"], "forward", payload, timeout=60)
        if not reply or "output" not in reply:
            print(f"Worker {node['worker_id']} failed, using local fallback")
            activations = worker_forward_pass(activations)
            continue
        activations = np.array(reply["output"])

    return activations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Endpoints whose handlers read fields out of the request payload.
_PAYLOAD_ENDPOINTS = frozenset({"register_worker", "heartbeat", "store_shard", "forward"})


def api_handler(endpoint: str, data: Dict) -> Dict:
    """Handle one API call and return a JSON-serialisable reply.

    Endpoints:
        register_worker -- register ``data["worker_id"]`` at ``data["worker_url"]``
                           with optional ``data["info"]``.
        heartbeat       -- refresh a worker's ``last_seen``; ``success`` is False
                           when the worker is not registered here, signalling it
                           to register again.
        store_shard     -- store ``data["data"]`` as ``data["name"]``, reshaped to
                           ``data["shape"]`` with dtype ``data["dtype"]``
                           (default "float32").
        forward         -- run this node's forward pass on ``data["data"]``
                           reshaped to ``data["shape"]``.
        status          -- node id, role, active worker count, shard names,
                           JAX device count.
        get_workers     -- list the active workers.

    Malformed payloads (non-object body, missing keys, bad dtype, shape
    mismatch) return ``{"error": ...}`` rather than raising, so HTTP callers
    get a readable reply instead of a server error. Unknown endpoints return
    ``{"error": "Unknown endpoint: <endpoint>"}``.
    """
    if endpoint in _PAYLOAD_ENDPOINTS and not isinstance(data, dict):
        return {"error": f"Request body for '{endpoint}' must be a JSON object"}

    try:
        if endpoint == "register_worker":
            success = STATE.register_worker(
                data["worker_id"],
                data["worker_url"],
                data.get("info", {}),
            )
            return {"success": success, "message": "Worker registered" if success else "Failed"}

        if endpoint == "heartbeat":
            worker_id = data.get("worker_id", "")
            # The registry lives only in memory; a worker unknown here (e.g.
            # after a head restart) is told so and can register again.
            with STATE.lock:
                known = worker_id in STATE.workers
            STATE.heartbeat(worker_id)
            return {"success": known}

        if endpoint == "store_shard":
            shard_data = np.array(data["data"], dtype=data.get("dtype", "float32"))
            shard_data = shard_data.reshape(data["shape"])
            STATE.store_shard(data["name"], shard_data)
            return {"success": True, "shard": data["name"]}

        if endpoint == "forward":
            input_data = np.array(data["data"]).reshape(data["shape"])
            output = worker_forward_pass(input_data)
            return {"output": output.tolist(), "shape": list(output.shape)}

        if endpoint == "status":
            # Snapshot under the lock; store_shard may run concurrently.
            with STATE.lock:
                shard_names = list(STATE.shards.keys())
            return {
                "node_id": CONFIG.node_id,
                "role": CONFIG.role,
                "workers": len(STATE.get_workers()),
                "shards": shard_names,
                "jax_devices": len(jax.devices()),
            }

        if endpoint == "get_workers":
            return {"workers": STATE.get_workers()}
    except (KeyError, ValueError, TypeError) as e:
        return {"error": f"Bad request for '{endpoint}': {type(e).__name__}: {e}"}

    return {"error": f"Unknown endpoint: {endpoint}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_test_model(
    num_layers: int = 4,
    hidden_size: int = 128,
    seed: Optional[int] = None,
) -> Dict[str, np.ndarray]:
    """Build random float32 parameters for a stack of square affine layers.

    For each layer ``i`` this creates ``layer_{i}_weight`` with shape
    ``(hidden_size, hidden_size)``, drawn from a standard normal and scaled
    by 0.02, and ``layer_{i}_bias``, a zero vector of length ``hidden_size``.

    Args:
        num_layers: Number of weight/bias pairs.
        hidden_size: Width of every layer.
        seed: When given, weights come from ``np.random.default_rng(seed)`` and
            are reproducible. When None (default), the global ``np.random``
            state is used, as before.

    Returns:
        Dict from parameter name to array, in layer order (weight, then bias).
    """
    rng = np.random.default_rng(seed) if seed is not None else None
    params: Dict[str, np.ndarray] = {}

    for i in range(num_layers):
        if rng is not None:
            noise = rng.standard_normal((hidden_size, hidden_size))
        else:
            noise = np.random.randn(hidden_size, hidden_size)
        params[f"layer_{i}_weight"] = noise.astype(np.float32) * 0.02
        params[f"layer_{i}_bias"] = np.zeros(hidden_size, dtype=np.float32)

    return params
|
|
|
|
|
|
|
|
def _model_input_width(default: int = 128) -> int:
    """Return the row count of this node's first 2-D shard (by name), else ``default``."""
    with STATE.lock:
        rows = sorted(
            (name, weight.shape[0])
            for name, weight in STATE.shards.items()
            if weight.ndim == 2
        )
    return rows[0][1] if rows else default


def gradio_run_inference(input_text: str) -> str:
    """Encode ``input_text`` as floats, run inference, and return a text summary.

    Each character becomes ``ord(c) / 128.0``. The vector is truncated or
    zero-padded to the model width: the leading dimension of a locally stored
    2-D shard when one exists, otherwise 128 (e.g. on a head whose shards all
    live on workers). With a fixed width of 128, a model built with any other
    hidden size failed every shape check in the forward pass, so no layer ran.

    A head pipes the vector through ``head_run_inference``; any other node
    runs ``worker_forward_pass`` locally.
    """
    width = _model_input_width()
    text = input_text or ""
    tokens = np.array([ord(c) / 128.0 for c in text[:width]], dtype=np.float32)
    if len(tokens) < width:
        tokens = np.pad(tokens, (0, width - len(tokens)))

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    if CONFIG.role == "head":
        output = head_run_inference(tokens)
    else:
        output = worker_forward_pass(tokens)
    latency = (time.perf_counter() - start) * 1000

    result = f"Output shape: {output.shape}\n"
    result += f"Output mean: {output.mean():.4f}\n"
    result += f"Output std: {output.std():.4f}\n"
    result += f"Latency: {latency:.1f}ms\n"
    result += f"Workers used: {len(STATE.get_workers())}"
    return result
|
|
|
|
|
|
|
|
def gradio_get_status() -> str:
    """Render this node's status as indented JSON for the Status tab.

    Always reports node id, role, JAX device count and backend, the number
    of local shards and up to ten of their names. A head adds its active
    workers; any other node adds its head URL and ``STATE.is_initialized``.
    """
    report = dict([
        ("Node ID", CONFIG.node_id),
        ("Role", CONFIG.role),
        ("JAX Devices", len(jax.devices())),
        ("JAX Backend", jax.default_backend()),
        ("Stored Shards", len(STATE.shards)),
        ("Shard Names", list(STATE.shards.keys())[:10]),
    ])

    if CONFIG.role != "head":
        report.update({"Head URL": CONFIG.head_url, "Registered": STATE.is_initialized})
    else:
        active = STATE.get_workers()
        report["Connected Workers"] = len(active)
        report["Worker List"] = [f"{w['worker_id']} @ {w['url']}" for w in active]

    return json.dumps(report, indent=2)
|
|
|
|
|
|
|
|
def _replace_local_shards(params: Dict[str, np.ndarray]) -> None:
    """Atomically replace every shard stored on this node with ``params``."""
    # worker_forward_pass iterates over all stored shards, so leftovers from a
    # previously initialised, larger model would otherwise keep being applied.
    with STATE.lock:
        STATE.shards.clear()
        STATE.shards.update(params)


def gradio_init_model(num_layers: int, hidden_size: int) -> str:
    """Create a fresh test model and place its shards.

    On a head with active workers the shards are sent out with
    ``head_distribute_model``. Otherwise (a head with no workers, or any
    other node) they replace whatever shards this node held before.
    Slider values are converted with ``int()``.

    Returns:
        A one-line description of what was done, for the Model tab.
    """
    params = create_test_model(int(num_layers), int(hidden_size))

    if CONFIG.role == "head":
        workers = STATE.get_workers()
        if workers:
            if head_distribute_model(params):
                return f"Distributed {len(params)} shards to {len(workers)} workers"
            return "Failed to distribute model"

        _replace_local_shards(params)
        return f"No workers - stored {len(params)} shards locally"

    _replace_local_shards(params)
    return f"Stored {len(params)} shards locally"
|
|
|
|
|
|
|
|
def gradio_register_worker(worker_url: str) -> str:
    """Manually register a worker by URL (head node only).

    The worker's ``status`` endpoint is queried. Its reported ``node_id``
    becomes the worker id (falling back to ``worker_<n>``), and the whole
    status reply is stored as the worker's info.

    Returns:
        A message describing the outcome, for the Workers tab.
    """
    if CONFIG.role != "head":
        return "Only head node can register workers"

    # Pasted URLs often carry stray whitespace; an empty field isn't worth a request.
    worker_url = (worker_url or "").strip()
    if not worker_url:
        return "Please enter a worker URL"

    result = make_request(worker_url, "status", {})
    if not result:
        return f"Failed to reach worker at {worker_url}"

    worker_id = result.get("node_id", f"worker_{len(STATE.workers)}")
    STATE.register_worker(worker_id, worker_url, result)
    return f"Registered worker {worker_id}"
|
|
|
|
|
|
|
|
def gradio_api_call(endpoint: str, json_data: str) -> str:
    """Call ``api_handler`` directly from the API tab and pretty-print the reply.

    ``json_data`` may be blank or whitespace (treated as ``{}``) or a JSON
    object. Any failure (invalid JSON, a non-object payload, or an exception
    from the handler) is returned as an ``"Error: ..."`` string for display.
    """
    try:
        data = json.loads(json_data) if json_data and json_data.strip() else {}
        if not isinstance(data, dict):
            return "Error: JSON data must be an object"
        result = api_handler((endpoint or "").strip(), data)
        return json.dumps(result, indent=2)
    except Exception as e:  # UI boundary: show the failure instead of raising into Gradio.
        return f"Error: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Guards the worker background threads so they start at most once per
# process, however many times create_app() is called.
_background_lock = threading.Lock()
_background_started = False


def _register_after_delay(delay: float = 5.0) -> None:
    """Sleep ``delay`` seconds, then register this worker with the head."""
    time.sleep(delay)
    worker_register_with_head()


def _start_worker_background_tasks() -> None:
    """Start the delayed-registration and heartbeat daemon threads, once per process."""
    global _background_started
    with _background_lock:
        if _background_started:
            return
        _background_started = True
    threading.Thread(target=_register_after_delay, daemon=True).start()
    threading.Thread(target=worker_heartbeat_loop, daemon=True).start()


def create_app():
    """Build and return the Gradio Blocks UI for this node.

    Tabs: Status, Inference, Model, Workers (head only) and API. On a worker
    with ``HEAD_URL`` set, this also starts daemon threads that register with
    the head after 5 s and send periodic heartbeats. create_app() may be
    called more than once per process, so those threads are started only on
    the first call.
    """
    if CONFIG.role == "worker" and CONFIG.head_url:
        _start_worker_background_tasks()

    if CONFIG.role == "head":
        role_blurb = "This is the **HEAD** node - it coordinates workers and runs inference."
    else:
        role_blurb = "This is a **WORKER** node - it stores model shards and computes."

    with gr.Blocks(title=f"WAN-JAX {CONFIG.role.upper()} - {CONFIG.node_id}") as app:
        gr.Markdown(f"""
        # WAN-Distributed JAX Inference

        **Node ID:** `{CONFIG.node_id}` | **Role:** `{CONFIG.role.upper()}`

        {role_blurb}
        """)

        with gr.Tab("Status"):
            status_output = gr.Textbox(label="Cluster Status", lines=15)
            refresh_btn = gr.Button("Refresh Status")
            refresh_btn.click(gradio_get_status, outputs=status_output)

            # Fill the status box as soon as the page loads.
            app.load(gradio_get_status, outputs=status_output)

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(
                        label="Input Text",
                        placeholder="Enter text to process...",
                        lines=3,
                    )
                    infer_btn = gr.Button("Run Inference", variant="primary")

                with gr.Column():
                    output_text = gr.Textbox(label="Output", lines=8)

            infer_btn.click(gradio_run_inference, inputs=input_text, outputs=output_text)

        with gr.Tab("Model"):
            with gr.Row():
                num_layers = gr.Slider(1, 12, value=4, step=1, label="Number of Layers")
                hidden_size = gr.Slider(32, 512, value=128, step=32, label="Hidden Size")

            init_btn = gr.Button("Initialize Model")
            init_output = gr.Textbox(label="Result")

            init_btn.click(
                gradio_init_model,
                inputs=[num_layers, hidden_size],
                outputs=init_output,
            )

        if CONFIG.role == "head":
            with gr.Tab("Workers"):
                worker_url_input = gr.Textbox(
                    label="Worker Space URL",
                    placeholder="https://username-spacename.hf.space",
                )
                register_btn = gr.Button("Register Worker")
                register_output = gr.Textbox(label="Result")

                register_btn.click(
                    gradio_register_worker,
                    inputs=worker_url_input,
                    outputs=register_output,
                )

        with gr.Tab("API"):
            gr.Markdown("""
            ### Direct API Access
            Use this tab to test API endpoints directly.

            **Endpoints:**
            - `status` - Get node status
            - `register_worker` - Register a worker (head only)
            - `store_shard` - Store a model shard
            - `forward` - Run forward pass
            - `get_workers` - List workers (head only)
            """)

            endpoint_input = gr.Textbox(label="Endpoint", value="status")
            json_input = gr.Textbox(label="JSON Data", value="{}", lines=5)
            api_btn = gr.Button("Call API")
            api_output = gr.Textbox(label="Response", lines=10)

            api_btn.click(
                gradio_api_call,
                inputs=[endpoint_input, json_input],
                outputs=api_output,
            )

    return app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import hmac

try:
    from fastapi import FastAPI, Request, HTTPException
    from fastapi.responses import JSONResponse

    api_app = FastAPI()

    def _is_authorized(request: Request) -> bool:
        """True if the request carries ``Authorization: Bearer <CONFIG.secret_token>``."""
        scheme, _, token = request.headers.get("Authorization", "").partition(" ")
        if scheme.lower() != "bearer":
            return False
        # Constant-time comparison so the token cannot be probed via response timing.
        return hmac.compare_digest(token.encode(), CONFIG.secret_token.encode())

    @api_app.post("/api/{endpoint}")
    async def api_endpoint(endpoint: str, request: Request):
        """Authenticated JSON entry point for inter-node calls; dispatches to api_handler.

        Responds 401 unless the shared bearer token matches; the previous
        check was a no-op, letting anyone push shards or register workers.
        """
        if not _is_authorized(request):
            raise HTTPException(status_code=401, detail="Invalid or missing bearer token")

        try:
            data = await request.json()
        except ValueError:
            # Empty or non-JSON body: treat as "no arguments".
            data = {}

        return JSONResponse(api_handler(endpoint, data))

    @api_app.get("/api/status")
    async def get_status():
        """Unauthenticated status probe; same payload as POST /api/status."""
        return JSONResponse(api_handler("status", {}))

    app = create_app()
    # Mounted after the /api routes are declared, so those routes stay reachable.
    api_app = gr.mount_gradio_app(api_app, app, path="/")
    print("Running with FastAPI + Gradio")

except ImportError:
    # FastAPI is not installed: serve the Gradio UI only.
    app = create_app()
    print("Running with pure Gradio")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("Starting WAN-JAX Node")
    print(f" Node ID: {CONFIG.node_id}")
    print(f" Role: {CONFIG.role}")
    print(f" Head URL: {CONFIG.head_url}")
    print(f" JAX devices: {jax.devices()}")

    # Launch the Blocks already built at import time instead of calling
    # create_app() again. A second call would build a second UI and could
    # start a second set of worker registration/heartbeat threads.
    # NOTE(review): launch() serves only the Gradio UI. The /api/* routes on
    # the FastAPI `api_app` (when FastAPI is installed) are reachable only if
    # an ASGI server serves `api_app` itself; confirm how the Space is started.
    app.launch(server_name="0.0.0.0", server_port=CONFIG.port)