"""FastAPI service that removes image backgrounds with the BriaRMBG model.

Endpoints:
    GET       /             -- liveness check.
    GET|POST  /rmbg         -- take an image from ``file`` (multipart upload),
                               ``base64`` (optionally prefixed with a
                               ``data:image/<type>;base64,`` header) or ``url``
                               and return a PNG whose alpha channel is the
                               model's min-max-normalised segmentation output.
                               Failures are reported as a JSON body with
                               ``"code": 503`` (HTTP status stays 200).
    GET       /pkg-version  -- list installed Python distributions.
"""

from fastapi import FastAPI, Request
from fastapi.responses import Response
from starlette.formparsers import MultiPartParser
import uvicorn
import aiohttp
import asyncio
import base64
import logging
import os
import re
import time
from io import BytesIO

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import AutoModelForImageSegmentation

try:
    import pkg_resources
except Exception:
    # pkg_resources ships with setuptools and may be absent; /pkg-version
    # falls back to importlib.metadata in that case.
    pkg_resources = None

# Upload limit: 50 MB. Also used as the cap for images fetched via ``url``.
MAX_UPLOAD_SIZE = 1024 * 1024 * 50
# Raise Starlette's multipart limits (in-memory spool threshold and maximum
# size of a single part) to the upload limit.
MultiPartParser.spool_max_size = MAX_UPLOAD_SIZE
MultiPartParser.max_part_size = MAX_UPLOAD_SIZE

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face token used to download the (possibly private) model repo.
read_key = os.environ.get("HF_TOKEN", None)
HF_DATASET_REPO = "Maid-10000/RMBG-DataBase"
# Only the files needed to build the model with trust_remote_code.
MODEL_ALLOW_PATTERNS = [
    "config.json",
    "MyConfig.py",
    "briarmbg.py",
    "pytorch_model.bin",
]

# Optional "data:image/<subtype>;base64," header accepted on the ``base64``
# parameter; stripped before decoding.
_DATA_URL_PREFIX = re.compile(r"^data:image\/[a-zA-Z]+;base64,")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    """Download the model files from the HF dataset repo and instantiate them.

    Returns:
        The segmentation model built by ``AutoModelForImageSegmentation``
        from the downloaded snapshot (custom code is executed because
        ``trust_remote_code=True``).
    """
    local_model_path = snapshot_download(
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        allow_patterns=MODEL_ALLOW_PATTERNS,
        token=read_key,
    )
    return AutoModelForImageSegmentation.from_pretrained(
        local_model_path,
        trust_remote_code=True,
        # The repo ships pytorch_model.bin, not safetensors weights.
        use_safetensors=False,
    )


# The model is loaded once, at import time, and shared by all requests.
logger.info(f"Loading BriaRMBG model on {device}...")
net = load_model()
net.to(device)
net.eval()
logger.info("Model loaded.")


def end_time(start_time: float, text: str) -> None:
    """Log ``text`` together with the seconds elapsed since ``start_time``."""
    logger.info(f"{text}: 共执行 {time.time() - start_time:.3f} 秒")


def resize_image(image: Image.Image) -> Image.Image:
    """Return an RGB copy of ``image`` resized to the model's 1024x1024 input."""
    image = image.convert("RGB")
    return image.resize((1024, 1024), Image.BILINEAR)


def process(image: Image.Image) -> Response:
    """Run background removal on ``image`` and return it as a PNG response.

    The image is resized to 1024x1024 and passed through ``net``. The
    prediction is resized back to the original size, min-max normalised to
    0..255, and attached as the alpha channel of the original RGB image.

    Args:
        image: Input image (any PIL mode; converted to RGB).

    Returns:
        A ``Response`` whose body is the RGBA PNG, media type ``image/png``.
    """
    orig_image = image.convert("RGB")
    w, h = orig_image.size

    input_image = resize_image(orig_image)
    im_np = np.array(input_image)
    # HWC uint8 -> CHW float, then add a batch dimension.
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = im_tensor.unsqueeze(0)
    im_tensor = im_tensor / 255.0
    # mean 0.5 / std 1.0: shifts pixel values from [0, 1] to [-0.5, 0.5].
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    with torch.no_grad():
        result = net(im_tensor)

    # NOTE(review): presumably picks the primary output of the nested
    # prediction list returned by briarmbg.py -- confirm against that file.
    # F.interpolate with a 2-tuple size requires it to be 4-D (N, C, H, W).
    result = result[0][0]
    result = F.interpolate(
        result, size=(h, w), mode="bilinear", align_corners=False
    )
    result = torch.squeeze(result, 0)

    # Min-max normalise; a constant prediction becomes a fully transparent mask.
    ma = torch.max(result)
    mi = torch.min(result)
    if ma == mi:
        result = torch.zeros_like(result)
    else:
        result = (result - mi) / (ma - mi)

    result_array = (result * 255).cpu().numpy().astype(np.uint8)
    # np.squeeze drops the remaining singleton channel axis (assumed C == 1).
    pil_mask = Image.fromarray(np.squeeze(result_array))

    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)

    buf = BytesIO()
    new_im.save(buf, format="PNG")
    buf.seek(0)
    return Response(content=buf.read(), media_type="image/png")


