| import json |
| import math |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from anyio.to_thread import run_sync |
| from fastapi import FastAPI, Request, UploadFile, File |
| from fastapi.responses import Response |
| from segmentation_models_pytorch import UnetPlusPlus |
|
|
| |
# Directory containing config.json and model.safetensors for the ink-erase model.
MODEL_PATH = "models/InkErase"

# Prefer GPU when available; all inference tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Side length (px) of each square tile fed to the model.
TILE_SIZE = 512

# Overlap (px) between adjacent tiles; overlapping regions are cross-faded
# in _precompute_axis_weights to hide tile seams.
OVERLAP = 64
|
|
|
|
| |
| |
| |
|
|
| def _ceil_to_multiple(value: int, multiple: int) -> int: |
| if multiple <= 1: |
| return value |
| return int(math.ceil(value / multiple) * multiple) |
|
|
|
|
| def _build_starts(length: int, tile: int, stride: int) -> list[int]: |
| if length <= tile: |
| return [0] |
| starts = list(range(0, length - tile + 1, stride)) |
| last = length - tile |
| if starts[-1] != last: |
| starts.append(last) |
| return starts |
|
|
|
|
| def _precompute_axis_weights(starts: list[int], tile: int, overlap: int) -> list[torch.Tensor]: |
| """预计算融合权重,用于消除拼接缝隙""" |
| max_start = starts[-1] |
| weights: list[torch.Tensor] = [] |
| if overlap <= 0: |
| one = torch.ones(tile, dtype=torch.float32) |
| return [one for _ in starts] |
|
|
| |
| ramp_up = torch.linspace(0.0, 1.0, overlap, dtype=torch.float32) |
| ramp_down = torch.linspace(1.0, 0.0, overlap, dtype=torch.float32) |
|
|
| for start in starts: |
| w = torch.ones(tile, dtype=torch.float32) |
| if start > 0: |
| w[:overlap] *= ramp_up |
| if start < max_start: |
| w[-overlap:] *= ramp_down |
| weights.append(w) |
| return weights |
|
|
|
|
def _tiled_infer(
    model: torch.nn.Module,
    x_cpu: torch.Tensor,
    tile_size: int = 512,
    overlap: int = 64,
    batch_size: int = 1,
    pad_multiple: int = 32,
    pad_mode: str = "replicate",
) -> torch.Tensor:
    """Run tiled inference over a large image and blend the tile outputs.

    x_cpu: [1, 3, H, W] float tensor on CPU. The image is padded so it
    holds at least one full tile and each side is a multiple of
    ``pad_multiple``, split into overlapping tiles, run through ``model``
    in batches of ``batch_size``, and the predictions are cross-faded back
    together using the precomputed axis weights.

    Returns a [1, C, H, W] tensor clamped to [0, 1] and cropped back to
    the original spatial size.
    """
    _, _, h, w = x_cpu.shape

    # Pad so the canvas fits at least one tile and each side is a multiple
    # of pad_multiple (presumably the encoder's downsampling factor — TODO confirm).
    padded_h = _ceil_to_multiple(max(h, tile_size), pad_multiple)
    padded_w = _ceil_to_multiple(max(w, tile_size), pad_multiple)

    pad_h = padded_h - h
    pad_w = padded_w - w
    if pad_h or pad_w:
        # Pad on the bottom/right only so original pixels keep their coords.
        x_cpu = F.pad(x_cpu, (0, pad_w, 0, pad_h), mode=pad_mode)

    stride = tile_size - overlap
    y_starts = _build_starts(padded_h, tile_size, stride)
    x_starts = _build_starts(padded_w, tile_size, stride)

    # Per-axis 1-D blend weights; their outer product forms each tile's 2-D mask.
    y_weights = _precompute_axis_weights(y_starts, tile_size, overlap)
    x_weights = _precompute_axis_weights(x_starts, tile_size, overlap)

    channels = x_cpu.shape[1]
    # NOTE(review): accumulators are sized by the INPUT channel count; this
    # assumes the model emits the same number of channels it consumes
    # (3-in / 3-out for this service) — confirm if the config changes.
    accum = torch.zeros((1, channels, padded_h, padded_w), dtype=torch.float32)
    weight = torch.zeros((1, 1, padded_h, padded_w), dtype=torch.float32)

    # Flattened (y, x) tile grid, carrying both pixel offsets and grid indices.
    coords = []
    for yi, yy in enumerate(y_starts):
        for xi, xx in enumerate(x_starts):
            coords.append((yy, xx, yi, xi))

    with torch.inference_mode():
        for i in range(0, len(coords), batch_size):
            chunk = coords[i : i + batch_size]

            # Gather this batch of tiles (assumes x_cpu's batch dim is 1).
            tiles = torch.stack(
                [x_cpu[0, :, yy : yy + tile_size, xx : xx + tile_size] for (yy, xx, _, _) in chunk],
                dim=0,
            ).to(device)

            # Predictions come back to CPU so accumulation stays off-device.
            pred = model(tiles).float().detach().cpu()

            for bi, (yy, xx, yi, xi) in enumerate(chunk):
                wy = y_weights[yi]
                wx = x_weights[xi]
                # 2-D blend mask shaped [1, 1, tile, tile] for broadcasting
                # across the channel dimension.
                m = (wy[:, None] * wx[None, :]).unsqueeze(0).unsqueeze(0)

                accum[:, :, yy : yy + tile_size, xx : xx + tile_size] += pred[bi : bi + 1] * m
                weight[:, :, yy : yy + tile_size, xx : xx + tile_size] += m

    # Normalize by total contributed weight (clamped to avoid divide-by-zero),
    # clamp to valid image range, then crop the padding back off.
    out = (accum / weight.clamp_min(1e-8)).clamp(0, 1)
    return out[:, :, :h, :w]
