"""Ink-erase (handwriting removal) inference service.

Loads a UNet++ segmentation model and exposes a ``/predict`` endpoint
that processes an uploaded image with tiled inference plus overlap
blending, so arbitrarily large images fit in memory without visible
seams in the output.
"""

import json
import math
from contextlib import asynccontextmanager
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from anyio.to_thread import run_sync
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import Response
from segmentation_models_pytorch import UnetPlusPlus

# Directory holding config.json and (optionally) model.safetensors.
MODEL_PATH = "models/InkErase"

# Inference device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tile size for chunked inference (reference script default: 512).
TILE_SIZE = 512

# Overlap between adjacent tiles (reference script default: 64).
OVERLAP = 64


# ==========================================
# Core tiling algorithm (ported from infer_hd.py)
# ==========================================

def _ceil_to_multiple(value: int, multiple: int) -> int:
    """Round *value* up to the nearest multiple of *multiple*.

    A *multiple* of 1 or less leaves the value unchanged.
    """
    if multiple <= 1:
        return value
    return int(math.ceil(value / multiple) * multiple)


def _build_starts(length: int, tile: int, stride: int) -> list[int]:
    """Return tile start offsets covering an axis of size *length*.

    The final start is clamped so the last tile ends exactly at
    *length*; an axis that fits in one tile yields ``[0]``.
    """
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    last = length - tile
    if starts[-1] != last:
        starts.append(last)
    return starts


def _precompute_axis_weights(starts: list[int], tile: int, overlap: int) -> list[torch.Tensor]:
    """Precompute per-tile 1-D blending weights to hide seam artifacts.

    Interior tile edges receive a linear cross-fade ramp over the
    overlap region; the outermost edges keep full weight.
    """
    max_start = starts[-1]
    weights: list[torch.Tensor] = []
    if overlap <= 0:
        one = torch.ones(tile, dtype=torch.float32)
        return [one for _ in starts]
    # Linear cross-fade ramps: adjacent tiles' ramps sum to 1 in the
    # overlap region, so blended output needs no extra normalization
    # there (the weight map below still guards the general case).
    ramp_up = torch.linspace(0.0, 1.0, overlap, dtype=torch.float32)
    ramp_down = torch.linspace(1.0, 0.0, overlap, dtype=torch.float32)
    for start in starts:
        w = torch.ones(tile, dtype=torch.float32)
        if start > 0:  # not the first tile: fade in on the leading edge
            w[:overlap] *= ramp_up
        if start < max_start:  # not the last tile: fade out on the trailing edge
            w[-overlap:] *= ramp_down
        weights.append(w)
    return weights


def _tiled_infer(
    model: torch.nn.Module,
    x_cpu: torch.Tensor,
    tile_size: int = 512,
    overlap: int = 64,
    batch_size: int = 1,
    pad_multiple: int = 32,
    pad_mode: str = "replicate",
) -> torch.Tensor:
    """Run tiled inference over *x_cpu* and blend the tile outputs.

    Parameters
    ----------
    model : module already moved to the global ``device``.
    x_cpu : ``[1, 3, H, W]`` float tensor on the CPU.
    tile_size / overlap : tiling geometry; stride is their difference.
    batch_size : tiles per forward pass (raise if VRAM allows).
    pad_multiple : pad H/W up to this multiple (encoder stride req.).
    pad_mode : padding mode passed to ``F.pad``.

    Returns
    -------
    ``[1, C, H, W]`` blended output clamped to [0, 1], cropped back to
    the original spatial size.
    """
    _, _, h, w = x_cpu.shape

    # 1. Pad so both dimensions hold at least one tile and are a
    #    multiple of pad_multiple.
    padded_h = _ceil_to_multiple(max(h, tile_size), pad_multiple)
    padded_w = _ceil_to_multiple(max(w, tile_size), pad_multiple)
    pad_h = padded_h - h
    pad_w = padded_w - w
    if pad_h or pad_w:
        x_cpu = F.pad(x_cpu, (0, pad_w, 0, pad_h), mode=pad_mode)

    # 2. Tile coordinates and their blending weights.
    stride = tile_size - overlap
    y_starts = _build_starts(padded_h, tile_size, stride)
    x_starts = _build_starts(padded_w, tile_size, stride)
    y_weights = _precompute_axis_weights(y_starts, tile_size, overlap)
    x_weights = _precompute_axis_weights(x_starts, tile_size, overlap)

    # 3. Accumulator and weight map. The channel count follows the
    #    input (3 for RGB), matching infer_hd.py; change to 1 if the
    #    model is known to emit a single-channel mask.
    channels = x_cpu.shape[1]
    accum = torch.zeros((1, channels, padded_h, padded_w), dtype=torch.float32)
    weight = torch.zeros((1, 1, padded_h, padded_w), dtype=torch.float32)

    coords = []
    for yi, yy in enumerate(y_starts):
        for xi, xx in enumerate(x_starts):
            coords.append((yy, xx, yi, xi))

    # 4. Batched forward passes; model is already on `device`.
    with torch.inference_mode():
        for i in range(0, len(coords), batch_size):
            chunk = coords[i : i + batch_size]
            tiles = torch.stack(
                [x_cpu[0, :, yy : yy + tile_size, xx : xx + tile_size] for (yy, xx, _, _) in chunk],
                dim=0,
            ).to(device)
            pred = model(tiles).float().detach().cpu()  # [B, C, tile, tile]

            # Weighted accumulation of each tile's prediction.
            for bi, (yy, xx, yi, xi) in enumerate(chunk):
                wy = y_weights[yi]
                wx = x_weights[xi]
                # Outer product -> [1, 1, tile, tile] blending mask.
                m = (wy[:, None] * wx[None, :]).unsqueeze(0).unsqueeze(0)
                accum[:, :, yy : yy + tile_size, xx : xx + tile_size] += pred[bi : bi + 1] * m
                weight[:, :, yy : yy + tile_size, xx : xx + tile_size] += m

    # 5. Normalize by total weight and crop away the padding.
    out = (accum / weight.clamp_min(1e-8)).clamp(0, 1)
    return out[:, :, :h, :w]


