"""Background-removal worker service.

A FastAPI application that accepts image uploads, stores jobs in MongoDB,
runs a segmentation model over each image and serves the resulting PNG.
Jobs are picked up by an in-process dispatcher loop; job state changes are
pushed to Server-Sent-Events subscribers.
"""
import asyncio
import hashlib
import hmac
import io
import json
import logging
import os
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

import psutil
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from huggingface_hub import login, snapshot_download
from PIL import Image
from pymongo import MongoClient, ReturnDocument
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError, PyMongoError
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

load_dotenv()

print("MONGO" if os.getenv("NEXT_MONGODB_URI") else "No MONGO URI configured")

HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Logged into Hugging Face Hub")
else:
    print("No HF_TOKEN configured; relying on cached models or public access")

UPLOADS_DIR = Path(__file__).parent / "uploads"
ORG_DIR = UPLOADS_DIR / "org"
PROCESSED_DIR = UPLOADS_DIR / "processed"
ORG_DIR.mkdir(parents=True, exist_ok=True)
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

ALLOWED_CONTENT_TYPES = {
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/tiff",
    "image/tif",
    "image/heif",
    "image/heic",
    "image/avif",
}
MAX_UPLOAD_SIZE_BYTES = int(os.getenv("WORKER_MAX_UPLOAD_SIZE_BYTES", str(20 * 1024 * 1024)))
MAX_CONCURRENCY = max(1, int(os.getenv("WORKER_MAX_CONCURRENCY", "2")))
MAX_JOBS_PER_CLIENT = max(1, int(os.getenv("WORKER_MAX_JOBS_PER_CLIENT", "1")))

# Retention is configured in minutes (WORKER_JOB_RETENTION_MINUTES). The legacy
# WORKER_JOB_RETENTION_HOURS variable is still honoured for backward
# compatibility. The minimum is 1 minute (useful for testing); the default is
# 10 minutes.
if "WORKER_JOB_RETENTION_MINUTES" in os.environ:
    JOB_RETENTION_MINUTES = max(1, int(os.environ["WORKER_JOB_RETENTION_MINUTES"]))
else:
    legacy_retention_hours = os.getenv("WORKER_JOB_RETENTION_HOURS")
    if legacy_retention_hours is not None:
        JOB_RETENTION_MINUTES = max(1, int(legacy_retention_hours)) * 60
    else:
        JOB_RETENTION_MINUTES = 10

QUEUE_POLL_SECONDS = float(os.getenv("WORKER_QUEUE_POLL_SECONDS", "1.0"))
CLEANUP_INTERVAL_SECONDS = int(os.getenv("WORKER_CLEANUP_INTERVAL_SECONDS", "60"))
CPU_THRESHOLD_PERCENT = int(os.getenv("WORKER_CPU_THRESHOLD_PERCENT", "80"))
MEMORY_THRESHOLD_PERCENT = int(os.getenv("WORKER_MEMORY_THRESHOLD_PERCENT", "80"))
MONGO_URI = os.getenv("NEXT_MONGODB_URI")
MONGO_DB_NAME = os.getenv("NEXT_MONGODB_DB", "bgremover")
WORKER_INTERNAL_TOKEN = os.getenv("WORKER_INTERNAL_TOKEN")
ALLOWED_ORIGINS = [
    origin.strip().rstrip("/")
    for origin in os.getenv(
        "WORKER_CORS_ORIGINS",
        "http://localhost:3000,http://localhost:3001",
    ).split(",")
    if origin.strip()
]

logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger("bgremover.worker")

MODEL_REPO_ID = os.getenv("WORKER_MODEL_REPO_ID", "Joker5800/ZhengPeng7_BiRefNet_lite")
MODEL_LOCAL_DIR = Path(
    os.getenv("WORKER_MODEL_LOCAL_DIR", str(Path(__file__).parent / "ZhengPeng7_BiRefNet_lite"))
)
MODEL_CACHE_DIR = os.getenv("WORKER_MODEL_CACHE_DIR")

# Statuses after which a job no longer changes; SSE streams end on these.
TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled", "expired", "cleaned"})

# Progress percentage reported for each job status.
_STATUS_PROGRESS = {
    "queued": 0,
    "starting": 25,
    "running": 50,
    "completed": 100,
    "failed": 0,
    "cancelled": 0,
    "expired": 0,
}

# Model input pipeline; built once because it does not depend on the request.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

mongo_client: Optional[MongoClient] = None
db: Optional[Database] = None
jobs_collection: Optional[Collection] = None
model = None
device = None
dispatcher_task: Optional[asyncio.Task] = None
cleanup_task: Optional[asyncio.Task] = None
dispatcher_wakeup: Optional[asyncio.Event] = None
active_tasks: set[asyncio.Task] = set()
active_tasks_lock = asyncio.Lock()
# Strong references to fire-and-forget tasks so they are not garbage-collected
# before they finish.
background_tasks: set[asyncio.Task] = set()
job_subscribers: dict[str, set[asyncio.Queue[dict]]] = {}
job_subscribers_lock = asyncio.Lock()


