NguyenDinhHieu's picture
Add files using upload-large-folder tool
2bd5c7a verified
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import shutil
import os
import uuid
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
# Giả định rằng file u2net.py chứa class U2NET và file u2net.pth
# nằm cùng thư mục với main.py hoặc trong PYTHONPATH.
# Ví dụ: from your_project.u2net import U2NET
# Nếu u2net.py là file định nghĩa class, bạn có thể import trực tiếp:
try:
from u2net import U2NET # Model definition
except ImportError:
print("Lỗi: Không tìm thấy file u2net.py định nghĩa class U2NET.")
print("Hãy đảm bảo file u2net.py (chứa class U2NET) nằm trong cùng thư mục hoặc PYTHONPATH.")
exit()
# --- Khởi tạo FastAPI app ---
app = FastAPI(title="VITON-Extends API")
# --- Cấu hình và tải mô hình U2NET ---
# Thực hiện một lần khi ứng dụng khởi động
U2NET_MODEL = None
_API_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
U2NET_MODEL_PATH = os.path.join(_API_SCRIPT_DIR, 'checkpoints', 'u2net.pth')
def load_u2net_model():
global U2NET_MODEL
if not os.path.exists(U2NET_MODEL_PATH):
print(f"Lỗi: Không tìm thấy file trọng số U2NET tại '{U2NET_MODEL_PATH}'.")
print("Hãy đảm bảo file u2net.pth nằm đúng vị trí.")
# Không exit() ở đây để FastAPI vẫn có thể khởi động và báo lỗi qua API nếu cần
# Hoặc bạn có thể quyết định exit() nếu U2NET là bắt buộc.
return False
try:
U2NET_MODEL = U2NET(3, 1) # 3 kênh đầu vào (RGB), 1 kênh đầu ra (mask)
# Sử dụng torch.device để đảm bảo tương thích CPU/GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
U2NET_MODEL.load_state_dict(torch.load(U2NET_MODEL_PATH, map_location=device))
U2NET_MODEL.to(device)
U2NET_MODEL.eval()
print(f"Đã tải thành công mô hình U2NET lên {device}.")
return True
except Exception as e:
print(f"Lỗi khi tải mô hình U2NET: {e}")
U2NET_MODEL = None # Đảm bảo model là None nếu tải lỗi
return False
# Gọi hàm tải model khi ứng dụng khởi động
# FastAPI sẽ chạy hàm này trong sự kiện startup nếu bạn dùng @app.on_event("startup")
# Tuy nhiên, để đơn giản, ta gọi trực tiếp. Nếu có lỗi, các endpoint sẽ kiểm tra U2NET_MODEL.
MODEL_LOADED_SUCCESSFULLY = load_u2net_model()
# --- Các hàm xử lý ảnh với U2NET (tương tự unet.py) ---
def _preprocess_for_u2net(pil_image: Image.Image) -> torch.Tensor:
"""Tiền xử lý ảnh PIL cho đầu vào U2NET."""
transform = transforms.Compose([
transforms.Resize((256, 192)), # (H, W) theo convention của U2NET trong unet.py
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return transform(pil_image).unsqueeze(0)
def generate_person_mask_from_image(
image_path: str,
output_mask_path: str,
device: torch.device
) -> str:
"""
Tạo mặt nạ cho người từ ảnh đầu vào bằng U2NET.
Mặt nạ được resize về kích thước ảnh gốc và lưu dưới dạng ảnh grayscale.
"""
if U2NET_MODEL is None:
raise RuntimeError("Mô hình U2NET chưa được tải thành công.")
pil_image = Image.open(image_path).convert('RGB')
original_w, original_h = pil_image.size
input_tensor = _preprocess_for_u2net(pil_image)
input_tensor = input_tensor.to(device) # Chuyển tensor sang device của model
with torch.no_grad():
# U2NET thường trả về nhiều output ở các scale khác nhau, d1 là output chính
d1, *_ = U2NET_MODEL(input_tensor)
pred = d1[:,0,:,:] # Lấy mask từ output, shape: (1, H_u2net, W_u2net)
pred = pred.squeeze().cpu().numpy() # Chuyển về numpy array trên CPU, shape: (H_u2net, W_u2net)
# Chuẩn hóa giá trị của mask về khoảng [0, 1]
pred_min = pred.min()
pred_max = pred.max()
if pred_max - pred_min > 1e-8: # Tránh chia cho 0
pred = (pred - pred_min) / (pred_max - pred_min)
else:
pred = np.zeros_like(pred) # Nếu ảnh đầu vào đồng màu, mask có thể là 0
# Resize mask về kích thước ảnh gốc
# cv2.resize yêu cầu dsize là (width, height)
mask_resized = cv2.resize(pred, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
# Lưu mask dưới dạng ảnh grayscale (giá trị 0-255)
cv2.imwrite(output_mask_path, (mask_resized * 255).astype(np.uint8))
return output_mask_path
def apply_mask_and_prepare_person_image(
original_image_path: str,
person_mask_path: str, # Đường dẫn đến file mask grayscale (0-255)
output_image_path: str,
target_size: tuple = (192, 256) # (W, H) cho ảnh output cuối cùng của người
) -> str:
"""
Áp dụng mặt nạ lên ảnh gốc (nền trắng) và resize về kích thước mục tiêu.
"""
original_bgr = cv2.imread(original_image_path)
if original_bgr is None:
raise FileNotFoundError(f"Không tìm thấy ảnh gốc: {original_image_path}")
original_rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
mask_gray = cv2.imread(person_mask_path, cv2.IMREAD_GRAYSCALE)
if mask_gray is None:
raise FileNotFoundError(f"Không tìm thấy ảnh mặt nạ: {person_mask_path}")
# Đảm bảo mask có cùng kích thước với ảnh gốc
if mask_gray.shape[:2] != original_rgb.shape[:2]:
mask_gray = cv2.resize(mask_gray, (original_rgb.shape[1], original_rgb.shape[0]),
interpolation=cv2.INTER_LINEAR)
# Chuẩn hóa mask về khoảng [0, 1]
mask_float = mask_gray / 255.0
# Mở rộng mask thành 3 kênh để áp dụng cho ảnh RGB
mask_3channel = np.repeat(np.expand_dims(mask_float, axis=2), 3, axis=2)
# Tạo ảnh nền trắng
white_background_rgb = np.full_like(original_rgb, 255, dtype=np.uint8)
# Áp dụng công thức: result = foreground * mask + background * (1 - mask)
composited_rgb = (original_rgb.astype(float) * mask_3channel + \
white_background_rgb.astype(float) * (1 - mask_3channel))
composited_uint8 = np.clip(composited_rgb, 0, 255).astype(np.uint8)
# Resize ảnh đã xử lý về kích thước mục tiêu (W, H)
# cv2.resize yêu cầu dsize là (width, height)
resized_image_rgb = cv2.resize(composited_uint8, target_size, interpolation=cv2.INTER_AREA)
# Chuyển lại BGR để lưu bằng OpenCV
output_bgr = cv2.cvtColor(resized_image_rgb, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_image_path, output_bgr)
return output_image_path
# --- Placeholder cho mô hình VITON-Extends ---
def run_viton_try_on_model(
person_image_path: str, # Ảnh người đã xử lý (nền trắng, resized)
clothing_image_path: str, # Ảnh trang phục
person_mask_path: str, # Mặt nạ người (có thể dùng làm "edge")
output_dir: str
) -> str:
"""
Hàm giả lập việc gọi mô hình VITON-Extends.
Trong thực tế, bạn sẽ thay thế phần này bằng code gọi mô hình của bạn.
"""
print(f"Gọi mô hình VITON (giả lập) với:")
print(f" - Ảnh người: {person_image_path}")
print(f" - Ảnh trang phục: {clothing_image_path}")
print(f" - Mặt nạ người ('edge'): {person_mask_path}")
# Tạo một file output giả lập
dummy_output_name = f"viton_result_{uuid.uuid4().hex}.png"
dummy_output_path = os.path.join(output_dir, dummy_output_name)
# Ví dụ: copy ảnh người làm kết quả giả lập
if os.path.exists(person_image_path):
shutil.copy(person_image_path, dummy_output_path)
print(f"Mô hình VITON (giả lập) đã lưu kết quả tại: {dummy_output_path}")
return dummy_output_path
else:
# Tạo ảnh trống nếu không có ảnh người
error_img = np.zeros((256, 192, 3), dtype=np.uint8)
cv2.putText(error_img, "VITON Error", (10,128), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),1)
cv2.imwrite(dummy_output_path, error_img)
print(f"Lỗi: Không tìm thấy ảnh người để tạo output giả lập. Đã tạo ảnh lỗi.")
return dummy_output_path
# --- Thiết lập thư mục tạm ---
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMP_DIR = os.path.join(BASE_DIR, "temp_files_viton_api")
os.makedirs(TEMP_DIR, exist_ok=True)
# Thư mục cho các file input gốc
UPLOADED_FILES_DIR = os.path.join(TEMP_DIR, "uploads")
os.makedirs(UPLOADED_FILES_DIR, exist_ok=True)
# Thư mục cho output của U2NET (ảnh người đã xử lý, mặt nạ người)
UNET_OUTPUT_DIR = os.path.join(TEMP_DIR, "unet_outputs")
os.makedirs(UNET_OUTPUT_DIR, exist_ok=True)
# Thư mục cho output cuối cùng của VITON
VITON_OUTPUT_DIR = os.path.join(TEMP_DIR, "viton_results")
os.makedirs(VITON_OUTPUT_DIR, exist_ok=True)
# --- API Endpoint ---
@app.post("/virtual-try-on/",
summary="Thực hiện thử đồ ảo",
description="Tải lên ảnh người và ảnh trang phục. API sẽ trả về ảnh người mặc trang phục đó.")
async def virtual_try_on_endpoint(
person_image: UploadFile = File(..., description="Ảnh người dùng (định dạng JPG, PNG)."),
clothing_image: UploadFile = File(..., description="Ảnh trang phục (định dạng JPG, PNG).")
):
if not MODEL_LOADED_SUCCESSFULLY or U2NET_MODEL is None:
return {"error": "Mô hình U2NET chưa sẵn sàng hoặc tải lỗi. Vui lòng kiểm tra console server."}
request_id = uuid.uuid4().hex
# --- 1. Lưu ảnh tải lên ---
person_image_name = f"{request_id}_person{os.path.splitext(person_image.filename)[1]}"
clothing_image_name = f"{request_id}_clothing{os.path.splitext(clothing_image.filename)[1]}"
original_person_image_path = os.path.join(UPLOADED_FILES_DIR, person_image_name)
original_clothing_image_path = os.path.join(UPLOADED_FILES_DIR, clothing_image_name)
try:
with open(original_person_image_path, "wb") as buffer:
shutil.copyfileobj(person_image.file, buffer)
with open(original_clothing_image_path, "wb") as buffer:
shutil.copyfileobj(clothing_image.file, buffer)
except Exception as e:
return {"error": f"Lỗi khi lưu file tải lên: {e}"}
finally:
person_image.file.close()
clothing_image.file.close()
# --- 2. Xử lý ảnh người bằng U2NET ---
# Đường dẫn cho mặt nạ người (person mask)
person_mask_filename = f"{request_id}_person_mask.png"
generated_person_mask_path = os.path.join(UNET_OUTPUT_DIR, person_mask_filename)
# Đường dẫn cho ảnh người đã xử lý (nền trắng, resized) - đây sẽ là "img" cho VITON
# Kích thước (192, 256) WxH như trong unet.py gốc
processed_person_image_filename = f"{request_id}_person_processed.png"
processed_person_image_path = os.path.join(UNET_OUTPUT_DIR, processed_person_image_filename)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
print(f"[{request_id}] Bắt đầu tạo mặt nạ cho: {original_person_image_path}")
actual_person_mask_path = generate_person_mask_from_image(
original_person_image_path,
generated_person_mask_path,
device
)
print(f"[{request_id}] Đã tạo mặt nạ người tại: {actual_person_mask_path}")
print(f"[{request_id}] Bắt đầu xử lý ảnh người: {original_person_image_path}")
# "img" cho VITON (192W x 256H, nền trắng)
actual_processed_person_image_path = apply_mask_and_prepare_person_image(
original_image_path=original_person_image_path,
person_mask_path=actual_person_mask_path,
output_image_path=processed_person_image_path,
target_size=(192, 256) # (W, H)
)
print(f"[{request_id}] Đã xử lý ảnh người tại: {actual_processed_person_image_path}")
except FileNotFoundError as e:
return {"error": f"Lỗi file trong quá trình xử lý U2NET: {e}"}
except RuntimeError as e: # Bắt lỗi Runtime từ U2NET (ví dụ model chưa tải)
return {"error": f"Lỗi Runtime U2NET: {e}"}
except Exception as e:
return {"error": f"Lỗi không xác định trong quá trình xử lý U2NET: {e}"}
# --- 3. Gọi mô hình VITON-Extends (Placeholder) ---
# Đầu vào cho mô hình VITON:
# - img: actual_processed_person_image_path (ảnh người 192x256, nền trắng)
# - clothes: original_clothing_image_path (ảnh trang phục gốc)
# - edge: actual_person_mask_path (mặt nạ người, kích thước gốc)
try:
print(f"[{request_id}] Bắt đầu gọi mô hình VITON (giả lập)...")
final_try_on_image_path = run_viton_try_on_model(
person_image_path=actual_processed_person_image_path,
clothing_image_path=original_clothing_image_path,
person_mask_path=actual_person_mask_path, # "edge" là mặt nạ người
output_dir=VITON_OUTPUT_DIR
)
print(f"[{request_id}] Mô hình VITON (giả lập) hoàn tất. Kết quả: {final_try_on_image_path}")
if not os.path.exists(final_try_on_image_path):
return {"error": "Mô hình VITON không tạo ra file output."}
# Trả về file ảnh kết quả
return FileResponse(final_try_on_image_path, media_type="image/png")
except Exception as e:
# Cân nhắc dọn dẹp file tạm ở đây nếu cần
return {"error": f"Lỗi trong quá trình chạy mô hình VITON: {e}"}
# --- Chạy FastAPI app (ví dụ với uvicorn) ---
# Để chạy: mở terminal, cd vào thư mục chứa file này và chạy:
# uvicorn main:app --reload
# (main là tên file .py, app là tên biến FastAPI instance)
if __name__ == "__main__":
import uvicorn
print("Để chạy API, sử dụng lệnh: uvicorn main:app --reload --host 0.0.0.0 --port 8000")
uvicorn.run(app, host="0.0.0.0", port=8000) # Bỏ comment để chạy trực tiếp khi thực thi file