Spaces:
Sleeping
Sleeping
Commit ·
5e0a683
1
Parent(s): 2f75e66
feat: add new deblur module
Browse files- .gitignore +27 -0
- app.py +31 -78
- requirements.txt +1 -2
- services/agents.py +0 -0
- services/deblur.py +58 -0
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. 虛擬環境目錄
|
| 2 |
+
.venv
|
| 3 |
+
venv/
|
| 4 |
+
env/
|
| 5 |
+
/site-packages
|
| 6 |
+
|
| 7 |
+
# 2. Python 編譯快取
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.pyc
|
| 10 |
+
*.pyd
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# 3. 測試和文件
|
| 14 |
+
.pytest_cache/
|
| 15 |
+
htmlcov/
|
| 16 |
+
.coverage
|
| 17 |
+
|
| 18 |
+
# 4. 編輯器/IDE 檔案 (可選,依據您使用的工具)
|
| 19 |
+
# PyCharm/IntelliJ 專案檔案
.idea/
|
| 20 |
+
# VS Code 設定 (如果不想共享)
.vscode/
|
| 21 |
+
# Vim 臨時檔案
*.swp
|
| 22 |
+
|
| 23 |
+
# 5. 您的靜態/媒體檔案 (保持原樣,但使用更精確的模式)
|
| 24 |
+
# 忽略整個 static 資料夾
/static/
|
| 25 |
+
# 日誌檔案
/logs/
|
| 26 |
+
*.log
|
| 27 |
+
# 如果使用 SQLite 資料庫
*.sqlite3
|
app.py
CHANGED
|
@@ -1,19 +1,17 @@
|
|
| 1 |
-
from fastapi import FastAPI, Request, Response
|
| 2 |
from fastapi.responses import JSONResponse
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
| 4 |
import requests
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
import traceback
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
from torchvision import transforms
|
| 10 |
from PIL import Image
|
| 11 |
import io
|
| 12 |
-
import numpy as np
|
| 13 |
import os
|
| 14 |
from datetime import datetime
|
| 15 |
-
|
| 16 |
-
from models.fpn_inception import FPNInception # 你自己的模型類別
|
| 17 |
|
| 18 |
STATIC_DIR = "static"
|
| 19 |
|
|
@@ -23,71 +21,20 @@ os.environ["TRANSFORMERS_CACHE"] = "./.cache"
|
|
| 23 |
os.makedirs("./.cache", exist_ok=True)
|
| 24 |
os.makedirs(STATIC_DIR, exist_ok=True)
|
| 25 |
|
| 26 |
-
# =====================
|
| 27 |
-
# 初始化模型
|
| 28 |
-
# =====================
|
| 29 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
-
print(f"🔹 Using device: {device}")
|
| 31 |
-
|
| 32 |
-
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
|
| 33 |
-
|
| 34 |
-
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
|
| 35 |
-
|
| 36 |
-
try:
|
| 37 |
-
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 38 |
-
except Exception as e:
|
| 39 |
-
print("❌ checkpoint 載入錯誤:", e)
|
| 40 |
-
traceback.print_exc()
|
| 41 |
-
|
| 42 |
-
G.load_state_dict(checkpoint["G"], strict=False)
|
| 43 |
-
G.eval()
|
| 44 |
-
print("✅ Model loaded from", checkpoint_path)
|
| 45 |
-
|
| 46 |
-
# =====================
|
| 47 |
-
# Tile-based 推論函式
|
| 48 |
-
# =====================
|
| 49 |
-
def deblur_image_tiled(model, img, device, tile_size=512, overlap=32):
|
| 50 |
-
model.eval()
|
| 51 |
-
w, h = img.size
|
| 52 |
-
new_w = (w // 32) * 32
|
| 53 |
-
new_h = (h // 32) * 32
|
| 54 |
-
if new_w != w or new_h != h:
|
| 55 |
-
img = img.resize((new_w, new_h), Image.BICUBIC)
|
| 56 |
-
w, h = new_w, new_h
|
| 57 |
-
|
| 58 |
-
img_np = np.array(img).astype(np.float32) / 255.0
|
| 59 |
-
img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
|
| 60 |
-
|
| 61 |
-
stride = tile_size - overlap
|
| 62 |
-
tiles_x = list(range(0, w, stride))
|
| 63 |
-
tiles_y = list(range(0, h, stride))
|
| 64 |
-
if tiles_x[-1] + tile_size > w:
|
| 65 |
-
tiles_x[-1] = w - tile_size
|
| 66 |
-
if tiles_y[-1] + tile_size > h:
|
| 67 |
-
tiles_y[-1] = h - tile_size
|
| 68 |
-
|
| 69 |
-
output = torch.zeros_like(img_tensor)
|
| 70 |
-
weight = torch.zeros_like(img_tensor)
|
| 71 |
-
|
| 72 |
-
with torch.no_grad():
|
| 73 |
-
for y in tiles_y:
|
| 74 |
-
for x in tiles_x:
|
| 75 |
-
patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
|
| 76 |
-
pred = model(patch)
|
| 77 |
-
output[:, :, y:y+tile_size, x:x+tile_size] += pred
|
| 78 |
-
weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0
|
| 79 |
-
|
| 80 |
-
output /= weight
|
| 81 |
-
output = torch.clamp(output, 0, 1)
|
| 82 |
-
out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
|
| 83 |
-
return Image.fromarray(out_np)
|
| 84 |
-
|
| 85 |
# =====================
|
| 86 |
# 初始化 FastAPI
|
| 87 |
# =====================
|
| 88 |
app = FastAPI(title="DeblurGANv2 API")
|
| 89 |
|
| 90 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# =====================
|
| 93 |
# API 路由
|
|
@@ -109,17 +56,19 @@ def greet_json(request: Request, response: Response):
|
|
| 109 |
return JSONResponse(content={"message": "Hello World", "client": client_host})
|
| 110 |
|
| 111 |
@app.post("/predict")
|
| 112 |
-
async def predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
try:
|
| 114 |
print("### start /predict !!")
|
| 115 |
-
# 1️⃣ 讀取 form-data
|
| 116 |
-
form = await request.form()
|
| 117 |
-
file_name = form.get("file_name")
|
| 118 |
-
file_format = form.get("file_format")
|
| 119 |
-
file_url = form.get("file_url")
|
| 120 |
-
file_width = int(form.get("file_width", 0))
|
| 121 |
-
file_height = int(form.get("file_height", 0))
|
| 122 |
-
file_created_at = form.get("file_created_at")
|
| 123 |
|
| 124 |
if not file_url:
|
| 125 |
return JSONResponse(
|
|
@@ -133,7 +82,7 @@ async def predict( request: Request, response: Response ):
|
|
| 133 |
img = Image.open(io.BytesIO(resp.content)).convert("RGB")
|
| 134 |
|
| 135 |
# 3️⃣ 去模糊
|
| 136 |
-
result = deblur_image_tiled(
|
| 137 |
|
| 138 |
# 4️⃣ 產生檔名
|
| 139 |
base_name = f"{file_name}_{file_width}_{file_height}_{file_created_at}.jpg"
|
|
@@ -159,4 +108,8 @@ async def predict( request: Request, response: Response ):
|
|
| 159 |
traceback.print_exc()
|
| 160 |
return JSONResponse(
|
| 161 |
{"status": "error", "message": str(e)}, status_code=500
|
| 162 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request, Response, Form
|
| 2 |
from fastapi.responses import JSONResponse
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware # 匯入 FastAPI 的 CORS 中介軟體
|
| 5 |
import requests
|
| 6 |
+
from typing import Annotated # 推薦用於 Pydantic v2+
|
| 7 |
+
|
| 8 |
+
from services.deblur import deblur_image_tiled
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
import io
|
|
|
|
| 12 |
import os
|
| 13 |
from datetime import datetime
|
| 14 |
+
import uvicorn
|
|
|
|
| 15 |
|
| 16 |
STATIC_DIR = "static"
|
| 17 |
|
|
|
|
| 21 |
os.makedirs("./.cache", exist_ok=True)
|
| 22 |
os.makedirs(STATIC_DIR, exist_ok=True)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
# =====================
|
| 25 |
# 初始化 FastAPI
|
| 26 |
# =====================
|
| 27 |
app = FastAPI(title="DeblurGANv2 API")
|
| 28 |
|
| 29 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
| 30 |
+
# 設定 CORS (跨來源資源共用)
|
| 31 |
+
app.add_middleware(
|
| 32 |
+
CORSMiddleware,
|
| 33 |
+
allow_origins=["*"], # 允許所有來源
|
| 34 |
+
allow_credentials=True, # 允許憑證
|
| 35 |
+
allow_methods=["*"], # 允許所有 HTTP 方法
|
| 36 |
+
allow_headers=["*"], # 允許所有 HTTP 標頭
|
| 37 |
+
)
|
| 38 |
|
| 39 |
# =====================
|
| 40 |
# API 路由
|
|
|
|
| 56 |
return JSONResponse(content={"message": "Hello World", "client": client_host})
|
| 57 |
|
| 58 |
@app.post("/predict")
|
| 59 |
+
async def predict(
|
| 60 |
+
request: Request,
|
| 61 |
+
# 將您的 form-data 欄位定義為函數參數,並使用 Form()
|
| 62 |
+
file_name: Annotated[str, Form(description="檔案名稱")],
|
| 63 |
+
file_format: Annotated[str, Form(description="檔案格式")],
|
| 64 |
+
file_url: Annotated[str, Form(description="檔案下載網址")], # 可選參數,有預設值
|
| 65 |
+
# 對於需要轉換類型 (例如 int) 的欄位,直接在類型提示中指定
|
| 66 |
+
file_width: Annotated[int, Form(description="檔案寬度")] = 0,
|
| 67 |
+
file_height: Annotated[int, Form(description="檔案高度")] = 0,
|
| 68 |
+
file_created_at: Annotated[str, Form(description="檔案建立時間")] = None,
|
| 69 |
+
):
|
| 70 |
try:
|
| 71 |
print("### start /predict !!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
if not file_url:
|
| 74 |
return JSONResponse(
|
|
|
|
| 82 |
img = Image.open(io.BytesIO(resp.content)).convert("RGB")
|
| 83 |
|
| 84 |
# 3️⃣ 去模糊
|
| 85 |
+
result = deblur_image_tiled(img)
|
| 86 |
|
| 87 |
# 4️⃣ 產生檔名
|
| 88 |
base_name = f"{file_name}_{file_width}_{file_height}_{file_created_at}.jpg"
|
|
|
|
| 108 |
traceback.print_exc()
|
| 109 |
return JSONResponse(
|
| 110 |
{"status": "error", "message": str(e)}, status_code=500
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
fastapi
|
| 2 |
uvicorn[standard]
|
| 3 |
torch
|
| 4 |
torchvision
|
|
@@ -10,5 +10,4 @@ pytorch-msssim
|
|
| 10 |
opencv-python
|
| 11 |
tqdm
|
| 12 |
torchsummary
|
| 13 |
-
python-multipart
|
| 14 |
requests
|
|
|
|
| 1 |
+
fastapi[all]
|
| 2 |
uvicorn[standard]
|
| 3 |
torch
|
| 4 |
torchvision
|
|
|
|
| 10 |
opencv-python
|
| 11 |
tqdm
|
| 12 |
torchsummary
|
|
|
|
| 13 |
requests
|
services/agents.py
ADDED
|
File without changes
|
services/deblur.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# services/deblur.py
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from models.fpn_inception import FPNInception
|
| 9 |
+
|
| 10 |
+
# =====================
|
| 11 |
+
# 初始化模型
|
| 12 |
+
# =====================
|
| 13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
print(f"🔹 [DeblurService] Using device: {device}")
|
| 15 |
+
|
| 16 |
+
checkpoint_path = os.path.join("model", "deblurgan_v2_latest.pth")
|
| 17 |
+
G = FPNInception(norm_layer=nn.InstanceNorm2d).to(device)
|
| 18 |
+
|
| 19 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 20 |
+
G.load_state_dict(checkpoint["G"], strict=False)
|
| 21 |
+
G.eval()
|
| 22 |
+
print(f"✅ [DeblurService] Model loaded from {checkpoint_path}")
|
| 23 |
+
|
| 24 |
+
# ✅ 去模糊函式
|
| 25 |
+
def deblur_image_tiled(img, tile_size=512, overlap=32):
|
| 26 |
+
w, h = img.size
|
| 27 |
+
new_w = (w // 32) * 32
|
| 28 |
+
new_h = (h // 32) * 32
|
| 29 |
+
if new_w != w or new_h != h:
|
| 30 |
+
img = img.resize((new_w, new_h), Image.BICUBIC)
|
| 31 |
+
w, h = new_w, new_h
|
| 32 |
+
|
| 33 |
+
img_np = np.array(img).astype(np.float32) / 255.0
|
| 34 |
+
img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
|
| 35 |
+
|
| 36 |
+
stride = tile_size - overlap
|
| 37 |
+
tiles_x = list(range(0, w, stride))
|
| 38 |
+
tiles_y = list(range(0, h, stride))
|
| 39 |
+
if tiles_x[-1] + tile_size > w:
|
| 40 |
+
tiles_x[-1] = w - tile_size
|
| 41 |
+
if tiles_y[-1] + tile_size > h:
|
| 42 |
+
tiles_y[-1] = h - tile_size
|
| 43 |
+
|
| 44 |
+
output = torch.zeros_like(img_tensor)
|
| 45 |
+
weight = torch.zeros_like(img_tensor)
|
| 46 |
+
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
for y in tiles_y:
|
| 49 |
+
for x in tiles_x:
|
| 50 |
+
patch = img_tensor[:, :, y:y+tile_size, x:x+tile_size]
|
| 51 |
+
pred = G(patch) # ✅ 改用 services 裡的模型 G
|
| 52 |
+
output[:, :, y:y+tile_size, x:x+tile_size] += pred
|
| 53 |
+
weight[:, :, y:y+tile_size, x:x+tile_size] += 1.0
|
| 54 |
+
|
| 55 |
+
output /= weight
|
| 56 |
+
output = torch.clamp(output, 0, 1)
|
| 57 |
+
out_np = (output.squeeze().permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
|
| 58 |
+
return Image.fromarray(out_np)
|