def get_dynamic_concurrency() -> int:
    """Return how many jobs may run at once given current CPU/memory load.

    Returns 1 when either metric exceeds its threshold or when sampling
    fails, and ``min(MAX_CONCURRENCY, 2)`` when both metrics are below half
    of their thresholds. Anything in between yields 1.

    NOTE: ``psutil.cpu_percent(interval=0.1)`` blocks for 0.1 s, so async
    callers should run this function in a worker thread.
    """
    try:
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory_percent = psutil.virtual_memory().percent
    except Exception:
        # Best effort: if metrics are unavailable fall back to the safest value.
        logger.debug("Could not sample system load; using concurrency 1", exc_info=True)
        return 1

    if cpu_percent > CPU_THRESHOLD_PERCENT or memory_percent > MEMORY_THRESHOLD_PERCENT:
        return 1
    if cpu_percent < CPU_THRESHOLD_PERCENT * 0.5 and memory_percent < MEMORY_THRESHOLD_PERCENT * 0.5:
        return min(MAX_CONCURRENCY, 2)
    return 1


def utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def build_job_paths(job_id: str) -> tuple[Path, Path]:
    """Return the (input, output) PNG paths used for *job_id*."""
    return ORG_DIR / f"{job_id}.png", PROCESSED_DIR / f"{job_id}.png"


def _create_collection_if_missing(database: Database, name: str) -> None:
    """Create collection *name*, ignoring the error raised if it exists."""
    try:
        database.create_collection(name)
    except PyMongoError:
        logger.debug("Collection %s not created (it probably already exists)", name)


def _safe_create_index(collection: Collection, keys, **kwargs) -> None:
    """Create an index, tolerating only an options conflict (code 85)."""
    try:
        collection.create_index(keys, **kwargs)
    except PyMongoError as exc:
        details = getattr(exc, "details", {}) or {}
        code_name = getattr(exc, "code_name", None) or details.get("codeName")
        code = getattr(exc, "code", None) or details.get("code")
        if code == 85 or code_name == "IndexOptionsConflict":
            logger.warning(
                "Index conflict on jobs collection; continuing",
                extra={"keys": keys, "options": kwargs},
            )
            return
        raise


def ensure_database() -> Collection:
    """Connect to MongoDB on first use and return the ``jobs`` collection.

    Also creates the ``analytics`` / ``analytics_seen`` collections and the
    indexes used by this module. Analytics setup is best effort; a failure
    to create a ``jobs`` index (other than an options conflict) is raised.

    Raises:
        RuntimeError: if ``NEXT_MONGODB_URI`` is not configured.
    """
    global mongo_client, db, jobs_collection
    if jobs_collection is not None:
        return jobs_collection
    if not MONGO_URI:
        raise RuntimeError("NEXT_MONGODB_URI not configured")

    # Reuse an existing client if a previous attempt failed after connecting.
    if mongo_client is None:
        mongo_client = MongoClient(MONGO_URI, tz_aware=True)
    db = mongo_client.get_database(MONGO_DB_NAME)
    collection = db["jobs"]

    # The analytics collection stores aggregated counts per day with hourly
    # breakdowns; analytics_seen is a temporary per-day dedupe store.
    _create_collection_if_missing(db, "analytics")
    _create_collection_if_missing(db, "analytics_seen")
    analytics_collection = db["analytics"]
    analytics_seen = db["analytics_seen"]
    try:
        analytics_collection.create_index([("date", 1)], unique=True)
    except PyMongoError:
        logger.warning("Could not create analytics index", exc_info=True)
    try:
        analytics_seen.create_index([("date", 1), ("h", 1)], unique=True)
        # Hashed client keys expire after ~25 hours.
        analytics_seen.create_index("createdAt", expireAfterSeconds=int(25 * 3600))
    except PyMongoError:
        logger.warning("Could not create analytics_seen indexes", exc_info=True)

    _safe_create_index(collection, [("jobId", 1)], unique=True)
    _safe_create_index(collection, [("status", 1), ("createdAt", 1)])
    _safe_create_index(collection, [("clientKey", 1), ("status", 1), ("createdAt", 1)])
    _safe_create_index(collection, [("expiresAt", 1)], expireAfterSeconds=0)

    # Publish the collection only after setup succeeded so a failed setup is
    # retried on the next call.
    jobs_collection = collection
    return jobs_collection


