Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import Response | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| import uuid | |
| import threading | |
| import os | |
| import json | |
| from pathlib import Path | |
| from core.trae_bot import TraeBot | |
| import logging | |
| from collections import deque | |
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("API")

app = FastAPI()

# CORS: any origin, method and header may call the API.
# NOTE(review): allow_credentials=True combined with a "*" origin makes
# Starlette reflect the caller's Origin on credentialed requests, so any site
# could make authenticated cross-origin calls. Nothing visible here relies on
# cookies or auth headers — consider allow_credentials=False (verify first).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _max_slots_from_env(default: int = 2) -> int:
    """Return the worker-slot limit, optionally overridden by ``MAX_SLOTS``.

    When the environment variable is unset, ``default`` is returned. When it
    is set but is not an integer, or is less than 1, a warning is logged and
    ``default`` is returned instead.
    """
    raw = os.environ.get("MAX_SLOTS")
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring non-integer MAX_SLOTS=%r; using %d", raw, default)
        return default
    if value < 1:
        logger.warning("Ignoring MAX_SLOTS=%r (must be >= 1); using %d", raw, default)
        return default
    return value


# Session storage. This state is shared between request handlers and worker
# threads and is meant to be accessed only while holding ``sessions_lock``.
sessions: dict[str, dict] = {}  # session_id -> {"bot", "thread", "client_key"}
sessions_lock = threading.Lock()
# Maximum number of bots that may be "running"/"waiting_for_scan" at once
# (2 unless the MAX_SLOTS environment variable overrides it).
max_slots = _max_slots_from_env()
queue: deque[str] = deque()  # session_ids waiting for a free slot, FIFO
class StartResponse(BaseModel):
    """Payload describing a started, queued, or reused bot session.

    Matches the dicts returned by ``start_session`` (the route decorator that
    would bind it as ``response_model`` is not visible here — confirm).
    """

    session_id: str  # str(uuid.uuid4()) assigned when the session is created
    status: str  # bot status, e.g. "running" or "queued"
    queue_position: int | None = None  # 1-based place in the wait queue; None if not queued
class StatusResponse(BaseModel):
    """Snapshot of one session's progress plus server-wide slot usage.

    Field names match the dict built by ``get_status`` (whether this model is
    bound as its ``response_model`` is not visible here — confirm).
    """

    session_id: str
    status: str  # bot status: "queued", "running", "waiting_for_scan", "success", "error", ...
    current_stage: str  # human-readable stage text set via bot.update_stage()
    logs: list  # most recent bot log entries (get_status sends at most the last 50)
    qr_code_url: str | None  # presumably the QR image to scan; None until available — confirm
    has_token: bool  # mirrors bot.ifGotToken
    queue_position: int | None = None  # 1-based; only set while status == "queued"
    active_slots: int  # sessions currently "running" or "waiting_for_scan"
    max_slots: int  # configured concurrency limit
    # Optional diagnostics copied from the bot when it defines them.
    last_error: str | None = None
    last_traceback: str | None = None
    last_url: str | None = None
    last_seen_url: str | None = None
    last_title: str | None = None
    has_screenshot: bool = False  # True when a base64 debug screenshot is stored on the bot
def run_bot(session_id: str):
    """Worker-thread entry point: run one session's bot, then free its slot.

    The session is looked up under ``sessions_lock``; ``bot.run()`` itself runs
    without the lock. Once ``run()`` returns or raises, a bot still marked
    "running" or "waiting_for_scan" is flagged as failed, and queued sessions
    are promoted into the slot this worker just released.

    Args:
        session_id: key into ``sessions``; unknown ids are ignored.
    """
    with sessions_lock:
        session = sessions.get(session_id)
        if not session:
            return
        bot = session["bot"]

    try:
        bot.run()
    except Exception:
        # logger.exception records the traceback, not just the message.
        logger.exception("Bot session %s failed", session_id)
    finally:
        # Deliberately no ``return`` in this finally: a return here would
        # silently swallow BaseExceptions (SystemExit, KeyboardInterrupt).
        with sessions_lock:
            session = sessions.get(session_id)
            bot = session.get("bot") if session else None
            if bot and bot.status in ("running", "waiting_for_scan"):
                bot.status = "error"
                bot.update_stage("Failed: worker stopped unexpectedly")
            # This worker no longer occupies a slot, so let queued sessions in.
            _promote_queue_locked()
