ynyg committed on
Commit
1c20dbd
·
1 Parent(s): 7bd5dc8

refactor: 重构分块推理逻辑并移除 Albumentations 依赖

Browse files

- 替换 Albumentations 的图像预处理逻辑,直接使用 PyTorch 实现归一化操作
- 引入自定义的分块推理方法 `_tiled_infer`,支持 Tiling + Overlap 机制
- 优化模型加载流程,动态检测和加载权重文件
- 使用 Torch 和 OpenCV 完成输入/输出的处理,移除了 Albumentations 相关代码
- 支持大尺寸图像的高效推理,提升内存和性能表现

Files changed (1) hide show
  1. app.py +184 -109
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import json
 
2
  from contextlib import asynccontextmanager
3
  from pathlib import Path
4
 
5
- import albumentations as A
6
  import cv2
7
  import numpy as np
8
  import torch
 
9
  from anyio.to_thread import run_sync
10
  from fastapi import FastAPI, Request, UploadFile, File
11
  from fastapi.responses import Response
@@ -15,45 +16,168 @@ from segmentation_models_pytorch import UnetPlusPlus
15
  MODEL_PATH = "models/InkErase"
16
  # 設備
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- # 分块大小
19
- TRAIN_SIZE = 512
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def load_model() -> UnetPlusPlus:
23
  """加載模型"""
24
- # 模型路径
25
  path = Path(MODEL_PATH)
26
- # 读取配置文件
27
  cfg = json.loads((path / "config.json").read_text(encoding="utf-8"))
28
- # 加載模型
29
- return UnetPlusPlus(
30
  encoder_name=cfg.get("encoder_name", "resnet50"),
31
- encoder_weights=None,
32
  in_channels=int(cfg.get("in_channels", 3)),
33
  classes=int(cfg.get("classes", 3)),
34
  decoder_attention_type=cfg.get("decoder_attention_type"),
35
  activation=cfg.get("activation", "sigmoid"),
36
  )
37
-
38
-
39
- def get_preprocessing() -> A.Compose:
40
- """获取Albumentations 預處理 pipeline"""
41
- return A.Compose([
42
- A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
43
- A.ToTensorV2()
44
- ])
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  @asynccontextmanager
48
  async def lifespan(instance: FastAPI):
49
- """
50
- FastAPI 應用程序的生命周期管理器。
51
- :param instance: FastAPI 應用程序實例
52
- """
53
- # 加載模型
54
  instance.state.model = load_model()
55
- # 初始化預處理函數
56
- instance.state.preprocess_fn = get_preprocessing()
57
  yield
58
 
59
 
@@ -63,105 +187,56 @@ app = FastAPI(lifespan=lifespan)
63
  @app.post("/predict")
64
  async def predict(request: Request, file: UploadFile = File(...)):
65
  """
66
- 笔迹擦除
67
- :param request: 请求对象
68
- :param file: 待处理的图片
69
- :return: 預測結果,包括文本、預測類別和置信度
70
  """
71
- # 1. 使用 OpenCV 直接從內存讀取圖片
72
  content = await file.read()
73
- # 將 bytes 轉換為 numpy array
74
  nparr = np.frombuffer(content, np.uint8)
75
- # 解碼圖片 (默認 BGR)
76
- original_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
77
- # 转换为 RGB
78
- original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
79
-
80
- # 获取图片尺寸
81
- orig_h, orig_w = original_image.shape[:2]
82
- # 获取模型和处理流
83
  model = request.app.state.model
84
- preprocess_fn = request.app.state.preprocess_fn
85
 
86
  def _inference_logic():
