""" BG Remover Pro — FastAPI Backend Supports: Fast Mode (u2net) & Thinking Mode (BiRefNet + Claude AI) Queue: max 10 waiting | Rate limiting | Anti-spam """ import asyncio import base64 import gc import io import json import logging import os import time import uuid from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import Dict, List, Optional import anthropic from fastapi import FastAPI, File, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from fastapi.staticfiles import StaticFiles from PIL import Image, ImageFilter import numpy as np # ───────────────────────────────────────────── # LOGGING # ───────────────────────────────────────────── logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger("bgremover") # ───────────────────────────────────────────── # CONSTANTS # ───────────────────────────────────────────── ALLOWED_MIME_TYPES = { "image/jpeg", "image/jpg", "image/png", "image/webp", "image/gif", "image/bmp", "image/tiff", "image/avif", "image/heic", "image/heif", "image/x-png", } ALLOWED_EXTENSIONS = { ".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp", ".tiff", ".tif", ".avif", } MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB MAX_QUEUE_SIZE = 10 # max waiting tasks RATE_LIMIT_WINDOW = 60 # seconds RATE_LIMIT_MAX = 5 # requests per window per IP MAX_ACTIVE_PER_IP = 2 # concurrent tasks per IP THINKING_TIMEOUT = 120 # seconds (2 min max) RESULT_TTL = 3600 # keep results for 1 hour # ───────────────────────────────────────────── # ENUMS & DATA CLASSES # ───────────────────────────────────────────── class Mode(str, Enum): FAST = "fast" THINKING = "thinking" class TaskStatus(str, Enum): PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" FAILED = "failed" @dataclass class Task: id: str 
mode: Mode image_data: bytes filename: str ip: str status: TaskStatus = TaskStatus.PENDING queue_pos: int = 0 created_at: float = field(default_factory=time.time) result_png: Optional[bytes] = None result_webp: Optional[bytes] = None error: Optional[str] = None analysis: Optional[str] = None orig_size: Optional[tuple] = None proc_time: Optional[float] = None stage: str = "انتظار" # ───────────────────────────────────────────── # GLOBAL STATE # ───────────────────────────────────────────── tasks: Dict[str, Task] = {} pending_queue: List[str] = [] queue_lock: asyncio.Lock = asyncio.Lock() ws_map: Dict[str, List[WebSocket]] = defaultdict(list) ip_times: Dict[str, List[float]] = defaultdict(list) ip_active: Dict[str, int] = defaultdict(int) current_task: Optional[str] = None # Sessions (loaded at startup) fast_session = None thinking_session = None anthropic_client = None # ───────────────────────────────────────────── # APP # ───────────────────────────────────────────── app = FastAPI(title="BG Remover Pro", version="2.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ───────────────────────────────────────────── # STARTUP # ───────────────────────────────────────────── @app.on_event("startup") async def startup_event(): global fast_session, thinking_session, anthropic_client log.info("Loading fast model (u2net)...") from rembg import new_session fast_session = new_session("u2net") log.info("✓ u2net loaded") log.info("Loading thinking model (birefnet-general)...") thinking_session = new_session("birefnet-general") log.info("✓ birefnet-general loaded") api_key = os.getenv("ANTHROPIC_API_KEY", "") if api_key: anthropic_client = anthropic.Anthropic(api_key=api_key) log.info("✓ Anthropic client initialized") else: log.warning("ANTHROPIC_API_KEY not set — AI analysis disabled") asyncio.create_task(queue_worker()) asyncio.create_task(cleanup_worker()) log.info("✓ Workers started") # 
# ─────────────────────────────────────────────
# RATE LIMITING
# ─────────────────────────────────────────────
def check_rate_limit(ip: str) -> tuple[bool, str]:
    """Enforce the per-IP sliding-window rate limit and concurrency cap.

    Returns:
        (allowed, message) — message is an Arabic error string when denied,
        empty when allowed.

    Side effect: on success the current timestamp is appended to the IP's
    window, so a caller that later rejects the request must pop it back off
    to refund the slot (see /upload).
    """
    now = time.time()
    # Keep only timestamps still inside the window.
    ip_times[ip] = [t for t in ip_times[ip] if now - t < RATE_LIMIT_WINDOW]
    if len(ip_times[ip]) >= RATE_LIMIT_MAX:
        remaining = int(RATE_LIMIT_WINDOW - (now - ip_times[ip][0]))
        return False, f"تجاوزت الحد المسموح به ({RATE_LIMIT_MAX} طلبات/{RATE_LIMIT_WINDOW}ث). انتظر {remaining}ث"
    if ip_active[ip] >= MAX_ACTIVE_PER_IP:
        return False, f"لديك {MAX_ACTIVE_PER_IP} مهام نشطة بالفعل. انتظر اكتمالها"
    ip_times[ip].append(now)
    return True, ""