def record_analytics(client_key: Optional[str]) -> None:
    """Record job and unique-user counts without storing raw client keys.

    Only a SHA-256 hash of *client_key* is stored, in ``analytics_seen``, to
    dedupe unique users per day. All errors are logged and swallowed.
    """
    try:
        if db is None:
            return
        analytics = db["analytics"]
        analytics_seen = db["analytics_seen"]
        now = utcnow()
        date_str = now.date().isoformat()
        hour_str = f"{now.hour:02d}"

        # Increment job counters (daily + hourly).
        analytics.update_one(
            {"date": date_str},
            {
                "$inc": {"jobs": 1, f"hours.{hour_str}.jobs": 1},
                "$setOnInsert": {"date": date_str},
            },
            upsert=True,
        )

        if not client_key:
            return
        key_hash = hashlib.sha256(client_key.encode("utf-8")).hexdigest()
        try:
            analytics_seen.insert_one({"date": date_str, "h": key_hash, "createdAt": utcnow()})
        except DuplicateKeyError:
            # Already seen today; unique-user counters stay unchanged.
            return
        # First sighting for this day: bump the unique-user counters.
        analytics.update_one(
            {"date": date_str},
            {
                "$inc": {"unique_users": 1, f"hours.{hour_str}.users": 1},
                "$setOnInsert": {"date": date_str},
            },
            upsert=True,
        )
    except Exception:
        logger.exception("Failed to record analytics")


def load_model():
    """Load the segmentation model into the module globals.

    Downloads the checkpoint from ``MODEL_REPO_ID`` into ``MODEL_LOCAL_DIR``
    when no checkpoint file is found there.

    Returns:
        The ``(model, device)`` pair that was stored in the globals.

    Raises:
        RuntimeError: if the download finishes but no checkpoint file exists.
    """
    global model, device

    def has_checkpoint_files(model_dir: Path) -> bool:
        if not model_dir.exists():
            return False
        patterns = [
            "model.safetensors",
            "pytorch_model.bin",
            "pytorch_model.bin.index.json",
            "*.safetensors",
        ]
        return any(any(model_dir.glob(pattern)) for pattern in patterns)

    local_model_path = MODEL_LOCAL_DIR
    if not has_checkpoint_files(local_model_path):
        logger.warning(
            "No local model checkpoint found at %s. Downloading %s...",
            local_model_path,
            MODEL_REPO_ID,
        )
        local_model_path.mkdir(parents=True, exist_ok=True)
        snapshot_kwargs = {
            "repo_id": MODEL_REPO_ID,
            "local_dir": str(local_model_path),
            "local_dir_use_symlinks": False,
        }
        if MODEL_CACHE_DIR:
            snapshot_kwargs["cache_dir"] = MODEL_CACHE_DIR
        snapshot_download(**snapshot_kwargs)
        if not has_checkpoint_files(local_model_path):
            raise RuntimeError(
                f"Model download completed but checkpoint files are still missing in {local_model_path}"
            )
        logger.info("Model weights downloaded to %s", local_model_path)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForImageSegmentation.from_pretrained(
        str(local_model_path),
        trust_remote_code=True,
        local_files_only=True,
    )
    model.to(device)
    model.eval()
    return model, device


def cleanup_files(job_id: str):
    """Delete the input and processed files of *job_id* if they exist."""
    for path in build_job_paths(job_id):
        # missing_ok avoids the exists()/unlink() race of a separate check.
        path.unlink(missing_ok=True)


async def broadcast_job_state(job: dict) -> None:
    """Push the status payload of *job* to every SSE subscriber of that job."""
    # Building the payload may query MongoDB, so keep it off the event loop.
    payload = await asyncio.to_thread(build_status_payload, job)
    async with job_subscribers_lock:
        queues = list(job_subscribers.get(job["jobId"], ()))
    for queue in queues:
        try:
            queue.put_nowait(payload)
        except asyncio.QueueFull:
            logger.warning("Dropping stale SSE update for job %s", job["jobId"])


async def register_job_subscriber(job_id: str) -> asyncio.Queue[dict]:
    """Create, register and return a bounded update queue for *job_id*."""
    queue: asyncio.Queue[dict] = asyncio.Queue(maxsize=10)
    async with job_subscribers_lock:
        job_subscribers.setdefault(job_id, set()).add(queue)
    return queue


async def unregister_job_subscriber(job_id: str, queue: asyncio.Queue[dict]) -> None:
    """Remove *queue* from the subscribers of *job_id*; drop empty entries."""
    async with job_subscribers_lock:
        subscribers = job_subscribers.get(job_id)
        if not subscribers:
            return
        subscribers.discard(queue)
        if not subscribers:
            job_subscribers.pop(job_id, None)


def process_image_bytes(image_bytes: bytes) -> bytes:
    """Run the model on an encoded image and return an RGBA PNG.

    The model output is used as the alpha channel of the original image.

    Raises:
        RuntimeError: if the model has not been loaded yet.
    """
    if model is None:
        raise RuntimeError("Model is not loaded")

    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    input_tensor = _PREPROCESS_TRANSFORM(image).unsqueeze(0).to(device)
    # Match the dtype of the model weights.
    input_tensor = input_tensor.to(next(model.parameters()).dtype)

    with torch.inference_mode():
        preds = model(input_tensor)[-1].sigmoid().cpu()

    mask = transforms.ToPILImage()(preds[0].squeeze())
    mask = mask.resize(image.size)
    image.putalpha(mask)

    output_buffer = io.BytesIO()
    image.save(
        output_buffer,
        format="PNG",
        optimize=True,  # Compresses the PNG structure
        compress_level=6,  # Standard balance between speed and file size (0-9)
    )
    return output_buffer.getvalue()


