Spaces:
Sleeping
Sleeping
File size: 5,540 Bytes
cda20d5 319d52a cda20d5 319d52a cda20d5 319d52a cda20d5 319d52a cda20d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | # -*- coding: utf-8 -*-
"""
Created on Thu Oct 16 12:05:42 2025
@author: ittraining
"""
# -*- coding: utf-8 -*-
"""
Use PyTorch DeblurGAN-v2 (.pth) to deblur images with Tkinter UI
"""
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image, ImageTk
from torchvision import transforms
import tkinter as tk
from tkinter import filedialog
# ======== Model definition ========
from models.fpn_inception import FPNInception # 你需確認這個檔案存在
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 Using device: {device}")

# Checkpoint path, resolved relative to the current working directory.
checkpoint_dir = os.path.join(os.getcwd(), "model")
ckpt_path = os.path.join(checkpoint_dir, "deblurgan_v2_latest_L1_vgg_D_jason_r20.pth")

# Build the generator and read the checkpoint onto the selected device.
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
checkpoint = torch.load(ckpt_path, map_location=device)

# strict=False tolerates key mismatches between the checkpoint and the model.
# Without a report, a wrong checkpoint would leave layers at their initial
# values and nobody would notice, so print whatever was skipped.
_load_report = G.load_state_dict(checkpoint["G"], strict=False)
if _load_report.missing_keys:
    print("WARNING: model weights missing from checkpoint:", _load_report.missing_keys)
if _load_report.unexpected_keys:
    print("WARNING: checkpoint weights not used by the model:", _load_report.unexpected_keys)

# Presumably the checkpoint stores a zero-based epoch index -- TODO confirm.
num_epoch = checkpoint["epoch"] + 1

# Switch the generator to evaluation mode before any inference.
G.eval()
print("✅ num epoch", num_epoch)
print("✅ Model loaded from", ckpt_path)
# ======== Tile-based inference function ========
def _tile_starts(length, tile_size, stride):
    """Return ascending, unique tile start offsets that cover ``length`` pixels."""
    if length <= tile_size:
        # A single tile (clipped by slicing) already covers the whole axis.
        return [0]
    last = length - tile_size
    starts = list(range(0, last, stride))
    # The final tile is aligned to the far edge so no pixel is left uncovered
    # and no tile sticks out of the image.
    starts.append(last)
    return starts