# ─────────────────────────────────────────────
# IMAGE VALIDATION
# ─────────────────────────────────────────────
async def validate_image(file: UploadFile, data: bytes) -> tuple[bool, str]:
    """Validate upload size, extension, MIME type, and actual image bytes.

    Returns (ok, error_message). Extension/MIME checks are advisory (missing
    values pass); the decisive check is that Pillow can parse the bytes.
    """
    if len(data) > MAX_FILE_SIZE:
        return False, "حجم الملف يتجاوز 100MB"
    fname = file.filename or ""
    ext = Path(fname).suffix.lower()
    if ext and ext not in ALLOWED_EXTENSIONS:
        return False, f"امتداد غير مسموح: {ext}. المسموح: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
    ct = (file.content_type or "").lower().split(";")[0].strip()
    if ct and ct not in ALLOWED_MIME_TYPES and not ct.startswith("image/"):
        return False, f"نوع الملف غير مسموح: {ct}"
    # Verify actual image bytes. verify() is cheap but leaves the Image
    # unusable, hence the fresh open()+load() fallback for formats where
    # verify() is too strict.
    try:
        img = Image.open(io.BytesIO(data))
        img.verify()
    except Exception:
        try:
            img = Image.open(io.BytesIO(data))
            img.load()
        except Exception:
            return False, "الملف تالف أو ليس صورة صالحة"
    return True, ""


# ─────────────────────────────────────────────
# AI ANALYSIS (Claude)
# ─────────────────────────────────────────────
async def analyze_image(image_data: bytes, mode: Mode) -> str:
    """Describe the image with Claude; always returns text, never raises.

    Thinking mode requests an extended-thinking, five-point difficulty
    analysis; fast mode asks for a two-sentence summary.
    """
    if not anthropic_client:
        return "تحليل الذكاء الاصطناعي غير متاح (ANTHROPIC_API_KEY غير محدد)"
    try:
        # Downscale before upload if too large — saves tokens and latency.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
        if max(img.size) > 1024:
            img.thumbnail((1024, 1024), Image.LANCZOS)
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        b64 = base64.standard_b64encode(buf.getvalue()).decode()

        image_block = {
            "type": "image",
            "source": {"type": "base64", "media_type": "image/jpeg", "data": b64},
        }

        def _call():
            """Synchronous SDK call, executed in a worker thread."""
            if mode == Mode.THINKING:
                # Extended thinking for maximum precision analysis.
                return anthropic_client.messages.create(
                    model="claude-sonnet-4-20250514",
                    max_tokens=2000,
                    thinking={"type": "enabled", "budget_tokens": 8000},
                    messages=[{
                        "role": "user",
                        "content": [
                            image_block,
                            {
                                "type": "text",
                                "text": (
                                    "أنت خبير محترف في معالجة الصور وإزالة الخلفيات. حلل هذه الصورة تحليلاً دقيقاً:\n\n"
                                    "1. **الموضوع الرئيسي**: ما هو؟ (شخص، حيوان، منتج، إلخ)\n"
                                    "2. **الخلفية**: طبيعتها ومدى تعقيدها\n"
                                    "3. **الحواف الصعبة**: هل يوجد شعر، فراء، شفافية، ظلال؟\n"
                                    "4. **مستوى الصعوبة**: سهل / متوسط / صعب جداً\n"
                                    "5. **توصية**: ما الإستراتيجية المثلى لإزالة الخلفية؟\n\n"
                                    "كن دقيقاً ومختصراً."
                                ),
                            },
                        ],
                    }],
                )
            return anthropic_client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=300,
                messages=[{
                    "role": "user",
                    "content": [
                        image_block,
                        {
                            "type": "text",
                            "text": "ما الموضوع الرئيسي في هذه الصورة؟ هل الخلفية بسيطة أم معقدة؟ جملتان فقط.",
                        },
                    ],
                }],
            )

        # FIX: `anthropic.Anthropic` is a synchronous client; calling it
        # inline blocked the event loop (WebSocket pushes, /health) for the
        # entire API round-trip. Run it in the default thread pool instead.
        response = await asyncio.get_running_loop().run_in_executor(None, _call)

        text_blocks = [b for b in response.content if b.type == "text"]
        return text_blocks[0].text if text_blocks else "تم التحليل"
    except Exception as e:
        log.error(f"Claude analysis error: {e}")
        return f"تعذر التحليل: {str(e)[:120]}"