def set_dispatcher_wakeup() -> None:
    """Wake the dispatcher loop if it has been started."""
    if dispatcher_wakeup is not None:
        dispatcher_wakeup.set()


def get_job_store() -> Collection:
    """Return the ``jobs`` collection, connecting on first use."""
    return ensure_database()


def is_internal_request(request: Request) -> bool:
    """Return True if *request* is allowed to submit jobs.

    Everything is allowed when no ``WORKER_INTERNAL_TOKEN`` is configured.
    Otherwise the request needs a matching ``x-internal-token`` header or an
    ``Origin`` header listed in ``ALLOWED_ORIGINS``.
    """
    if not WORKER_INTERNAL_TOKEN:
        return True

    supplied_token = request.headers.get("x-internal-token")
    # Constant-time comparison so the token cannot be probed via timing.
    if supplied_token is not None and hmac.compare_digest(
        supplied_token.encode("utf-8"), WORKER_INTERNAL_TOKEN.encode("utf-8")
    ):
        return True

    # NOTE(review): the Origin header is client-controlled, so a non-browser
    # client can set it and skip the token check. Kept for compatibility with
    # existing browser clients - confirm this is intended.
    origin = request.headers.get("origin")
    return bool(origin) and origin in ALLOWED_ORIGINS


async def update_job(job_id: str, updates: dict) -> None:
    """Apply *updates* to the job document and broadcast the new state.

    The ``progress`` and ``completedAt`` keys are never stored (progress is
    derived from the status); ``updatedAt`` is always set. The caller's dict
    is not modified.
    """
    collection = get_job_store()
    changes = {
        key: value
        for key, value in updates.items()
        if key not in ("progress", "completedAt")
    }
    changes["updatedAt"] = utcnow()
    await asyncio.to_thread(collection.update_one, {"jobId": job_id}, {"$set": changes})
    job = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
    if job:
        await broadcast_job_state(job)


async def process_job(job_id: str) -> None:
    """Process one job end to end and record the outcome in MongoDB.

    Marks the job ``running``, reads the input file, runs the model in a
    worker thread, writes the result and marks the job ``completed``. Any
    exception marks the job ``failed``. On exit the current task is removed
    from ``active_tasks`` and the dispatcher is woken.
    """
    collection = get_job_store()
    started_at = utcnow()
    try:
        job = await asyncio.to_thread(collection.find_one, {"jobId": job_id})
        if not job:
            logger.warning("Job %s not found, skipping", job_id)
            return

        input_path = Path(job["inputPath"])
        output_path = Path(job["outputPath"])

        # Progress is derived from the status, so only the status is stored.
        await update_job(job_id, {"status": "running", "startedAt": started_at})
        logger.info("[Job %s] Processing started", job_id)

        try:
            # File I/O is blocking; keep it off the event loop.
            image_bytes = await asyncio.to_thread(input_path.read_bytes)
        except FileNotFoundError:
            logger.warning("[Job %s] Input file not found, marking failed", job_id)
            await update_job(job_id, {"status": "failed", "error": "Input file not found"})
            return
        logger.info("[Job %s] Image loaded: %d bytes", job_id, len(image_bytes))

        processed_bytes = await asyncio.to_thread(process_image_bytes, image_bytes)
        await asyncio.to_thread(output_path.write_bytes, processed_bytes)
        logger.info("[Job %s] Image processed and saved", job_id)

        await update_job(job_id, {"status": "completed", "error": None})
        logger.info("[Job %s] Successfully completed", job_id)
    except Exception as exc:
        logger.exception("Job %s failed", job_id)
        try:
            await update_job(job_id, {"status": "failed", "error": str(exc)})
        except Exception:
            # Do not let a secondary failure escape as an unhandled task error.
            logger.exception("[Job %s] Could not record failure state", job_id)
    finally:
        current_task = asyncio.current_task()
        async with active_tasks_lock:
            active_tasks.discard(current_task)
        set_dispatcher_wakeup()


def can_run_job(job: dict) -> bool:
    """Return True if the job's client is below ``MAX_JOBS_PER_CLIENT``."""
    collection = get_job_store()
    client_key = job.get("clientKey") or "anonymous"
    active_count = collection.count_documents(
        {
            "clientKey": client_key,
            "status": {"$in": ["starting", "running"]},
        }
    )
    return active_count < MAX_JOBS_PER_CLIENT


def _stale_job_ids(collection: Collection, status: str, older_than: datetime) -> list:
    """Return the ids of jobs in *status* whose startedAt is before *older_than*."""
    cursor = collection.find(
        {"status": status, "startedAt": {"$lt": older_than}},
        {"jobId": 1},
    )
    return [doc["jobId"] for doc in cursor]


def _fetch_queued_candidates(collection: Collection) -> list:
    """Return up to 100 queued jobs, oldest first."""
    return list(collection.find({"status": "queued"}).sort("createdAt", 1).limit(100))