def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
    """Deblur an image tile by tile so large images fit in limited GPU memory.

    The image is converted to RGB if needed and resized so that both sides are
    multiples of 32. Pixel values are scaled to [0, 1], the model is run on
    overlapping tiles, and overlapping predictions are averaged.

    Args:
        model: Generator with weights already loaded; called as ``model(patch)``
            on a ``(1, 3, tile_h, tile_w)`` float tensor and assumed to return
            a tensor of the same shape.
        img: PIL image to process.
        device: ``torch.device`` the input tensor is moved to.
        tile_size: Side length of each square tile in pixels (512 suggested).
            Presumably should itself be a multiple of 32 -- TODO confirm
            against the network's requirements.
        overlap: Number of pixels shared by neighbouring tiles (16-64
            suggested). Must satisfy ``0 <= overlap < tile_size``.

    Returns:
        A PIL RGB image whose sides are the input sides rounded down to a
        multiple of 32. It is not resized back to the original dimensions.

    Raises:
        ValueError: If ``tile_size``/``overlap`` are inconsistent, or if the
            image has a side shorter than 32 pixels.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, got overlap={overlap}, "
            f"tile_size={tile_size}"
        )

    model.eval()

    # ---- Pre-processing ----
    # The tensor layout below needs an H x W x 3 array.
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    # Round both sides down to a multiple of 32.
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    if new_w == 0 or new_h == 0:
        raise ValueError(
            f"image is too small to process: {w}x{h} (both sides must be at least 32 px)"
        )
    if (new_w, new_h) != (w, h):
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h
    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # ---- Tile grid ----
    # Start offsets are clamped to the image, so an image smaller than one
    # tile yields a single full-image tile instead of a negative offset.
    stride = tile_size - overlap
    tiles_x = _tile_starts(w, tile_size, stride)
    tiles_y = _tile_starts(h, tile_size, stride)

    # ---- Accumulators: summed predictions and per-pixel hit counts ----
    output = torch.zeros_like(img_tensor)
    weight = torch.zeros_like(img_tensor[:, :1])  # one channel, broadcast on divide

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y + tile_size, x:x + tile_size]
                pred = model(patch)
                # Accumulate the prediction at the tile's position.
                output[:, :, y:y + tile_size, x:x + tile_size] += pred
                weight[:, :, y:y + tile_size, x:x + tile_size] += 1.0

    # ---- Average overlapping regions so they are not over-exposed ----
    # Every pixel is covered by at least one tile, so weight >= 1 everywhere.
    output /= weight
    output = torch.clamp(output, 0, 1)

    # ---- Back to an 8-bit image ----
    # squeeze(0) removes only the batch axis; rint avoids the systematic
    # darkening that plain truncation to uint8 would cause.
    out_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    out_np = np.rint(out_np * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)
# ======== Wrapped as the DeblurModel class ========
class DeblurModel:
    """Wrapper that runs tiled deblurring on an image file."""

    def __init__(self, model):
        """Keep a reference to the network passed to ``deblur_image_tiled``."""
        self.model = model

    def predict(self, image_path):
        """Load ``image_path`` as RGB and return the deblurred PIL image.

        Inference uses the module-level ``device`` with 512-pixel tiles and a
        32-pixel overlap.
        """
        # Image.open is lazy and can keep the file handle open (notably for
        # multi-frame formats). convert() decodes the pixels into a new image,
        # so the file can be closed as soon as the block exits.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        return deblur_image_tiled(self.model, img, device, tile_size=512, overlap=32)
# ======== Tkinter GUI ========
class ImageViewerApp:
    """Tkinter window showing a chosen image next to its deblurred result."""

    # Bounding box (width, height) of both preview canvases, in pixels.
    PREVIEW_SIZE = (480, 420)

    def __init__(self, root):
        """Configure the window, build the widgets and wrap the global model.

        Args:
            root: The ``tk.Tk`` root window that hosts the widgets.
        """
        self.root = root
        self.root.title("AI Image Deblurring Viewer (PyTorch)")
        self.root.geometry("1500x700")
        self.create_gui()
        # G is the module-level generator loaded at import time.
        self.model = DeblurModel(G)

    def create_gui(self):
        """Create the browse button, the two preview canvases and the label."""
        label_font = ("Helvetica", 16)
        preview_w, preview_h = self.PREVIEW_SIZE

        self.browse_button = tk.Button(
            self.root, text="Browse Image", command=self.browse_image, font=label_font
        )
        self.canvas_original = tk.Canvas(
            self.root, width=preview_w, height=preview_h, bg="lightgray"
        )
        self.canvas_result = tk.Canvas(
            self.root, width=preview_w, height=preview_h, bg="lightgray"
        )
        self.result_label = tk.Label(
            self.root, text="", font=("Helvetica", 18, "bold"), fg="blue"
        )

        # Layout: button on top, canvases side by side, label underneath.
        self.browse_button.grid(row=0, column=0, columnspan=2, pady=10)
        self.canvas_original.grid(row=1, column=0, padx=10, pady=10)
        self.canvas_result.grid(row=1, column=1, padx=10, pady=10)
        self.result_label.grid(row=2, column=0, columnspan=2, pady=10)

    def browse_image(self):
        """Ask the user for an image file and process it if one was chosen."""
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.gif *.bmp *.tif")]
        )
        # An empty result means the dialog was cancelled.
        if file_path:
            self.display_images(file_path)

    def _show_on_canvas(self, canvas, image):
        """Replace the content of ``canvas`` with a preview of ``image``.

        ``image`` is shrunk in place to fit ``PREVIEW_SIZE``.
        """
        image.thumbnail(self.PREVIEW_SIZE)
        photo = ImageTk.PhotoImage(image)
        # Remove earlier previews; otherwise a larger old image stays visible
        # around a smaller new one and canvas items pile up on every browse.
        canvas.delete("all")
        canvas.create_image(0, 0, anchor="nw", image=photo)
        # Keep a reference so the PhotoImage is not garbage-collected.
        canvas.image = photo

    def display_images(self, image_path):
        """Show the original image, run the model, and show the result.

        Inference runs synchronously in the Tk callback. The full-size result
        is not written to disk; only its preview is displayed.
        """
        # Close the source file as soon as the preview has been built.
        with Image.open(image_path) as img:
            self._show_on_canvas(self.canvas_original, img)

        result_img = self.model.predict(image_path)
        self._show_on_canvas(self.canvas_result, result_img)

        self.result_label.config(text=f"File: {os.path.basename(image_path)} → Deblurred by DeblurGAN-v2")
if __name__ == "__main__":
    # Script entry point: create the Tk root window, attach the viewer to it,
    # then hand control to the Tk event loop until the window is closed.
    main_window = tk.Tk()
    viewer = ImageViewerApp(main_window)
    main_window.mainloop()
|