# ─────────────────────────────────────────────
# BACKGROUND REMOVAL
# ─────────────────────────────────────────────
def _do_remove_fast(data: bytes) -> bytes:
    """Fast removal using u2net — standard quality, quick."""
    from rembg import remove
    return remove(
        data,
        session=fast_session,
        alpha_matting=False,
        post_process_mask=True,
        bgcolor=None,
    )


def _do_remove_thinking(data: bytes) -> bytes:
    """Thinking removal using BiRefNet + alpha matting.

    Multi-pass for maximum edge precision: pass 1 segments with alpha
    matting; pass 2 smooths only the semi-transparent edge band of the
    alpha channel to reduce haloing. Pass-2 failures fall back to pass 1.
    """
    from rembg import remove
    # Pass 1: BiRefNet segmentation with alpha matting.
    result_bytes = remove(
        data,
        session=thinking_session,
        alpha_matting=True,
        alpha_matting_foreground_threshold=240,
        alpha_matting_background_threshold=10,
        alpha_matting_erode_size=10,
        post_process_mask=True,
        bgcolor=None,
    )
    # Pass 2: alpha channel refinement.
    try:
        result_img = Image.open(io.BytesIO(result_bytes)).convert("RGBA")
        r, g, b, alpha = result_img.split()
        alpha_arr = np.array(alpha, dtype=np.float32)
        # Only smooth near-edge pixels; keep full opacity/transparency intact.
        edge_mask = (alpha_arr > 20) & (alpha_arr < 235)
        if edge_mask.any():
            # NOTE: ImageFilter is already imported at module level; the
            # original re-imported it locally, which was redundant.
            alpha_smooth = alpha.filter(ImageFilter.SMOOTH_MORE)
            alpha_arr2 = np.array(alpha_smooth, dtype=np.float32)
            # Blend smoothed values in only at edge pixels.
            alpha_arr[edge_mask] = (
                alpha_arr[edge_mask] * 0.4 + alpha_arr2[edge_mask] * 0.6
            )
        alpha_final = Image.fromarray(alpha_arr.clip(0, 255).astype(np.uint8))
        final_img = Image.merge("RGBA", (r, g, b, alpha_final))
        out = io.BytesIO()
        # Low compression: speed over size; the client gets WebP too.
        final_img.save(out, format="PNG", optimize=False, compress_level=1)
        return out.getvalue()
    except Exception as e:
        log.warning(f"Pass 2 refinement failed (returning pass 1): {e}")
        return result_bytes


async def run_removal(task: Task) -> bytes:
    """Run the removal model for `task` off the event loop.

    Fast mode has no timeout; thinking mode is capped at THINKING_TIMEOUT
    seconds and raises asyncio.TimeoutError past it. NOTE: on timeout the
    executor thread keeps running to completion — only the await is cancelled.
    """
    # FIX: get_event_loop() is deprecated inside coroutines; use the
    # explicit running-loop accessor.
    loop = asyncio.get_running_loop()
    if task.mode == Mode.FAST:
        return await loop.run_in_executor(None, _do_remove_fast, task.image_data)
    return await asyncio.wait_for(
        loop.run_in_executor(None, _do_remove_thinking, task.image_data),
        timeout=THINKING_TIMEOUT,
    )


# ─────────────────────────────────────────────
# WEBSOCKET BROADCAST
# ─────────────────────────────────────────────
async def broadcast(task_id: str, payload: dict):
    """Send `payload` to every socket subscribed to `task_id`, pruning dead ones."""
    dead = []
    for ws in ws_map.get(task_id, []):
        try:
            await ws.send_json(payload)
        except Exception:
            dead.append(ws)
    for ws in dead:
        try:
            ws_map[task_id].remove(ws)
        except ValueError:
            pass


async def broadcast_all_positions():
    """Notify all waiting tasks of their new queue positions."""
    async with queue_lock:
        for i, tid in enumerate(pending_queue):
            await broadcast(tid, {
                "event": "position_update",
                "position": i + 1,
                "total": len(pending_queue),
            })