|
|
|
|
| |
| |
| |
|
|
def load_model() -> UnetPlusPlus:
    """Build a UNet++ model from config.json and load safetensors weights.

    Reads ``MODEL_PATH/config.json`` for architecture hyperparameters, then
    loads ``MODEL_PATH/model.safetensors`` if present. Weight loading is
    best-effort: a load failure or missing file leaves the model randomly
    initialized, but is reported to stdout instead of passing silently.

    Returns the model moved to ``device`` and set to eval mode.
    """
    path = Path(MODEL_PATH)
    cfg = json.loads((path / "config.json").read_text(encoding="utf-8"))

    model = UnetPlusPlus(
        encoder_name=cfg.get("encoder_name", "resnet50"),
        encoder_weights=None,  # weights come from the safetensors file below
        in_channels=int(cfg.get("in_channels", 3)),
        classes=int(cfg.get("classes", 3)),
        decoder_attention_type=cfg.get("decoder_attention_type"),
        activation=cfg.get("activation", "sigmoid"),
    )

    weights_path = path / "model.safetensors"
    if weights_path.exists():
        try:
            from safetensors.torch import load_file

            state_dict = load_file(str(weights_path))
            # Drop checkpoint keys the current architecture doesn't have
            # (e.g. from a differently-configured training run).
            model_keys = set(model.state_dict().keys())
            filtered_dict = {k: v for k, v in state_dict.items() if k in model_keys}
            model.load_state_dict(filtered_dict, strict=False)
            print(f"Loaded weights from {weights_path}")
        except Exception as e:
            # Best-effort: keep serving, but make the failure visible.
            print(f"Failed to load weights: {e}")
    else:
        # Fix: previously a missing weights file was completely silent,
        # serving an untrained model with no hint why outputs were garbage.
        print(f"Warning: weights file not found at {weights_path}; "
              "model is randomly initialized")

    model.to(device)
    model.eval()
    return model
|
|
|
|
@asynccontextmanager
async def lifespan(instance: FastAPI):
    # Load the model once at startup and keep it on app.state so request
    # handlers can reach it without re-loading per request.
    instance.state.model = load_model()
    yield
|
|
|
|
# Application instance; `lifespan` loads the model at startup.
app = FastAPI(lifespan=lifespan)
|
|
|
|
@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Erase handwriting from an uploaded image (tiled inference + overlap blending).

    Accepts any image format ``cv2.imdecode`` understands and returns the
    cleaned image as PNG. Responds 400 if the upload is not a decodable
    image and 500 if PNG encoding of the result fails.
    """
    content = await file.read()
    nparr = np.frombuffer(content, np.uint8)

    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        # Fix: cv2.imdecode returns None for corrupt/unsupported data; the
        # old code then crashed inside cvtColor with an opaque 500.
        return Response(
            content="Uploaded file is not a decodable image",
            media_type="text/plain",
            status_code=400,
        )

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    model = request.app.state.model

    def _inference_logic():
        # HWC uint8 [0, 255] -> NCHW float [0, 1]
        input_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        input_tensor = input_tensor.unsqueeze(0)

        output_tensor = _tiled_infer(
            model=model,
            x_cpu=input_tensor,
            tile_size=TILE_SIZE,
            overlap=OVERLAP,
            batch_size=1,
            pad_mode="replicate",
        )

        # NCHW float [0, 1] -> HWC uint8 [0, 255]
        output_tensor = output_tensor.squeeze(0).permute(1, 2, 0)
        output_np = (output_tensor.numpy() * 255).astype(np.uint8)

        return output_np

    # Run the blocking, compute-heavy inference off the event loop thread.
    result_rgb = await run_sync(_inference_logic)

    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)

    success, encoded_image = cv2.imencode(".png", result_bgr)
    if not success:
        # Fix: the encode success flag was previously ignored; a failed
        # encode would have returned garbage bytes labeled image/png.
        return Response(
            content="Failed to encode result image",
            media_type="text/plain",
            status_code=500,
        )
    return Response(content=encoded_image.tobytes(), media_type="image/png")
|
|
|
|
@app.get("/")
def greet_json():
    """Liveness endpoint returning a static JSON greeting."""
    payload = {"Hello": "World!"}
    return payload
|
|
|
|
if __name__ == '__main__':
    # Local dev entry point; in production run via an ASGI server directly.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000)
|
|