87
- with torch.no_grad():
88
- # ==============================
89
- # 情況 A: 圖片大於 512,進行切塊處理
90
- # ==============================
91
- if orig_w > TRAIN_SIZE or orig_h > TRAIN_SIZE:
92
- # 1. 計算新的寬高(補齊為 512 的倍數)
93
- new_w = (orig_w // TRAIN_SIZE + (1 if orig_w % TRAIN_SIZE != 0 else 0)) * TRAIN_SIZE
94
- new_h = (orig_h // TRAIN_SIZE + (1 if orig_h % TRAIN_SIZE != 0 else 0)) * TRAIN_SIZE
95
-
96
- # 2. Padding 原圖 (使用 NumPy 創建黑色畫布)
97
- padded_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
98
- # 將原圖貼到左上角 (NumPy 切片賦值)
99
- padded_img[:orig_h, :orig_w, :] = original_image
100
-
101
- # 創建結果掩碼畫布 (單通道)
102
- result_mask = np.zeros((new_h, new_w), dtype=np.uint8)
103
-
104
- # 3. 循環切割
105
- for y in range(0, new_h, TRAIN_SIZE):
106
- for x in range(0, new_w, TRAIN_SIZE):
107
- # NumPy 切片代替 crop
108
- patch = padded_img[y:y + TRAIN_SIZE, x:x + TRAIN_SIZE]
109
-
110
- # Albumentations 處理 (patch 已經是 numpy array)
111
- transformed = preprocess_fn(image=patch)
112
- input_tensor = transformed["image"].unsqueeze(0).to(device)
113
-
114
- output = model(input_tensor)
115
-
116
- pred_mask = (output > 0.5).float().squeeze().cpu().numpy()
117
- pred_mask = (pred_mask * 255).astype(np.uint8)
118
-
119
- # 將預測結果貼回大圖
120
- result_mask[y:y + TRAIN_SIZE, x:x + TRAIN_SIZE] = pred_mask
121
-
122
- # 裁剪回原始尺寸
123
- final_image = result_mask[:orig_h, :orig_w]
124
-
125
- # ==============================
126
- # 情況 B: 圖片小於等於 512
127
- # ==============================
128
- else:
129
- # 創建黑色畫布
130
- padded_img = np.zeros((TRAIN_SIZE, TRAIN_SIZE, 3), dtype=np.uint8)
131
- padded_img[:orig_h, :orig_w, :] = original_image
132
-
133
- # Albumentations 處理
134
- transformed = preprocess_fn(image=padded_img)
135
- input_tensor = transformed["image"].unsqueeze(0).to(device)
136
-
137
- output = model(input_tensor)
138
-
139
- pred_mask = (output > 0.5).float().squeeze().cpu().numpy()
140
- pred_mask = (pred_mask * 255).astype(np.uint8)
141
-
142
- # 裁剪回原始尺寸
143
- final_image = pred_mask[:orig_h, :orig_w]
144
-
145
- return final_image
146
-
147
- # 執行推理
148
- result_image = await run_sync(_inference_logic)
149
-
150
- # 返回圖片流 (使用 cv2.imencode)
151
- # result_image 是单通道灰度图,可以直接编码为 PNG
152
- success, encoded_image = cv2.imencode(".png", result_image)
153
  return Response(content=encoded_image.tobytes(), media_type="image/png")
154
 
155
 
156
  @app.get("/")
157
  def greet_json():
158
- """
159
- 返回一個 JSON 格式的歡迎訊息。
160
- """
161
  return {"Hello": "World!"}
162
 
163
 
164
  if __name__ == '__main__':
165
  import uvicorn
166
-
167
  uvicorn.run("app:app", host="0.0.0.0", port=8000)
 
1
  import json
2
+ import math
3
  from contextlib import asynccontextmanager
4
  from pathlib import Path
5
 
 
6
  import cv2
7
  import numpy as np
8
  import torch
9
+ import torch.nn.functional as F
10
  from anyio.to_thread import run_sync
11
  from fastapi import FastAPI, Request, UploadFile, File
12
  from fastapi.responses import Response
 
16
# Directory containing config.json and (optionally) model.safetensors.
MODEL_PATH = "models/InkErase"
# Inference device: prefer CUDA when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tile edge length for chunked inference (reference script default: 512).
TILE_SIZE = 512
# Overlap between adjacent tiles, cross-faded to hide seams (reference script default: 64).
OVERLAP = 64
23
+
24
+
25
+ # ==========================================
26
+ # 核心 Tiling 算法 (移植自 infer_hd.py)
27
+ # ==========================================
28
+
29
+ def _ceil_to_multiple(value: int, multiple: int) -> int:
30
+ if multiple <= 1:
31
+ return value
32
+ return int(math.ceil(value / multiple) * multiple)
33
+
34
+
35
+ def _build_starts(length: int, tile: int, stride: int) -> list[int]:
36
+ if length <= tile:
37
+ return [0]
38
+ starts = list(range(0, length - tile + 1, stride))
39
+ last = length - tile
40
+ if starts[-1] != last:
41
+ starts.append(last)
42
+ return starts
43
+
44
+
45
+ def _precompute_axis_weights(starts: list[int], tile: int, overlap: int) -> list[torch.Tensor]:
46
+ """预计算融合权重,用于消除拼接缝隙"""
47
+ max_start = starts[-1]
48
+ weights: list[torch.Tensor] = []
49
+ if overlap <= 0:
50
+ one = torch.ones(tile, dtype=torch.float32)
51
+ return [one for _ in starts]
52
+
53
+ # 创建渐变权重 (Ramp)
54
+ ramp_up = torch.linspace(0.0, 1.0, overlap, dtype=torch.float32)
55
+ ramp_down = torch.linspace(1.0, 0.0, overlap, dtype=torch.float32)
56
+
57
+ for start in starts:
58
+ w = torch.ones(tile, dtype=torch.float32)
59
+ if start > 0:
60
+ w[:overlap] *= ramp_up
61
+ if start < max_start:
62
+ w[-overlap:] *= ramp_down
63
+ weights.append(w)
64
+ return weights
65
+
66
+
67
def _tiled_infer(
    model: torch.nn.Module,
    x_cpu: torch.Tensor,
    tile_size: int = 512,
    overlap: int = 64,
    batch_size: int = 1,
    pad_multiple: int = 32,
    pad_mode: str = "replicate",
) -> torch.Tensor:
    """Run `model` over `x_cpu` tile by tile and blend the results seamlessly.

    x_cpu is a [1, C, H, W] float tensor on the CPU. The image is padded up
    to at least one full tile (and to a multiple of `pad_multiple`), split
    into overlapping tiles, inferred in batches on the global `device`, and
    recombined using linear cross-fade weights. Returns a [1, C, H, W]
    tensor clamped to [0, 1] and cropped back to the input size.
    """
    _, _, orig_h, orig_w = x_cpu.shape

    # 1. Pad so the canvas holds at least one tile and divides evenly.
    canvas_h = _ceil_to_multiple(max(orig_h, tile_size), pad_multiple)
    canvas_w = _ceil_to_multiple(max(orig_w, tile_size), pad_multiple)
    extra_h = canvas_h - orig_h
    extra_w = canvas_w - orig_w
    if extra_h or extra_w:
        x_cpu = F.pad(x_cpu, (0, extra_w, 0, extra_h), mode=pad_mode)

    # 2. Tile grid (top-left corners) and matching per-axis blend weights.
    step = tile_size - overlap
    rows = _build_starts(canvas_h, tile_size, step)
    cols = _build_starts(canvas_w, tile_size, step)
    row_w = _precompute_axis_weights(rows, tile_size, overlap)
    col_w = _precompute_axis_weights(cols, tile_size, overlap)

    # 3. Weighted accumulators. Output channel count is assumed to equal the
    #    input's (C == 3 for this model config) — NOTE(review): confirm if
    #    `classes` in config.json ever differs from `in_channels`.
    n_ch = x_cpu.shape[1]
    total = torch.zeros((1, n_ch, canvas_h, canvas_w), dtype=torch.float32)
    norm = torch.zeros((1, 1, canvas_h, canvas_w), dtype=torch.float32)

    grid = [
        (yy, xx, yi, xi)
        for yi, yy in enumerate(rows)
        for xi, xx in enumerate(cols)
    ]

    # 4. Batched inference; `model` is expected to already live on `device`.
    with torch.inference_mode():
        for offset in range(0, len(grid), batch_size):
            batch = grid[offset : offset + batch_size]

            # Gather this batch of tiles and move them to the compute device.
            stacked = torch.stack(
                [x_cpu[0, :, yy : yy + tile_size, xx : xx + tile_size] for yy, xx, _, _ in batch],
                dim=0,
            ).to(device)

            preds = model(stacked).float().detach().cpu()  # [B, C, tile, tile]

            # Blend each prediction back into the canvas under its 2-D weight.
            for idx, (yy, xx, yi, xi) in enumerate(batch):
                blend = (row_w[yi][:, None] * col_w[xi][None, :]).unsqueeze(0).unsqueeze(0)
                total[:, :, yy : yy + tile_size, xx : xx + tile_size] += preds[idx : idx + 1] * blend
                norm[:, :, yy : yy + tile_size, xx : xx + tile_size] += blend

    # 5. Normalise by accumulated weight, clamp, and drop the padding.
    merged = (total / norm.clamp_min(1e-8)).clamp(0, 1)
    return merged[:, :, :orig_h, :orig_w]
139
+
140
+
141
+ # ==========================================
142
+ # FastAPI 逻辑
143
+ # ==========================================
144
 
145
def load_model() -> UnetPlusPlus:
    """Build the UNet++ model described by config.json, best-effort load any
    local safetensors weights, and return it in eval mode on `device`."""
    model_dir = Path(MODEL_PATH)
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))

    net = UnetPlusPlus(
        encoder_name=config.get("encoder_name", "resnet50"),
        encoder_weights=None,  # pretrained encoder weights are not fetched here
        in_channels=int(config.get("in_channels", 3)),
        classes=int(config.get("classes", 3)),
        decoder_attention_type=config.get("decoder_attention_type"),
        activation=config.get("activation", "sigmoid"),
    )

    # Fine-tuned weights shipped alongside the config (see infer_hd.py).
    checkpoint = model_dir / "model.safetensors"
    if checkpoint.exists():
        try:
            from safetensors.torch import load_file

            raw_state = load_file(str(checkpoint))
            # Keep only keys the architecture actually declares, so a
            # checkpoint saved from a slightly different wrapper still loads.
            known = set(net.state_dict().keys())
            usable = {name: tensor for name, tensor in raw_state.items() if name in known}
            net.load_state_dict(usable, strict=False)
            print(f"Loaded weights from {checkpoint}")
        except Exception as e:
            # Deliberately non-fatal: fall back to randomly initialised weights.
            print(f"Failed to load weights: {e}")

    net.to(device)
    net.eval()
    return net