# ==========================================
# FastAPI wiring
# ==========================================

def load_model() -> UnetPlusPlus:
    """Build the UNet++ from config.json and load local weights if present."""
    path = Path(MODEL_PATH)
    cfg = json.loads((path / "config.json").read_text(encoding="utf-8"))

    model = UnetPlusPlus(
        encoder_name=cfg.get("encoder_name", "resnet50"),
        encoder_weights=None,  # pretrained encoder weights handled here if ever needed
        in_channels=int(cfg.get("in_channels", 3)),
        classes=int(cfg.get("classes", 3)),
        decoder_attention_type=cfg.get("decoder_attention_type"),
        activation=cfg.get("activation", "sigmoid"),
    )

    # Local weight file (cf. model.safetensors in infer_hd.py).
    weights_path = path / "model.safetensors"
    if weights_path.exists():
        try:
            from safetensors.torch import load_file

            state_dict = load_file(str(weights_path))
            # Drop keys absent from the model to avoid load errors.
            model_keys = set(model.state_dict().keys())
            filtered_dict = {k: v for k, v in state_dict.items() if k in model_keys}
            model.load_state_dict(filtered_dict, strict=False)
            print(f"Loaded weights from {weights_path}")
        except Exception as e:
            # Best-effort load: fall back to the randomly initialized model.
            print(f"Failed to load weights: {e}")

    model.to(device)
    model.eval()
    return model


@asynccontextmanager
async def lifespan(instance: FastAPI):
    """Load the model once at startup and keep it on app state."""
    instance.state.model = load_model()
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Ink erasure endpoint (tiling + overlap blending).

    Accepts an uploaded image, returns the processed image as PNG.
    Responds 400 on undecodable input and 500 on encode failure.
    """
    content = await file.read()
    nparr = np.frombuffer(content, np.uint8)

    # 1. Decode with OpenCV -> BGR. imdecode returns None (not an
    #    exception) for corrupt/unsupported data, so guard explicitly.
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        return Response(
            content=b'{"error": "invalid or unsupported image data"}',
            status_code=400,
            media_type="application/json",
        )
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    model = request.app.state.model

    def _inference_logic():
        # 2. Preprocess: NumPy (H, W, C) uint8 -> Tensor (1, C, H, W)
        #    scaled to [0, 1], matching TF.to_tensor in the reference.
        input_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        input_tensor = input_tensor.unsqueeze(0)  # [1, 3, H, W]

        # 3. Tiled inference.
        output_tensor = _tiled_infer(
            model=model,
            x_cpu=input_tensor,
            tile_size=TILE_SIZE,
            overlap=OVERLAP,
            batch_size=1,  # raise if VRAM allows
            pad_mode="replicate",
        )

        # 4. Postprocess: Tensor (1, C, H, W) -> NumPy (H, W, C) [0, 255].
        output_tensor = output_tensor.squeeze(0).permute(1, 2, 0)  # [H, W, C]
        output_np = (output_tensor.numpy() * 255).astype(np.uint8)
        return output_np

    # Run the CPU/GPU-heavy work in a worker thread to keep the event
    # loop responsive.
    result_rgb = await run_sync(_inference_logic)

    # 5. Back to BGR for OpenCV encoding.
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    success, encoded_image = cv2.imencode(".png", result_bgr)
    if not success:
        return Response(
            content=b'{"error": "failed to encode result image"}',
            status_code=500,
            media_type="application/json",
        )
    return Response(content=encoded_image.tobytes(), media_type="image/png")


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


if __name__ == '__main__':
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=8000)