def claim_next_job() -> Optional[dict]:
    """Perform queue housekeeping and atomically claim the next runnable job.

    Housekeeping: ``starting`` jobs older than 2 minutes are re-queued,
    ``running`` jobs older than 5 minutes are marked failed, and queued jobs
    whose input file is missing are marked failed.

    Returns:
        The claimed job document (status ``starting``), or None if no queued
        job can currently run.
    """
    collection = get_job_store()
    now = utcnow()

    # Every housekeeping update repeats the status in its filter so a job
    # that changed state in the meantime is never overwritten.
    for job_id in _stale_job_ids(collection, "starting", now - timedelta(minutes=2)):
        logger.warning("[Job %s] Stale starting job, resetting to queued", job_id)
        collection.update_one(
            {"jobId": job_id, "status": "starting"},
            {"$set": {"status": "queued", "updatedAt": now}},
        )

    for job_id in _stale_job_ids(collection, "running", now - timedelta(minutes=5)):
        logger.warning("[Job %s] Stale running job, marking failed", job_id)
        collection.update_one(
            {"jobId": job_id, "status": "running"},
            {"$set": {"status": "failed", "error": "Job timed out", "updatedAt": now}},
        )

    candidates = _fetch_queued_candidates(collection)
    any_marked_failed = False
    for job in candidates:
        # Check the raw value: Path("") becomes Path("."), which always exists.
        raw_input_path = job.get("inputPath")
        if raw_input_path and Path(raw_input_path).exists():
            continue
        logger.warning("[Job %s] Queued job has missing input file, marking failed", job["jobId"])
        collection.update_one(
            {"jobId": job["jobId"], "status": "queued"},
            {
                "$set": {
                    "status": "failed",
                    "error": "Queued job input file is missing",
                    "updatedAt": now,
                }
            },
        )
        any_marked_failed = True

    if any_marked_failed:
        # Re-read so jobs that were just failed are not offered for claiming.
        candidates = _fetch_queued_candidates(collection)

    for job in candidates:
        if not can_run_job(job):
            continue
        claimed = collection.find_one_and_update(
            {"jobId": job["jobId"], "status": "queued"},
            {"$set": {"status": "starting", "updatedAt": now, "startedAt": now}},
            return_document=ReturnDocument.AFTER,
        )
        if claimed:
            logger.info("[Job %s] Claimed for processing", claimed["jobId"])
            return claimed
    return None


async def _dispatch_available_jobs() -> None:
    """Claim and start jobs until the concurrency limit or the queue is hit."""
    while True:
        async with active_tasks_lock:
            active_count = len(active_tasks)
        # Load sampling blocks briefly, so run it in a worker thread.
        current_concurrency = await asyncio.to_thread(get_dynamic_concurrency)
        if active_count >= current_concurrency:
            return

        # claim_next_job handles cleanup and returns an already-claimed job.
        job = await asyncio.to_thread(claim_next_job)
        if not job:
            return

        logger.info("[Dispatcher] Processing job %s", job["jobId"])
        async with active_tasks_lock:
            # Register the task before it gets a chance to run and finish.
            task = asyncio.create_task(process_job(job["jobId"]))
            active_tasks.add(task)


async def dispatcher_loop() -> None:
    """Run forever, starting queued jobs whenever capacity is available.

    Sleeps until ``dispatcher_wakeup`` is set or ``QUEUE_POLL_SECONDS``
    elapse. Errors during a dispatch round are logged and the loop goes on.

    Raises:
        RuntimeError: if ``dispatcher_wakeup`` has not been created.
    """
    if dispatcher_wakeup is None:
        raise RuntimeError("dispatcher_wakeup must be created before starting the dispatcher")

    while True:
        try:
            await _dispatch_available_jobs()
        except Exception:
            # A transient failure (e.g. MongoDB unavailable) must not stop the
            # dispatcher for the lifetime of the process.
            logger.exception("[Dispatcher] Error while dispatching jobs")

        # Nothing more to start right now; wait for new work.
        try:
            await asyncio.wait_for(dispatcher_wakeup.wait(), timeout=QUEUE_POLL_SECONDS)
        except asyncio.TimeoutError:
            pass
        dispatcher_wakeup.clear()
        logger.debug("[Dispatcher] Woke up, checking for new jobs")


def _sweep_orphaned_files() -> int:
    """Delete job files older than the retention period; return the count.

    The TTL index can delete a job document before the cleanup loop has seen
    it, which would otherwise leave its files on disk forever.
    """
    cutoff = utcnow().timestamp() - (JOB_RETENTION_MINUTES * 60 + CLEANUP_INTERVAL_SECONDS)
    removed = 0
    for directory in (ORG_DIR, PROCESSED_DIR):
        for path in directory.glob("*.png"):
            try:
                if path.stat().st_mtime < cutoff:
                    path.unlink(missing_ok=True)
                    removed += 1
            except OSError:
                logger.warning("[TTL Cleanup] Could not inspect or remove %s", path, exc_info=True)
    return removed


