import cv2 import torch import numpy as np from torchvision import transforms from PIL import Image import shutil import os import uuid from fastapi import FastAPI, File, UploadFile from fastapi.responses import FileResponse # Giả định rằng file u2net.py chứa class U2NET và file u2net.pth # nằm cùng thư mục với main.py hoặc trong PYTHONPATH. # Ví dụ: from your_project.u2net import U2NET # Nếu u2net.py là file định nghĩa class, bạn có thể import trực tiếp: try: from u2net import U2NET # Model definition except ImportError: print("Lỗi: Không tìm thấy file u2net.py định nghĩa class U2NET.") print("Hãy đảm bảo file u2net.py (chứa class U2NET) nằm trong cùng thư mục hoặc PYTHONPATH.") exit() # --- Khởi tạo FastAPI app --- app = FastAPI(title="VITON-Extends API") # --- Cấu hình và tải mô hình U2NET --- # Thực hiện một lần khi ứng dụng khởi động U2NET_MODEL = None _API_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) U2NET_MODEL_PATH = os.path.join(_API_SCRIPT_DIR, 'checkpoints', 'u2net.pth') def load_u2net_model(): global U2NET_MODEL if not os.path.exists(U2NET_MODEL_PATH): print(f"Lỗi: Không tìm thấy file trọng số U2NET tại '{U2NET_MODEL_PATH}'.") print("Hãy đảm bảo file u2net.pth nằm đúng vị trí.") # Không exit() ở đây để FastAPI vẫn có thể khởi động và báo lỗi qua API nếu cần # Hoặc bạn có thể quyết định exit() nếu U2NET là bắt buộc. return False try: U2NET_MODEL = U2NET(3, 1) # 3 kênh đầu vào (RGB), 1 kênh đầu ra (mask) # Sử dụng torch.device để đảm bảo tương thích CPU/GPU device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') U2NET_MODEL.load_state_dict(torch.load(U2NET_MODEL_PATH, map_location=device)) U2NET_MODEL.to(device) U2NET_MODEL.eval() print(f"Đã tải thành công mô hình U2NET lên {device}.") return True except Exception as e: print(f"Lỗi khi tải mô hình U2NET: {e}") U2NET_MODEL = None # Đảm bảo model là None nếu tải lỗi return False # Gọi hàm tải model khi ứng dụng khởi động # FastAPI sẽ chạy hàm này trong sự kiện startup nếu bạn dùng @app.on_event("startup") # Tuy nhiên, để đơn giản, ta gọi trực tiếp. Nếu có lỗi, các endpoint sẽ kiểm tra U2NET_MODEL. MODEL_LOADED_SUCCESSFULLY = load_u2net_model() # --- Các hàm xử lý ảnh với U2NET (tương tự unet.py) --- def _preprocess_for_u2net(pil_image: Image.Image) -> torch.Tensor: """Tiền xử lý ảnh PIL cho đầu vào U2NET.""" transform = transforms.Compose([ transforms.Resize((256, 192)), # (H, W) theo convention của U2NET trong unet.py transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) return transform(pil_image).unsqueeze(0) def generate_person_mask_from_image( image_path: str, output_mask_path: str, device: torch.device ) -> str: """ Tạo mặt nạ cho người từ ảnh đầu vào bằng U2NET. Mặt nạ được resize về kích thước ảnh gốc và lưu dưới dạng ảnh grayscale. """ if U2NET_MODEL is None: raise RuntimeError("Mô hình U2NET chưa được tải thành công.") pil_image = Image.open(image_path).convert('RGB') original_w, original_h = pil_image.size input_tensor = _preprocess_for_u2net(pil_image) input_tensor = input_tensor.to(device) # Chuyển tensor sang device của model with torch.no_grad(): # U2NET thường trả về nhiều output ở các scale khác nhau, d1 là output chính d1, *_ = U2NET_MODEL(input_tensor) pred = d1[:,0,:,:] # Lấy mask từ output, shape: (1, H_u2net, W_u2net) pred = pred.squeeze().cpu().numpy() # Chuyển về numpy array trên CPU, shape: (H_u2net, W_u2net) # Chuẩn hóa giá trị của mask về khoảng [0, 1] pred_min = pred.min() pred_max = pred.max() if pred_max - pred_min > 1e-8: # Tránh chia cho 0 pred = (pred - pred_min) / (pred_max - pred_min) else: pred = np.zeros_like(pred) # Nếu ảnh đầu vào đồng màu, mask có thể là 0 # Resize mask về kích thước ảnh gốc # cv2.resize yêu cầu dsize là (width, height) mask_resized = cv2.resize(pred, (original_w, original_h), interpolation=cv2.INTER_LINEAR) # Lưu mask dưới dạng ảnh grayscale (giá trị 0-255) cv2.imwrite(output_mask_path, (mask_resized * 255).astype(np.uint8)) return output_mask_path def apply_mask_and_prepare_person_image( original_image_path: str, person_mask_path: str, # Đường dẫn đến file mask grayscale (0-255) output_image_path: str, target_size: tuple = (192, 256) # (W, H) cho ảnh output cuối cùng của người ) -> str: """ Áp dụng mặt nạ lên ảnh gốc (nền trắng) và resize về kích thước mục tiêu. """ original_bgr = cv2.imread(original_image_path) if original_bgr is None: raise FileNotFoundError(f"Không tìm thấy ảnh gốc: {original_image_path}") original_rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB) mask_gray = cv2.imread(person_mask_path, cv2.IMREAD_GRAYSCALE) if mask_gray is None: raise FileNotFoundError(f"Không tìm thấy ảnh mặt nạ: {person_mask_path}") # Đảm bảo mask có cùng kích thước với ảnh gốc if mask_gray.shape[:2] != original_rgb.shape[:2]: mask_gray = cv2.resize(mask_gray, (original_rgb.shape[1], original_rgb.shape[0]), interpolation=cv2.INTER_LINEAR) # Chuẩn hóa mask về khoảng [0, 1] mask_float = mask_gray / 255.0 # Mở rộng mask thành 3 kênh để áp dụng cho ảnh RGB mask_3channel = np.repeat(np.expand_dims(mask_float, axis=2), 3, axis=2) # Tạo ảnh nền trắng white_background_rgb = np.full_like(original_rgb, 255, dtype=np.uint8) # Áp dụng công thức: result = foreground * mask + background * (1 - mask) composited_rgb = (original_rgb.astype(float) * mask_3channel + \ white_background_rgb.astype(float) * (1 - mask_3channel)) composited_uint8 = np.clip(composited_rgb, 0, 255).astype(np.uint8) # Resize ảnh đã xử lý về kích thước mục tiêu (W, H) # cv2.resize yêu cầu dsize là (width, height) resized_image_rgb = cv2.resize(composited_uint8, target_size, interpolation=cv2.INTER_AREA) # Chuyển lại BGR để lưu bằng OpenCV output_bgr = cv2.cvtColor(resized_image_rgb, cv2.COLOR_RGB2BGR) cv2.imwrite(output_image_path, output_bgr) return output_image_path # --- Placeholder cho mô hình VITON-Extends --- def run_viton_try_on_model( person_image_path: str, # Ảnh người đã xử lý (nền trắng, resized) clothing_image_path: str, # Ảnh trang phục person_mask_path: str, # Mặt nạ người (có thể dùng làm "edge") output_dir: str ) -> str: """ Hàm giả lập việc gọi mô hình VITON-Extends. Trong thực tế, bạn sẽ thay thế phần này bằng code gọi mô hình của bạn. """ print(f"Gọi mô hình VITON (giả lập) với:") print(f" - Ảnh người: {person_image_path}") print(f" - Ảnh trang phục: {clothing_image_path}") print(f" - Mặt nạ người ('edge'): {person_mask_path}") # Tạo một file output giả lập dummy_output_name = f"viton_result_{uuid.uuid4().hex}.png" dummy_output_path = os.path.join(output_dir, dummy_output_name) # Ví dụ: copy ảnh người làm kết quả giả lập if os.path.exists(person_image_path): shutil.copy(person_image_path, dummy_output_path) print(f"Mô hình VITON (giả lập) đã lưu kết quả tại: {dummy_output_path}") return dummy_output_path else: # Tạo ảnh trống nếu không có ảnh người error_img = np.zeros((256, 192, 3), dtype=np.uint8) cv2.putText(error_img, "VITON Error", (10,128), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),1) cv2.imwrite(dummy_output_path, error_img) print(f"Lỗi: Không tìm thấy ảnh người để tạo output giả lập. Đã tạo ảnh lỗi.") return dummy_output_path # --- Thiết lập thư mục tạm --- BASE_DIR = os.path.dirname(os.path.abspath(__file__)) TEMP_DIR = os.path.join(BASE_DIR, "temp_files_viton_api") os.makedirs(TEMP_DIR, exist_ok=True) # Thư mục cho các file input gốc UPLOADED_FILES_DIR = os.path.join(TEMP_DIR, "uploads") os.makedirs(UPLOADED_FILES_DIR, exist_ok=True) # Thư mục cho output của U2NET (ảnh người đã xử lý, mặt nạ người) UNET_OUTPUT_DIR = os.path.join(TEMP_DIR, "unet_outputs") os.makedirs(UNET_OUTPUT_DIR, exist_ok=True) # Thư mục cho output cuối cùng của VITON VITON_OUTPUT_DIR = os.path.join(TEMP_DIR, "viton_results") os.makedirs(VITON_OUTPUT_DIR, exist_ok=True) # --- API Endpoint --- @app.post("/virtual-try-on/", summary="Thực hiện thử đồ ảo", description="Tải lên ảnh người và ảnh trang phục. API sẽ trả về ảnh người mặc trang phục đó.") async def virtual_try_on_endpoint( person_image: UploadFile = File(..., description="Ảnh người dùng (định dạng JPG, PNG)."), clothing_image: UploadFile = File(..., description="Ảnh trang phục (định dạng JPG, PNG).") ): if not MODEL_LOADED_SUCCESSFULLY or U2NET_MODEL is None: return {"error": "Mô hình U2NET chưa sẵn sàng hoặc tải lỗi. Vui lòng kiểm tra console server."} request_id = uuid.uuid4().hex # --- 1. Lưu ảnh tải lên --- person_image_name = f"{request_id}_person{os.path.splitext(person_image.filename)[1]}" clothing_image_name = f"{request_id}_clothing{os.path.splitext(clothing_image.filename)[1]}" original_person_image_path = os.path.join(UPLOADED_FILES_DIR, person_image_name) original_clothing_image_path = os.path.join(UPLOADED_FILES_DIR, clothing_image_name) try: with open(original_person_image_path, "wb") as buffer: shutil.copyfileobj(person_image.file, buffer) with open(original_clothing_image_path, "wb") as buffer: shutil.copyfileobj(clothing_image.file, buffer) except Exception as e: return {"error": f"Lỗi khi lưu file tải lên: {e}"} finally: person_image.file.close() clothing_image.file.close() # --- 2. Xử lý ảnh người bằng U2NET --- # Đường dẫn cho mặt nạ người (person mask) person_mask_filename = f"{request_id}_person_mask.png" generated_person_mask_path = os.path.join(UNET_OUTPUT_DIR, person_mask_filename) # Đường dẫn cho ảnh người đã xử lý (nền trắng, resized) - đây sẽ là "img" cho VITON # Kích thước (192, 256) WxH như trong unet.py gốc processed_person_image_filename = f"{request_id}_person_processed.png" processed_person_image_path = os.path.join(UNET_OUTPUT_DIR, processed_person_image_filename) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') try: print(f"[{request_id}] Bắt đầu tạo mặt nạ cho: {original_person_image_path}") actual_person_mask_path = generate_person_mask_from_image( original_person_image_path, generated_person_mask_path, device ) print(f"[{request_id}] Đã tạo mặt nạ người tại: {actual_person_mask_path}") print(f"[{request_id}] Bắt đầu xử lý ảnh người: {original_person_image_path}") # "img" cho VITON (192W x 256H, nền trắng) actual_processed_person_image_path = apply_mask_and_prepare_person_image( original_image_path=original_person_image_path, person_mask_path=actual_person_mask_path, output_image_path=processed_person_image_path, target_size=(192, 256) # (W, H) ) print(f"[{request_id}] Đã xử lý ảnh người tại: {actual_processed_person_image_path}") except FileNotFoundError as e: return {"error": f"Lỗi file trong quá trình xử lý U2NET: {e}"} except RuntimeError as e: # Bắt lỗi Runtime từ U2NET (ví dụ model chưa tải) return {"error": f"Lỗi Runtime U2NET: {e}"} except Exception as e: return {"error": f"Lỗi không xác định trong quá trình xử lý U2NET: {e}"} # --- 3. Gọi mô hình VITON-Extends (Placeholder) --- # Đầu vào cho mô hình VITON: # - img: actual_processed_person_image_path (ảnh người 192x256, nền trắng) # - clothes: original_clothing_image_path (ảnh trang phục gốc) # - edge: actual_person_mask_path (mặt nạ người, kích thước gốc) try: print(f"[{request_id}] Bắt đầu gọi mô hình VITON (giả lập)...") final_try_on_image_path = run_viton_try_on_model( person_image_path=actual_processed_person_image_path, clothing_image_path=original_clothing_image_path, person_mask_path=actual_person_mask_path, # "edge" là mặt nạ người output_dir=VITON_OUTPUT_DIR ) print(f"[{request_id}] Mô hình VITON (giả lập) hoàn tất. Kết quả: {final_try_on_image_path}") if not os.path.exists(final_try_on_image_path): return {"error": "Mô hình VITON không tạo ra file output."} # Trả về file ảnh kết quả return FileResponse(final_try_on_image_path, media_type="image/png") except Exception as e: # Cân nhắc dọn dẹp file tạm ở đây nếu cần return {"error": f"Lỗi trong quá trình chạy mô hình VITON: {e}"} # --- Chạy FastAPI app (ví dụ với uvicorn) --- # Để chạy: mở terminal, cd vào thư mục chứa file này và chạy: # uvicorn main:app --reload # (main là tên file .py, app là tên biến FastAPI instance) if __name__ == "__main__": import uvicorn print("Để chạy API, sử dụng lệnh: uvicorn main:app --reload --host 0.0.0.0 --port 8000") uvicorn.run(app, host="0.0.0.0", port=8000) # Bỏ comment để chạy trực tiếp khi thực thi file