# ─────────────────────────────────────────────
# QUEUE WORKER
# ─────────────────────────────────────────────
async def queue_worker():
    """Single consumer loop: pops tasks off pending_queue FIFO and processes
    them one at a time (analysis → removal → WebP), pushing progress over
    WebSockets. Never exits; every failure is captured onto the task.
    """
    global current_task
    log.info("Queue worker started")
    while True:
        task_id = None
        async with queue_lock:
            if pending_queue:
                task_id = pending_queue.pop(0)
                t = tasks.get(task_id)
                if t:
                    t.status = TaskStatus.PROCESSING
                    t.stage = "تحليل الصورة"
                    t.queue_pos = 0
                current_task = task_id
                # Shift queue positions for the tasks still waiting.
                for i, tid in enumerate(pending_queue):
                    if tid in tasks:
                        tasks[tid].queue_pos = i + 1
        if not task_id:
            await asyncio.sleep(0.3)
            continue
        task = tasks.get(task_id)
        if not task:
            # Task was cancelled/expired between pop and here.
            current_task = None
            continue
        start = time.time()
        try:
            # Step 1: AI analysis.
            await broadcast(task_id, {"event": "stage", "stage": "تحليل الصورة بالذكاء الاصطناعي..."})
            task.stage = "تحليل"
            task.analysis = await analyze_image(task.image_data, task.mode)
            # Step 2: background removal.
            stage_msg = (
                "إزالة الخلفية — وضع التفكير العميق (حتى دقيقتين)..."
                if task.mode == Mode.THINKING
                else "إزالة الخلفية — الوضع السريع..."
            )
            await broadcast(task_id, {"event": "stage", "stage": stage_msg, "analysis": task.analysis})
            task.stage = "إزالة الخلفية"
            result_bytes = await run_removal(task)
            task.result_png = result_bytes
            # Step 3: generate lossless WebP alternative.
            await broadcast(task_id, {"event": "stage", "stage": "توليد ملف WebP..."})
            result_img = Image.open(io.BytesIO(result_bytes)).convert("RGBA")
            webp_buf = io.BytesIO()
            result_img.save(webp_buf, format="WEBP", lossless=True, quality=100)
            task.result_webp = webp_buf.getvalue()
            task.proc_time = time.time() - start
            task.status = TaskStatus.COMPLETED
            task.stage = "مكتمل"
            log.info(f"Task {task_id[:8]} completed in {task.proc_time:.1f}s ({task.mode})")
            await broadcast(task_id, {
                "event": "completed",
                "task_id": task_id,
                "proc_time": f"{task.proc_time:.1f}",
                "analysis": task.analysis,
                "size_kb": len(task.result_png) // 1024,
            })
        except asyncio.TimeoutError:
            task.status = TaskStatus.FAILED
            task.error = "انتهت مهلة المعالجة (120 ثانية). جرب الوضع السريع"
            log.warning(f"Task {task_id[:8]} timed out")
            await broadcast(task_id, {"event": "failed", "error": task.error})
        except Exception as exc:
            task.status = TaskStatus.FAILED
            task.error = str(exc)
            log.error(f"Task {task_id[:8]} failed: {exc}", exc_info=True)
            await broadcast(task_id, {"event": "failed", "error": str(exc)[:300]})
        finally:
            ip_active[task.ip] = max(0, ip_active[task.ip] - 1)
            current_task = None
            # FIX: was `del task.image_data`, which removes the dataclass
            # attribute entirely — any later access (or even repr(task))
            # would raise AttributeError. Rebinding to b"" releases the
            # memory just as well while keeping the object intact.
            task.image_data = b""
            gc.collect()
            await broadcast_all_positions()
        await asyncio.sleep(0.1)


# ─────────────────────────────────────────────
# CLEANUP WORKER — removes old results
# ─────────────────────────────────────────────
async def cleanup_worker():
    """Every 5 minutes: drop finished tasks older than RESULT_TTL and prune
    stale per-IP bookkeeping so the maps don't grow without bound.
    """
    while True:
        await asyncio.sleep(300)
        now = time.time()
        stale = [
            tid for tid, t in tasks.items()
            if now - t.created_at > RESULT_TTL
            and t.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)
        ]
        for tid in stale:
            del tasks[tid]
        if stale:
            log.info(f"Cleaned up {len(stale)} old tasks")
        # FIX (leak): ip_times / ip_active entries were never removed, so a
        # long-running public server accumulated one entry per client IP
        # forever. Drop IPs whose whole window has aged out, and zeroed
        # concurrency counters (defaultdict recreates them on demand).
        for ip in [ip for ip, ts in ip_times.items()
                   if all(now - t >= RATE_LIMIT_WINDOW for t in ts)]:
            del ip_times[ip]
        for ip in [ip for ip, n in ip_active.items() if n <= 0]:
            del ip_active[ip]
