JasonFinley0821 committed on
Commit
cda20d5
·
1 Parent(s): 36e825c

feat: add predict api and load model

Browse files
Files changed (3) hide show
  1. app.py +93 -3
  2. app_DeblurGan_PyTorch.py +159 -0
  3. models/fpn_inception.py +167 -0
app.py CHANGED
@@ -1,11 +1,80 @@
1
  from fastapi import FastAPI, Request, Response
2
  from fastapi.responses import JSONResponse
3
 
4
- app = FastAPI()
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @app.get("/")
7
  def root():
8
- return {"Hello": "World!"}
9
 
10
  @app.get("/greetjson")
11
  def greet_json(request: Request, response: Response):
@@ -17,4 +86,25 @@ def greet_json(request: Request, response: Response):
17
  response.headers["X-Custom-Header"] = "HelloHeader"
18
 
19
  # 回傳 JSON
20
- return JSONResponse(content={"message": "Hello World", "client": client_host})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import os

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, Request, Response, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from torchvision import transforms

from models.fpn_inception import FPNInception  # local model class definition
13
+
14
# =====================
# Model initialization (runs at import time)
# =====================
# Prefer GPU when available; model and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔹 Using device: {device}")

# NOTE(review): weights are looked up under "model/" while the code package
# is "models/" — confirm the checkpoint directory name is intentional.
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")

