# SACCP-head / app.py
# (Hugging Face page artifacts from the original listing: "Bc-AI's picture",
#  "Update app.py", commit 41df2ce verified — kept here as comments so the
#  module remains valid Python.)
"""
SharePUTERβ„’ v2.1 Head Node
Result fusion, live streaming, IPU/TPU, fixed stats display
"""
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, Any
import httpx, asyncio, uuid, os, json, time, ast, math
from datetime import datetime
app = FastAPI(title="SharePUTERβ„’ v2.1")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
DB_URL = os.environ.get("SACCP_DB_URL", "https://saccpshareputer1.pythonanywhere.com/")
DB_SECRET = os.environ.get("SACCP_DB_SECRET", "saccp-v2-master-key")
HEAD_ID = os.environ.get("HEAD_NODE_ID", f"head-{uuid.uuid4().hex[:8]}")
UNIT_PRICING = {
"cpu": 0.5, "ram": 0.1,
"gpu_t4": 4.0, "gpu_a10": 10.0, "gpu_a100": 25.0, "gpu_generic": 5.0,
"tpu_v2": 15.0, "tpu_v4": 40.0,
"ipu_pod4": 20.0, "ipu_pod16": 60.0,
}
NODE_PAY_RATES = {"RAM": 0.3, "CPU": 1.5, "GPU": 5.0, "TPU": 8.0, "IPU": 7.0}
active_nodes: Dict[str, dict] = {}
def db_h(): return {"X-SACCP-Secret": DB_SECRET, "Content-Type": "application/json"}
async def db(method, path, **kw):
try:
async with httpx.AsyncClient(timeout=30) as c:
return await getattr(c, method)(f"{DB_URL}{path}", headers=db_h(), **kw)
except Exception as e:
print(f"[HEAD] DB err {method} {path}: {e}")
return None
class UserCreate(BaseModel):
username: str
password: str
class TaskSubmit(BaseModel):
api_key: str
code: str
task_type: str = "general"
config: dict = {}
required_libs: list = []
class NodeRegister(BaseModel):
node_type: str
node_url: str
owner: str = "anonymous"
specs: dict = {}
installed_libs: list = []
# ─── Homepage (JS-loaded stats so page always renders) ─────────────────────
@app.get("/", response_class=HTMLResponse)
async def homepage():
return HTMLResponse(f"""<!DOCTYPE html>
<html><head><meta charset="UTF-8"><title>SharePUTERβ„’ v2.1</title>
<style>
body{{font-family:system-ui;background:#08080d;color:#ddd;margin:0;padding:40px;max-width:1000px;margin:0 auto}}
h1{{font-size:2.5rem;background:linear-gradient(135deg,#00ff88,#00aaff);-webkit-background-clip:text;-webkit-text-fill-color:transparent}}
.g{{display:grid;grid-template-columns:repeat(auto-fill,minmax(140px,1fr));gap:12px;margin:24px 0}}
.c{{background:#111;border:1px solid #222;border-radius:12px;padding:20px;text-align:center}}
.c .n{{font-size:1.8rem;font-weight:700;color:#00ff88;font-family:monospace}}
.c .l{{color:#888;font-size:.75rem;text-transform:uppercase;letter-spacing:1px;margin-top:4px}}
.pulse{{display:inline-block;width:8px;height:8px;border-radius:50%;background:#00ff88;
animation:p 2s infinite;margin-right:6px}}
@keyframes p{{0%,100%{{box-shadow:0 0 0 0 rgba(0,255,136,.4)}}50%{{box-shadow:0 0 0 6px rgba(0,255,136,0)}}}}
</style></head><body>
<h1>SharePUTERβ„’ v2.1</h1>
<p><span class="pulse"></span>Head Node <code>{HEAD_ID}</code> β€” Online</p>
<div class="g" id="stats"></div>
<script>
async function load(){{try{{const r=await fetch('/api/stats');const s=await r.json();
const items=[
['online_nodes','Nodes Online','#00ff88'],['cpu_nodes_online','CPU','#00aaff'],
['gpu_nodes_online','GPU','#aa66ff'],['tpu_nodes_online','TPU','#ff8844'],
['ipu_nodes_online','IPU','#00eeff'],['total_users','Users','#ffcc00'],
['active_tasks','Active Tasks','#00aaff'],['completed_tasks','Completed','#00ff88'],
['pending_fragments','Pending Frags','#ffcc00'],['completed_fragments','Done Frags','#00ff88'],
];
document.getElementById('stats').innerHTML=items.map(([k,l,c])=>
`<div class="c"><div class="n" style="color:${{c}}">${{s[k]||0}}</div><div class="l">${{l}}</div></div>`).join('')
}}catch(e){{document.getElementById('stats').innerHTML='<p style="color:#f44">Stats unavailable</p>'}}}}
load();setInterval(load,5000);
</script>
</body></html>""")
# ─── Auth ───────────────────────────────────────────────────────────────────
@app.post("/api/register")
async def register(u: UserCreate):
r = await db("post", "/users", json={"username": u.username, "password": u.password})
if not r or r.status_code != 201: raise HTTPException(r.status_code if r else 500, r.json().get("error") if r else "DB error")
return r.json()
@app.post("/api/login")
async def login(u: UserCreate):
r = await db("post", "/users/auth", json={"username": u.username, "password": u.password})
if not r or r.status_code != 200: raise HTTPException(401, "Invalid credentials")
return r.json()
@app.get("/api/balance")
async def balance(api_key: str):
r = await db("post", "/users/by_api_key", json={"api_key": api_key})
if not r or r.status_code != 200: raise HTTPException(401, "Invalid API key")
u = r.json(); return {"username": u["username"], "balance": u["balance"]}
@app.get("/api/capabilities")
async def caps():
r = await db("get", "/nodes/capabilities"); return r.json() if r and r.status_code == 200 else {}
@app.get("/api/pricing")
async def pricing(): return UNIT_PRICING
@app.get("/api/available_libs")
async def avlibs():
r = await db("get", "/libs/available"); return r.json() if r and r.status_code == 200 else {}
# ─── Cost ───────────────────────────────────────────────────────────────────
def calc_cost(cfg):
c = cfg.get("cpus", 1) * UNIT_PRICING["cpu"] + cfg.get("ram_gb", 4) * UNIT_PRICING["ram"]
gpus = cfg.get("gpus", 0)
gt = cfg.get("gpu_type", "generic")
if gpus: c += gpus * UNIT_PRICING.get(f"gpu_{gt}", 5)
tpus = cfg.get("tpus", 0)
tt = cfg.get("tpu_type", "v2")
if tpus: c += tpus * UNIT_PRICING.get(f"tpu_{tt}", 15)
ipus = cfg.get("ipus", 0)
it = cfg.get("ipu_type", "pod4")
if ipus: c += ipus * UNIT_PRICING.get(f"ipu_{it}", 20)
return round(c, 4)
# ─── Result Fusion Engine ──────────────────────────────────────────────────
def fuse_results(results):
"""Merge fragment results into one unified result."""
if not results: return None
if len(results) == 1: return results[0]
# All dicts β†’ merge keys (sum numbers, concat lists)
if all(isinstance(r, dict) for r in results):
merged = {}
all_keys = set()
for r in results: all_keys.update(r.keys())
for key in all_keys:
vals = [r[key] for r in results if key in r]
if not vals: continue
if all(isinstance(v, (int, float)) for v in vals):
merged[key] = sum(vals)
elif all(isinstance(v, list) for v in vals):
merged[key] = [item for v in vals for item in v]
elif all(isinstance(v, str) for v in vals):
merged[key] = "\n".join(vals)
else:
merged[key] = vals[-1] # Take last value
return merged
# All lists β†’ concatenate
if all(isinstance(r, list) for r in results):
return [item for r in results for item in r]
# All numbers β†’ sum
if all(isinstance(r, (int, float)) for r in results):
return sum(results)
# All strings β†’ merge (remove redundant lines)
if all(isinstance(r, str) for r in results):
return "\n".join(results)
return results
def fuse_stdout(fragments):
"""Merge fragment stdout into one continuous stream."""
merged = []
for f in sorted(fragments, key=lambda x: x["fragment_index"]):
stdout = f.get("stdout", "")
if stdout:
for line in stdout.split('\n'):
if line.strip():
merged.append(line)
return "\n".join(merged)
# ─── Fragmentation ─────────────────────────────────────────────────────────
# Add this to your head node fragmentation logic
def analyze_gpu_requirements(code):
"""
Detect if code needs GPU by looking for:
1. @gpu_required decorator
2. torch.cuda usage
3. device="cuda" patterns
"""
# Check for explicit decorator
if "@gpu_required" in code or "@requires_gpu" in code:
return True
# Check for torch GPU patterns
gpu_patterns = [
"torch.cuda",
'device="cuda"',
"device='cuda'",
".cuda()",
"torch.device('cuda')",
'torch.device("cuda")',
]
return any(pattern in code for pattern in gpu_patterns)
def is_parallelizable_safely(code, loop_info):
"""
Determine if a loop can be safely parallelized.
Returns: (can_parallelize: bool, reason: str)
"""
# Don't parallelize GPU training loops
gpu_training_keywords = [
"model.train()",
"optimizer.step()",
"loss.backward()",
".backward()",
"torch.nn",
"nn.Module",
]
if any(kw in code for kw in gpu_training_keywords):
return False, "GPU training loop (stateful)"
# Don't parallelize if it has shared state
stateful_keywords = [
"global ",
"nonlocal ",
".append(", # modifying shared list
"self.", # class methods
]
if any(kw in code for kw in stateful_keywords):
return False, "Stateful operations detected"
# Safe to parallelize: map-reduce style loops
safe_patterns = [
"for i in range",
"for idx in range",
"for chunk in range",
]
return True, "Independent iterations"
def fragment_code(code, cfg, task_id, libs):
"""
IMPROVED fragmentation with GPU awareness
"""
frags = []
cpus = max(1, cfg.get("cpus", 1))
# Detect GPU requirements
needs_gpu = cfg.get("gpus", 0) > 0 or analyze_gpu_requirements(code)
ram = cfg.get("ram_gb", 4)
# Find loops
loop = None
loop_safe = False
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.For) and isinstance(node.iter, ast.Call):
fn = node.iter.func
if isinstance(fn, ast.Name) and fn.id == "range":
args = node.iter.args
try:
if len(args) == 1:
loop = {"s": 0, "e": ast.literal_eval(args[0])}
elif len(args) >= 2:
loop = {"s": ast.literal_eval(args[0]), "e": ast.literal_eval(args[1])}
if loop:
loop["var"] = node.target.id if isinstance(node.target, ast.Name) else "i"
# Check if safe to parallelize
can_parallelize, reason = is_parallelizable_safely(code, loop)
if can_parallelize:
loop_safe = True
print(f"[HEAD] Loop is safe to parallelize: {reason}")
else:
print(f"[HEAD] Loop NOT safe to parallelize: {reason}")
loop = None
break
except:
loop = None
except:
pass
# Only fragment if safe AND big enough
if loop and loop_safe and (loop["e"] - loop["s"]) >= 50: # higher threshold
total = loop["e"] - loop["s"]
n = min(cpus * 2, total // 10) # at least 10 iterations per fragment
chunk = max(10, total // n)
print(f"[HEAD] Fragmenting loop: {total} iterations β†’ {n} fragments ({chunk} each)")
# Build pattern matching
patterns = []
if loop["s"] == 0:
patterns.append(f"range({loop['e']})")
patterns.append(f"range(0, {loop['e']})")
patterns.append(f"range({loop['s']}, {loop['e']})")
patterns.append(f"range({loop['s']},{loop['e']})")
for idx in range(n):
cs = loop["s"] + idx * chunk
ce = min(loop["s"] + (idx + 1) * chunk, loop["e"])
if cs >= loop["e"]:
break
mod = code
for pat in patterns:
if pat in mod:
mod = mod.replace(pat, f"range({cs}, {ce})", 1)
break
# Inject metadata
header = f"""# ═══ SACCP Fragment {idx}/{n} ═══
__saccp_rank__ = {idx}
__saccp_world_size__ = {n}
__saccp_chunk__ = ({cs}, {ce})
__saccp_is_fragment__ = True
"""
frags.append({
"fragment_id": f"{task_id}_frag_{idx}",
"task_id": task_id,
"fragment_index": idx,
"fragment_type": "compute",
"code": header + mod,
"input_data": json.dumps({"rank": idx, "world_size": n, "chunk": [cs, ce]}),
"required_libs": libs,
"required_gpu": needs_gpu,
"min_ram_gb": max(1, ram // n),
"timeout_seconds": 600,
})
# Fallback: single fragment
if not frags:
reason = "single" if not loop else ("unsafe" if not loop_safe else "too small")
print(f"[HEAD] No fragmentation ({reason}) β€” single fragment")
header = """# ═══ SACCP Single Fragment ═══
__saccp_rank__ = 0
__saccp_world_size__ = 1
__saccp_is_fragment__ = False
"""
frags.append({
"fragment_id": f"{task_id}_frag_0",
"task_id": task_id,
"fragment_index": 0,
"fragment_type": "compute",
"code": header + code,
"input_data": json.dumps({"rank": 0, "world_size": 1}),
"required_libs": libs,
"required_gpu": needs_gpu,
"min_ram_gb": ram,
"timeout_seconds": 600,
})
print(f"[HEAD] Task {task_id[:8]} β†’ {len(frags)} fragments (GPU required: {needs_gpu})")
return frags
# ─── Task Submit ─────────────────────────────────────────────────────────────
@app.post("/api/submit_task")
async def submit(task: TaskSubmit):
ur = await db("post", "/users/by_api_key", json={"api_key": task.api_key})
if not ur or ur.status_code != 200: raise HTTPException(401, "Invalid API key")
user = ur.json()
cfg = task.config
hourly = calc_cost(cfg)
if user["balance"] < hourly * 0.001: raise HTTPException(402, "Insufficient balance")
task_id = str(uuid.uuid4())
tr = await db("post", "/tasks", json={
"task_id": task_id, "owner": user["username"], "code": task.code,
"task_type": task.task_type, "config": cfg, "required_libs": task.required_libs,
})
if not tr or tr.status_code != 201: raise HTTPException(500, "Failed to create task")
fragments = fragment_code(task.code, cfg, task_id, task.required_libs)
fr = await db("post", "/fragments/batch", json={"fragments": fragments})
if not fr or fr.status_code != 201: raise HTTPException(500, "Failed to create fragments")
await db("patch", f"/tasks/{task_id}", json={"status": "running", "total_fragments": len(fragments)})
print(f"[HEAD] βœ… Task {task_id[:8]} submitted: {len(fragments)} frags, {hourly}/hr")
return {"task_id": task_id, "status": "running", "total_fragments": len(fragments), "hourly_cost": hourly}
# ─── Completion Check (called from submit_result) ──────────────────────────
async def try_complete(task_id):
fr = await db("get", f"/fragments/by_task/{task_id}")
if not fr or fr.status_code != 200: return
fragments = fr.json()
if not fragments: return
completed = [f for f in fragments if f["status"] == "completed"]
failed = [f for f in fragments if f["status"] == "failed"]
await db("patch", f"/tasks/{task_id}", json={
"completed_fragments": len(completed), "failed_fragments": len(failed)})
if len(completed) + len(failed) < len(fragments): return
# ═══ ALL DONE β€” FUSE RESULTS ═══
print(f"[HEAD] Task {task_id[:8]} complete β€” fusing {len(fragments)} fragments")
# Collect raw results and parse JSON strings
raw_results = []
for f in sorted(completed, key=lambda x: x["fragment_index"]):
r = f.get("result")
if r is not None:
if isinstance(r, str):
try: r = json.loads(r)
except: pass
raw_results.append(r)
# Fuse results into one unified output
fused = fuse_results(raw_results)
merged_stdout = fuse_stdout(fragments)
errors = [f"[Fragment {f['fragment_index']}] {f.get('error', '?')}" for f in failed]
status = "failed" if len(failed) == len(fragments) else "completed"
final = {
"fused_result": fused,
"stdout": merged_stdout,
"fragments_total": len(fragments),
"fragments_ok": len(completed),
"fragments_failed": len(failed),
"errors": errors,
"per_fragment": raw_results,
}
# Cost
tr = await db("get", f"/tasks/{task_id}")
owner, cfg = "unknown", {}
elapsed = 0.01
if tr and tr.status_code == 200:
t = tr.json(); owner = t.get("owner", "unknown"); cfg = t.get("config", {})
try:
created = datetime.fromisoformat(t["created_at"].split("+")[0].split("Z")[0])
elapsed = max((datetime.utcnow() - created).total_seconds() / 3600, 0.001)
except: pass
cost = round(calc_cost(cfg) * elapsed, 6)
if owner != "unknown":
await db("patch", f"/users/{owner}/balance", json={"amount": -cost, "reason": f"task_{task_id}"})
await db("patch", f"/tasks/{task_id}", json={
"status": status, "result": json.dumps(final), "cost": cost,
"completed_at": datetime.utcnow().isoformat(),
"error": "\n".join(errors) if errors else None,
})
print(f"[HEAD] πŸŽ‰ Task {task_id[:8]} β†’ {status} | Cost: {cost}")
# ─── Submit Result ──────────────────────────────────────────────────────────
@app.post("/api/submit_result")
async def submit_result(request: Request):
data = await request.json()
fid = data.get("fragment_id", "")
if not fid: raise HTTPException(400, "fragment_id required")
await db("patch", f"/fragments/{fid}", json={
"status": data.get("status", "failed"),
"result": data.get("result"), "error": data.get("error") or None,
"stdout": data.get("stdout") or "", "completed_at": datetime.utcnow().isoformat(),
"resource_usage": data.get("resource_usage", {}),
})
print(f"[HEAD] πŸ“₯ {fid[:20]}... β†’ {data.get('status')}")
nid = data.get("node_id", "")
if data.get("status") == "completed" and nid:
nt = active_nodes.get(nid, {}).get("node_type", "CPU")
pay = NODE_PAY_RATES.get(nt, 0) / 360
if pay > 0: await db("post", f"/nodes/{nid}/pay", json={"amount": round(pay, 6)})
# Check completion
fg = await db("get", f"/fragments/{fid}")
if fg and fg.status_code == 200:
tid = fg.json().get("task_id")
if tid: await try_complete(tid)
return {"ok": True}
# ─── Live Task Stream ──────────────────────────────────────────────────────
@app.get("/api/task/{task_id}/live")
async def live_task(task_id: str):
"""Return current state with merged stdout for live streaming."""
tr = await db("get", f"/tasks/{task_id}")
if not tr or tr.status_code != 200: raise HTTPException(404, "Not found")
task = tr.json()
fr = await db("get", f"/fragments/by_task/{task_id}")
fragments = fr.json() if fr and fr.status_code == 200 else []
merged_stdout = fuse_stdout(fragments)
completed = len([f for f in fragments if f["status"] == "completed"])
running = len([f for f in fragments if f["status"] == "running"])
total = len(fragments)
return {
"task_id": task_id, "status": task["status"],
"completed": completed, "running": running, "total": total,
"stdout": merged_stdout, "result": task.get("result"),
"cost": task.get("cost", 0),
}
# ─── Node + Task Endpoints ─────────────────────────────────────────────────
@app.post("/api/register_node")
async def reg_node(node: NodeRegister):
nid = f"node-{uuid.uuid4().hex[:12]}"
await db("post", "/nodes", json={
"node_id": nid, "node_type": node.node_type.upper(),
"node_url": node.node_url, "owner": node.owner,
"specs": node.specs, "installed_libs": node.installed_libs,
})
active_nodes[nid] = {"node_type": node.node_type.upper(), "last_hb": time.time()}
print(f"[HEAD] Node {nid} ({node.node_type})")
return {"node_id": nid, "pay_rate": f"{NODE_PAY_RATES.get(node.node_type.upper(), 0)} SACCP/hr"}
@app.post("/api/node_heartbeat")
async def hb(data: dict):
nid = data.get("node_id", "")
if nid in active_nodes: active_nodes[nid]["last_hb"] = time.time()
await db("post", f"/nodes/{nid}/heartbeat", json=data)
return {"ok": True}
@app.get("/api/get_work")
async def get_work(node_id: str, node_type: str = "CPU", has_gpu: bool = False, ram_gb: float = 4, libs: str = ""):
r = await db("get", "/fragments/pending?limit=10")
if not r or r.status_code != 200: return {"work": None}
for f in r.json():
if f.get("required_gpu") and not has_gpu: continue
if f.get("min_ram_gb", 0) > ram_gb > 0: continue
fid = f["fragment_id"]
await db("patch", f"/fragments/{fid}", json={
"status": "running", "assigned_node": node_id, "started_at": datetime.utcnow().isoformat()})
print(f"[HEAD] πŸ“€ {fid[:20]}... β†’ {node_id}")
return {"work": f}
return {"work": None}
@app.get("/api/task/{task_id}")
async def get_task(task_id: str):
r = await db("get", f"/tasks/{task_id}")
if not r or r.status_code != 200: raise HTTPException(404, "Not found")
return r.json()
@app.get("/api/task/{task_id}/fragments")
async def get_frags(task_id: str):
r = await db("get", f"/fragments/by_task/{task_id}")
return r.json() if r and r.status_code == 200 else []
@app.get("/api/my_tasks")
async def my_tasks(api_key: str):
ur = await db("post", "/users/by_api_key", json={"api_key": api_key})
if not ur or ur.status_code != 200: raise HTTPException(401, "Invalid API key")
r = await db("get", f"/tasks?owner={ur.json()['username']}")
return r.json() if r and r.status_code == 200 else []
@app.get("/api/stats")
async def stats():
r = await db("get", "/stats"); return r.json() if r and r.status_code == 200 else {}
@app.get("/health")
async def health():
return {"status": "healthy", "head_id": HEAD_ID, "version": "2.1", "nodes": len(active_nodes)}
# ─── Background ─────────────────────────────────────────────────────────────
@app.on_event("startup")
async def startup():
print(f"[HEAD] SharePUTER v2.1 β€” {HEAD_ID}")
asyncio.create_task(stuck_checker())
asyncio.create_task(node_cleaner())
async def stuck_checker():
while True:
await asyncio.sleep(15)
try:
r = await db("get", "/tasks")
if not r or r.status_code != 200: continue
for t in r.json():
if t["status"] not in ("running", "pending"): continue
fr = await db("get", f"/fragments/by_task/{t['task_id']}")
if not fr or fr.status_code != 200: continue
fs = fr.json()
if fs and len([f for f in fs if f["status"] in ("completed", "failed")]) >= len(fs):
print(f"[HEAD] πŸ”§ Unsticking {t['task_id'][:8]}")
await try_complete(t["task_id"])
except: pass
async def node_cleaner():
while True:
await asyncio.sleep(60)
dead = [n for n, d in list(active_nodes.items()) if d.get("last_hb", 0) < time.time() - 120]
for n in dead:
active_nodes.pop(n, None)
await db("post", f"/nodes/{n}/heartbeat", json={"status": "offline"})
if __name__ == "__main__":
import uvicorn; uvicorn.run(app, host="0.0.0.0", port=7860)