Bc-AI committed
Commit 4b3d37d · verified · 1 Parent(s): 199e2cf

Create app.py

Files changed (1): app.py (+584, −0)
app.py ADDED
@@ -0,0 +1,584 @@
+ """
+ WAN-Distributed JAX Inference on Hugging Face Spaces
+ Each Space runs this app and can be configured as head or worker.
+ """
+
+ import os
+ import json
+ import time
+ import threading
+ import queue
+ from typing import Dict, List, Optional, Any
+ from dataclasses import dataclass, field
+ import hashlib
+
+ import gradio as gr
+ import numpy as np
+ import requests
+
+ # Use CPU JAX (must be set before jax is imported)
+ os.environ["JAX_PLATFORMS"] = "cpu"
+ import jax
+ import jax.numpy as jnp
+
+
+ # ============================================================================
+ # CONFIGURATION
+ # ============================================================================
+
+ @dataclass
+ class NodeConfig:
+     """Node configuration from environment."""
+     role: str = os.environ.get("NODE_ROLE", "worker")  # "head" or "worker"
+     node_id: str = os.environ.get("NODE_ID", hashlib.md5(os.urandom(8)).hexdigest()[:8])
+     head_url: str = os.environ.get("HEAD_URL", "")  # URL of head Space (for workers)
+     secret_token: str = os.environ.get("SECRET_TOKEN", "default-token")
+     port: int = int(os.environ.get("PORT", "7860"))
+
+
+ CONFIG = NodeConfig()
+
+
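+ # Example configuration (illustrative values; set these as Space variables/secrets):
+ #   head Space:   NODE_ROLE=head
+ #   worker Space: NODE_ROLE=worker
+ #                 HEAD_URL=https://username-headspace.hf.space
+ #                 SECRET_TOKEN=<same shared token on every node>
+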
+ # ============================================================================
+ # SHARED STATE
+ # ============================================================================
+
+ class ClusterState:
+     """Shared state for the cluster."""
+
+     def __init__(self):
+         self.workers: Dict[str, Dict] = {}       # worker_id -> info
+         self.shards: Dict[str, np.ndarray] = {}  # shard_name -> data
+         self.lock = threading.Lock()
+         self.is_initialized = False
+         self.pending_results: Dict[str, Any] = {}
+         self.request_queue: queue.Queue = queue.Queue()
+
+     def register_worker(self, worker_id: str, url: str, info: Dict) -> bool:
+         with self.lock:
+             self.workers[worker_id] = {
+                 "url": url,
+                 "info": info,
+                 "registered_at": time.time(),
+                 "last_seen": time.time(),
+                 "status": "active",
+             }
+         return True
+
+     def get_workers(self) -> List[Dict]:
+         with self.lock:
+             return [
+                 {"worker_id": wid, **winfo}
+                 for wid, winfo in self.workers.items()
+                 if winfo.get("status") == "active"
+             ]
+
+     def store_shard(self, name: str, data: np.ndarray):
+         with self.lock:
+             self.shards[name] = data
+
+     def get_shard(self, name: str) -> Optional[np.ndarray]:
+         with self.lock:
+             return self.shards.get(name)
+
+     def heartbeat(self, worker_id: str):
+         with self.lock:
+             if worker_id in self.workers:
+                 self.workers[worker_id]["last_seen"] = time.time()
+
+
+ STATE = ClusterState()
+
+
+ # ============================================================================
+ # HTTP COMMUNICATION LAYER
+ # ============================================================================
+
+ def make_request(url: str, endpoint: str, data: Dict, timeout: int = 30) -> Optional[Dict]:
+     """Make an HTTP request to another Space."""
+     try:
+         full_url = f"{url.rstrip('/')}/api/{endpoint}"
+         headers = {"Authorization": f"Bearer {CONFIG.secret_token}"}
+
+         response = requests.post(
+             full_url,
+             json=data,
+             headers=headers,
+             timeout=timeout,
+         )
+
+         if response.status_code == 200:
+             return response.json()
+         else:
+             print(f"Request failed: {response.status_code} - {response.text}")
+             return None
+     except Exception as e:
+         print(f"Request error: {e}")
+         return None
+
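+ # Illustrative usage (values made up): make_request("https://username-head.hf.space",
+ # "heartbeat", {"worker_id": "abc123"}) POSTs the JSON body to
+ # https://username-head.hf.space/api/heartbeat with the shared bearer token attached.
+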
+
+ # ============================================================================
+ # WORKER LOGIC
+ # ============================================================================
+
+ def worker_register_with_head():
+     """Register this worker with the head node."""
+     if not CONFIG.head_url:
+         print("No HEAD_URL configured, cannot register")
+         return False
+
+     # Get this Space's URL from the environment, or fall back to localhost
+     space_url = os.environ.get("SPACE_URL", f"http://localhost:{CONFIG.port}")
+
+     result = make_request(
+         CONFIG.head_url,
+         "register_worker",
+         {
+             "worker_id": CONFIG.node_id,
+             "worker_url": space_url,
+             "info": {
+                 "jax_devices": len(jax.devices()),
+                 "platform": jax.default_backend(),
+             },
+         },
+     )
+
+     if result and result.get("success"):
+         print(f"Registered with head at {CONFIG.head_url}")
+         STATE.is_initialized = True  # surfaced as "Registered" in the Status tab
+         return True
+     return False
+
+
+ def worker_heartbeat_loop():
+     """Send periodic heartbeats to the head."""
+     while True:
+         time.sleep(30)
+         if CONFIG.head_url:
+             make_request(
+                 CONFIG.head_url,
+                 "heartbeat",
+                 {"worker_id": CONFIG.node_id},
+             )
+
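+ # Note: the head records last_seen on every heartbeat, but nothing in this
+ # file ever expires a worker, so an unreachable worker stays "active" until
+ # it is re-registered.
+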
+
+ def worker_forward_pass(input_data: np.ndarray) -> np.ndarray:
+     """Run a forward pass over the locally stored shards."""
+     x = jnp.array(input_data)
+
+     # Apply each stored shard in name order (simple linear layers for demo)
+     for name, weight in sorted(STATE.shards.items()):
+         if weight.ndim == 2:
+             # Matrix multiply for weight matrices
+             if x.shape[-1] == weight.shape[0]:
+                 x = x @ weight
+         elif weight.ndim == 1:
+             # Add for biases
+             if x.shape[-1] == weight.shape[0]:
+                 x = x + weight
+
+     # Apply a simple activation
+     x = jax.nn.relu(x)
+
+     return np.array(x)
+
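+ # A shard is applied only when its leading dimension matches the current
+ # activation width; anything else is skipped, so the demo forward pass stays
+ # shape-safe no matter how the shards were assigned.
+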
+
+ # ============================================================================
+ # HEAD NODE LOGIC
+ # ============================================================================
+
+ def head_distribute_model(params: Dict[str, np.ndarray]) -> bool:
+     """Distribute model parameters to workers."""
+     workers = STATE.get_workers()
+     if not workers:
+         print("No workers available")
+         return False
+
+     # Simple round-robin distribution; the last worker takes the remainder
+     param_list = list(params.items())
+     shards_per_worker = max(1, len(param_list) // len(workers))
+
+     for i, worker in enumerate(workers):
+         start_idx = i * shards_per_worker
+         end_idx = start_idx + shards_per_worker if i < len(workers) - 1 else len(param_list)
+
+         worker_shards = dict(param_list[start_idx:end_idx])
+
+         for shard_name, shard_data in worker_shards.items():
+             result = make_request(
+                 worker["url"],
+                 "store_shard",
+                 {
+                     "name": shard_name,
+                     "data": shard_data.tolist(),
+                     "shape": list(shard_data.shape),
+                     "dtype": str(shard_data.dtype),
+                 },
+                 timeout=60,
+             )
+
+             if not result or not result.get("success"):
+                 print(f"Failed to send shard {shard_name} to worker {worker['worker_id']}")
+                 return False
+
+     print(f"Distributed {len(params)} shards to {len(workers)} workers")
+     return True
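+
+ # Worked example (illustrative): a 4-layer test model yields 8 arrays; with
+ # 3 workers, shards_per_worker = max(1, 8 // 3) = 2, so the workers receive
+ # 2, 2, and 4 shards respectively.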
+
+
+ def head_run_inference(input_data: np.ndarray) -> np.ndarray:
+     """Run distributed inference across workers."""
+     workers = STATE.get_workers()
+
+     if not workers:
+         # No workers, run locally
+         return worker_forward_pass(input_data)
+
+     # Pipeline the activations through the workers in sequence
+     current_data = input_data
+
+     for worker in workers:
+         result = make_request(
+             worker["url"],
+             "forward",
+             {
+                 "data": current_data.tolist(),
+                 "shape": list(current_data.shape),
+             },
+             timeout=60,
+         )
+
+         if result and "output" in result:
+             current_data = np.array(result["output"])
+         else:
+             print(f"Worker {worker['worker_id']} failed, using local fallback")
+             current_data = worker_forward_pass(current_data)
+
+     return current_data
+
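+ # Each hop serializes the full activation tensor to a JSON list (tolist), so
+ # per-request latency grows with both activation size and worker count.
+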
+
+ # ============================================================================
+ # API ENDPOINTS (Gradio doesn't expose a native API, so requests are dispatched by name)
+ # ============================================================================
+
+ def api_handler(endpoint: str, data: Dict) -> Dict:
+     """Handle API requests based on endpoint."""
+
+     # Verify token
+     # (in production, reject requests without a valid Authorization header;
+     # the FastAPI layer below currently lets unauthenticated calls through)
+
+     if endpoint == "register_worker":
+         success = STATE.register_worker(
+             data["worker_id"],
+             data["worker_url"],
+             data.get("info", {}),
+         )
+         return {"success": success, "message": "Worker registered" if success else "Failed"}
+
+     elif endpoint == "heartbeat":
+         STATE.heartbeat(data.get("worker_id", ""))
+         return {"success": True}
+
+     elif endpoint == "store_shard":
+         shard_data = np.array(data["data"], dtype=data.get("dtype", "float32"))
+         shard_data = shard_data.reshape(data["shape"])
+         STATE.store_shard(data["name"], shard_data)
+         return {"success": True, "shard": data["name"]}
+
+     elif endpoint == "forward":
+         input_data = np.array(data["data"]).reshape(data["shape"])
+         output = worker_forward_pass(input_data)
+         return {"output": output.tolist(), "shape": list(output.shape)}
+
+     elif endpoint == "status":
+         return {
+             "node_id": CONFIG.node_id,
+             "role": CONFIG.role,
+             "workers": len(STATE.get_workers()),
+             "shards": list(STATE.shards.keys()),
+             "jax_devices": len(jax.devices()),
+         }
+
+     elif endpoint == "get_workers":
+         return {"workers": STATE.get_workers()}
+
+     else:
+         return {"error": f"Unknown endpoint: {endpoint}"}
+
+
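+ # Illustrative dispatch: api_handler("store_shard", {"name": "layer_0_weight",
+ # "data": [[0.0, 0.1], [0.2, 0.3]], "shape": [2, 2], "dtype": "float32"})
+ # stores a 2x2 float32 array under that shard name.
+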
+ # ============================================================================
+ # GRADIO INTERFACE
+ # ============================================================================
+
+ def create_test_model(num_layers: int = 4, hidden_size: int = 128) -> Dict[str, np.ndarray]:
+     """Create a simple test model."""
+     params = {}
+
+     for i in range(num_layers):
+         params[f"layer_{i}_weight"] = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02
+         params[f"layer_{i}_bias"] = np.zeros(hidden_size, dtype=np.float32)
+
+     return params
+
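+ # With the defaults (num_layers=4, hidden_size=128) this produces 8 arrays:
+ # layer_{i}_weight of shape (128, 128) and layer_{i}_bias of shape (128,).
+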
+ def gradio_run_inference(input_text: str) -> str:
+     """Run inference from the Gradio UI."""
+     # Simple tokenization: character code points scaled by 1/128
+     tokens = np.array([ord(c) / 128.0 for c in input_text[:128]], dtype=np.float32)
+
+     # Pad to a fixed size of 128
+     if len(tokens) < 128:
+         tokens = np.pad(tokens, (0, 128 - len(tokens)))
+
+     # Run inference
+     start_time = time.time()
+
+     if CONFIG.role == "head":
+         output = head_run_inference(tokens)
+     else:
+         output = worker_forward_pass(tokens)
+
+     latency = (time.time() - start_time) * 1000
+
+     # Format output
+     result = f"Output shape: {output.shape}\n"
+     result += f"Output mean: {output.mean():.4f}\n"
+     result += f"Output std: {output.std():.4f}\n"
+     result += f"Latency: {latency:.1f}ms\n"
+     result += f"Workers used: {len(STATE.get_workers())}"
+
+     return result
+
+
+ def gradio_get_status() -> str:
+     """Get cluster status for the Gradio UI."""
+     status = {
+         "Node ID": CONFIG.node_id,
+         "Role": CONFIG.role,
+         "JAX Devices": len(jax.devices()),
+         "JAX Backend": jax.default_backend(),
+         "Stored Shards": len(STATE.shards),
+         "Shard Names": list(STATE.shards.keys())[:10],  # first 10 only
+     }
+
+     if CONFIG.role == "head":
+         workers = STATE.get_workers()
+         status["Connected Workers"] = len(workers)
+         status["Worker List"] = [
+             f"{w['worker_id']} @ {w['url']}"
+             for w in workers
+         ]
+     else:
+         status["Head URL"] = CONFIG.head_url
+         status["Registered"] = STATE.is_initialized
+
+     return json.dumps(status, indent=2)
+
+
+ def gradio_init_model(num_layers: int, hidden_size: int) -> str:
+     """Initialize the model and distribute it if workers are available."""
+     params = create_test_model(int(num_layers), int(hidden_size))
+
+     if CONFIG.role == "head":
+         workers = STATE.get_workers()
+         if workers:
+             success = head_distribute_model(params)
+             if success:
+                 return f"Distributed {len(params)} shards to {len(workers)} workers"
+             else:
+                 return "Failed to distribute model"
+         else:
+             # No workers yet: store locally
+             for name, data in params.items():
+                 STATE.store_shard(name, data)
+             return f"No workers - stored {len(params)} shards locally"
+     else:
+         # Workers store locally
+         for name, data in params.items():
+             STATE.store_shard(name, data)
+         return f"Stored {len(params)} shards locally"
+
+
+ def gradio_register_worker(worker_url: str) -> str:
+     """Manually register a worker (head node only)."""
+     if CONFIG.role != "head":
+         return "Only the head node can register workers"
+
+     # Ping the worker
+     result = make_request(worker_url, "status", {})
+
+     if result:
+         worker_id = result.get("node_id", f"worker_{len(STATE.workers)}")
+         STATE.register_worker(worker_id, worker_url, result)
+         return f"Registered worker {worker_id}"
+     else:
+         return f"Failed to reach worker at {worker_url}"
+
+
+ def gradio_api_call(endpoint: str, json_data: str) -> str:
+     """Make an API call (for testing)."""
+     try:
+         data = json.loads(json_data) if json_data else {}
+         result = api_handler(endpoint, data)
+         return json.dumps(result, indent=2)
+     except Exception as e:
+         return f"Error: {e}"
+
+
+ # ============================================================================
+ # MAIN APP
+ # ============================================================================
+
+ def create_app():
+     """Create the Gradio app based on node role."""
+
+     # Start background tasks
+     if CONFIG.role == "worker" and CONFIG.head_url:
+         def delayed_register():
+             # Give the server a few seconds to come up before registering
+             time.sleep(5)
+             worker_register_with_head()
+
+         threading.Thread(target=delayed_register, daemon=True).start()
+         # Heartbeat loop
+         threading.Thread(target=worker_heartbeat_loop, daemon=True).start()
+
+     # Create the Gradio interface
+     with gr.Blocks(title=f"WAN-JAX {CONFIG.role.upper()} - {CONFIG.node_id}") as app:
+         gr.Markdown(f"""
+         # 🌐 WAN-Distributed JAX Inference
+
+         **Node ID:** `{CONFIG.node_id}` | **Role:** `{CONFIG.role.upper()}`
+
+         {"This is the **HEAD** node - it coordinates workers and runs inference." if CONFIG.role == "head" else "This is a **WORKER** node - it stores model shards and computes."}
+         """)
+
+         with gr.Tab("Status"):
+             status_output = gr.Textbox(label="Cluster Status", lines=15)
+             refresh_btn = gr.Button("Refresh Status")
+             refresh_btn.click(gradio_get_status, outputs=status_output)
+
+             # Auto-refresh on load
+             app.load(gradio_get_status, outputs=status_output)
+
+         with gr.Tab("Inference"):
+             with gr.Row():
+                 with gr.Column():
+                     input_text = gr.Textbox(
+                         label="Input Text",
+                         placeholder="Enter text to process...",
+                         lines=3,
+                     )
+                     infer_btn = gr.Button("Run Inference", variant="primary")
+
+                 with gr.Column():
+                     output_text = gr.Textbox(label="Output", lines=8)
+
+             infer_btn.click(gradio_run_inference, inputs=input_text, outputs=output_text)
+
+         with gr.Tab("Model"):
+             with gr.Row():
+                 num_layers = gr.Slider(1, 12, value=4, step=1, label="Number of Layers")
+                 hidden_size = gr.Slider(32, 512, value=128, step=32, label="Hidden Size")
+
+             init_btn = gr.Button("Initialize Model")
+             init_output = gr.Textbox(label="Result")
+
+             init_btn.click(
+                 gradio_init_model,
+                 inputs=[num_layers, hidden_size],
+                 outputs=init_output,
+             )
+
+         if CONFIG.role == "head":
+             with gr.Tab("Workers"):
+                 worker_url_input = gr.Textbox(
+                     label="Worker Space URL",
+                     placeholder="https://username-spacename.hf.space",
+                 )
+                 register_btn = gr.Button("Register Worker")
+                 register_output = gr.Textbox(label="Result")
+
+                 register_btn.click(
+                     gradio_register_worker,
+                     inputs=worker_url_input,
+                     outputs=register_output,
+                 )
+
+         with gr.Tab("API"):
+             gr.Markdown("""
+             ### Direct API Access
+             Use this tab to test API endpoints directly.
+
+             **Endpoints:**
+             - `status` - Get node status
+             - `register_worker` - Register a worker (head only)
+             - `store_shard` - Store a model shard
+             - `forward` - Run a forward pass
+             - `get_workers` - List workers (head only)
+             """)
+
+             endpoint_input = gr.Textbox(label="Endpoint", value="status")
+             json_input = gr.Textbox(label="JSON Data", value="{}", lines=5)
+             api_btn = gr.Button("Call API")
+             api_output = gr.Textbox(label="Response", lines=10)
+
+             api_btn.click(
+                 gradio_api_call,
+                 inputs=[endpoint_input, json_input],
+                 outputs=api_output,
+             )
+
+     return app
+
+
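+ # Example external call (illustrative URL; relies on the FastAPI mount below):
+ #   curl -X POST https://username-node.hf.space/api/status \
+ #        -H "Authorization: Bearer <SECRET_TOKEN>" \
+ #        -H "Content-Type: application/json" -d '{}'
+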
+ # ============================================================================
+ # FASTAPI MOUNTING FOR TRUE API ACCESS
+ # ============================================================================
+
+ # Optional: mount FastAPI so the /api/* endpoints are real HTTP routes
+ try:
+     from fastapi import FastAPI, Request
+     from fastapi.responses import JSONResponse
+
+     HAS_FASTAPI = True
+     api_app = FastAPI()
+
+     @api_app.post("/api/{endpoint}")
+     async def api_endpoint(endpoint: str, request: Request):
+         # Check authorization (currently permissive: requests without a
+         # bearer token are still allowed, for testing)
+         auth_header = request.headers.get("Authorization", "")
+         if not auth_header.startswith("Bearer "):
+             pass
+
+         try:
+             data = await request.json()
+         except Exception:
+             data = {}
+
+         result = api_handler(endpoint, data)
+         return JSONResponse(result)
+
+     @api_app.get("/api/status")
+     async def get_status():
+         return JSONResponse(api_handler("status", {}))
+
+     # Mount the Gradio app on the FastAPI server
+     app = create_app()
+     api_app = gr.mount_gradio_app(api_app, app, path="/")
+
+     print("Running with FastAPI + Gradio")
+
+ except ImportError:
+     # FastAPI not available, fall back to pure Gradio
+     HAS_FASTAPI = False
+     app = create_app()
+     print("Running with pure Gradio")
+
+
+ # ============================================================================
+ # LAUNCH
+ # ============================================================================
+
+ if __name__ == "__main__":
+     print("Starting WAN-JAX Node")
+     print(f"  Node ID: {CONFIG.node_id}")
+     print(f"  Role: {CONFIG.role}")
+     print(f"  Head URL: {CONFIG.head_url}")
+     print(f"  JAX devices: {jax.devices()}")
+
+     if HAS_FASTAPI:
+         # Serve the FastAPI app (with Gradio mounted at "/"); calling
+         # create_app() again and launching it directly would bypass the
+         # /api/* routes and start the background threads a second time.
+         import uvicorn
+         uvicorn.run(api_app, host="0.0.0.0", port=CONFIG.port)
+     else:
+         app.launch(server_name="0.0.0.0", server_port=CONFIG.port)