# Build the generator and load the "G" weights from the checkpoint.
# strict=False tolerates missing/unexpected keys (e.g. the frozen backbone).
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
G.load_state_dict(checkpoint["G"], strict=False)
G.eval()
print("✅ Model loaded from", checkpoint_path)
27
+
28
# =====================
# Tile-based inference
# =====================
def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
    """Deblur a PIL image by running the generator over overlapping tiles.

    Args:
        model: generator network, called as ``model(tensor)`` on NCHW input.
        img: input ``PIL.Image`` (expected 3-channel RGB — callers convert).
        device: torch device the model lives on.
        tile_size: side length of each square tile in pixels.
        overlap: pixels shared by neighbouring tiles; overlapping predictions
            are averaged to hide seams.

    Returns:
        A ``PIL.Image`` with the deblurred result (possibly resized so both
        dimensions are multiples of 32).
    """
    model.eval()

    # Snap both dimensions down to multiples of 32 (the FPN downsamples by 32),
    # but never below 32 so tiny inputs don't collapse to a zero-sized resize.
    w, h = img.size
    new_w = max(32, (w // 32) * 32)
    new_h = max(32, (h // 32) * 32)
    if new_w != w or new_h != h:
        img = img.resize((new_w, new_h), Image.BICUBIC)
        w, h = new_w, new_h

    img_np = np.array(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Tile grid; the last tile is pulled back so it ends exactly at the border.
    stride = tile_size - overlap
    tiles_x = list(range(0, w, stride))
    tiles_y = list(range(0, h, stride))
    # Clamp at 0: for images smaller than tile_size, `w - tile_size` would be
    # negative and silently produce wrong (Python negative-index) slices.
    if tiles_x[-1] + tile_size > w:
        tiles_x[-1] = max(0, w - tile_size)
    if tiles_y[-1] + tile_size > h:
        tiles_y[-1] = max(0, h - tile_size)

    # Accumulate predictions plus per-pixel counts for seam-free averaging.
    output = torch.zeros_like(img_tensor)
    weight = torch.zeros_like(img_tensor)

    with torch.no_grad():
        for y in tiles_y:
            for x in tiles_x:
                patch = img_tensor[:, :, y:y + tile_size, x:x + tile_size]
                pred = model(patch)
                output[:, :, y:y + tile_size, x:x + tile_size] += pred
                weight[:, :, y:y + tile_size, x:x + tile_size] += 1.0

    # NOTE(review): the FPNInception generator clamps its output to [-1, 1]
    # while this function treats predictions as [0, 1] — confirm the expected
    # value range / normalization convention.
    output /= weight
    output = torch.clamp(output, 0, 1)
    out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(out_np)
66
+
67
# =====================
# FastAPI application
# =====================
app = FastAPI(title="DeblurGANv2 API")


# =====================
# Routes
# =====================
@app.get("/")
def root():
    """Health-check endpoint: confirms the service process is running."""
    return {"message": "DeblurGANv2 API ready!"}
78
 
79
  @app.get("/greetjson")
80
  def greet_json(request: Request, response: Response):
 
86
  response.headers["X-Custom-Header"] = "HelloHeader"
87
 
88
  # 回傳 JSON
89
+ return JSONResponse(content={"message": "Hello World", "client": client_host})
90
+
91
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Deblur an uploaded image and stream the result back as a PNG.

    Returns HTTP 400 when the payload cannot be decoded as an image and
    HTTP 500 for unexpected inference/serialization failures.
    """
    # Read the uploaded file fully into memory.
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        # A non-image / truncated upload is a client error, not a server 500.
        return JSONResponse(
            {"status": "error", "message": f"invalid image: {e}"}, status_code=400
        )

    try:
        # Tile-based deblurring with the globally loaded generator.
        result = deblur_image_tiled(G, img, device)

        # Serialize the result to PNG bytes.
        img_byte_arr = io.BytesIO()
        result.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)

        # Stream the image straight back to the client.
        return StreamingResponse(img_byte_arr, media_type="image/png")

    except Exception as e:
        return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
app_DeblurGan_PyTorch.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""
Use PyTorch DeblurGAN-v2 (.pth) to deblur images with a Tkinter UI.

Created on Thu Oct 16 12:05:42 2025
@author: ittraining
"""
12
+
13
+ import os
14
+ import torch
15
+ import torch.nn as nn
16
+ import numpy as np
17
+ from PIL import Image, ImageTk
18
+ from torchvision import transforms
19
+ import tkinter as tk
20
+ from tkinter import filedialog
21
+
22
+ # ======== 模型定義區 ========
23
+ from models.fpn_inception import FPNInception # 你需確認這個檔案存在
24
+
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ print(f"🔹 Using device: {device}")
27
+
28
+ # 模型 checkpoint 路徑
29
+ checkpoint_dir = os.path.join(os.getcwd(), "model")
30
+ ckpt_path = os.path.join(checkpoint_dir, "deblurgan_v2_latest.pth")
31
+
32
+ # 初始化模型
33
+ G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
34
+ checkpoint = torch.load(ckpt_path, map_location=device)
35
+ G.load_state_dict(checkpoint["G"], strict=False)
36
+ G.eval()
37
+ print("✅ Model loaded from", ckpt_path)
38
+
39
+
40
+ # ======== Tile-based 推論函式 ========
41
+ def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
42
+ """
43
+ 用 tile-based 方法在 GPU 記憶體有限時推論整張大圖。
44
+ Args:
45
+ model: 已載入權重的 DeblurGAN-v2 Generator
46
+ img: 要處理的影像
47
+ device: torch.device("cuda" or "cpu")
48
+ tile_size: 每塊大小(建議 512)
49
+ overlap: 重疊區域像素數(建議 16~64)
50
+ """
51
+ model.eval()
52
+
53
+ # ---- 預處理 ----
54
+ w, h = img.size
55
+
56
+ # 確保為 32 倍數
57
+ new_w = (w // 32) * 32
58
+ new_h = (h // 32) * 32
59
+ if new_w != w or new_h != h:
60
+ img = img.resize((new_w, new_h), Image.BICUBIC)
61
+ w, h = new_w, new_h
62
+
63
+ img_np = np.array(img).astype(np.float32) / 255.0
64
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
65
+
66
+ # ---- 計算 tile 網格 ----
67
+ stride = tile_size - overlap
68
+ tiles_x = list(range(0, w, stride))
69
+ tiles_y = list(range(0, h, stride))
70
+ if tiles_x[-1] + tile_size > w:
71
+ tiles_x[-1] = w - tile_size
72
+ if tiles_y[-1] + tile_size > h:
73
+ tiles_y[-1] = h - tile_size
74
+
75
+ # ---- 準備空白輸出與權重 ----
76
+ output = torch.zeros_like(img_tensor)
77
+ weight = torch.zeros_like(img_tensor)
78
+
79
+ with torch.no_grad():
80
+ for y in tiles_y:
81
+ for x in tiles_x:
82
+ patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
83
+ pred = model(patch)
84
+
85
+ # 疊加到對應位置
86
+ output[:, :, y:y+tile_size, x:x+tile_size] += pred
87
+ weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0
88
+
89
+ # ---- 平均化(避免重疊區域過曝)----
90
+ output /= weight
91
+ output = torch.clamp(output, 0, 1)
92
+
93
+ # ---- 轉回圖片 ----
94
+ out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
95
+ return Image.fromarray(out_np)
96
+
97
+
98
# ======== DeblurModel wrapper ========
class DeblurModel:
    """Thin wrapper exposing a file-path based ``predict`` over the generator."""

    def __init__(self, model):
        self.model = model

    def predict(self, image_path):
        """Load the image at *image_path* and return the deblurred PIL image."""
        source = Image.open(image_path).convert("RGB")
        return deblur_image_tiled(self.model, source, device, tile_size=512, overlap=32)
107
+
108
+
109
# ======== Tkinter GUI ========
class ImageViewerApp:
    """Tkinter window showing an original image and its deblurred counterpart side by side."""

    def __init__(self, root):
        # `root` is the Tk top-level window supplied by the caller.
        self.root = root
        self.root.title("AI Image Deblurring Viewer (PyTorch)")
        self.root.geometry("1500x700")
        self.create_gui()
        # Wrap the module-level generator G for path-based prediction.
        self.model = DeblurModel(G)

    def create_gui(self):
        # Build the widgets: browse button, two image canvases, status label.
        label_font = ("Helvetica", 16)
        self.browse_button = tk.Button(
            self.root, text="Browse Image", command=self.browse_image, font=label_font
        )

        self.canvas_original = tk.Canvas(self.root, width=480, height=420, bg="lightgray")
        self.canvas_result = tk.Canvas(self.root, width=480, height=420, bg="lightgray")
        self.result_label = tk.Label(self.root, text="", font=("Helvetica", 18, "bold"), fg="blue")

        # Grid layout: button on top, original/result canvases side by side, label below.
        self.browse_button.grid(row=0, column=0, columnspan=2, pady=10)
        self.canvas_original.grid(row=1, column=0, padx=10, pady=10)
        self.canvas_result.grid(row=1, column=1, padx=10, pady=10)
        self.result_label.grid(row=2, column=0, columnspan=2, pady=10)

    def browse_image(self):
        # Ask the user for an image file; a cancelled dialog returns "" and is ignored.
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.gif *.bmp *.tif")]
        )
        if file_path:
            self.display_images(file_path)

    def display_images(self, image_path):
        # Show the original, thumbnailed to fit the canvas.
        img = Image.open(image_path)
        img.thumbnail((480, 420))
        photo = ImageTk.PhotoImage(img)
        self.canvas_original.create_image(0, 0, anchor="nw", image=photo)
        # Keep a reference on the canvas so Tk's PhotoImage isn't garbage-collected.
        self.canvas_original.image = photo

        # Run deblurring (blocks the UI thread for the duration of inference).
        result_img = self.model.predict(image_path)
        result_img.thumbnail((480, 420))
        photo_result = ImageTk.PhotoImage(result_img)
        self.canvas_result.create_image(0, 0, anchor="nw", image=photo_result)
        self.canvas_result.image = photo_result

        self.result_label.config(text=f"File: {os.path.basename(image_path)} → Deblurred by DeblurGAN-v2")
154
+
155
+
156
if __name__ == "__main__":
    # Launch the Tk event loop with the viewer window.
    main_window = tk.Tk()
    viewer = ImageViewerApp(main_window)
    main_window.mainloop()
models/fpn_inception.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchsummary import summary
4
+ from pretrainedmodels import inceptionresnetv2
5
+ import torch.nn.functional as F
6
+
7
class FPNHead(nn.Module):
    """Two stacked 3x3 conv + ReLU layers used as a prediction head on one FPN level."""

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        # Bias-free convolutions, matching the reference DeblurGAN-v2 code.
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = F.relu(self.block0(x), inplace=True)
        out = F.relu(self.block1(out), inplace=True)
        return out
18
+
19
class ConvBlock(nn.Module):
    """3x3 convolution followed by normalization and ReLU."""

    def __init__(self, num_in, num_out, norm_layer):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
            norm_layer(num_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
30
+
31
+
32
class FPNInception(nn.Module):
    """DeblurGAN-v2 generator: Inception-ResNet-v2 FPN backbone with fusion heads.

    The network predicts a residual: output is ``tanh(final) + x`` clamped to
    [-1, 1] (which presumes the input is normalized to [-1, 1] — confirm with
    the caller's preprocessing).
    """

    def __init__(self, norm_layer=nn.InstanceNorm2d, output_ch=3, num_filters=128, num_filters_fpn=256):
        super().__init__()

        # Feature pyramid with maps at 1/4, 1/8, 1/16 and 1/32 resolution,
        # each carrying `num_filters_fpn` channels.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)

        # One prediction head per pyramid level.
        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuse the concatenated head outputs, then refine twice while upsampling.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )
        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        """Unfreeze the backbone encoder for fine-tuning."""
        self.fpn.unfreeze()

    def forward(self, x):
        p0, p1, p2, p3, p4 = self.fpn(x)

        # Bring every head output up to 1/4 resolution before fusing.
        u4 = F.interpolate(self.head4(p4), scale_factor=8, mode="nearest")
        u3 = F.interpolate(self.head3(p3), scale_factor=4, mode="nearest")
        u2 = F.interpolate(self.head2(p2), scale_factor=2, mode="nearest")
        u1 = F.interpolate(self.head1(p1), scale_factor=1, mode="nearest")

        fused = self.smooth(torch.cat([u4, u3, u2, u1], dim=1))
        fused = F.interpolate(fused, scale_factor=2, mode="nearest")
        fused = self.smooth2(fused + p0)
        fused = F.interpolate(fused, scale_factor=2, mode="nearest")

        residual = self.final(fused)
        return torch.clamp(torch.tanh(residual) + x, min=-1, max=1)
82
+
83
+
84
class FPN(nn.Module):

    def __init__(self, norm_layer, num_filters=256):
        """Creates an `FPN` instance for feature extraction.
        Args:
            norm_layer: normalization-layer factory used in the top-down blocks
            num_filters: the number of filters in each output pyramid level
        """

        super().__init__()
        # Inception-ResNet-v2 backbone; pretrained='imagenet' downloads weights
        # on first use.
        self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')

        # Bottom-up encoder stages sliced out of the backbone; trailing
        # comments give each stage's output channel count.
        self.enc0 = self.inception.conv2d_1a
        self.enc1 = nn.Sequential(
            self.inception.conv2d_2a,
            self.inception.conv2d_2b,
            self.inception.maxpool_3a,
        ) # 64
        self.enc2 = nn.Sequential(
            self.inception.conv2d_3b,
            self.inception.conv2d_4a,
            self.inception.maxpool_5a,
        ) # 192
        self.enc3 = nn.Sequential(
            self.inception.mixed_5b,
            self.inception.repeat,
            self.inception.mixed_6a,
        ) # 1088
        self.enc4 = nn.Sequential(
            self.inception.repeat_1,
            self.inception.mixed_7a,
        ) #2080
        # Top-down refinement blocks (3x3 conv + norm + ReLU).
        self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.pad = nn.ReflectionPad2d(1)
        # 1x1 lateral projections from each encoder stage down to num_filters
        # (or num_filters // 2 for the highest-resolution map).
        self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; see unfreeze().
        for param in self.inception.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Make the backbone trainable for fine-tuning."""
        for param in self.inception.parameters():
            param.requires_grad = True

    def forward(self, x):

        # Bottom-up pathway through the Inception backbone stages.
        enc0 = self.enc0(x)

        enc1 = self.enc1(enc0) # 256

        enc2 = self.enc2(enc1) # 512

        enc3 = self.enc3(enc2) # 1024

        enc4 = self.enc4(enc3) # 2048

        # Lateral connections.
        # NOTE(review): lateral4/3/1 get symmetric ReflectionPad2d(1) while
        # lateral2 and lateral0 are padded asymmetrically below — this appears
        # to compensate for off-by-one spatial sizes from the backbone; confirm
        # against the reference implementation before changing.

        lateral4 = self.pad(self.lateral4(enc4))
        lateral3 = self.pad(self.lateral3(enc3))
        lateral2 = self.lateral2(enc2)
        lateral1 = self.pad(self.lateral1(enc1))
        lateral0 = self.lateral0(enc0)

        # Top-down pathway: upsample coarser map, add lateral, refine.
        pad = (1, 2, 1, 2) # (left, right, top, bottom) asymmetric reflect pad
        pad1 = (0, 1, 0, 1)
        map4 = lateral4
        map3 = self.td1(lateral3 + nn.functional.interpolate(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.interpolate(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + nn.functional.interpolate(map2, scale_factor=2, mode="nearest"))
        return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4