def _cleanup_expired_jobs(collection: Collection, files_cleaned_so_far: int) -> int:
    """Run one cleanup round; return the updated cumulative cleanup count."""
    now = utcnow()
    # Expired jobs that have not been cleaned yet.
    expired_jobs = list(
        collection.find(
            {"expiresAt": {"$lt": now}, "status": {"$ne": "cleaned"}},
            {"jobId": 1, "status": 1, "resultPath": 1},
        )
    )
    files_cleaned = files_cleaned_so_far

    if expired_jobs:
        logger.info("[TTL Cleanup] Found %d expired jobs to clean", len(expired_jobs))
        for job in expired_jobs:
            job_id = job.get("jobId")
            if not job_id:
                continue
            try:
                cleanup_files(job_id)
            except OSError:
                # Leave the job unmarked so the next round retries it.
                logger.warning("[TTL Cleanup] Could not remove files for job %s", job_id, exc_info=True)
                continue
            files_cleaned += 1
            logger.debug("[TTL Cleanup] Cleaned files for job %s", job_id)
            # Mark as cleaned so the job is not processed again.
            try:
                collection.update_one({"jobId": job_id}, {"$set": {"status": "cleaned"}})
            except PyMongoError as exc:
                logger.warning("[TTL Cleanup] Failed to mark job %s as cleaned: %s", job_id, exc)
        logger.info(
            "[TTL Cleanup] Cleaned %d jobs, %d total file cleanups",
            len(expired_jobs),
            files_cleaned,
        )

    orphans_removed = _sweep_orphaned_files()
    if orphans_removed:
        logger.info("[TTL Cleanup] Removed %d orphaned files", orphans_removed)

    expiring_soon = collection.count_documents(
        {"expiresAt": {"$gte": now, "$lte": now + timedelta(minutes=5)}}
    )
    if expiring_soon > 0:
        logger.info("[TTL Cleanup] %d jobs expiring in next 5 minutes", expiring_soon)
    return files_cleaned


async def cleanup_loop() -> None:
    """Periodically remove the files of expired jobs.

    Every ``CLEANUP_INTERVAL_SECONDS`` the files of expired jobs are deleted
    and the jobs are marked ``cleaned``; files left behind by documents that
    the TTL index already removed are swept by age. Errors are logged and the
    loop continues.
    """
    collection = get_job_store()
    files_cleaned = 0
    while True:
        try:
            # pymongo and file-system calls block, so run the round in a thread.
            files_cleaned = await asyncio.to_thread(_cleanup_expired_jobs, collection, files_cleaned)
        except Exception as exc:
            logger.error("[TTL Cleanup] Error during monitoring: %s", exc, exc_info=True)
        await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)


def get_queue_position(job_id: str, created_at: datetime) -> Optional[int]:
    """Return the 1-based queue position for a job created at *created_at*.

    The position is the number of queued jobs created earlier, plus one.
    *job_id* is accepted for interface compatibility but is not used; the
    caller is responsible for only asking about queued jobs.
    """
    collection = get_job_store()
    jobs_ahead = collection.count_documents(
        {"status": "queued", "createdAt": {"$lt": created_at}}
    )
    return jobs_ahead + 1


def estimate_wait(queue_position: Optional[int], avg_seconds_per_job: int = 12) -> Optional[int]:
    """Estimate the wait in seconds for *queue_position* (None stays None)."""
    if queue_position is None:
        return None
    return queue_position * avg_seconds_per_job


def get_progress_from_status(status: str) -> int:
    """Return the progress percentage for *status* (0 for unknown values).

    Progress is derived from the status, so no separate field is stored.
    """
    return _STATUS_PROGRESS.get(status, 0)


def build_status_payload(job: dict) -> dict:
    """Build the public status response for a job document.

    The internal ``starting`` status is reported as ``running``. For queued
    jobs the payload includes the queue position and an estimated wait, which
    requires a MongoDB query - async callers should use a worker thread.
    """
    public_status = job.get("status", "queued")
    if public_status == "starting":
        public_status = "running"

    queue_position = None
    estimated_wait = None
    if public_status == "queued":
        queue_position = get_queue_position(job["jobId"], job["createdAt"])
        estimated_wait = estimate_wait(queue_position)

    return {
        "job_id": job["jobId"],
        "status": public_status,
        "progress": get_progress_from_status(public_status),
        "error": job.get("error"),
        "queue_position": queue_position,
        "estimated_wait_seconds": estimated_wait,
    }


