| import cv2
|
| import torch
|
| import numpy as np
|
| from torchvision import transforms
|
| from PIL import Image
|
| import shutil
|
| import os
|
| import uuid
|
| from fastapi import FastAPI, File, UploadFile
|
| from fastapi.responses import FileResponse
|
|
|
|
|
|
|
|
|
|
|
| try:
|
| from u2net import U2NET
|
| except ImportError:
|
| print("Lỗi: Không tìm thấy file u2net.py định nghĩa class U2NET.")
|
| print("Hãy đảm bảo file u2net.py (chứa class U2NET) nằm trong cùng thư mục hoặc PYTHONPATH.")
|
| exit()
|
|
|
|
|
|
|
| app = FastAPI(title="VITON-Extends API")
|
|
|
|
|
|
|
| U2NET_MODEL = None
|
| _API_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| U2NET_MODEL_PATH = os.path.join(_API_SCRIPT_DIR, 'checkpoints', 'u2net.pth')
|
|
|
| def load_u2net_model():
|
| global U2NET_MODEL
|
| if not os.path.exists(U2NET_MODEL_PATH):
|
| print(f"Lỗi: Không tìm thấy file trọng số U2NET tại '{U2NET_MODEL_PATH}'.")
|
| print("Hãy đảm bảo file u2net.pth nằm đúng vị trí.")
|
|
|
|
|
| return False
|
|
|
| try:
|
| U2NET_MODEL = U2NET(3, 1)
|
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| U2NET_MODEL.load_state_dict(torch.load(U2NET_MODEL_PATH, map_location=device))
|
| U2NET_MODEL.to(device)
|
| U2NET_MODEL.eval()
|
| print(f"Đã tải thành công mô hình U2NET lên {device}.")
|
| return True
|
| except Exception as e:
|
| print(f"Lỗi khi tải mô hình U2NET: {e}")
|
| U2NET_MODEL = None
|
| return False
|
|
|
|
|
|
|
|
|
| MODEL_LOADED_SUCCESSFULLY = load_u2net_model()
|
|
|
|
|
|
|
| def _preprocess_for_u2net(pil_image: Image.Image) -> torch.Tensor:
|
| """Tiền xử lý ảnh PIL cho đầu vào U2NET."""
|
| transform = transforms.Compose([
|
| transforms.Resize((256, 192)),
|
| transforms.ToTensor(),
|
| transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| ])
|
| return transform(pil_image).unsqueeze(0)
|
|
|
| def generate_person_mask_from_image(
|
| image_path: str,
|
| output_mask_path: str,
|
| device: torch.device
|
| ) -> str:
|
| """
|
| Tạo mặt nạ cho người từ ảnh đầu vào bằng U2NET.
|
| Mặt nạ được resize về kích thước ảnh gốc và lưu dưới dạng ảnh grayscale.
|
| """
|
| if U2NET_MODEL is None:
|
| raise RuntimeError("Mô hình U2NET chưa được tải thành công.")
|
|
|
| pil_image = Image.open(image_path).convert('RGB')
|
| original_w, original_h = pil_image.size
|
|
|
| input_tensor = _preprocess_for_u2net(pil_image)
|
| input_tensor = input_tensor.to(device)
|
|
|
| with torch.no_grad():
|
|
|
| d1, *_ = U2NET_MODEL(input_tensor)
|
|
|
| pred = d1[:,0,:,:]
|
| pred = pred.squeeze().cpu().numpy()
|
|
|
|
|
| pred_min = pred.min()
|
| pred_max = pred.max()
|
| if pred_max - pred_min > 1e-8:
|
| pred = (pred - pred_min) / (pred_max - pred_min)
|
| else:
|
| pred = np.zeros_like(pred)
|
|
|
|
|
|
|
| mask_resized = cv2.resize(pred, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
| cv2.imwrite(output_mask_path, (mask_resized * 255).astype(np.uint8))
|
| return output_mask_path
|
|
|
| def apply_mask_and_prepare_person_image(
|
| original_image_path: str,
|
| person_mask_path: str,
|
| output_image_path: str,
|
| target_size: tuple = (192, 256)
|
| ) -> str:
|
| """
|
| Áp dụng mặt nạ lên ảnh gốc (nền trắng) và resize về kích thước mục tiêu.
|
| """
|
| original_bgr = cv2.imread(original_image_path)
|
| if original_bgr is None:
|
| raise FileNotFoundError(f"Không tìm thấy ảnh gốc: {original_image_path}")
|
| original_rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
|
|
|
| mask_gray = cv2.imread(person_mask_path, cv2.IMREAD_GRAYSCALE)
|
| if mask_gray is None:
|
| raise FileNotFoundError(f"Không tìm thấy ảnh mặt nạ: {person_mask_path}")
|
|
|
|
|
| if mask_gray.shape[:2] != original_rgb.shape[:2]:
|
| mask_gray = cv2.resize(mask_gray, (original_rgb.shape[1], original_rgb.shape[0]),
|
| interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
| mask_float = mask_gray / 255.0
|
|
|
| mask_3channel = np.repeat(np.expand_dims(mask_float, axis=2), 3, axis=2)
|
|
|
|
|
| white_background_rgb = np.full_like(original_rgb, 255, dtype=np.uint8)
|
|
|
|
|
| composited_rgb = (original_rgb.astype(float) * mask_3channel + \
|
| white_background_rgb.astype(float) * (1 - mask_3channel))
|
| composited_uint8 = np.clip(composited_rgb, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
| resized_image_rgb = cv2.resize(composited_uint8, target_size, interpolation=cv2.INTER_AREA)
|
|
|
|
|
| output_bgr = cv2.cvtColor(resized_image_rgb, cv2.COLOR_RGB2BGR)
|
| cv2.imwrite(output_image_path, output_bgr)
|
| return output_image_path
|
|
|
|
|
| def run_viton_try_on_model(
|
| person_image_path: str,
|
| clothing_image_path: str,
|
| person_mask_path: str,
|
| output_dir: str
|
| ) -> str:
|
| """
|
| Hàm giả lập việc gọi mô hình VITON-Extends.
|
| Trong thực tế, bạn sẽ thay thế phần này bằng code gọi mô hình của bạn.
|
| """
|
| print(f"Gọi mô hình VITON (giả lập) với:")
|
| print(f" - Ảnh người: {person_image_path}")
|
| print(f" - Ảnh trang phục: {clothing_image_path}")
|
| print(f" - Mặt nạ người ('edge'): {person_mask_path}")
|
|
|
|
|
| dummy_output_name = f"viton_result_{uuid.uuid4().hex}.png"
|
| dummy_output_path = os.path.join(output_dir, dummy_output_name)
|
|
|
|
|
| if os.path.exists(person_image_path):
|
| shutil.copy(person_image_path, dummy_output_path)
|
| print(f"Mô hình VITON (giả lập) đã lưu kết quả tại: {dummy_output_path}")
|
| return dummy_output_path
|
| else:
|
|
|
| error_img = np.zeros((256, 192, 3), dtype=np.uint8)
|
| cv2.putText(error_img, "VITON Error", (10,128), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),1)
|
| cv2.imwrite(dummy_output_path, error_img)
|
| print(f"Lỗi: Không tìm thấy ảnh người để tạo output giả lập. Đã tạo ảnh lỗi.")
|
| return dummy_output_path
|
|
|
|
|
|
|
| BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| TEMP_DIR = os.path.join(BASE_DIR, "temp_files_viton_api")
|
| os.makedirs(TEMP_DIR, exist_ok=True)
|
|
|
|
|
| UPLOADED_FILES_DIR = os.path.join(TEMP_DIR, "uploads")
|
| os.makedirs(UPLOADED_FILES_DIR, exist_ok=True)
|
|
|
|
|
| UNET_OUTPUT_DIR = os.path.join(TEMP_DIR, "unet_outputs")
|
| os.makedirs(UNET_OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
| VITON_OUTPUT_DIR = os.path.join(TEMP_DIR, "viton_results")
|
| os.makedirs(VITON_OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
| @app.post("/virtual-try-on/",
|
| summary="Thực hiện thử đồ ảo",
|
| description="Tải lên ảnh người và ảnh trang phục. API sẽ trả về ảnh người mặc trang phục đó.")
|
| async def virtual_try_on_endpoint(
|
| person_image: UploadFile = File(..., description="Ảnh người dùng (định dạng JPG, PNG)."),
|
| clothing_image: UploadFile = File(..., description="Ảnh trang phục (định dạng JPG, PNG).")
|
| ):
|
| if not MODEL_LOADED_SUCCESSFULLY or U2NET_MODEL is None:
|
| return {"error": "Mô hình U2NET chưa sẵn sàng hoặc tải lỗi. Vui lòng kiểm tra console server."}
|
|
|
| request_id = uuid.uuid4().hex
|
|
|
|
|
| person_image_name = f"{request_id}_person{os.path.splitext(person_image.filename)[1]}"
|
| clothing_image_name = f"{request_id}_clothing{os.path.splitext(clothing_image.filename)[1]}"
|
|
|
| original_person_image_path = os.path.join(UPLOADED_FILES_DIR, person_image_name)
|
| original_clothing_image_path = os.path.join(UPLOADED_FILES_DIR, clothing_image_name)
|
|
|
| try:
|
| with open(original_person_image_path, "wb") as buffer:
|
| shutil.copyfileobj(person_image.file, buffer)
|
| with open(original_clothing_image_path, "wb") as buffer:
|
| shutil.copyfileobj(clothing_image.file, buffer)
|
| except Exception as e:
|
| return {"error": f"Lỗi khi lưu file tải lên: {e}"}
|
| finally:
|
| person_image.file.close()
|
| clothing_image.file.close()
|
|
|
|
|
|
|
| person_mask_filename = f"{request_id}_person_mask.png"
|
| generated_person_mask_path = os.path.join(UNET_OUTPUT_DIR, person_mask_filename)
|
|
|
|
|
|
|
| processed_person_image_filename = f"{request_id}_person_processed.png"
|
| processed_person_image_path = os.path.join(UNET_OUTPUT_DIR, processed_person_image_filename)
|
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
| try:
|
| print(f"[{request_id}] Bắt đầu tạo mặt nạ cho: {original_person_image_path}")
|
| actual_person_mask_path = generate_person_mask_from_image(
|
| original_person_image_path,
|
| generated_person_mask_path,
|
| device
|
| )
|
| print(f"[{request_id}] Đã tạo mặt nạ người tại: {actual_person_mask_path}")
|
|
|
| print(f"[{request_id}] Bắt đầu xử lý ảnh người: {original_person_image_path}")
|
|
|
| actual_processed_person_image_path = apply_mask_and_prepare_person_image(
|
| original_image_path=original_person_image_path,
|
| person_mask_path=actual_person_mask_path,
|
| output_image_path=processed_person_image_path,
|
| target_size=(192, 256)
|
| )
|
| print(f"[{request_id}] Đã xử lý ảnh người tại: {actual_processed_person_image_path}")
|
|
|
| except FileNotFoundError as e:
|
| return {"error": f"Lỗi file trong quá trình xử lý U2NET: {e}"}
|
| except RuntimeError as e:
|
| return {"error": f"Lỗi Runtime U2NET: {e}"}
|
| except Exception as e:
|
| return {"error": f"Lỗi không xác định trong quá trình xử lý U2NET: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| try:
|
| print(f"[{request_id}] Bắt đầu gọi mô hình VITON (giả lập)...")
|
| final_try_on_image_path = run_viton_try_on_model(
|
| person_image_path=actual_processed_person_image_path,
|
| clothing_image_path=original_clothing_image_path,
|
| person_mask_path=actual_person_mask_path,
|
| output_dir=VITON_OUTPUT_DIR
|
| )
|
| print(f"[{request_id}] Mô hình VITON (giả lập) hoàn tất. Kết quả: {final_try_on_image_path}")
|
|
|
| if not os.path.exists(final_try_on_image_path):
|
| return {"error": "Mô hình VITON không tạo ra file output."}
|
|
|
|
|
| return FileResponse(final_try_on_image_path, media_type="image/png")
|
|
|
| except Exception as e:
|
|
|
| return {"error": f"Lỗi trong quá trình chạy mô hình VITON: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| if __name__ == "__main__":
|
| import uvicorn
|
| print("Để chạy API, sử dụng lệnh: uvicorn main:app --reload --host 0.0.0.0 --port 8000")
|
| uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|