# -*- coding: utf-8 -*-
"""
Use PyTorch DeblurGAN-v2 (.pth) to deblur images with a Tkinter UI.

Created on Thu Oct 16 12:05:42 2025
@author: ittraining
"""
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageTk
from torchvision import transforms
import tkinter as tk
from tkinter import filedialog, messagebox

# ======== Model definition ========
from models.fpn_inception import FPNInception  # make sure this module exists in the project

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 Using device: {device}")

# Model checkpoint path
checkpoint_dir = os.path.join(os.getcwd(), "model")
ckpt_path = os.path.join(checkpoint_dir, "deblurgan_v2_latest_L1_vgg_D_jason_r20.pth")

# Initialise the generator and load its weights.
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
checkpoint = torch.load(ckpt_path, map_location=device)
# strict=False tolerates key mismatches, but report them so a wrong checkpoint
# (which would leave layers randomly initialised) does not go unnoticed.
incompatible = G.load_state_dict(checkpoint["G"], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"⚠️ Checkpoint mismatch: {len(incompatible.missing_keys)} missing keys, "
        f"{len(incompatible.unexpected_keys)} unexpected keys"
    )
num_epoch = checkpoint["epoch"] + 1
G.eval()
print("✅ num epoch", num_epoch)
print("✅ Model loaded from", ckpt_path)