async def fetch_url(url: str, max_size: int = MAX_UPLOAD_SIZE) -> bytes:
    """Download ``url`` and return its body, refusing bodies over ``max_size``.

    The body is streamed and the running size is checked, so a huge or
    endless response cannot exhaust memory. This also enforces the same
    limit that applies to multipart uploads.

    NOTE(review): ``url`` comes straight from the client, so this endpoint
    can be used to make the server fetch internal addresses (SSRF). Restrict
    hosts/schemes if the service is exposed publicly.

    Args:
        url: HTTP(S) URL to fetch.
        max_size: Maximum accepted body size in bytes.

    Returns:
        The response body.

    Raises:
        RuntimeError: on a non-200 status or when the body exceeds ``max_size``.
        aiohttp.ClientError / asyncio.TimeoutError: on network failures or when
            the 15 s total timeout is exceeded.
    """
    timeout = aiohttp.ClientTimeout(total=15)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise RuntimeError(f"URL returned status {response.status}")

            declared = response.content_length
            if declared is not None and declared > max_size:
                raise RuntimeError(
                    f"URL content too large: {declared} bytes > {max_size}"
                )

            # Content-Length may be absent or wrong, so also cap while streaming.
            body = bytearray()
            async for chunk in response.content.iter_chunked(64 * 1024):
                body.extend(chunk)
                if len(body) > max_size:
                    raise RuntimeError(
                        f"URL content exceeds {max_size} bytes"
                    )
            return bytes(body)


@app.get("/")
async def main():
    """Liveness check."""
    return {"code": 200, "msg": "Success"}


@app.api_route("/rmbg", methods=["GET", "POST"])
async def rmbg(request: Request):
    """Remove the background of an image supplied via ``file``, ``base64`` or ``url``.

    Parameters are merged from the query string and, for POST, from either
    a JSON body or a form body (body values override query values). The
    first present source wins, in the order ``file`` > ``base64`` > ``url``.

    Returns:
        The PNG ``Response`` from ``process`` on success. On any failure
        it returns a dict ``{"code": 503, "msg": ...}`` naming the failing
        stage.
    """
    init_time = time.time()

    # Stage 1: merge request parameters.
    try:
        start_time = time.time()
        params = dict(request.query_params)
        if request.method in ["POST", "PUT", "PATCH"]:
            content_type = request.headers.get("Content-Type", "").lower()
            if "application/json" in content_type:
                json_data = await request.json()
                params.update(json_data)
            else:
                form_data = await request.form(
                    max_part_size=MultiPartParser.max_part_size
                )
                for key, value in form_data.items():
                    params[key] = value
        end_time(start_time, "参数合并完成")
    except Exception as e:
        logger.exception(e)
        return {
            "code": 503,
            "msg": "An unexpected error occurred during the parameter assignment process"
        }

    url = params.get("url")
    file = params.get("file")
    b64 = params.get("base64")

    # Stage 2: obtain the raw image bytes.
    try:
        start_time = time.time()
        if file:
            # Expected to be an UploadFile from a multipart form; anything
            # else fails here and is reported as a parsing error.
            data = await file.read()
        elif b64:
            b64 = _DATA_URL_PREFIX.sub("", b64, count=1)
            data = base64.b64decode(b64, validate=True)
        elif url:
            data = await fetch_url(url)
        else:
            return {"code": 503, "msg": "No image parameters entered"}
        end_time(start_time, "图片下载完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image parameter parsing error"}

    # Stage 3: decode the bytes (off the event loop, decoding is CPU-bound).
    try:
        start_time = time.time()
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(
            None, lambda: Image.open(BytesIO(data)).convert("RGB")
        )
        end_time(start_time, "图片读取完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "The input is not an image"}

    # Stage 4: model inference in the default thread pool.
    try:
        start_time = time.time()
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: process(image))
        end_time(start_time, "图片分析完成")
        end_time(init_time, "[ 任务总耗时 ]")
        return result
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image processing failed"}


@app.get("/pkg-version")
async def pkg_version():
    """List installed Python distributions as ``{"name", "version"}`` dicts sorted by name.

    Uses ``pkg_resources`` when available. Otherwise it falls back to
    ``importlib.metadata``, so the endpoint keeps working once setuptools'
    deprecated ``pkg_resources`` is gone.
    """
    if pkg_resources is not None:
        packages_list = [
            {"name": pkg.project_name, "version": pkg.version}
            for pkg in pkg_resources.working_set
        ]
    else:
        try:
            from importlib import metadata as importlib_metadata
        except ImportError:
            return {"code": 503, "msg": "pkg_resources unavailable"}
        packages_list = []
        for dist in importlib_metadata.distributions():
            name = dist.metadata["Name"]
            # Skip broken installs without a Name so sorting never compares None.
            if name:
                packages_list.append({"name": name, "version": dist.version})

    packages_list = sorted(packages_list, key=lambda x: x["name"])
    return {
        "code": 200,
        "msg": "Success",
        "pkg-version": packages_list
    }


if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. With the string,
    # uvicorn re-imports this file as module ``app`` and re-runs load_model(),
    # which downloads and loads the weights a second time. The string also
    # breaks if the file is not named app.py. An app object is allowed as long
    # as workers == 1 and reload is off.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1
    )