Spaces:
Sleeping
Sleeping
Commit ·
cda20d5
1
Parent(s): 36e825c
feat: add predict api and load model
Browse files- app.py +93 -3
- app_DeblurGan_PyTorch.py +159 -0
- models/fpn_inception.py +167 -0
app.py
CHANGED
|
@@ -1,11 +1,80 @@
|
|
| 1 |
from fastapi import FastAPI, Request, Response
|
| 2 |
from fastapi.responses import JSONResponse
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
@app.get("/")
|
| 7 |
def root():
|
| 8 |
-
return {"
|
| 9 |
|
| 10 |
@app.get("/greetjson")
|
| 11 |
def greet_json(request: Request, response: Response):
|
|
@@ -17,4 +86,25 @@ def greet_json(request: Request, response: Response):
|
|
| 17 |
response.headers["X-Custom-Header"] = "HelloHeader"
|
| 18 |
|
| 19 |
# 回傳 JSON
|
| 20 |
-
return JSONResponse(content={"message": "Hello World", "client": client_host})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
import os

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, Request, Response, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from torchvision import transforms

from models.fpn_inception import FPNInception  # project-local generator definition
|
| 13 |
+
|
| 14 |
+
# =====================
# Model initialisation
# =====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 Using device: {device}")

# Checkpoint lives under ./model relative to the working directory.
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")

# Build the generator and restore its weights.
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source (consider weights_only=True).
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
G.load_state_dict(checkpoint["G"], strict=False)
G.eval()
print("✅ Model loaded from", checkpoint_path)
|
| 27 |
+
|
| 28 |
+
# =====================
# Tile-based inference
# =====================
def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
    """Run the deblurring model over ``img`` tile-by-tile and stitch the result.

    Processes the image in overlapping square tiles so large images fit in
    GPU memory, averages predictions in the overlap regions, and returns a
    new PIL image.

    Args:
        model: generator network, called as ``model(tensor) -> tensor`` of
            the same shape.  Assumes outputs are meaningful in [0, 1] —
            TODO(review): confirm the checkpoint expects [0, 1] inputs
            (FPNInception clamps its output to [-1, 1]).
        img: input PIL image (RGB).
        device: torch.device on which the tensor is placed.
        tile_size: side length of each square tile.
        overlap: pixel overlap between neighbouring tiles.

    Returns:
        Deblurred PIL image, resized down to the nearest multiple of 32 in
        each dimension if the input was not already one.
    """
    model.eval()
    w, h = img.size

    # The backbone downsamples by 32x, so snap dimensions to multiples of 32.
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    if new_w != w or new_h != h:
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Bug fix: if the image is smaller than tile_size, the original code set
    # tiles_x[-1] = w - tile_size < 0, producing an empty patch.  Shrink the
    # tile to fit the image and clamp start offsets to >= 0.
    tile_size = min(tile_size, w, h)
    overlap = min(overlap, tile_size - 1) if tile_size > 1 else 0
    stride = max(tile_size - overlap, 1)
    tiles_x = list(range(0, w, stride))
    tiles_y = list(range(0, h, stride))
    if tiles_x[-1] + tile_size > w:
        tiles_x[-1] = max(w - tile_size, 0)
    if tiles_y[-1] + tile_size > h:
        tiles_y[-1] = max(h - tile_size, 0)

    # Accumulate predictions and per-pixel coverage counts.
    output = torch.zeros_like(img_tensor)
    weight = torch.zeros_like(img_tensor)

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y + tile_size, x:x + tile_size]
                pred = model(patch)
                output[:, :, y:y + tile_size, x:x + tile_size] += pred
                weight[:, :, y:y + tile_size, x:x + tile_size] += 1.0

    # Average overlapping regions so they are not over-counted.
    output /= weight
    output = torch.clamp(output, 0, 1)
    out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)
|
| 66 |
+
|
| 67 |
+
# =====================
# FastAPI application
# =====================
app = FastAPI(title="DeblurGANv2 API")


