| from fastapi import FastAPI, File, UploadFile, Form |
| from fastapi.responses import JSONResponse |
| from PIL import Image |
| import io |
| import torch |
| import logging |
| from typing import List, Optional, Tuple |
| import requests |
| import threading |
| from contextlib import asynccontextmanager |
| import queue |
|
|
| from logic import WatermarkRemover |
|
|
| |
# Root logging config: timestamped, leveled, named — shared by all module loggers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
logger = logging.getLogger(__name__)


# Singleton model wrapper; created once in lifespan() at startup, then only
# read by the single worker thread (so no lock is needed around it).
remover: Optional[WatermarkRemover] = None
# Hard cap on pending tasks; enqueue beyond this returns HTTP 429 to callers.
MAX_QUEUE_SIZE = 16
# Work items are (task_id, raw image bytes, callback URL, optional webhook secret).
task_queue: "queue.Queue[Tuple[int, List[bytes], str, Optional[str]]]" = queue.Queue(maxsize=MAX_QUEUE_SIZE)
|
|
def process_in_background(task_id: int, image_contents: List[bytes], callback_url: str, webhook_secret: str):
    """Run watermark removal on a batch of images and POST results to the callback.

    Called only from the single worker thread (see queue_worker), so no model
    lock is needed around the shared ``remover`` instance.

    Args:
        task_id: Identifier echoed back to the callback receiver.
        image_contents: Raw bytes of each uploaded image.
        callback_url: URL that receives the multipart result POST.
        webhook_secret: Optional shared secret sent as ``X-Webhook-Secret``.

    The callback is always attempted: on processing failure it carries
    ``status=error`` (with whatever images finished before the failure).
    Callback delivery errors are logged, never raised.
    """
    logger.info(f"[Task {task_id}] Worker picked up task. Processing {len(image_contents)} images.")
    status = "success"
    cleaned_image_data: List[bytes] = []

    try:
        # Loop index is not needed here; names for the multipart upload are
        # generated below when the payload is assembled.
        for contents in image_contents:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
            cleaned_image = remover.run(image)

            buf = io.BytesIO()
            cleaned_image.save(buf, format="JPEG", quality=90, optimize=True)
            cleaned_image_data.append(buf.getvalue())

        logger.info(f"[Task {task_id}] All images processed successfully.")

    except Exception as e:
        logger.error(f"[Task {task_id}] Error during processing: {e}", exc_info=True)
        status = "error"

    headers = {"X-Webhook-Secret": webhook_secret if webhook_secret else ""}
    files = [('images', (f'image_{i}.jpeg', img_bytes, 'image/jpeg')) for i, img_bytes in enumerate(cleaned_image_data)]
    data_payload = {'task_id': str(task_id), 'status': status}

    try:
        logger.info(f"[Task {task_id}] Sending callback to {callback_url}")
        # Bind the response before raise_for_status so a failure can still be
        # inspected/logged with the response attached to the exception.
        response = requests.post(callback_url, files=files, data=data_payload, headers=headers, timeout=600)
        response.raise_for_status()
        logger.info(f"[Task {task_id}] Callback sent successfully.")
    except requests.RequestException as e:
        logger.error(f"[Task {task_id}] Failed to send callback: {e}")