# ─────────────────────────────────────────────
# WEBSOCKET ENDPOINT
# ─────────────────────────────────────────────
@app.websocket("/ws/{task_id}")
async def ws_endpoint(websocket: WebSocket, task_id: str):
    """Per-task progress channel.

    On connect, immediately replays the task's current state so late
    subscribers aren't left waiting; then holds the socket open, using the
    receive timeout only as a liveness probe.
    """
    await websocket.accept()
    ws_map[task_id].append(websocket)
    # Send current state immediately.
    task = tasks.get(task_id)
    if task:
        if task.status == TaskStatus.COMPLETED:
            await websocket.send_json({"event": "completed", "task_id": task_id,
                                       "proc_time": str(task.proc_time or 0),
                                       "analysis": task.analysis})
        elif task.status == TaskStatus.FAILED:
            await websocket.send_json({"event": "failed", "error": task.error})
        elif task.status == TaskStatus.PENDING:
            await websocket.send_json({"event": "queued", "position": task.queue_pos,
                                       "total": len(pending_queue)})
        elif task.status == TaskStatus.PROCESSING:
            await websocket.send_json({"event": "stage", "stage": task.stage})
    try:
        while True:
            try:
                await asyncio.wait_for(websocket.receive_text(), timeout=60)
            except asyncio.TimeoutError:
                # FIX: previously a 60s receive timeout closed the socket,
                # disconnecting silent clients before a Thinking-mode task
                # (up to 120s) could finish. Probe instead: send_json raises
                # if the peer is really gone, which exits the loop below.
                await websocket.send_json({"event": "ping"})
    except WebSocketDisconnect:
        pass
    except Exception:
        # Peer vanished without a close frame (probe send failed).
        pass
    finally:
        try:
            ws_map[task_id].remove(websocket)
        except ValueError:
            pass


# ─────────────────────────────────────────────
# HTTP ENDPOINTS
# ─────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe with basic queue stats."""
    return {"status": "ok", "queue": len(pending_queue), "processing": current_task is not None}


@app.get("/")
async def root():
    """Serve the single-page frontend."""
    from fastapi.responses import FileResponse
    return FileResponse("static/index.html")


def _refund_rate_slot(ip: str) -> None:
    """Give back the rate-limit slot recorded by check_rate_limit for a
    request that was ultimately rejected."""
    if ip_times[ip]:
        ip_times[ip].pop()


@app.post("/upload")
async def upload(
    request: Request,
    file: UploadFile = File(...),
    mode: str = "fast",
):
    """Accept an image, validate it, and enqueue a removal task.

    Raises 400 (bad mode / bad image), 429 (rate limit), 503 (queue full).
    """
    # FIX: request.client can be None (some ASGI servers / test clients);
    # the original `request.client.host or "unknown"` would then raise.
    ip = request.client.host if request.client else "unknown"

    # Validate mode (Mode is a str-Enum, so plain strings compare equal).
    if mode not in (Mode.FAST, Mode.THINKING):
        raise HTTPException(400, "وضع غير صالح. استخدم 'fast' أو 'thinking'")

    # Rate limit (records a slot on success — refunded on any rejection below).
    allowed, msg = check_rate_limit(ip)
    if not allowed:
        raise HTTPException(429, msg)

    # Fail fast if the queue is already full. FIX: also refund the
    # rate-limit slot here (the original only refunded on validation errors).
    async with queue_lock:
        if len(pending_queue) >= MAX_QUEUE_SIZE:
            _refund_rate_slot(ip)
            raise HTTPException(503, f"الطابور ممتلئ ({MAX_QUEUE_SIZE}/{MAX_QUEUE_SIZE}). يرجى الانتظار")

    # Read & validate the actual bytes.
    data = await file.read()
    valid, err = await validate_image(file, data)
    if not valid:
        _refund_rate_slot(ip)
        raise HTTPException(400, err)

    # Image metadata (bytes already proven decodable by validate_image).
    img = Image.open(io.BytesIO(data))
    orig_size = img.size

    # Create task.
    task_id = str(uuid.uuid4())
    task = Task(
        id=task_id,
        mode=Mode(mode),
        image_data=data,
        filename=file.filename or "image",
        ip=ip,
        orig_size=orig_size,
    )

    async with queue_lock:
        # FIX (race): the earlier capacity check released the lock before
        # the (slow) file read, so concurrent uploads could overfill the
        # queue. Re-check under the same lock that does the append.
        if len(pending_queue) >= MAX_QUEUE_SIZE:
            _refund_rate_slot(ip)
            raise HTTPException(503, f"الطابور ممتلئ ({MAX_QUEUE_SIZE}/{MAX_QUEUE_SIZE}). يرجى الانتظار")
        tasks[task_id] = task
        pending_queue.append(task_id)
        task.queue_pos = len(pending_queue)
        ip_active[ip] += 1

    log.info(f"New task {task_id[:8]} | mode={mode} | size={orig_size} | ip={ip}")
    return JSONResponse({
        "task_id": task_id,
        "queue_pos": task.queue_pos,
        "queue_total": len(pending_queue),
        "mode": mode,
        "image_size": f"{orig_size[0]}×{orig_size[1]}",
        "filename": file.filename,
    })


