# InkErase — app.py
# Last change: refactor tiled-inference logic and remove the Albumentations
# dependency (commit 1c20dbd, author ynyg).
import json
import math
from contextlib import asynccontextmanager
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from anyio.to_thread import run_sync
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.responses import Response
from segmentation_models_pytorch import UnetPlusPlus
# Directory containing the model's config.json (and optional model.safetensors)
MODEL_PATH = "models/InkErase"
# Compute device: prefer CUDA when available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tile size for chunked inference (reference script default: 512)
TILE_SIZE = 512
# Overlap between adjacent tiles in pixels (reference script default: 64)
OVERLAP = 64
# ==========================================
# 核心 Tiling 算法 (移植自 infer_hd.py)
# ==========================================
def _ceil_to_multiple(value: int, multiple: int) -> int:
if multiple <= 1:
return value
return int(math.ceil(value / multiple) * multiple)
def _build_starts(length: int, tile: int, stride: int) -> list[int]:
if length <= tile:
return [0]
starts = list(range(0, length - tile + 1, stride))
last = length - tile
if starts[-1] != last:
starts.append(last)
return starts
def _precompute_axis_weights(starts: list[int], tile: int, overlap: int) -> list[torch.Tensor]:
"""预计算融合权重,用于消除拼接缝隙"""
max_start = starts[-1]
weights: list[torch.Tensor] = []
if overlap <= 0:
one = torch.ones(tile, dtype=torch.float32)
return [one for _ in starts]
# 创建渐变权重 (Ramp)
ramp_up = torch.linspace(0.0, 1.0, overlap, dtype=torch.float32)
ramp_down = torch.linspace(1.0, 0.0, overlap, dtype=torch.float32)
for start in starts:
w = torch.ones(tile, dtype=torch.float32)
if start > 0:
w[:overlap] *= ramp_up
if start < max_start:
w[-overlap:] *= ramp_down
weights.append(w)
return weights
def _tiled_infer(
    model: torch.nn.Module,
    x_cpu: torch.Tensor,
    tile_size: int = 512,
    overlap: int = 64,
    batch_size: int = 1,
    pad_multiple: int = 32,
    pad_mode: str = "replicate",
) -> torch.Tensor:
    """Run *model* tile-by-tile over a large image and blend the results.

    Args:
        model: network mapping [B, C, tile, tile] -> [B, C_out, tile, tile];
            tiles are moved to the device its parameters live on.
        x_cpu: [1, C, H, W] float tensor on the CPU.
        tile_size / overlap: tile geometry; adjacent tiles share `overlap`
            pixels that are linearly cross-faded to hide seams.
        batch_size: number of tiles per forward pass.
        pad_multiple: the padded canvas is rounded up to this multiple
            (encoder downsampling constraint).
        pad_mode: F.pad mode for the right/bottom borders.

    Returns:
        [1, C_out, H, W] tensor clamped to [0, 1]. C_out is whatever channel
        count the model actually produces (previously the accumulator was
        wrongly sized from the INPUT channel count).

    Raises:
        ValueError: if overlap is negative or >= tile_size.
    """
    if overlap < 0 or overlap >= tile_size:
        raise ValueError(f"overlap must satisfy 0 <= overlap < tile_size, got {overlap} / {tile_size}")
    _, _, h, w = x_cpu.shape

    # 1. Pad so the canvas holds at least one tile and is a multiple of
    #    pad_multiple (encoder downsampling constraint).
    padded_h = _ceil_to_multiple(max(h, tile_size), pad_multiple)
    padded_w = _ceil_to_multiple(max(w, tile_size), pad_multiple)
    pad_h, pad_w = padded_h - h, padded_w - w
    if pad_h or pad_w:
        x_cpu = F.pad(x_cpu, (0, pad_w, 0, pad_h), mode=pad_mode)

    # 2. Tile origins and per-axis blending ramps.
    stride = tile_size - overlap
    y_starts = _build_starts(padded_h, tile_size, stride)
    x_starts = _build_starts(padded_w, tile_size, stride)
    y_weights = _precompute_axis_weights(y_starts, tile_size, overlap)
    x_weights = _precompute_axis_weights(x_starts, tile_size, overlap)

    # Run tiles on the device the model's parameters live on (falling back to
    # CPU for parameter-less models) instead of relying on a module global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # 3. The accumulator is allocated lazily once the first prediction reveals
    #    the model's output channel count — it need not equal the input's.
    accum: torch.Tensor | None = None
    weight = torch.zeros((1, 1, padded_h, padded_w), dtype=torch.float32)
    coords = [(yy, xx, yi, xi) for yi, yy in enumerate(y_starts) for xi, xx in enumerate(x_starts)]

    # 4. Batched inference with weighted accumulation.
    with torch.inference_mode():
        for i in range(0, len(coords), batch_size):
            chunk = coords[i : i + batch_size]
            tiles = torch.stack(
                [x_cpu[0, :, yy : yy + tile_size, xx : xx + tile_size] for (yy, xx, _, _) in chunk],
                dim=0,
            ).to(run_device)
            pred = model(tiles).float().detach().cpu()  # [B, C_out, tile, tile]
            if accum is None:
                accum = torch.zeros((1, pred.shape[1], padded_h, padded_w), dtype=torch.float32)
            for bi, (yy, xx, yi, xi) in enumerate(chunk):
                # Outer product of the two 1-D ramps -> [1, 1, tile, tile] mask.
                m = (y_weights[yi][:, None] * x_weights[xi][None, :]).unsqueeze(0).unsqueeze(0)
                accum[:, :, yy : yy + tile_size, xx : xx + tile_size] += pred[bi : bi + 1] * m
                weight[:, :, yy : yy + tile_size, xx : xx + tile_size] += m

    # 5. Normalize by total blend weight, clamp to [0, 1], drop the padding.
    out = (accum / weight.clamp_min(1e-8)).clamp(0, 1)
    return out[:, :, :h, :w]