# =====================
# API routes
# =====================
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up."""
    return {"message": "DeblurGANv2 API ready!"}
|
| 78 |
|
| 79 |
@app.get("/greetjson")
|
| 80 |
def greet_json(request: Request, response: Response):
|
|
|
|
| 86 |
response.headers["X-Custom-Header"] = "HelloHeader"
|
| 87 |
|
| 88 |
# 回傳 JSON
|
| 89 |
+
return JSONResponse(content={"message": "Hello World", "client": client_host})
|
| 90 |
+
|
| 91 |
+
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Deblur an uploaded image and stream the PNG result back."""
    try:
        # Read the uploaded image fully into memory.
        payload = await file.read()
        image = Image.open(io.BytesIO(payload)).convert("RGB")

        # Run tile-based deblurring with the module-level generator.
        deblurred = deblur_image_tiled(G, image, device)

        # Serialise the result as PNG bytes.
        buffer = io.BytesIO()
        deblurred.save(buffer, format="PNG")
        buffer.seek(0)

        # Return the image directly to the client.
        return StreamingResponse(buffer, media_type="image/png")

    except Exception as e:
        # Top-level API boundary: report any failure as a JSON 500.
        return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
|
app_DeblurGan_PyTorch.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Thu Oct 16 12:05:42 2025
|
| 4 |
+
|
| 5 |
+
@author: ittraining
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# -*- coding: utf-8 -*-
|
| 9 |
+
"""
|
| 10 |
+
Use PyTorch DeblurGAN-v2 (.pth) to deblur images with Tkinter UI
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import numpy as np
|
| 17 |
+
from PIL import Image, ImageTk
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
import tkinter as tk
|
| 20 |
+
from tkinter import filedialog
|
| 21 |
+
|
| 22 |
+
# ======== 模型定義區 ========
|
| 23 |
+
from models.fpn_inception import FPNInception # 你需確認這個檔案存在
|
| 24 |
+
|
| 25 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 Using device: {device}")

# Model checkpoint path (./model/deblurgan_v2_latest.pth).
checkpoint_dir = os.path.join(os.getcwd(), "model")
ckpt_path = os.path.join(checkpoint_dir, "deblurgan_v2_latest.pth")

# Build the generator and restore its weights.
# NOTE(review): torch.load unpickles arbitrary objects — trusted checkpoints only.
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
checkpoint = torch.load(ckpt_path, map_location=device)
G.load_state_dict(checkpoint["G"], strict=False)
G.eval()
print("✅ Model loaded from", ckpt_path)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ======== Tile-based inference ========
def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
    """Deblur a full image tile-by-tile to keep GPU memory bounded.

    Args:
        model: DeblurGAN-v2 generator with weights loaded; called as
            ``model(tensor) -> tensor`` of the same shape.
        img: input PIL image.
        device: torch.device("cuda" or "cpu").
        tile_size: side length of each tile (512 recommended).
        overlap: overlap in pixels between tiles (16~64 recommended).

    Returns:
        Deblurred PIL image, resized down to the nearest multiple of 32 in
        each dimension if the input was not already one.
    """
    model.eval()

    # ---- Preprocessing ----
    w, h = img.size

    # Ensure dimensions are multiples of 32 (backbone downsamples by 32x).
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    if new_w != w or new_h != h:
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # ---- Compute the tile grid ----
    # Bug fix: for images smaller than tile_size the original code produced a
    # negative start offset (w - tile_size < 0) and an empty patch.  Shrink
    # the tile to fit and clamp offsets to >= 0.
    tile_size = min(tile_size, w, h)
    overlap = min(overlap, tile_size - 1) if tile_size > 1 else 0
    stride = max(tile_size - overlap, 1)
    tiles_x = list(range(0, w, stride))
    tiles_y = list(range(0, h, stride))
    if tiles_x[-1] + tile_size > w:
        tiles_x[-1] = max(w - tile_size, 0)
    if tiles_y[-1] + tile_size > h:
        tiles_y[-1] = max(h - tile_size, 0)

    # ---- Blank accumulators for output and coverage ----
    output = torch.zeros_like(img_tensor)
    weight = torch.zeros_like(img_tensor)

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
                pred = model(patch)

                # Accumulate the prediction at its position.
                output[:, :, y:y+tile_size, x:x+tile_size] += pred
                weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0

    # ---- Average (avoid over-counting overlap regions) ----
    output /= weight
    output = torch.clamp(output, 0, 1)

    # ---- Convert back to a PIL image ----
    out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ======== Wrap the generator in a DeblurModel class ========
class DeblurModel:
    """Thin wrapper giving the generator a path-based predict() API."""

    def __init__(self, model):
        self.model = model

    def predict(self, image_path):
        """Load the image at ``image_path`` and return the deblurred PIL image."""
        source = Image.open(image_path).convert("RGB")
        return deblur_image_tiled(self.model, source, device, tile_size=512, overlap=32)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ======== Tkinter GUI ========
class ImageViewerApp:
    """Side-by-side viewer: original image left, deblurred result right."""

    def __init__(self, root):
        self.root = root
        self.root.title("AI Image Deblurring Viewer (PyTorch)")
        self.root.geometry("1500x700")
        self.create_gui()
        # Uses the module-level generator G loaded at import time.
        self.model = DeblurModel(G)

    def create_gui(self):
        """Build and lay out all widgets."""
        label_font = ("Helvetica", 16)
        self.browse_button = tk.Button(
            self.root, text="Browse Image", command=self.browse_image, font=label_font
        )
        self.canvas_original = tk.Canvas(self.root, width=480, height=420, bg="lightgray")
        self.canvas_result = tk.Canvas(self.root, width=480, height=420, bg="lightgray")
        self.result_label = tk.Label(self.root, text="", font=("Helvetica", 18, "bold"), fg="blue")

        self.browse_button.grid(row=0, column=0, columnspan=2, pady=10)
        self.canvas_original.grid(row=1, column=0, padx=10, pady=10)
        self.canvas_result.grid(row=1, column=1, padx=10, pady=10)
        self.result_label.grid(row=2, column=0, columnspan=2, pady=10)

    def browse_image(self):
        """Prompt for an image file and process it if one is chosen."""
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.gif *.bmp *.tif")]
        )
        if file_path:
            self.display_images(file_path)

    def display_images(self, image_path):
        """Show the original image and its deblurred counterpart."""
        original = Image.open(image_path)
        original.thumbnail((480, 420))
        photo = ImageTk.PhotoImage(original)
        self.canvas_original.create_image(0, 0, anchor="nw", image=photo)
        # Keep a reference on the canvas so Tk does not garbage-collect it.
        self.canvas_original.image = photo

        deblurred = self.model.predict(image_path)
        deblurred.thumbnail((480, 420))
        photo_result = ImageTk.PhotoImage(deblurred)
        self.canvas_result.create_image(0, 0, anchor="nw", image=photo_result)
        self.canvas_result.image = photo_result

        self.result_label.config(text=f"File: {os.path.basename(image_path)} → Deblurred by DeblurGAN-v2")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Launch the Tkinter viewer when run as a script.
if __name__ == "__main__":
    main_window = tk.Tk()
    viewer = ImageViewerApp(main_window)
    main_window.mainloop()
|
models/fpn_inception.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchsummary import summary
|
| 4 |
+
from pretrainedmodels import inceptionresnetv2
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
class FPNHead(nn.Module):
    """Two bias-free 3x3 convs with ReLU, mapping num_in -> num_out channels."""

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = F.relu(self.block0(x), inplace=True)
        out = F.relu(self.block1(out), inplace=True)
        return out
|
| 18 |
+
|
| 19 |
+
class ConvBlock(nn.Module):
    """Conv3x3 -> norm_layer -> ReLU, preserving spatial size."""

    def __init__(self, num_in, num_out, norm_layer):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
            norm_layer(num_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class FPNInception(nn.Module):
    """DeblurGAN-v2 generator: InceptionResNetV2-backed FPN with fusion heads.

    Predicts a residual image via tanh, adds it to the input, and clamps the
    result to [-1, 1].
    """

    def __init__(self, norm_layer=nn.InstanceNorm2d, output_ch=3, num_filters=128, num_filters_fpn=256):
        super().__init__()

        # Feature pyramid with maps at 1/4, 1/8, 1/16 and 1/32 resolution,
        # each carrying `num_filters_fpn` channels.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)

        # One refinement head per pyramid level.
        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuse the concatenated heads, then halve channels before the output conv.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )
        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        """Allow the backbone to be fine-tuned."""
        self.fpn.unfreeze()

    def forward(self, x):
        map0, map1, map2, map3, map4 = self.fpn(x)

        # Upsample every refined level to map1's 1/4-resolution grid.
        map4 = F.interpolate(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = F.interpolate(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = F.interpolate(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = F.interpolate(self.head1(map1), scale_factor=1, mode="nearest")

        fused = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        fused = F.interpolate(fused, scale_factor=2, mode="nearest")
        fused = self.smooth2(fused + map0)
        fused = F.interpolate(fused, scale_factor=2, mode="nearest")

        # Residual connection: predict a correction and add it to the input.
        residual = self.final(fused)
        out = torch.tanh(residual) + x
        return torch.clamp(out, min=-1, max=1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class FPN(nn.Module):
    """Feature pyramid over an ImageNet-pretrained InceptionResNetV2 backbone."""

    def __init__(self, norm_layer, num_filters=256):
        """Creates an `FPN` instance for feature extraction.

        Args:
            norm_layer: normalisation layer constructor (e.g. nn.InstanceNorm2d).
            num_filters: the number of filters in each output pyramid level.
        """
        super().__init__()
        # NOTE(review): downloads ImageNet weights on first use.
        self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')

        # Bottom-up encoder stages sliced out of the backbone.
        self.enc0 = self.inception.conv2d_1a
        self.enc1 = nn.Sequential(
            self.inception.conv2d_2a,
            self.inception.conv2d_2b,
            self.inception.maxpool_3a,
        )  # 64 channels
        self.enc2 = nn.Sequential(
            self.inception.conv2d_3b,
            self.inception.conv2d_4a,
            self.inception.maxpool_5a,
        )  # 192 channels
        self.enc3 = nn.Sequential(
            self.inception.mixed_5b,
            self.inception.repeat,
            self.inception.mixed_6a,
        )  # 1088 channels
        self.enc4 = nn.Sequential(
            self.inception.repeat_1,
            self.inception.mixed_7a,
        )  # 2080 channels

        # Top-down refinement blocks (conv -> norm -> ReLU).
        self.td1 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(inplace=True),
        )
        self.td2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(inplace=True),
        )
        self.td3 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(inplace=True),
        )

        # 1x1 lateral projections from each encoder stage.
        self.pad = nn.ReflectionPad2d(1)
        self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; call unfreeze() to fine-tune it.
        for param in self.inception.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Make all backbone parameters trainable."""
        for param in self.inception.parameters():
            param.requires_grad = True

    def forward(self, x):
        # Bottom-up pathway through the backbone stages.
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)

        # Lateral connections; reflection padding re-aligns spatial sizes
        # coming out of the Inception stem.
        lateral4 = self.pad(self.lateral4(enc4))
        lateral3 = self.pad(self.lateral3(enc3))
        lateral2 = self.lateral2(enc2)
        lateral1 = self.pad(self.lateral1(enc1))
        lateral0 = self.lateral0(enc0)

        # Top-down pathway with 2x nearest-neighbour upsampling.
        pad = (1, 2, 1, 2)
        pad1 = (0, 1, 0, 1)
        map4 = lateral4
        map3 = self.td1(lateral3 + F.interpolate(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(F.pad(lateral2, pad, "reflect") + F.interpolate(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + F.interpolate(map2, scale_factor=2, mode="nearest"))
        return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
|