def _active_slots_locked() -> int:
    """Count the sessions whose bot currently occupies a worker slot.

    A slot is occupied while the bot's status is "running" or
    "waiting_for_scan". Expects the caller to hold ``sessions_lock``.
    """
    occupied = 0
    for entry in sessions.values():
        worker_bot = entry.get("bot")
        if worker_bot and worker_bot.status in ("running", "waiting_for_scan"):
            occupied += 1
    return occupied
| def _queue_position_locked(session_id: str) -> int | None: | |
| try: | |
| return list(queue).index(session_id) + 1 | |
| except ValueError: | |
| return None | |
def _start_session_locked(session_id: str) -> bool:
    """Mark a session's bot as running and launch its worker thread.

    Expects the caller to hold ``sessions_lock``.

    Returns:
        True if a worker is running for the session, whether it was already
        alive or was just started. False if the session or its bot is
        missing, or if the thread could not be started.
    """
    session = sessions.get(session_id)
    if not session:
        return False
    bot = session.get("bot")
    if not bot:
        return False

    thread = session.get("thread")
    if thread and thread.is_alive():
        return True

    bot.status = "running"
    thread = threading.Thread(
        target=run_bot, args=(session_id,), name=f"bot-{session_id}"
    )
    session["thread"] = thread
    try:
        thread.start()
    except RuntimeError:
        # If the thread never starts, run_bot never runs and nothing would
        # ever clear the "running" status, so the session would hold a slot
        # forever. Fail the session instead.
        logger.exception("Could not start worker for session %s", session_id)
        bot.status = "error"
        bot.update_stage("Failed: could not start worker")
        session["thread"] = None
        return False
    return True
def _promote_queue_locked():
    """Start queued sessions in FIFO order while worker slots are free.

    Queue entries whose session has disappeared, or whose bot is no longer
    "queued", are discarded without taking a slot. Expects the caller to hold
    ``sessions_lock``.
    """
    while _active_slots_locked() < max_slots and queue:
        candidate_id = queue.popleft()
        entry = sessions.get(candidate_id)
        candidate_bot = entry.get("bot") if entry else None
        if candidate_bot and candidate_bot.status == "queued":
            _start_session_locked(candidate_id)
async def start_session(request: Request):
    """Start a bot session for the calling client, or return its live one.

    Clients are keyed by ``request.client.host`` ("unknown" when absent). If
    that client already owns a session whose bot status is neither "success"
    nor "error", that session is returned instead of creating another.
    Otherwise a new ``TraeBot`` session is registered and either started at
    once (a worker slot is free) or appended to the wait queue.

    Returns:
        dict with ``session_id`` and ``status``. Reused and queued sessions
        also carry ``queue_position``, which is 1-based or None if not queued.
    """
    # NOTE(review): behind a reverse proxy, request.client.host may be the
    # proxy's address, so distinct users could share one client_key and be
    # handed each other's session — confirm proxy-header handling.
    client_key = "unknown" if not request.client else request.client.host

    with sessions_lock:
        owned = (
            (sid, entry.get("bot"))
            for sid, entry in sessions.items()
            if entry.get("client_key") == client_key
        )
        for sid, owned_bot in owned:
            if owned_bot and owned_bot.status not in ("success", "error"):
                return {
                    "session_id": sid,
                    "status": owned_bot.status,
                    "queue_position": _queue_position_locked(sid),
                }

        new_id = str(uuid.uuid4())
        new_bot = TraeBot(new_id)
        sessions[new_id] = {"bot": new_bot, "thread": None, "client_key": client_key}

        if _active_slots_locked() >= max_slots:
            # Every slot is busy: park the session until a worker finishes.
            new_bot.status = "queued"
            new_bot.update_stage("Queued")
            queue.append(new_id)
            return {
                "session_id": new_id,
                "status": "queued",
                "queue_position": _queue_position_locked(new_id),
            }

        _start_session_locked(new_id)
        return {"session_id": new_id, "status": "running"}
