File size: 1,950 Bytes
5e0a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# services/deblur.py
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image

from models.fpn_inception import FPNInception

# =====================
# Model initialization (runs once at import time)
# =====================
# Prefer GPU when available; all tensors created by this service go here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 [DeblurService] Using device: {device}")

# NOTE(review): path is relative to the process working directory, not this
# file — confirm the service is always launched from the project root.
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (consider weights_only=True if the torch version supports it).
checkpoint = torch.load(checkpoint_path, map_location=device)
# strict=False silently skips mismatched/missing keys in the "G" state dict;
# presumably intentional for partial checkpoints — verify against training code.
G.load_state_dict(checkpoint["G"], strict=False)
G.eval()  # inference mode: freezes InstanceNorm/dropout behavior
print(f"✅ [DeblurService] Model loaded from {checkpoint_path}")

# Deblur entry point
def deblur_image_tiled(img, tile_size=512, overlap=32):
    """Deblur *img* with the module-level generator ``G`` using overlapping tiles.

    The image is tiled into ``tile_size``×``tile_size`` patches with
    ``overlap`` pixels of overlap; per-tile predictions are accumulated and
    averaged so tile seams are hidden.

    Args:
        img: Input ``PIL.Image``. Converted to RGB if needed. Dimensions are
            snapped down to multiples of 32 (FPN downsampling requirement),
            so the output may be slightly smaller than the input.
        tile_size: Side length of the square tiles fed to the network.
        overlap: Overlap in pixels between neighbouring tiles; must be
            smaller than ``tile_size``.

    Returns:
        A deblurred RGB ``PIL.Image`` at the snapped dimensions.

    Raises:
        ValueError: If ``overlap >= tile_size`` (would make the stride <= 0).
    """
    if overlap >= tile_size:
        raise ValueError("overlap must be smaller than tile_size")

    # Guard against grayscale/RGBA inputs: np.array on those would not give
    # the HxWx3 layout assumed by permute(2, 0, 1) below.
    if img.mode != "RGB":
        img = img.convert("RGB")

    w, h = img.size
    # Snap dimensions down to multiples of 32; the max(..., 32) keeps images
    # smaller than 32px from collapsing to a zero-size dimension.
    new_w = max((w // 32) * 32, 32)
    new_h = max((h // 32) * 32, 32)
    if new_w != w or new_h != h:
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    stride = tile_size - overlap
    tiles_x = list(range(0, w, stride))
    tiles_y = list(range(0, h, stride))
    # Pull the last tile back so it ends exactly at the image border.
    # max(0, ...) fixes images smaller than tile_size, where the original
    # offset (w - tile_size) went negative and produced a wrapped slice.
    if tiles_x[-1] + tile_size > w:
        tiles_x[-1] = max(0, w - tile_size)
    if tiles_y[-1] + tile_size > h:
        tiles_y[-1] = max(0, h - tile_size)

    output = torch.zeros_like(img_tensor)
    weight = torch.zeros_like(img_tensor)

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y + tile_size, x:x + tile_size]
                pred = G(patch)  # module-level generator from this service
                output[:, :, y:y + tile_size, x:x + tile_size] += pred
                weight[:, :, y:y + tile_size, x:x + tile_size] += 1.0

    # Average overlapping predictions; every pixel is covered at least once,
    # so weight is never zero.
    output /= weight
    output = torch.clamp(output, 0.0, 1.0)
    out_np = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)