JasonFinley0821 commited on
Commit
5e0a683
·
1 Parent(s): 2f75e66

feat : deblur add new module

Browse files
Files changed (5) hide show
  1. .gitignore +27 -0
  2. app.py +31 -78
  3. requirements.txt +1 -2
  4. services/agents.py +0 -0
  5. services/deblur.py +58 -0
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. 虛擬環境目錄
2
+ .venv
3
+ venv/
4
+ env/
5
+ /site-packages
6
+
7
+ # 2. Python 編譯快取
8
+ __pycache__/*
9
+ *.pyc
10
+ *.pyd
11
+ *.so
12
+
13
+ # 3. 測試和文件
14
+ .pytest_cache/
15
+ htmlcov/
16
+ .coverage
17
+
18
+ # 4. 編輯器/IDE 檔案 (可選,依據您使用的工具)
19
+ .idea/ # PyCharm/IntelliJ 專案檔案
20
+ .vscode/ # VS Code 設定 (如果不想共享)
21
+ *.swp # Vim 臨時檔案
22
+
23
+ # 5. 您的靜態/媒體檔案 (保持原樣,但使用更精確的模式)
24
+ /static/ # 忽略整個 static 資料夾
25
+ /logs/ # 日誌檔案
26
+ *.log
27
+ *.sqlite3 # 如果使用 SQLite 資料庫
app.py CHANGED
@@ -1,19 +1,17 @@
1
- from fastapi import FastAPI, Request, Response
2
  from fastapi.responses import JSONResponse
3
  from fastapi.staticfiles import StaticFiles
 
4
  import requests
 
 
 
5
 
6
- import traceback
7
- import torch
8
- import torch.nn as nn
9
- from torchvision import transforms
10
  from PIL import Image
11
  import io
12
- import numpy as np
13
  import os
14
  from datetime import datetime
15
-
16
- from models.fpn_inception import FPNInception # 你自己的模型類別
17
 
18
  STATIC_DIR = "static"
19
 
@@ -23,71 +21,20 @@ os.environ["TRANSFORMERS_CACHE"] = "./.cache"
23
  os.makedirs("./.cache", exist_ok=True)
24
  os.makedirs(STATIC_DIR, exist_ok=True)
25
 
26
- # =====================
27
- # 初始化模型
28
- # =====================
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- print(f"🔹 Using device: {device}")
31
-
32
- checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
33
-
34
- G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
35
-
36
- try:
37
- checkpoint = torch.load(checkpoint_path, map_location=device)
38
- except Exception as e:
39
- print("❌ checkpoint 載入錯誤:", e)
40
- traceback.print_exc()
41
-
42
- G.load_state_dict(checkpoint["G"], strict=False)
43
- G.eval()
44
- print("✅ Model loaded from", checkpoint_path)
45
-
46
- # =====================
47
- # Tile-based 推論函式
48
- # =====================
49
- def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
50
- model.eval()
51
- w, h = img.size
52
- new_w = (w // 32) * 32
53
- new_h = (h // 32) * 32
54
- if new_w != w or new_h != h:
55
- img = img.resize((new_w, new_h), Image.BICUBIC)
56
- w, h = new_w, new_h
57
-
58
- img_np = np.array(img).astype(np.float32) / 255.0
59
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
60
-
61
- stride = tile_size - overlap
62
- tiles_x = list(range(0, w, stride))
63
- tiles_y = list(range(0, h, stride))
64
- if tiles_x[-1] + tile_size > w:
65
- tiles_x[-1] = w - tile_size
66
- if tiles_y[-1] + tile_size > h:
67
- tiles_y[-1] = h - tile_size
68
-
69
- output = torch.zeros_like(img_tensor)
70
- weight = torch.zeros_like(img_tensor)
71
-
72
- with torch.no_grad():
73
- for y in tiles_y:
74
- for x in tiles_x:
75
- patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
76
- pred = model(patch)
77
- output[:, :, y:y+tile_size, x:x+tile_size] += pred
78
- weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0
79
-
80
- output /= weight
81
- output = torch.clamp(output, 0, 1)
82
- out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
83
- return Image.fromarray(out_np)
84
-
85
  # =====================
86
  # 初始化 FastAPI
87
  # =====================
88
  app = FastAPI(title="DeblurGANv2 API")
89
 
90
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 
 
 
 
 
 
 
 
91
 
92
  # =====================
93
  # API 路由
@@ -109,17 +56,19 @@ def greet_json(request: Request, response: Response):
109
  return JSONResponse(content={"message": "Hello World", "client": client_host})
110
 
111
  @app.post("/predict")
112
- async def predict( request: Request, response: Response ):
 
 
 
 
 
 
 
 
 
 
113
  try:
114
  print("### start /predict !!")
115
- # 1️⃣ 讀取 form-data
116
- form = await request.form()
117
- file_name = form.get("file_name")
118
- file_format = form.get("file_format")
119
- file_url = form.get("file_url")
120
- file_width = int(form.get("file_width", 0))
121
- file_height = int(form.get("file_height", 0))
122
- file_created_at = form.get("file_created_at")
123
 
124
  if not file_url:
125
  return JSONResponse(
@@ -133,7 +82,7 @@ async def predict( request: Request, response: Response ):
133
  img = Image.open(io.BytesIO(resp.content)).convert("RGB")
134
 
135
  # 3️⃣ 去模糊
136
- result = deblur_image_tiled(G, img, device)
137
 
138
  # 4️⃣ 產生檔名
139
  base_name = f"{file_name}_{file_width}_{file_height}_{file_created_at}.jpg"
@@ -159,4 +108,8 @@ async def predict( request: Request, response: Response ):
159
  traceback.print_exc()
160
  return JSONResponse(
161
  {"status": "error", "message": str(e)}, status_code=500
162
- )
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, Response, Form
2
  from fastapi.responses import JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.middleware.cors import CORSMiddleware # 匯入 FastAPI 的 CORS 中介軟體
5
  import requests
6
+ from typing import Annotated # 推薦用於 Pydantic v2+
7
+
8
+ from services.deblur import deblur_image_tiled
9
 
 
 
 
 
10
  from PIL import Image
11
  import io
 
12
  import os
13
  from datetime import datetime
14
+ import uvicorn
 
15
 
16
  STATIC_DIR = "static"
17
 
 
21
  os.makedirs("./.cache", exist_ok=True)
22
  os.makedirs(STATIC_DIR, exist_ok=True)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # =====================
25
  # 初始化 FastAPI
26
  # =====================
27
  app = FastAPI(title="DeblurGANv2 API")
28
 
29
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
30
+ # 設定 CORS (跨來源資源共用)
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"], # 允許所有來源
34
+ allow_credentials=True, # 允許憑證
35
+ allow_methods=["*"], # 允許所有 HTTP 方法
36
+ allow_headers=["*"], # 允許所有 HTTP 標頭
37
+ )
38
 
39
  # =====================
40
  # API 路由
 
56
  return JSONResponse(content={"message": "Hello World", "client": client_host})
57
 
58
  @app.post("/predict")
59
+ async def predict(
60
+ request: Request,
61
+ # 將您的 form-data 欄位定義為函數參數,並使用 Form()
62
+ file_name: Annotated[str, Form(description="檔案名稱")],
63
+ file_format: Annotated[str, Form(description="檔案格式")],
64
+ file_url: Annotated[str, Form(description="檔案下載網址")], # 可選參數,有預設值
65
+ # 對於需要轉換類型 (例如 int) 的欄位,直接在類型提示中指定
66
+ file_width: Annotated[int, Form(description="檔案寬度")] = 0,
67
+ file_height: Annotated[int, Form(description="檔案高度")] = 0,
68
+ file_created_at: Annotated[str, Form(description="檔案建立時間")] = None,
69
+ ):
70
  try:
71
  print("### start /predict !!")
 
 
 
 
 
 
 
 
72
 
73
  if not file_url:
74
  return JSONResponse(
 
82
  img = Image.open(io.BytesIO(resp.content)).convert("RGB")
83
 
84
  # 3️⃣ 去模糊
85
+ result = deblur_image_tiled(img)
86
 
87
  # 4️⃣ 產生檔名
88
  base_name = f"{file_name}_{file_width}_{file_height}_{file_created_at}.jpg"
 
108
  traceback.print_exc()
109
  return JSONResponse(
110
  {"status": "error", "message": str(e)}, status_code=500
111
+ )
112
+
113
+
114
+ if __name__ == "__main__":
115
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- fastapi
2
  uvicorn[standard]
3
  torch
4
  torchvision
@@ -10,5 +10,4 @@ pytorch-msssim
10
  opencv-python
11
  tqdm
12
  torchsummary
13
- python-multipart
14
  requests
 
1
+ fastapi[all]
2
  uvicorn[standard]
3
  torch
4
  torchvision
 
10
  opencv-python
11
  tqdm
12
  torchsummary
 
13
  requests
services/agents.py ADDED
File without changes
services/deblur.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # services/deblur.py
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from models.fpn_inception import FPNInception
9
+
10
+ # =====================
11
+ # 初始化模型
12
+ # =====================
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"🔹 [DeblurService] Using device: {device}")
15
+
16
+ checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
17
+ G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
18
+
19
+ checkpoint = torch.load(checkpoint_path, map_location=device)
20
+ G.load_state_dict(checkpoint["G"], strict=False)
21
+ G.eval()
22
+ print(f"✅ [DeblurService] Model loaded from {checkpoint_path}")
23
+
24
+ # ✅ 去模糊函式
25
+ def deblur_image_tiled(img, tile_size=512, overlap=32):
26
+ w, h = img.size
27
+ new_w = (w // 32) * 32
28
+ new_h = (h // 32) * 32
29
+ if new_w != w or new_h != h:
30
+ img = img.resize((new_w, new_h), Image.BICUBIC)
31
+ w, h = new_w, new_h
32
+
33
+ img_np = np.array(img).astype(np.float32) / 255.0
34
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
35
+
36
+ stride = tile_size - overlap
37
+ tiles_x = list(range(0, w, stride))
38
+ tiles_y = list(range(0, h, stride))
39
+ if tiles_x[-1] + tile_size > w:
40
+ tiles_x[-1] = w - tile_size
41
+ if tiles_y[-1] + tile_size > h:
42
+ tiles_y[-1] = h - tile_size
43
+
44
+ output = torch.zeros_like(img_tensor)
45
+ weight = torch.zeros_like(img_tensor)
46
+
47
+ with torch.no_grad():
48
+ for y in tiles_y:
49
+ for x in tiles_x:
50
+ patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
51
+ pred = G(patch) # ✅ 改用 services 裡的模型 G
52
+ output[:, :, y:y+tile_size, x:x+tile_size] += pred
53
+ weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0
54
+
55
+ output /= weight
56
+ output = torch.clamp(output, 0, 1)
57
+ out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
58
+ return Image.fromarray(out_np)