# ==========================================
# FastAPI 逻辑
# ==========================================
def load_model() -> UnetPlusPlus:
    """Construct the UNet++ model from `config.json` under MODEL_PATH, load
    local safetensors weights when present, and return it in eval mode on the
    configured device."""
    model_dir = Path(MODEL_PATH)
    cfg = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    model = UnetPlusPlus(
        encoder_name=cfg.get("encoder_name", "resnet50"),
        encoder_weights=None,  # NOTE: pretrained encoder weights would need handling here
        in_channels=int(cfg.get("in_channels", 3)),
        classes=int(cfg.get("classes", 3)),
        decoder_attention_type=cfg.get("decoder_attention_type"),
        activation=cfg.get("activation", "sigmoid"),
    )
    # Optional local checkpoint (mirrors infer_hd.py's model.safetensors).
    weights_path = model_dir / "model.safetensors"
    if weights_path.exists():
        try:
            from safetensors.torch import load_file

            loaded = load_file(str(weights_path))
            # Drop keys the architecture does not know about so a mismatched
            # checkpoint cannot abort loading.
            known = set(model.state_dict().keys())
            model.load_state_dict({k: v for k, v in loaded.items() if k in known}, strict=False)
            print(f"Loaded weights from {weights_path}")
        except Exception as e:
            print(f"Failed to load weights: {e}")
    model.to(device)
    model.eval()
    return model
@asynccontextmanager
async def lifespan(instance: FastAPI):
    """Application lifespan: build the model once at startup and expose it
    to request handlers via `app.state.model`."""
    model = load_model()
    instance.state.model = model
    yield
# ASGI application instance; the lifespan hook loads the model at startup.
app = FastAPI(lifespan=lifespan)
@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Handwriting-erasure endpoint (tiled inference with overlap blending).

    Accepts an uploaded image, runs the model tile-by-tile, and returns the
    cleaned image as PNG. Responds with HTTP 400 when the upload cannot be
    decoded as an image and HTTP 500 when PNG encoding fails.
    """
    content = await file.read()
    nparr = np.frombuffer(content, np.uint8)

    # 1. Decode to BGR. cv2.imdecode returns None (it does not raise) on
    #    corrupt/unsupported data, which previously crashed cvtColor.
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        return Response(content="Invalid or unsupported image data", media_type="text/plain", status_code=400)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    model = request.app.state.model

    def _inference_logic():
        # 2. Preprocess: (H, W, C) uint8 -> (1, C, H, W) float in [0, 1],
        #    matching TF.to_tensor in the reference script.
        input_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        input_tensor = input_tensor.unsqueeze(0)  # [1, 3, H, W]

        # 3. Tiled inference.
        output_tensor = _tiled_infer(
            model=model,
            x_cpu=input_tensor,
            tile_size=TILE_SIZE,
            overlap=OVERLAP,
            batch_size=1,  # raise this if there is spare GPU memory
            pad_mode="replicate"
        )

        # 4. Postprocess: (1, C, H, W) [0, 1] -> (H, W, C) uint8 [0, 255].
        output_tensor = output_tensor.squeeze(0).permute(1, 2, 0)  # [H, W, C]
        return (output_tensor.numpy() * 255).astype(np.uint8)

    # Heavy CPU/GPU work runs on a worker thread so the event loop stays free.
    result_rgb = await run_sync(_inference_logic)

    # 5. Back to BGR for OpenCV encoding.
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    success, encoded_image = cv2.imencode(".png", result_bgr)
    if not success:
        # Previously `success` was computed but never checked.
        return Response(content="Failed to encode result image", media_type="text/plain", status_code=500)
    return Response(content=encoded_image.tobytes(), media_type="image/png")
@app.get("/")
def greet_json():
    """Health-check route: return a static greeting payload."""
    payload = {"Hello": "World!"}
    return payload
if __name__ == '__main__':
    # Local development entry point; in production run `uvicorn app:app` directly.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000)