@app.get("/status/{task_id}")
async def status(task_id: str):
    """Poll-style status: shape of the payload depends on task state."""
    task = tasks.get(task_id)
    if not task:
        raise HTTPException(404, "المهمة غير موجودة أو انتهت صلاحيتها")
    base = {
        "task_id": task_id,
        "status": task.status.value,
        "mode": task.mode.value,
        "filename": task.filename,
    }
    if task.status == TaskStatus.PENDING:
        base.update({"queue_pos": task.queue_pos,
                     "queue_total": len(pending_queue) + (1 if current_task else 0)})
    elif task.status == TaskStatus.PROCESSING:
        base.update({"stage": task.stage})
    elif task.status == TaskStatus.COMPLETED:
        base.update({"proc_time": task.proc_time, "analysis": task.analysis,
                     "size_kb": len(task.result_png or b"") // 1024})
    elif task.status == TaskStatus.FAILED:
        base.update({"error": task.error})
    return JSONResponse(base)


@app.get("/result/{task_id}")
async def result(task_id: str, fmt: str = "png"):
    """Download the result as an attachment; fmt=webp falls back to PNG if
    the WebP variant is unavailable."""
    task = tasks.get(task_id)
    if not task:
        raise HTTPException(404, "المهمة غير موجودة")
    if task.status != TaskStatus.COMPLETED:
        raise HTTPException(400, f"المهمة لم تكتمل. الحالة: {task.status.value}")
    stem = Path(task.filename).stem
    if fmt == "webp" and task.result_webp:
        return Response(
            content=task.result_webp,
            media_type="image/webp",
            headers={"Content-Disposition": f'attachment; filename="{stem}_nobg.webp"'},
        )
    return Response(
        content=task.result_png,
        media_type="image/png",
        headers={"Content-Disposition": f'attachment; filename="{stem}_nobg.png"'},
    )


@app.get("/preview/{task_id}")
async def preview(task_id: str):
    """Inline preview (no Content-Disposition) for display in browser."""
    task = tasks.get(task_id)
    if not task or task.status != TaskStatus.COMPLETED:
        raise HTTPException(404, "النتيجة غير متاحة")
    return Response(content=task.result_png, media_type="image/png")


@app.get("/queue-info")
async def queue_info():
    """Public queue snapshot for the frontend's capacity indicator."""
    return JSONResponse({
        "waiting": len(pending_queue),
        "max": MAX_QUEUE_SIZE,
        "free_slots": MAX_QUEUE_SIZE - len(pending_queue),
        "processing": current_task is not None,
        "total_tasks": len(tasks),
    })


@app.delete("/task/{task_id}")
async def cancel_task(task_id: str, request: Request):
    """Cancel a waiting task or delete a finished one.

    Tasks already PROCESSING cannot be cancelled (the worker holds them).
    The per-IP active counter is only refunded for still-queued tasks —
    the worker already decremented it for finished ones.
    """
    task = tasks.get(task_id)
    if not task:
        raise HTTPException(404, "المهمة غير موجودة")
    if task.status == TaskStatus.PROCESSING:
        raise HTTPException(400, "لا يمكن إلغاء مهمة قيد المعالجة")
    async with queue_lock:
        if task_id in pending_queue:
            pending_queue.remove(task_id)
            ip_active[task.ip] = max(0, ip_active[task.ip] - 1)
        if task_id in tasks:
            del tasks[task_id]
    await broadcast_all_positions()
    return JSONResponse({"message": "تم إلغاء المهمة"})


# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")

# ─────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, loop="asyncio")