|
|
| |
def queue_worker():
    """Consume tasks from task_queue forever; runs on one daemon thread.

    Single consumer by design: it serializes access to the shared ``remover``
    model, which is why process_in_background needs no lock.
    """
    logger.info("Queue worker thread started.")
    while True:
        try:
            # Blocks until a task is available.
            task_id, image_contents, callback_url, webhook_secret = task_queue.get()
            try:
                process_in_background(task_id, image_contents, callback_url, webhook_secret)
            finally:
                # Must run even if processing raised; otherwise the queue's
                # unfinished-task count drifts and task_queue.join() would
                # deadlock on the first failed task.
                task_queue.task_done()
        except Exception as e:
            # Never let the worker thread die: log and keep consuming.
            logger.exception(f"Critical error in queue_worker loop: {e}")
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: 'soft' warm-up on startup, log-only on shutdown.

    Startup order matters: the global ``remover`` must be assigned and its
    detector preloaded BEFORE the worker thread starts consuming tasks.
    """
    global remover
    logger.info("Application startup... Performing 'soft' warm-up.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    remover = WatermarkRemover(device=device)

    # Only the (lighter) detector is preloaded here; per the log line below,
    # the main inpainting model is loaded lazily on first request.
    # NOTE(review): calls a private method of WatermarkRemover — consider a
    # public warm_up() API on the logic module.
    remover._load_detector()
    logger.info("Detector model pre-loaded successfully. Main model will be loaded on first request.")


    # Informational GPU memory snapshot (device 0 only), bytes -> GiB.
    if torch.cuda.is_available():
        mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"Total GPU memory: {mem:.2f} GB")
        logger.info(f"Allocated after detector load: {torch.cuda.memory_allocated() / (1024**3):.2f} GB")

    # Daemon thread: dies with the process, so no explicit join on shutdown.
    worker_thread = threading.Thread(target=queue_worker, daemon=True, name="QueueWorker")
    worker_thread.start()
    logger.info("Queue worker has been started.")

    # Application serves requests while suspended here.
    yield

    logger.info("Application shutdown.")
|
|
# Application instance; lifespan handles model warm-up and worker startup.
app = FastAPI(lifespan=lifespan)
|
|
@app.get("/")
def root():
    """Health/status endpoint: model load flags plus queue occupancy."""
    # Both flags must tolerate the brief startup window where remover is None.
    detector_ready = remover is not None and remover.detector is not None
    inpainting_ready = remover is not None and remover.inpainting_pipe is not None
    return {
        "message": "Simba AI Services is running",
        "detector_model_loaded": detector_ready,
        "inpainting_model_loaded": inpainting_ready,
        "tasks_in_queue": task_queue.qsize(),
        "queue_capacity": MAX_QUEUE_SIZE,
    }
|
|
| |
@app.post("/process_images/")
async def process_images_endpoint(
    images: List[UploadFile] = File(...),
    task_id: int = Form(...),
    callback_url: str = Form(...),
    webhook_secret: Optional[str] = Form(None)
):
    """Accept a batch of images for asynchronous watermark removal.

    Reads all uploads into memory, enqueues the task for the background
    worker, and returns immediately; results are delivered later via an
    HTTP POST to ``callback_url``.

    Returns:
        202 with the queue position when the task is accepted,
        429 when the processing queue is full,
        500 on any other failure while accepting the task.
    """
    logger.info(f"Accepted task {task_id}. Images count: {len(images)}")

    try:
        # Read uploads now: UploadFile streams belong to this request and
        # would be closed before the worker thread could consume them.
        image_contents = [await image.read() for image in images]

        # task_id is already an int (coerced by FastAPI's Form validation),
        # so no extra cast is needed. put_nowait raises queue.Full instead
        # of blocking the event loop when the queue is at capacity.
        task_queue.put_nowait((task_id, image_contents, callback_url, webhook_secret))

        # Approximate only: the worker may dequeue concurrently.
        current_queue_size = task_queue.qsize()
        logger.info(f"Task {task_id} added to queue. Current queue size: {current_queue_size}")

        return JSONResponse(
            status_code=202,
            content={
                "message": "Task accepted and placed in queue.",
                "task_id": task_id,
                "position_in_queue": current_queue_size
            }
        )
    except queue.Full:
        logger.warning(f"Task {task_id} rejected because queue is full (size: {task_queue.qsize()})")
        return JSONResponse(
            status_code=429,
            content={"message": "Server is busy, the processing queue is full. Please try again later."}
        )
    except Exception as e:
        # Boundary handler: log with traceback, return an opaque 500.
        logger.error(f"Failed to accept task {task_id}: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"message": "An internal error occurred while accepting the task."}
        )