# Author: JasonFinley0821
# Commit 5e0a683 — feat: deblur, add new module
# services/deblur.py
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from models.fpn_inception import FPNInception
# =====================
# Model initialization
# =====================
# Runs once at import time: selects a device, builds the generator G, loads
# its weights from the checkpoint and switches it to eval mode. G and device
# are then used by deblur_image_tiled below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 [DeblurService] Using device: {device}")
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
# SECURITY NOTE: torch.load unpickles the file, so only load trusted
# checkpoints. weights_only=True would be safer but may reject checkpoints
# that store non-tensor objects — verify against this file before switching.
checkpoint = torch.load(checkpoint_path, map_location=device)
# strict=False tolerates key mismatches between checkpoint["G"] and the model.
# Report them instead of silently running a partially initialised generator.
load_result = G.load_state_dict(checkpoint["G"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ [DeblurService] State-dict mismatch: "
        f"{len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
G.eval()
print(f"✅ [DeblurService] Model loaded from {checkpoint_path}")
# Deblur function


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets that together cover ``[0, length)`` with no gaps.

    Offsets advance by ``stride`` from 0. The last tile is pinned so that it
    ends exactly at ``length``. When ``length <= tile_size`` the only offset
    is 0; slicing then clips that tile to the image. Offsets are never
    negative and never repeated.
    """
    last = max(length - tile_size, 0)
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        # The regular grid stops short of the border; add a tile flush with it.
        starts.append(last)
    return starts


def deblur_image_tiled(img, tile_size=512, overlap=32):
    """Deblur a PIL image with the module-level generator ``G``, tile by tile.

    Processing steps:
      1. Convert the image to RGB if needed.
      2. Resize it (bicubic) so width and height are rounded down to
         multiples of 32.
      3. Run ``G`` on square tiles of ``tile_size`` pixels that overlap by
         ``overlap`` pixels.
      4. Average overlapping predictions and clamp the result to [0, 1].

    Args:
        img: input ``PIL.Image``. Non-RGB modes are converted to RGB.
        tile_size: side length of each square tile, in pixels.
        overlap: pixels shared by neighbouring tiles. Must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        A new RGB ``PIL.Image`` at the rounded-down size, not the original
        size.

    Raises:
        ValueError: if ``overlap`` is outside ``[0, tile_size)``, or if the
            image is narrower or shorter than 32 pixels.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, "
            f"got overlap={overlap}, tile_size={tile_size}"
        )

    if img.mode != "RGB":
        # An 'L' image becomes a 2-D array and would break permute(2, 0, 1)
        # below. RGBA/P would give a channel count other than 3 (the model
        # presumably expects RGB — TODO confirm).
        img = img.convert("RGB")

    w, h = img.size
    new_w, new_h = (w // 32) * 32, (h // 32) * 32
    if new_w == 0 or new_h == 0:
        raise ValueError(f"image must be at least 32x32 pixels, got {w}x{h}")
    if (new_w, new_h) != (w, h):
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    # NOTE(review): the upstream DeblurGAN-v2 predictor normalises inputs to
    # [-1, 1] and maps outputs back with (x + 1) / 2. This code feeds [0, 1]
    # and clamps to [0, 1]. Confirm which range this checkpoint expects.
    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    stride = tile_size - overlap
    tiles_x = _tile_starts(w, tile_size, stride)
    tiles_y = _tile_starts(h, tile_size, stride)

    output = torch.zeros_like(img_tensor)
    # One weight channel suffices: all channels of a tile get the same count,
    # and it broadcasts over channels in the division below.
    weight = torch.zeros((1, 1, h, w), dtype=img_tensor.dtype, device=img_tensor.device)

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y + tile_size, x:x + tile_size]
                pred = G(patch)  # shared module-level generator
                output[:, :, y:y + tile_size, x:x + tile_size] += pred
                weight[:, :, y:y + tile_size, x:x + tile_size] += 1.0

    # _tile_starts guarantees full coverage, so weight >= 1 everywhere (no NaN).
    output = torch.clamp(output / weight, 0, 1)
    out_np = (output[0].permute(1, 2, 0).cpu().numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(out_np)