# ======== Tile-based inference ========
def _tile_starts(length, tile, stride):
    """Return sorted, unique start offsets of ``tile``-sized windows covering [0, length).

    Requires ``1 <= tile <= length`` and ``stride >= 1``. Windows are spaced
    ``stride`` apart; if they do not reach the end, one extra window is placed
    flush with the border so every position is covered at least once.
    """
    starts = list(range(0, length - tile + 1, stride))  # never empty since tile <= length
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
    """Deblur a whole PIL image tile by tile to bound GPU memory use.

    The image is first resized so both sides are multiples of 32. It is then
    processed in (up to) ``tile_size`` x ``tile_size`` patches. Overlapping
    predictions are averaged.

    Args:
        model: DeblurGAN-v2 generator with loaded weights. Assumed to return a
            tensor of the same shape as its input patch — TODO confirm for
            other generators.
        img: PIL image to process. Non-RGB modes are converted to RGB.
        device: ``torch.device`` to run inference on ("cuda" or "cpu").
        tile_size: Maximum patch side length. It is clamped to the image size
            for small images, and should be a multiple of 32 (default 512).
        overlap: Pixels shared by neighbouring tiles (recommended 16-64).

    Returns:
        PIL RGB image whose size is the input size rounded down to multiples
        of 32.

    Raises:
        ValueError: if ``overlap`` is not in ``[0, tile_size)`` or the image
            is smaller than 32 px on either side.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError(f"overlap must be in [0, tile_size), got overlap={overlap}, tile_size={tile_size}")
    model.eval()

    # ---- Preprocessing ----
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    # Ensure both sides are multiples of 32.
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    if new_w == 0 or new_h == 0:
        raise ValueError(f"Image too small ({w}x{h}); both sides must be at least 32 px.")
    if new_w != w or new_h != h:
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # ---- Tile grid ----
    # Clamp the tile to the image so small images become a single full-size tile
    # instead of producing negative offsets (which left uncovered pixels -> NaN).
    tile_w = min(tile_size, w)
    tile_h = min(tile_size, h)
    # max(..., 1) only matters when the tile spans the whole axis (one window).
    xs = _tile_starts(w, tile_w, max(tile_w - overlap, 1))
    ys = _tile_starts(h, tile_h, max(tile_h - overlap, 1))

    # ---- Accumulators ----
    output = torch.zeros_like(img_tensor)
    # One channel is enough for the hit count; it broadcasts over RGB below.
    weight = torch.zeros((1, 1, h, w), dtype=img_tensor.dtype, device=img_tensor.device)

    with torch.no_grad():
        for y in ys:
            for x in xs:
                patch = img_tensor[:, :, y:y + tile_h, x:x + tile_w]
                # Accumulate the prediction at the patch's position.
                output[:, :, y:y + tile_h, x:x + tile_w] += model(patch)
                weight[:, :, y:y + tile_h, x:x + tile_w] += 1.0

    # ---- Average overlaps (summed regions would otherwise be over-bright) ----
    output = torch.clamp(output / weight, 0, 1)

    # ---- Back to a PIL image ----
    out_np = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)


# ======== DeblurModel wrapper ========
class DeblurModel:
    """Run a deblurring generator on image files using tiled inference."""

    def __init__(self, model, tile_size=512, overlap=32):
        self.model = model
        self.tile_size = tile_size
        self.overlap = overlap

    def predict(self, image_path):
        """Load ``image_path`` and return its deblurred version as a PIL RGB image."""
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        return deblur_image_tiled(
            self.model, img, device, tile_size=self.tile_size, overlap=self.overlap
        )


# ======== Tkinter GUI ========
class ImageViewerApp:
    """Tkinter window showing a chosen image next to its deblurred result."""

    THUMBNAIL_SIZE = (480, 420)

    def __init__(self, root):
        self.root = root
        self.root.title("AI Image Deblurring Viewer (PyTorch)")
        self.root.geometry("1500x700")
        self.create_gui()
        self.model = DeblurModel(G)

    def create_gui(self):
        """Create and lay out the browse button, two preview canvases and the status label."""
        label_font = ("Helvetica", 16)
        canvas_w, canvas_h = self.THUMBNAIL_SIZE

        self.browse_button = tk.Button(
            self.root, text="Browse Image", command=self.browse_image, font=label_font
        )
        self.canvas_original = tk.Canvas(self.root, width=canvas_w, height=canvas_h, bg="lightgray")
        self.canvas_result = tk.Canvas(self.root, width=canvas_w, height=canvas_h, bg="lightgray")
        self.result_label = tk.Label(self.root, text="", font=("Helvetica", 18, "bold"), fg="blue")

        self.browse_button.grid(row=0, column=0, columnspan=2, pady=10)
        self.canvas_original.grid(row=1, column=0, padx=10, pady=10)
        self.canvas_result.grid(row=1, column=1, padx=10, pady=10)
        self.result_label.grid(row=2, column=0, columnspan=2, pady=10)

    def browse_image(self):
        """Ask the user for an image file and display it with its deblurred result."""
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.gif *.bmp *.tif")]
        )
        if file_path:
            self.display_images(file_path)

    @staticmethod
    def _show_on_canvas(canvas, photo):
        """Replace everything on ``canvas`` with ``photo`` anchored at the top-left corner."""
        canvas.delete("all")  # drop items from previously shown images
        canvas.create_image(0, 0, anchor="nw", image=photo)
        # Keep a Python reference; otherwise the PhotoImage may be garbage-collected.
        canvas.image = photo

    def display_images(self, image_path):
        """Show ``image_path`` on the left canvas and its deblurred version on the right."""
        name = os.path.basename(image_path)
        try:
            with Image.open(image_path) as img:
                img.thumbnail(self.THUMBNAIL_SIZE)
                photo = ImageTk.PhotoImage(img)
        except OSError as err:  # includes PIL.UnidentifiedImageError
            messagebox.showerror("Open failed", f"Cannot open {name}:\n{err}")
            return
        self._show_on_canvas(self.canvas_original, photo)

        # Clear the stale result and repaint before the (blocking) inference starts.
        self.canvas_result.delete("all")
        self.canvas_result.image = None
        self.result_label.config(text=f"File: {name} → Deblurring…")
        self.root.update_idletasks()

        try:
            result_img = self.model.predict(image_path)
        except (OSError, ValueError, RuntimeError) as err:  # RuntimeError covers CUDA OOM
            self.result_label.config(text=f"File: {name} → Deblurring failed")
            messagebox.showerror("Deblurring failed", str(err))
            return

        result_img.thumbnail(self.THUMBNAIL_SIZE)
        self._show_on_canvas(self.canvas_result, ImageTk.PhotoImage(result_img))
        self.result_label.config(text=f"File: {os.path.basename(image_path)} → Deblurred by DeblurGAN-v2")


if __name__ == "__main__":
    root = tk.Tk()
    app = ImageViewerApp(root)
    root.mainloop()