# RMBG background-removal FastAPI service (Hugging Face Space)
import asyncio
import base64
import importlib.metadata
import logging
import os
import re
import time
from io import BytesIO

import aiohttp
import numpy as np
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import Response
from huggingface_hub import snapshot_download
from PIL import Image
from starlette.formparsers import MultiPartParser
from torchvision.transforms.functional import normalize
from transformers import AutoModelForImageSegmentation

try:
    import pkg_resources
except Exception:  # pkg_resources is removed in recent setuptools releases
    pkg_resources = None
# Upload limit: 50 MB, applied to both the multipart spool file and each part.
MAX_UPLOAD_SIZE = 1024 * 1024 * 50
# NOTE(review): mutating MultiPartParser class attributes changes the limit
# process-wide for every Starlette/FastAPI app, not just this one.
MultiPartParser.spool_max_size = MAX_UPLOAD_SIZE
MultiPartParser.max_part_size = MAX_UPLOAD_SIZE
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
# Optional HF access token for the private dataset repo below.
read_key = os.environ.get("HF_TOKEN", None)
# Model weights are hosted in a *dataset* repo; only these files are fetched.
HF_DATASET_REPO = "Maid-10000/RMBG-DataBase"
MODEL_ALLOW_PATTERNS = [
    "config.json",
    "MyConfig.py",
    "briarmbg.py",
    "pytorch_model.bin",
]
# Prefer GPU when available; all tensors and the net are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Fetch the RMBG model files from the HF dataset repo and build the net.

    Returns the instantiated ``AutoModelForImageSegmentation`` loaded from
    the locally cached snapshot.
    """
    snapshot_dir = snapshot_download(
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        allow_patterns=MODEL_ALLOW_PATTERNS,
        token=read_key,
    )
    model = AutoModelForImageSegmentation.from_pretrained(
        snapshot_dir,
        trust_remote_code=True,
        use_safetensors=False,
    )
    return model
# Load the segmentation model once at import time, pin it to the chosen
# device, and switch to eval mode; all requests share this single instance.
logger.info(f"Loading BriaRMBG model on {device}...")
net = load_model()
net.to(device)
net.eval()
logger.info("Model loaded.")
def end_time(start_time, text):
    """Log the wall-clock time elapsed since *start_time*, labelled *text*.

    Despite the name, nothing is stopped: this only reports
    ``time.time() - start_time`` in seconds at INFO level.
    """
    # Lazy %-style args: the message string is only built if INFO is enabled.
    logger.info("%s: 共执行 %.3f 秒", text, time.time() - start_time)
def resize_image(image: Image.Image, size: tuple = (1024, 1024)) -> Image.Image:
    """Convert *image* to RGB and resize it to the model input resolution.

    Args:
        image: Source image in any PIL mode.
        size: Target (width, height); defaults to the 1024x1024 input that
            the RMBG net is fed elsewhere in this file.

    Returns:
        A new RGB image, resized with bilinear interpolation.
    """
    image = image.convert("RGB")
    return image.resize(size, Image.BILINEAR)
def process(image: Image.Image) -> Response:
    """Run background removal on *image* and return the result as a PNG Response.

    The model consumes a fixed 1024x1024 normalized RGB tensor; the predicted
    matte is rescaled to the original size, min-max normalized, and attached
    to the original image as its alpha channel.
    """
    orig_image = image.convert("RGB")
    w, h = orig_image.size
    input_image = resize_image(orig_image)
    # HWC uint8 -> CHW float32 batch of 1, scaled to [0, 1]
    im_np = np.array(input_image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = im_tensor.unsqueeze(0)
    im_tensor = im_tensor / 255.0
    # Per-channel mean 0.5 / std 1.0 — i.e. shift to roughly [-0.5, 0.5]
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)
    with torch.no_grad():
        result = net(im_tensor)
    # NOTE(review): assumes the net returns nested outputs where [0][0] is the
    # primary saliency map (batched, single-channel) — confirm in briarmbg.py.
    result = result[0][0]
    # Upsample the matte back to the original (height, width)
    result = F.interpolate(result, size=(h, w), mode="bilinear", align_corners=False)
    result = torch.squeeze(result, 0)
    # Min-max normalize to [0, 1]; a constant map becomes fully transparent
    ma = torch.max(result)
    mi = torch.min(result)
    if ma == mi:
        result = torch.zeros_like(result)
    else:
        result = (result - mi) / (ma - mi)
    result_array = (result * 255).cpu().numpy().astype(np.uint8)
    # squeeze drops the remaining channel dim so fromarray yields an "L" mask
    pil_mask = Image.fromarray(np.squeeze(result_array))
    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)
    # Encode the RGBA composite as PNG in memory
    buf = BytesIO()
    new_im.save(buf, format="PNG")
    buf.seek(0)
    return Response(content=buf.read(), media_type="image/png")
async def fetch_url(url: str) -> bytes:
    """Download *url* and return the raw response body.

    Remote fetches are held to the same 50 MB cap (``MAX_UPLOAD_SIZE``) as
    direct uploads, so a malicious URL cannot exhaust memory.

    Raises:
        RuntimeError: on a non-200 status or an oversized payload.
        aiohttp.ClientError / asyncio.TimeoutError: on network failure.
    """
    # NOTE(review): the URL is caller-supplied; consider an allow-list to
    # mitigate SSRF if this service is reachable from untrusted clients.
    timeout = aiohttp.ClientTimeout(total=15)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise RuntimeError(f"URL returned status {response.status}")
            # Reject early when the server declares an oversized length.
            if response.content_length is not None and response.content_length > MAX_UPLOAD_SIZE:
                raise RuntimeError("Remote file exceeds the upload size limit")
            data = await response.read()
            # Re-check the actual size (Content-Length may be absent or wrong).
            if len(data) > MAX_UPLOAD_SIZE:
                raise RuntimeError("Remote file exceeds the upload size limit")
            return data
async def main():
    """Health-check endpoint: always returns a static success envelope."""
    payload = {}
    payload["code"] = 200
    payload["msg"] = "Success"
    return payload
async def rmbg(request: Request):
    """Background-removal endpoint.

    Accepts an image through one of three parameters — "file" (multipart
    upload), "base64" (optionally a data URI), or "url" (remote fetch) —
    runs the model, and returns the PNG as a Response. Failures are
    reported as ``{"code": 503, "msg": ...}`` dicts rather than HTTP error
    statuses.
    """
    init_time = time.time()
    # --- 1. Merge query string, JSON body and form fields into one dict ---
    try:
        start_time = time.time()
        params = dict(request.query_params)
        if request.method in ["POST", "PUT", "PATCH"]:
            content_type = request.headers.get("Content-Type", "").lower()
            if "application/json" in content_type:
                json_data = await request.json()
                params.update(json_data)
            else:
                # NOTE(review): request.form(max_part_size=...) requires a
                # recent Starlette — confirm the pinned version supports it.
                form_data = await request.form(
                    max_part_size=MultiPartParser.max_part_size
                )
                for key, value in form_data.items():
                    params[key] = value
        end_time(start_time, "参数合并完成")
    except Exception as e:
        logger.exception(e)
        return {
            "code": 503,
            "msg": "An unexpected error occurred during the parameter assignment process"
        }
    # Input precedence below: uploaded file > base64 string > remote URL.
    url = params.get("url")
    file = params.get("file")
    b64 = params.get("base64")
    # --- 2. Obtain the raw image bytes from whichever source was given ---
    try:
        start_time = time.time()
        if file:
            # assumes "file" is a Starlette UploadFile — TODO confirm
            data = await file.read()
        elif b64:
            # Strip an optional "data:image/...;base64," prefix first.
            pattern = r"^data:image\/[a-zA-Z]+;base64,"
            if re.match(pattern, b64):
                b64 = re.sub(pattern, "", b64)
            data = base64.b64decode(b64, validate=True)
        elif url:
            data = await fetch_url(url)
        else:
            return {"code": 503, "msg": "No image parameters entered"}
        end_time(start_time, "图片下载完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image parameter parsing error"}
    # --- 3. Decode the bytes into a PIL image off the event loop ---
    try:
        start_time = time.time()
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(
            None,
            lambda: Image.open(BytesIO(data)).convert("RGB")
        )
        end_time(start_time, "图片读取完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "The input is not an image"}
    # --- 4. Run inference in a worker thread and return the PNG response ---
    try:
        start_time = time.time()
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: process(image))
        end_time(start_time, "图片分析完成")
        end_time(init_time, "[ 任务总耗时 ]")
        return result
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image processing failed"}
async def pkg_version():
    """Report the installed Python packages as a name-sorted list.

    Uses the stdlib ``importlib.metadata`` instead of the deprecated
    ``pkg_resources`` API (removed in recent setuptools), so the endpoint
    no longer has a 503 failure mode when setuptools is absent.

    Returns:
        ``{"code": 200, "msg": "Success", "pkg-version": [{"name", "version"}, ...]}``
    """
    packages_list = []
    for dist in importlib.metadata.distributions():
        name = dist.metadata["Name"]
        # Broken on-disk metadata can yield a missing name; skip those entries.
        if not name:
            continue
        packages_list.append({"name": name, "version": dist.version})
    packages_list.sort(key=lambda x: x["name"])
    return {
        "code": 200,
        "msg": "Success",
        "pkg-version": packages_list
    }
# Script entry point: serve on all interfaces with one uvicorn worker
# (port 7860 is the Hugging Face Spaces convention).
# NOTE(review): "app:app" assumes this module is saved as app.py — confirm.
if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        workers=1
    )