Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| SLICK Cloud Relay — lightweight Flask server deployed on HuggingFace Spaces. | |
| The Spark pushes training metrics and chat responses here. | |
| The user's phone/browser pulls from here. | |
| No inbound connections to the Spark needed. | |
| Auth: | |
| - Spark endpoints (/api/push/*, /api/pull/*) use X-Token header | |
| - Frontend endpoints use session cookie set via /login | |
| """ | |
| import base64 | |
| import io | |
| import os | |
| import secrets | |
| import threading | |
| import time | |
| import uuid | |
| from flask import Flask, jsonify, request, send_file, send_from_directory, session, redirect | |
app = Flask(__name__, static_folder="static")

# Session signing key. An unset *or empty* FLASK_SECRET falls back to a random
# per-process key (an empty key would leave Flask sessions unusable). A random
# key invalidates every login cookie whenever the process restarts, so set
# FLASK_SECRET to keep users logged in across restarts.
app.secret_key = os.environ.get("FLASK_SECRET") or secrets.token_hex(32)

# Auth tokens. An empty value means the token has not been configured.
RELAY_TOKEN = os.environ.get("RELAY_TOKEN", "")
# ---------------------------------------------------------------------------
# In-memory state
# ---------------------------------------------------------------------------
# Each shared structure is guarded by its own lock. Nothing here is persisted;
# everything is lost when the process restarts.

# Latest training snapshot reported by the Spark.
latest_status: dict = dict(
    status="waiting",
    step=0,
    total_steps=40461,
    loss=0,
    percent=0,
    speed_s_per_batch=0,
    active_watchers=0,
    spark_connected=False,
    last_heartbeat=0,
)
status_lock = threading.Lock()

# Text questions waiting for the Spark: query_id -> {message, created_at}
pending_questions: dict = {}
pending_lock = threading.Lock()

# Answers per query: query_id -> {status, response, has_audio, metrics}
responses: dict = {}
responses_lock = threading.Lock()

# Raw audio clips: query_id (or notification id) -> bytes
audio_store: dict = {}
audio_lock = threading.Lock()

# Notifications waiting to be collected by the frontend
notifications: list = []
notifications_lock = threading.Lock()

# Chat history entries: {sender, text, timestamp}
chat_messages: list = []
chat_lock = threading.Lock()

# Voice message queue (phone uploads audio → Spark transcribes + processes)
# Entries: {id, audio_bytes, created_at}
voice_queue: list = []
voice_lock = threading.Lock()

# Dashboard data (pushed by Spark agent from :5050). Section order matters
# only for readability; each section gets its own fresh container.
dashboard_data: dict = {
    **{
        section: {}
        for section in ("training", "gpu", "system", "cpu", "disk",
                        "download", "control")
    },
    "loss_history": [],  # [{step, loss}]
    "gpu_history": [],   # [{t, gpu, temp}]
    "alter_ego": {},
}
dashboard_lock = threading.Lock()

# Command queue (frontend → Spark): pause, resume
command_queue: list = []    # [{id, action, created_at}]
command_results: dict = {}  # cmd_id -> {ok, status, message}
command_lock = threading.Lock()
| # --------------------------------------------------------------------------- | |
| # Cleanup old data (prevent memory bloat) | |
| # --------------------------------------------------------------------------- | |
# Response IDs that had no "completed_at" when a cleanup sweep first saw them,
# mapped to that sighting time. This lets responses stuck in "processing"
# (e.g. the Spark never answered) expire instead of living forever.
_unfinished_first_seen = {}


def _prune_responses(cutoff, now):
    """Delete responses completed (or first seen unfinished) before *cutoff*."""
    with responses_lock:
        for qid in list(responses):
            completed_at = responses[qid].get("completed_at")
            if completed_at is None:
                stamp = _unfinished_first_seen.setdefault(qid, now)
            else:
                stamp = completed_at
            if stamp < cutoff:
                del responses[qid]
        # Forget sightings for IDs that have since completed or been removed.
        for qid in list(_unfinished_first_seen):
            entry = responses.get(qid)
            if entry is None or "completed_at" in entry:
                del _unfinished_first_seen[qid]


def _prune_audio(limit=50, keep=20):
    """Once more than *limit* clips are stored, keep only the *keep* newest."""
    with audio_lock:
        if len(audio_store) <= limit:
            return
        # Dicts preserve insertion order, so the first keys are the oldest
        # clips. The keys themselves are random UUID prefixes and do not sort
        # by age, which is why they are not sorted here.
        for key in list(audio_store)[:-keep]:
            del audio_store[key]


def cleanup_old_data():
    """Background loop: every 5 minutes drop responses and audio to bound memory.

    Responses are removed about an hour after completion. Unfinished ones are
    removed about an hour after a sweep first saw them. Audio is trimmed to
    the 20 most recently added clips once more than 50 are stored. A failing
    sweep is reported and the loop keeps running, so the daemon thread cannot
    die silently.
    """
    while True:
        time.sleep(300)  # Every 5 min
        now = time.time()
        try:
            _prune_responses(now - 3600, now)
            _prune_audio()
        except Exception as exc:  # thread boundary: log and keep sweeping
            print(f"cleanup_old_data: sweep failed: {exc!r}")


threading.Thread(target=cleanup_old_data, daemon=True).start()
| # --------------------------------------------------------------------------- | |
| # Auth helpers | |
| # --------------------------------------------------------------------------- | |
def _token_valid(candidate):
    """Return True if *candidate* equals RELAY_TOKEN.

    An unset/empty RELAY_TOKEN never matches, so a deployment without a
    configured token fails closed instead of accepting a missing or empty
    token. Non-string input never matches. The comparison is constant-time
    to avoid leaking the token through response timing.
    """
    if not RELAY_TOKEN or not isinstance(candidate, str):
        return False
    return secrets.compare_digest(candidate.encode("utf-8"),
                                  RELAY_TOKEN.encode("utf-8"))


def check_spark_token():
    """Verify the Spark's X-Token header against RELAY_TOKEN."""
    return _token_valid(request.headers.get("X-Token", ""))


def check_session():
    """Verify user is logged in (session flag set by login())."""
    return session.get("authenticated") is True


# ---------------------------------------------------------------------------
# Login page
# ---------------------------------------------------------------------------
def login_page():
    """Serve the login page, or redirect already-authenticated users to /."""
    if check_session():
        return redirect("/")
    return send_from_directory("static", "login.html")


def login():
    """Exchange the relay token (JSON body {"token": ...}) for a session cookie.

    Returns {"ok": True} on success. A wrong token, missing token, non-string
    token, or malformed body returns 401 with "Invalid token".
    """
    data = request.get_json(silent=True)
    token = data.get("token") if isinstance(data, dict) else None
    if isinstance(token, str) and _token_valid(token.strip()):
        session["authenticated"] = True
        session.permanent = True
        return jsonify({"ok": True})
    return jsonify({"ok": False, "error": "Invalid token"}), 401


def logout():
    """Drop the session, logging the user out."""
    session.clear()
    return jsonify({"ok": True})
| # --------------------------------------------------------------------------- | |
| # Spark push endpoints (token-protected) | |
| # --------------------------------------------------------------------------- | |
def push_status():
    """Merge a training-status snapshot from the Spark and record a heartbeat.

    Expects a JSON object whose keys are merged into ``latest_status``.
    Returns 401 on a bad token and 400 if the body is not a JSON object.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # dict.update(None) would raise TypeError and surface as a 500.
        return jsonify({"error": "expected a JSON object"}), 400
    with status_lock:
        latest_status.update(data)
        # Written after the merge so a pushed payload cannot override them.
        latest_status["spark_connected"] = True
        latest_status["last_heartbeat"] = time.time()
    return jsonify({"ok": True})
def push_transcript():
    """Push voice transcript immediately after Whisper transcription (before Claude).

    Body: JSON object with ``query_id`` and ``transcript``. If both are present,
    the text is stored on the query's response entry (when it exists) and
    replaces a "[Voice message]" placeholder in the chat history. Returns 400
    if the body is not a JSON object.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    qid = data.get("query_id")
    text = data.get("transcript", "")
    if qid and text:
        with responses_lock:
            if qid in responses:
                responses[qid]["transcript"] = text
        # Update the user's "[Voice message]" chat entry with actual text.
        # NOTE(review): chat entries carry no query_id, so this rewrites the
        # most recent placeholder; with several voice messages in flight it
        # may update the wrong one.
        with chat_lock:
            for msg in reversed(chat_messages):
                if msg["text"] == "[Voice message]" and msg["sender"] == "user":
                    msg["text"] = text
                    break
    return jsonify({"ok": True})
def push_response():
    """Record the Spark's answer (and optional audio) for a query.

    JSON body fields:
      query_id     -- required; the query being answered.
      response     -- answer text.
      has_audio    -- whether audio accompanies the answer.
      metrics      -- optional value stored with the response.
      audio_only   -- if true, this is a follow-up that only flags the
                      existing response as having audio (no chat message).
      audio_base64 -- optional base64-encoded audio bytes.

    Audio is stored *before* the response is published, so a client that sees
    has_audio=True can fetch the clip immediately. If audio was sent but cannot
    be decoded, has_audio is not advertised. Returns 400 for a malformed body
    or a missing query_id.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    qid = data.get("query_id")
    if not qid:
        return jsonify({"error": "missing query_id"}), 400

    # Decode and store audio first (see docstring).
    audio_b64 = data.get("audio_base64")
    audio_bytes = None
    if audio_b64:
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError):  # binascii.Error is a ValueError
            audio_bytes = None  # malformed payload: keep the text answer
    if audio_bytes:
        with audio_lock:
            audio_store[qid] = audio_bytes
    # Audio was sent but is unusable: don't claim it exists.
    audio_broken = bool(audio_b64) and not audio_bytes

    if data.get("audio_only", False):
        # Audio-only update: just flag the existing response.
        with responses_lock:
            if qid in responses and not audio_broken:
                responses[qid]["has_audio"] = True
        return jsonify({"ok": True})

    with responses_lock:
        entry = {
            "status": "complete",
            "response": data.get("response", ""),
            "has_audio": bool(data.get("has_audio", False)) and not audio_broken,
            "metrics": data.get("metrics", {}),
            "completed_at": time.time(),
        }
        # Keep a transcript stored earlier by push_transcript instead of
        # discarding it when the entry is replaced.
        previous = responses.get(qid)
        if previous and "transcript" in previous:
            entry["transcript"] = previous["transcript"]
        responses[qid] = entry
    # Remove from pending
    with pending_lock:
        pending_questions.pop(qid, None)
    # Save to chat history (only for first push, not audio updates)
    with chat_lock:
        chat_messages.append({
            "sender": "tars",
            "text": data.get("response", ""),
            "timestamp": time.time(),
        })
    return jsonify({"ok": True})
def push_notification():
    """Queue a proactive notification (optionally with audio) for the frontend.

    JSON body: id (optional; a random 8-char id is generated otherwise),
    response, condition, has_audio, audio_base64. Audio is stored before the
    notification becomes visible, so the frontend can fetch the clip as soon
    as it sees the entry. Undecodable audio is dropped and has_audio is not
    advertised. Returns 400 if the body is not a JSON object.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    notif_id = data.get("id", str(uuid.uuid4())[:8])

    audio_b64 = data.get("audio_base64")
    audio_bytes = None
    if audio_b64:
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError):  # binascii.Error is a ValueError
            audio_bytes = None  # malformed payload: deliver text only
    if audio_bytes:
        with audio_lock:
            audio_store[notif_id] = audio_bytes
    audio_broken = bool(audio_b64) and not audio_bytes

    notif = {
        "id": notif_id,
        "response": data.get("response", ""),
        "condition": data.get("condition", ""),
        "has_audio": bool(data.get("has_audio", False)) and not audio_broken,
    }
    with notifications_lock:
        notifications.append(notif)
    return jsonify({"ok": True})
def push_dashboard():
    """Replace dashboard sections with the Spark's latest snapshot.

    Only the known top-level sections are copied. The histories are trimmed to
    the last 5000 loss points and 200 GPU samples, and histories that are not
    lists are ignored instead of failing the request. Returns 400 if the body
    is not a JSON object.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    with dashboard_lock:
        for key in ("training", "gpu", "system", "cpu", "disk",
                    "download", "control", "alter_ego"):
            if key in data:
                dashboard_data[key] = data[key]
        for key, limit in (("loss_history", 5000), ("gpu_history", 200)):
            history = data.get(key)
            if isinstance(history, list):
                dashboard_data[key] = history[-limit:]
    return jsonify({"ok": True})
def push_command_result():
    """Record the Spark's outcome for a previously pulled command.

    Body: JSON object with command_id, ok, status, message. A request without
    a non-empty string command_id is acknowledged but ignored. Returns 400 if
    the body is not a JSON object.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    cmd_id = data.get("command_id")
    # Command ids are strings (see send_command); anything else, e.g. a list,
    # could not be looked up later and may not even be hashable.
    if isinstance(cmd_id, str) and cmd_id:
        with command_lock:
            command_results[cmd_id] = {
                "ok": data.get("ok", False),
                "status": data.get("status", ""),
                "message": data.get("message", ""),
            }
    return jsonify({"ok": True})
def pull_pending():
    """Hand every queued text question to the Spark and empty the queue.

    Questions are removed as they are returned, so each is delivered at most
    once. Response shape: {"pending": [{"query_id": ..., "message": ...}]}.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    with pending_lock:
        batch = [
            dict(query_id=qid, message=entry["message"])
            for qid, entry in pending_questions.items()
        ]
        # Emptied right away so the next poll cannot see the same questions.
        pending_questions.clear()
    return jsonify({"pending": batch})
def pull_voice():
    """Hand queued voice uploads to the Spark as base64 and empty the queue.

    The queue is swapped out under the lock and encoded afterwards. Phone
    uploads append under the same lock, so they are not blocked while
    potentially large clips are base64-encoded.
    Response shape: {"pending": [{"id": ..., "audio_base64": ...}]}.
    """
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    with voice_lock:
        batch = list(voice_queue)
        voice_queue.clear()
    items = [
        {"id": v["id"],
         "audio_base64": base64.b64encode(v["audio_bytes"]).decode()}
        for v in batch
    ]
    return jsonify({"pending": items})
def pull_commands():
    """Deliver queued frontend commands to the Spark, each at most once."""
    if not check_spark_token():
        return jsonify({"error": "unauthorized"}), 401
    with command_lock:
        cmds = command_queue[:]
        del command_queue[:]
    return jsonify({"commands": cmds})
| # --------------------------------------------------------------------------- | |
| # Frontend endpoints (session-protected) | |
| # --------------------------------------------------------------------------- | |
def ask():
    """Queue a text question from the logged-in user for the Spark to answer.

    Body: JSON object with a non-empty string ``message``. Responses:
      - 400 for a malformed body or a missing, non-string or empty message;
      - 503 if the Spark is not marked connected or its last heartbeat is
        older than 120 s;
      - otherwise {"id": ..., "status": "processing"}. The id keys the
        ``responses`` entry that result() reports on.
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Missing message"}), 400
    raw = data["message"]
    if not isinstance(raw, str):
        # .strip() on a number/list/None would raise and produce a 500.
        return jsonify({"error": "Missing message"}), 400
    message = raw.strip()
    if not message:
        return jsonify({"error": "Empty message"}), 400
    # Check Spark connection
    with status_lock:
        connected = latest_status.get("spark_connected", False)
        stale = time.time() - latest_status.get("last_heartbeat", 0) > 120
    if not connected or stale:
        return jsonify({"error": "Spark not connected. Is the relay agent running?"}), 503
    query_id = str(uuid.uuid4())[:8]
    now = time.time()
    with pending_lock:
        # Drop questions nobody picked up within 5 minutes.
        expired = [k for k, v in pending_questions.items()
                   if v.get("created_at", 0) < now - 300]
        for k in expired:
            del pending_questions[k]
        pending_questions[query_id] = {
            "message": message,
            "created_at": now,
        }
    with responses_lock:
        responses[query_id] = {"status": "processing"}
    # Save user message to chat history
    with chat_lock:
        chat_messages.append({
            "sender": "user",
            "text": message,
            "timestamp": now,
        })
    return jsonify({"id": query_id, "status": "processing"})
def result(query_id):
    """Return {"id": query_id, **entry} for a known query, else 404."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with responses_lock:
        entry = responses.get(query_id)
    if entry:
        payload = {"id": query_id}
        payload.update(entry)
        return jsonify(payload)
    return jsonify({"error": "Unknown query ID"}), 404
def audio(query_id):
    """Serve a stored audio clip inline, labelled as MP3 or WAV.

    The clip is treated as MP3 if it starts with an ID3 tag or an MPEG frame
    sync (byte 0 is 0xFF and the top three bits of byte 1 are set). Anything
    else is served as WAV. Returns 404 when no non-empty clip is stored.
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with audio_lock:
        clip = audio_store.get(query_id)
    if not clip:
        return jsonify({"error": "No audio"}), 404
    has_id3 = clip[:3] == b'ID3'
    has_frame_sync = len(clip) > 1 and clip[0] == 0xFF and (clip[1] & 0xE0) == 0xE0
    if has_id3 or has_frame_sync:
        mime, ext = "audio/mpeg", "mp3"
    else:
        mime, ext = "audio/wav", "wav"
    return send_file(
        io.BytesIO(clip),
        mimetype=mime,
        as_attachment=False,
        download_name=f"tars_{query_id}.{ext}",
    )
def status():
    """Return a copy of latest_status; spark_connected is forced off after 60 s without a heartbeat."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with status_lock:
        snapshot = latest_status.copy()
    silent_for = time.time() - snapshot.get("last_heartbeat", 0)
    if silent_for > 60:
        snapshot["spark_connected"] = False
    return jsonify(snapshot)
def connection_status():
    """Quick connection check — returns Spark online/offline + last seen.

    "connected" is true when a heartbeat arrived less than 60 s ago.
    "seconds_ago" is -1 when no heartbeat has been seen yet.
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with status_lock:
        hb = latest_status.get("last_heartbeat", 0)
    if hb > 0:
        elapsed = time.time() - hb
        online = elapsed < 60
        seconds_ago = round(elapsed, 1) if elapsed >= 0 else -1
    else:
        online = False
        seconds_ago = -1
    return jsonify({
        "connected": online,
        "last_heartbeat": hb,
        "seconds_ago": seconds_ago,
    })
def get_notifications():
    """Return and clear every queued notification, so each is delivered once."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with notifications_lock:
        batch = notifications[:]
        notifications.clear()
    return jsonify({"notifications": batch})
def get_watchers():
    """Stub endpoint: the relay keeps no watcher list, so the list is always empty."""
    if check_session():
        return jsonify({"watchers": []})
    return jsonify({"error": "Not authenticated"}), 401
def get_history():
    """Return up to the 200 most recent chat messages as {sender, text, timestamp}."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with chat_lock:
        recent = chat_messages[-200:]
        msgs = [
            dict(sender=m["sender"], text=m["text"],
                 timestamp=m.get("timestamp", ""))
            for m in recent
        ]
    return jsonify({"messages": msgs})
def list_archives():
    """Stub endpoint: archives are not stored, so the list is always empty."""
    if check_session():
        return jsonify({"archives": []})
    return jsonify({"error": "Not authenticated"}), 401
def get_dashboard():
    """Return a shallow copy of all dashboard sections (the values are shared, not copied)."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with dashboard_lock:
        snapshot = dict(dashboard_data)
    return jsonify(snapshot)
def get_dashboard_history():
    """Return the loss and GPU history series (an empty list for any missing one)."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with dashboard_lock:
        body = {
            series: dashboard_data.get(series, [])
            for series in ("loss_history", "gpu_history")
        }
        return jsonify(body)
def send_command():
    """Queue a training-control command ("pause" or "resume") for the Spark.

    Returns {"ok": True, "command_id": ...}; get_command_result reports the
    outcome for that id. A malformed body or any other action returns 400
    "Invalid action".
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    data = request.get_json(silent=True)
    # A null/non-object body would otherwise crash on .get() with a 500.
    action = data.get("action", "") if isinstance(data, dict) else ""
    if action not in ("pause", "resume"):
        return jsonify({"error": "Invalid action"}), 400
    cmd_id = str(uuid.uuid4())[:8]
    with command_lock:
        command_queue.append({
            "id": cmd_id,
            "action": action,
            "created_at": time.time(),
        })
    return jsonify({"ok": True, "command_id": cmd_id})
def get_command_result(cmd_id):
    """Return the Spark's reported outcome for a command, or {"status": "pending"}."""
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with command_lock:
        outcome = command_results.get(cmd_id)
    return jsonify(outcome if outcome else {"status": "pending"})
def voice_message():
    """Accept an uploaded voice clip (multipart field ``audio``) and queue it for the Spark.

    Returns 400 if the file is missing or empty. Returns 503 if the Spark is
    not marked connected or its heartbeat is older than 120 s. Otherwise the
    raw bytes are queued, the query is marked "processing", and a
    "[Voice message]" placeholder is added to chat (push_transcript later
    replaces it with the transcription).
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    upload = request.files.get("audio")
    # Compare with None: a FileStorage's truthiness depends on its filename.
    if upload is None:
        return jsonify({"error": "No audio file"}), 400
    audio_bytes = upload.read()
    if not audio_bytes:
        return jsonify({"error": "Empty audio"}), 400
    # Check Spark connection
    with status_lock:
        is_connected = latest_status.get("spark_connected", False)
        is_stale = time.time() - latest_status.get("last_heartbeat", 0) > 120
    if is_stale or not is_connected:
        return jsonify({"error": "Spark not connected"}), 503
    query_id = str(uuid.uuid4())[:8]
    queued = {
        "id": query_id,
        "audio_bytes": audio_bytes,
        "created_at": time.time(),
    }
    with voice_lock:
        voice_queue.append(queued)
    with responses_lock:
        responses[query_id] = {"status": "processing"}
    placeholder = {
        "sender": "user",
        "text": "[Voice message]",
        "timestamp": time.time(),
    }
    with chat_lock:
        chat_messages.append(placeholder)
    return jsonify({"id": query_id, "status": "processing"})
def archive_chat():
    """Clear the chat history and report how many messages were removed.

    Despite the name, nothing is stored: the messages are discarded, and
    list_archives always returns an empty list.
    """
    if not check_session():
        return jsonify({"error": "Not authenticated"}), 401
    with chat_lock:
        removed = len(chat_messages)
        del chat_messages[:]
    return jsonify({"archived": True, "count": removed})
| # --------------------------------------------------------------------------- | |
| # Serve frontend | |
| # --------------------------------------------------------------------------- | |
def index():
    """Serve the frontend to logged-in users; redirect everyone else to /login."""
    if check_session():
        return send_from_directory("static", "index.html")
    return redirect("/login")
# Health check (HF Spaces uses this). Performs no auth check.
def health():
    """Liveness probe that always answers {"status": "ok"}."""
    payload = {"status": "ok"}
    return jsonify(payload)
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    listen_port = int(os.environ.get("PORT", 7860))
    token_note = "yes" if RELAY_TOKEN else "NO — set RELAY_TOKEN env var!"
    print(f"SLICK Relay starting on port {listen_port}")
    print(f"Token configured: {token_note}")
    # threaded=True lets the Spark's and the browser's polling requests be
    # served concurrently by the development server.
    app.run(host="0.0.0.0", port=listen_port, debug=False, threaded=True)