def format_sse_payload(job: dict) -> str:
    """Return the status payload of *job* as one SSE ``data:`` frame."""
    return f"data: {json.dumps(build_status_payload(job))}\n\n"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and start background loops; stop them on shutdown."""
    global dispatcher_task, cleanup_task, dispatcher_wakeup
    global mongo_client, db, jobs_collection

    print("Loading model...")
    load_model()
    get_job_store()
    dispatcher_wakeup = asyncio.Event()
    dispatcher_task = asyncio.create_task(dispatcher_loop())
    cleanup_task = asyncio.create_task(cleanup_loop())
    print(f"Model loaded on {device}")
    try:
        yield
    finally:
        async with active_tasks_lock:
            in_flight = list(active_tasks)
        pending = [
            task
            for task in (dispatcher_task, cleanup_task, *in_flight)
            if task is not None
        ]
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        # Close the client only after every task that uses it has stopped.
        if mongo_client is not None:
            mongo_client.close()
            mongo_client = None
            db = None
            jobs_collection = None
        print("Shutting down...")


app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS or ["*"],
    # Credentials must not be combined with a wildcard origin.
    allow_credentials=bool(ALLOWED_ORIGINS),
    allow_methods=["*"],
    allow_headers=["*"],
)


def _encode_rgba_png(image: Image.Image) -> bytes:
    """Convert *image* to RGBA and return it encoded as PNG."""
    buffer = io.BytesIO()
    image.convert("RGBA").save(buffer, format="PNG")
    return buffer.getvalue()


def _normalize_upload(file_bytes: bytes) -> tuple[bytes, str]:
    """Return ``(bytes_to_store, description)`` for an uploaded file.

    Tries PIL first, then ``pillow_heif`` when it is installed. If neither
    can decode the upload, the original bytes are returned unchanged so that
    downstream processing reports the error.
    """
    try:
        return _encode_rgba_png(Image.open(io.BytesIO(file_bytes))), "normalized PNG saved"
    except Exception:
        logger.debug("PIL failed to decode; attempting format-specific decoders if available")

    try:
        import pillow_heif

        heif = pillow_heif.read_heif(file_bytes)
        decoded = Image.frombytes(heif.mode, heif.size, heif.data)
        return _encode_rgba_png(decoded), "decoded HEIF/HEIC and saved PNG"
    except Exception:
        logger.debug("pillow_heif not available or failed to decode")

    return file_bytes, "original file saved"


@app.post("/remove")
async def remove_background(
    request: Request,
    wait: bool = Form(False),
    file: UploadFile = File(...),
):
    """Accept an image upload and create a background-removal job.

    With ``wait`` false the job is queued and a 202 response with the job id
    is returned. With ``wait`` true the job is processed within the request
    and the resulting PNG is returned (or a 500 response on failure).

    Raises:
        HTTPException: 401 (unauthorized), 415 (content type), 413 (size) or
            500 (result missing after processing).
    """
    logger.info("Received remove request: filename=%s, size=%s", file.filename, file.size)

    if not is_internal_request(request):
        raise HTTPException(status_code=401, detail="Unauthorized")

    collection = get_job_store()

    if file.content_type and file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail="Unsupported media type")

    # Reject oversized uploads before pulling them into memory, and never
    # read more than one byte past the limit.
    if file.size is not None and file.size > MAX_UPLOAD_SIZE_BYTES:
        raise HTTPException(status_code=413, detail="File too large")
    file_bytes = await file.read(MAX_UPLOAD_SIZE_BYTES + 1)
    logger.info("File read complete: %d bytes", len(file_bytes))
    if len(file_bytes) > MAX_UPLOAD_SIZE_BYTES:
        raise HTTPException(status_code=413, detail="File too large")

    job_id = str(uuid.uuid4())
    input_path, output_path = build_job_paths(job_id)

    # Decoding and re-encoding is CPU-bound; keep it off the event loop.
    stored_bytes, how_saved = await asyncio.to_thread(_normalize_upload, file_bytes)
    await asyncio.to_thread(input_path.write_bytes, stored_bytes)
    logger.info("Job %s created, %s to %s", job_id, how_saved, input_path)

    client_key = request.headers.get("x-client-ip")
    if not client_key and request.client is not None:
        client_key = request.client.host

    created_at = utcnow()
    job_record = {
        "jobId": job_id,
        # A job processed inline must not look claimable to the dispatcher,
        # otherwise it could be processed twice.
        "status": "starting" if wait else "queued",
        "createdAt": created_at,
        "updatedAt": created_at,
        "expiresAt": created_at + timedelta(minutes=JOB_RETENTION_MINUTES),
        "inputPath": str(input_path),
        "outputPath": str(output_path),
        "fileName": file.filename or f"{job_id}.png",
        "clientKey": client_key or "anonymous",
        "error": None,
    }
    if wait:
        job_record["startedAt"] = created_at

    await asyncio.to_thread(collection.insert_one, job_record)
    logger.info("[Job %s] Created - expires at %s", job_id, job_record["expiresAt"].isoformat())

    # Record analytics (jobs + unique users) in the background. Only hashed
    # client keys are stored, and only temporarily.
    try:
        analytics_task = asyncio.create_task(asyncio.to_thread(record_analytics, client_key))
        # Keep a strong reference until the task is done.
        background_tasks.add(analytics_task)
        analytics_task.add_done_callback(background_tasks.discard)
    except Exception:
        logger.exception("Failed to schedule analytics recording")

    if wait:
        await process_job(job_id)
        completed_job = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
        if not completed_job:
            raise HTTPException(status_code=500, detail="Job not found after processing")
        final_status = completed_job.get("status", "failed")
        if final_status != "completed":
            return JSONResponse(
                {
                    "job_id": job_id,
                    "status": final_status,
                    "progress": get_progress_from_status(final_status),
                    "error": completed_job.get("error"),
                },
                status_code=500,
            )
        output_path = Path(completed_job["outputPath"])
        if not output_path.exists():
            raise HTTPException(status_code=500, detail="Processed image not found")
        return FileResponse(
            output_path,
            media_type="image/png",
            headers={"X-Job-Id": job_id},
        )

    set_dispatcher_wakeup()
    await broadcast_job_state(job_record)
    return JSONResponse(
        {
            "job_id": job_id,
            "status": "queued",
            "progress": 0,
        },
        status_code=202,
    )


@app.get("/status/{job_id}")
async def get_status(job_id: str):
    """Return the status payload of a job, or 404 if it does not exist."""
    collection = get_job_store()
    job = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    payload = await asyncio.to_thread(build_status_payload, job)
    return JSONResponse(payload)


@app.get("/queue-status")
async def queue_status():
    """Return job counts per status together with the concurrency limit."""
    collection = get_job_store()
    queued_jobs, running_jobs, failed_jobs, completed_jobs = await asyncio.gather(
        asyncio.to_thread(collection.count_documents, {"status": "queued"}),
        asyncio.to_thread(collection.count_documents, {"status": {"$in": ["starting", "running"]}}),
        asyncio.to_thread(collection.count_documents, {"status": "failed"}),
        asyncio.to_thread(collection.count_documents, {"status": "completed"}),
    )
    return JSONResponse(
        {
            "queue_length": queued_jobs,
            "running_jobs": running_jobs,
            "batch_size": MAX_CONCURRENCY,
            "max_concurrency": MAX_CONCURRENCY,
            "failed_jobs": failed_jobs,
            "completed_jobs": completed_jobs,
        }
    )


@app.get("/result/{job_id}")
async def get_result(job_id: str):
    """Return the processed PNG of a completed job.

    Raises:
        HTTPException: 404 if the job or its file is missing, 409 if the job
            is not completed.
    """
    collection = get_job_store()
    job = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.get("status") != "completed":
        raise HTTPException(status_code=409, detail="Job not completed")

    output_path = Path(job["outputPath"])
    if not output_path.exists():
        raise HTTPException(status_code=404, detail="Result not found")
    return FileResponse(
        output_path,
        media_type="image/png",
        headers={"X-Job-Id": job_id},
    )


@app.get("/events/{job_id}")
async def job_events(job_id: str, request: Request):
    """Stream status updates of a job as Server-Sent Events.

    The stream starts with the current state and ends once the job reaches a
    terminal status or the client disconnects. A keep-alive comment is sent
    after 15 seconds without updates.
    """
    collection = get_job_store()
    job = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    async def event_stream():
        # Subscribe inside the generator so the subscription is always
        # released by the finally block, even if the stream is cut early.
        subscriber_queue = await register_job_subscriber(job_id)
        try:
            # Re-read after subscribing so an update that landed between the
            # lookup above and the subscription is not missed.
            current = await asyncio.to_thread(collection.find_one, {"jobId": job_id}, {"_id": 0})
            initial = await asyncio.to_thread(build_status_payload, current or job)
            yield f"data: {json.dumps(initial)}\n\n"
            if initial.get("status") in TERMINAL_STATUSES:
                # Nothing more will be published for this job.
                return

            while True:
                if await request.is_disconnected():
                    break
                try:
                    payload = await asyncio.wait_for(subscriber_queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    yield ": keep-alive\n\n"
                    continue
                yield f"data: {json.dumps(payload)}\n\n"
                if payload.get("status") in TERMINAL_STATUSES:
                    break
        finally:
            await unregister_job_subscriber(job_id, subscriber_queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")


def _collect_health_snapshot() -> dict:
    """Gather queue counts and system load for the health endpoint."""
    collection = get_job_store()
    queued_jobs = collection.count_documents({"status": "queued"})
    running_jobs = collection.count_documents({"status": {"$in": ["starting", "running"]}})
    current_concurrency = get_dynamic_concurrency()

    cpu_percent = 0
    memory_percent = 0
    try:
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory_percent = psutil.virtual_memory().percent
    except Exception:
        # Best effort: report zeros when metrics are unavailable.
        logger.debug("Could not sample system load for health check", exc_info=True)

    return {
        "status": "healthy",
        "device": device,
        "model_loaded": model is not None,
        "queued_jobs": queued_jobs,
        "running_jobs": running_jobs,
        "max_concurrency": current_concurrency,
        "cpu_percent": cpu_percent,
        "memory_percent": memory_percent,
    }


@app.get("/health")
async def health():
    """Return service health, queue counts and current system load."""
    # The snapshot performs blocking MongoDB and psutil calls.
    return await asyncio.to_thread(_collect_health_snapshot)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))