async def get_status(session_id: str):
    """Return one session's progress together with server-wide slot usage.

    Raises:
        HTTPException: 404 when ``session_id`` is not a known session.
    """
    with sessions_lock:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        bot = sessions[session_id]["bot"]

        slots_in_use = _active_slots_locked()
        position = None
        if bot.status == "queued":
            position = _queue_position_locked(session_id)

        # Optional diagnostics: any attribute the bot lacks is reported as None.
        diagnostics = {
            attr: getattr(bot, attr, None)
            for attr in (
                "last_error",
                "last_traceback",
                "last_url",
                "last_seen_url",
                "last_title",
            )
        }
        return {
            "session_id": session_id,
            "status": bot.status,
            "current_stage": bot.current_stage,
            "logs": bot.logs[-50:],  # only the 50 most recent entries
            "qr_code_url": bot.qr_code_url,
            "has_token": bot.ifGotToken,
            "queue_position": position,
            "active_slots": slots_in_use,
            "max_slots": max_slots,
            **diagnostics,
            "has_screenshot": bool(getattr(bot, "last_screenshot_png_b64", None)),
        }
async def get_debug(session_id: str):
    """Return detailed diagnostics for one session, including captured HTML.

    Any diagnostic attribute the bot does not define is reported as None.
    NOTE(review): tracebacks and page HTML are returned to anyone who knows
    the session id — confirm this endpoint is meant to be public.

    Raises:
        HTTPException: 404 when ``session_id`` is not a known session.
    """
    with sessions_lock:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        bot = sessions[session_id]["bot"]

        payload = {
            "session_id": session_id,
            "status": bot.status,
            "current_stage": bot.current_stage,
        }
        for attr in (
            "last_error",
            "last_traceback",
            "last_url",
            "last_seen_url",
            "last_title",
            "last_html",
        ):
            payload[attr] = getattr(bot, attr, None)
        payload["has_screenshot"] = bool(getattr(bot, "last_screenshot_png_b64", None))
        return payload
async def get_screenshot(session_id: str):
    """Serve the bot's most recent debug screenshot as an inline PNG.

    Raises:
        HTTPException: 404 if the session is unknown or has no screenshot;
            500 if the stored base64 data cannot be decoded.
    """
    import base64

    with sessions_lock:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        bot = sessions[session_id]["bot"]
        b64 = getattr(bot, "last_screenshot_png_b64", None)

    # Decoding is pure CPU work on a local reference, so it runs without the
    # lock instead of stalling every other request that needs sessions_lock.
    if not b64:
        raise HTTPException(status_code=404, detail="Screenshot not available")
    try:
        # b64decode accepts an ASCII str as well as bytes, so no .encode() is
        # needed and a bytes payload decodes too.
        png = base64.b64decode(b64)
    except (ValueError, TypeError) as err:
        # binascii.Error (bad padding/alphabet) and non-ASCII str input are
        # both ValueErrors; TypeError covers an unexpected payload type.
        raise HTTPException(status_code=500, detail="Screenshot decode failed") from err
    return Response(
        content=png,
        media_type="image/png",
        headers={"Content-Disposition": f'inline; filename="debug_{session_id}.png"'},
    )
async def download_token(session_id: str):
    """Send the bot's obtained token as an ``account.json`` attachment.

    Raises:
        HTTPException: 404 if the session is unknown; 400 if the bot has not
            obtained a token yet (or its token data is empty).
    """
    # Read the shared sessions dict under sessions_lock, like the other
    # handlers. The token is read once, so the check and the response use the
    # same value.
    with sessions_lock:
        session = sessions.get(session_id)
        if session is None:
            raise HTTPException(status_code=404, detail="Session not found")
        bot = session["bot"]
        token_data = bot.token_data if bot.ifGotToken else None

    if not token_data:
        raise HTTPException(status_code=400, detail="Token not yet available")
    return Response(
        content=token_data,
        media_type="application/json",
        headers={"Content-Disposition": 'attachment; filename="account.json"'},
    )
# Serve the built frontend from the site root when its dist directory exists
# (two levels above this file, then frontend/dist). html=True lets
# StaticFiles answer directory requests with index.html. It is mounted after
# the handlers above, presumably so the API routes take precedence — confirm.
frontend_dist_dir = Path(__file__).resolve().parents[1].joinpath("frontend", "dist")
if frontend_dist_dir.exists():
    frontend_app = StaticFiles(directory=str(frontend_dist_dir), html=True)
    app.mount("/", frontend_app, name="frontend")
if __name__ == "__main__":
    # Direct-execution entry point: serve ``app`` on every interface, port 8000.
    import uvicorn

    listen_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **listen_options)