176
 
177
 
178
@asynccontextmanager
async def lifespan(instance: FastAPI):
    """Application lifespan: load the segmentation model once at startup.

    :param instance: the FastAPI application instance
    """
    # Stored on app state so request handlers can reach the model without globals.
    instance.state.model = load_model()
    yield
182
 
183
 
 
187
@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """
    笔迹擦除 (使用 Tiling + Overlap)

    Ink/handwriting erasure endpoint: decodes the uploaded image, runs the
    tiled segmentation model over it, and streams the result back as a PNG.

    :param request: the incoming request (used to reach app.state.model)
    :param file: the uploaded image to process
    :return: PNG-encoded result image, or a 400/500 error response
    """
    content = await file.read()
    nparr = np.frombuffer(content, np.uint8)

    # 1. OpenCV decode -> BGR. imdecode returns None for invalid payloads;
    #    without this guard cvtColor would crash with a cryptic error.
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        return Response(content=b"invalid image payload", media_type="text/plain", status_code=400)
    # BGR -> RGB to match the model's expected channel order.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    model = request.app.state.model

    def _inference_logic():
        # 2. Preprocess: NumPy (H, W, C) uint8 -> Tensor (1, C, H, W) in [0, 1]
        #    (equivalent to torchvision's TF.to_tensor used by the reference script).
        input_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        input_tensor = input_tensor.unsqueeze(0)  # [1, 3, H, W]

        # 3. Tiled inference; batch_size can be raised when VRAM allows.
        output_tensor = _tiled_infer(
            model=model,
            x_cpu=input_tensor,
            tile_size=TILE_SIZE,
            overlap=OVERLAP,
            batch_size=1,
            pad_mode="replicate",
        )

        # 4. Postprocess: Tensor (1, C, H, W) [0, 1] -> NumPy (H, W, C) [0, 255].
        output_tensor = output_tensor.squeeze(0).permute(1, 2, 0)  # [H, W, C]
        return (output_tensor.numpy() * 255).astype(np.uint8)

    # Run the compute-heavy work in a worker thread to keep the event loop free.
    result_rgb = await run_sync(_inference_logic)

    # 5. Back to BGR for OpenCV's PNG encoder.
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    success, encoded_image = cv2.imencode(".png", result_bgr)
    if not success:
        # Should not happen for a valid uint8 image, but fail loudly instead of
        # dereferencing a None result.
        return Response(content=b"failed to encode result", media_type="text/plain", status_code=500)
    return Response(content=encoded_image.tobytes(), media_type="image/png")
233
 
234
 
235
@app.get("/")
def greet_json():
    """Return a static JSON greeting; doubles as a liveness probe."""
    return {"Hello": "World!"}
238
 
239
 
240
if __name__ == '__main__':
    import uvicorn
    # Dev-only entry point; production deployments launch the ASGI server externally.
    uvicorn.run("app:app", host="0.0.0.0", port=8000)