File size: 15,088 Bytes
2bd5c7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import shutil
import os
import uuid
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse

# Giả định rằng file u2net.py chứa class U2NET và file u2net.pth 
# nằm cùng thư mục với main.py hoặc trong PYTHONPATH.
# Ví dụ: from your_project.u2net import U2NET
# Nếu u2net.py là file định nghĩa class, bạn có thể import trực tiếp:
try:
    from u2net import U2NET  # Model definition
except ImportError:
    print("Lỗi: Không tìm thấy file u2net.py định nghĩa class U2NET.")
    print("Hãy đảm bảo file u2net.py (chứa class U2NET) nằm trong cùng thư mục hoặc PYTHONPATH.")
    exit()


# --- Khởi tạo FastAPI app ---
app = FastAPI(title="VITON-Extends API")

# --- Cấu hình và tải mô hình U2NET ---
# Thực hiện một lần khi ứng dụng khởi động
U2NET_MODEL = None
_API_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
U2NET_MODEL_PATH = os.path.join(_API_SCRIPT_DIR, 'checkpoints', 'u2net.pth') 

def load_u2net_model():
    global U2NET_MODEL
    if not os.path.exists(U2NET_MODEL_PATH):
        print(f"Lỗi: Không tìm thấy file trọng số U2NET tại '{U2NET_MODEL_PATH}'.")
        print("Hãy đảm bảo file u2net.pth nằm đúng vị trí.")
        # Không exit() ở đây để FastAPI vẫn có thể khởi động và báo lỗi qua API nếu cần
        # Hoặc bạn có thể quyết định exit() nếu U2NET là bắt buộc.
        return False

    try:
        U2NET_MODEL = U2NET(3, 1) # 3 kênh đầu vào (RGB), 1 kênh đầu ra (mask)
        # Sử dụng torch.device để đảm bảo tương thích CPU/GPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        U2NET_MODEL.load_state_dict(torch.load(U2NET_MODEL_PATH, map_location=device))
        U2NET_MODEL.to(device)
        U2NET_MODEL.eval()
        print(f"Đã tải thành công mô hình U2NET lên {device}.")
        return True
    except Exception as e:
        print(f"Lỗi khi tải mô hình U2NET: {e}")
        U2NET_MODEL = None # Đảm bảo model là None nếu tải lỗi
        return False

# Gọi hàm tải model khi ứng dụng khởi động
# FastAPI sẽ chạy hàm này trong sự kiện startup nếu bạn dùng @app.on_event("startup")
# Tuy nhiên, để đơn giản, ta gọi trực tiếp. Nếu có lỗi, các endpoint sẽ kiểm tra U2NET_MODEL.
MODEL_LOADED_SUCCESSFULLY = load_u2net_model()


# --- Các hàm xử lý ảnh với U2NET (tương tự unet.py) ---
def _preprocess_for_u2net(pil_image: Image.Image) -> torch.Tensor:
    """Tiền xử lý ảnh PIL cho đầu vào U2NET."""
    transform = transforms.Compose([
        transforms.Resize((256, 192)), # (H, W) theo convention của U2NET trong unet.py
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return transform(pil_image).unsqueeze(0)

def generate_person_mask_from_image(

    image_path: str, 

    output_mask_path: str,

    device: torch.device

) -> str:
    """

    Tạo mặt nạ cho người từ ảnh đầu vào bằng U2NET.

    Mặt nạ được resize về kích thước ảnh gốc và lưu dưới dạng ảnh grayscale.

    """
    if U2NET_MODEL is None:
        raise RuntimeError("Mô hình U2NET chưa được tải thành công.")

    pil_image = Image.open(image_path).convert('RGB')
    original_w, original_h = pil_image.size
    
    input_tensor = _preprocess_for_u2net(pil_image)
    input_tensor = input_tensor.to(device) # Chuyển tensor sang device của model
    
    with torch.no_grad():
        # U2NET thường trả về nhiều output ở các scale khác nhau, d1 là output chính
        d1, *_ = U2NET_MODEL(input_tensor) 
        
        pred = d1[:,0,:,:] # Lấy mask từ output, shape: (1, H_u2net, W_u2net)
        pred = pred.squeeze().cpu().numpy() # Chuyển về numpy array trên CPU, shape: (H_u2net, W_u2net)
        
        # Chuẩn hóa giá trị của mask về khoảng [0, 1]
        pred_min = pred.min()
        pred_max = pred.max()
        if pred_max - pred_min > 1e-8: # Tránh chia cho 0
            pred = (pred - pred_min) / (pred_max - pred_min)
        else:
            pred = np.zeros_like(pred) # Nếu ảnh đầu vào đồng màu, mask có thể là 0
            
        # Resize mask về kích thước ảnh gốc
        # cv2.resize yêu cầu dsize là (width, height)
        mask_resized = cv2.resize(pred, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
        
        # Lưu mask dưới dạng ảnh grayscale (giá trị 0-255)
        cv2.imwrite(output_mask_path, (mask_resized * 255).astype(np.uint8))
    return output_mask_path

def apply_mask_and_prepare_person_image(

    original_image_path: str,

    person_mask_path: str, # Đường dẫn đến file mask grayscale (0-255)

    output_image_path: str,

    target_size: tuple = (192, 256) # (W, H) cho ảnh output cuối cùng của người

) -> str:
    """

    Áp dụng mặt nạ lên ảnh gốc (nền trắng) và resize về kích thước mục tiêu.

    """
    original_bgr = cv2.imread(original_image_path)
    if original_bgr is None:
        raise FileNotFoundError(f"Không tìm thấy ảnh gốc: {original_image_path}")
    original_rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
    
    mask_gray = cv2.imread(person_mask_path, cv2.IMREAD_GRAYSCALE)
    if mask_gray is None:
        raise FileNotFoundError(f"Không tìm thấy ảnh mặt nạ: {person_mask_path}")

    # Đảm bảo mask có cùng kích thước với ảnh gốc
    if mask_gray.shape[:2] != original_rgb.shape[:2]:
        mask_gray = cv2.resize(mask_gray, (original_rgb.shape[1], original_rgb.shape[0]), 
                               interpolation=cv2.INTER_LINEAR)

    # Chuẩn hóa mask về khoảng [0, 1]
    mask_float = mask_gray / 255.0
    # Mở rộng mask thành 3 kênh để áp dụng cho ảnh RGB
    mask_3channel = np.repeat(np.expand_dims(mask_float, axis=2), 3, axis=2)
    
    # Tạo ảnh nền trắng
    white_background_rgb = np.full_like(original_rgb, 255, dtype=np.uint8)
    
    # Áp dụng công thức: result = foreground * mask + background * (1 - mask)
    composited_rgb = (original_rgb.astype(float) * mask_3channel + \
                      white_background_rgb.astype(float) * (1 - mask_3channel))
    composited_uint8 = np.clip(composited_rgb, 0, 255).astype(np.uint8)
    
    # Resize ảnh đã xử lý về kích thước mục tiêu (W, H)
    # cv2.resize yêu cầu dsize là (width, height)
    resized_image_rgb = cv2.resize(composited_uint8, target_size, interpolation=cv2.INTER_AREA)
        
    # Chuyển lại BGR để lưu bằng OpenCV
    output_bgr = cv2.cvtColor(resized_image_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite(output_image_path, output_bgr)
    return output_image_path

# --- Placeholder cho mô hình VITON-Extends ---
def run_viton_try_on_model(

    person_image_path: str,      # Ảnh người đã xử lý (nền trắng, resized)

    clothing_image_path: str,    # Ảnh trang phục

    person_mask_path: str,       # Mặt nạ người (có thể dùng làm "edge")

    output_dir: str

) -> str:
    """

    Hàm giả lập việc gọi mô hình VITON-Extends.

    Trong thực tế, bạn sẽ thay thế phần này bằng code gọi mô hình của bạn.

    """
    print(f"Gọi mô hình VITON (giả lập) với:")
    print(f"  - Ảnh người: {person_image_path}")
    print(f"  - Ảnh trang phục: {clothing_image_path}")
    print(f"  - Mặt nạ người ('edge'): {person_mask_path}")

    # Tạo một file output giả lập
    dummy_output_name = f"viton_result_{uuid.uuid4().hex}.png"
    dummy_output_path = os.path.join(output_dir, dummy_output_name)
    
    # Ví dụ: copy ảnh người làm kết quả giả lập
    if os.path.exists(person_image_path):
        shutil.copy(person_image_path, dummy_output_path)
        print(f"Mô hình VITON (giả lập) đã lưu kết quả tại: {dummy_output_path}")
        return dummy_output_path
    else:
        # Tạo ảnh trống nếu không có ảnh người
        error_img = np.zeros((256, 192, 3), dtype=np.uint8)
        cv2.putText(error_img, "VITON Error", (10,128), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),1)
        cv2.imwrite(dummy_output_path, error_img)
        print(f"Lỗi: Không tìm thấy ảnh người để tạo output giả lập. Đã tạo ảnh lỗi.")
        return dummy_output_path


# --- Thiết lập thư mục tạm ---
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMP_DIR = os.path.join(BASE_DIR, "temp_files_viton_api")
os.makedirs(TEMP_DIR, exist_ok=True)

# Thư mục cho các file input gốc
UPLOADED_FILES_DIR = os.path.join(TEMP_DIR, "uploads")
os.makedirs(UPLOADED_FILES_DIR, exist_ok=True)

# Thư mục cho output của U2NET (ảnh người đã xử lý, mặt nạ người)
UNET_OUTPUT_DIR = os.path.join(TEMP_DIR, "unet_outputs")
os.makedirs(UNET_OUTPUT_DIR, exist_ok=True)

# Thư mục cho output cuối cùng của VITON
VITON_OUTPUT_DIR = os.path.join(TEMP_DIR, "viton_results")
os.makedirs(VITON_OUTPUT_DIR, exist_ok=True)


# --- API Endpoint ---
@app.post("/virtual-try-on/", 

          summary="Thực hiện thử đồ ảo",

          description="Tải lên ảnh người và ảnh trang phục. API sẽ trả về ảnh người mặc trang phục đó.")
async def virtual_try_on_endpoint(

    person_image: UploadFile = File(..., description="Ảnh người dùng (định dạng JPG, PNG)."),

    clothing_image: UploadFile = File(..., description="Ảnh trang phục (định dạng JPG, PNG).")

):
    if not MODEL_LOADED_SUCCESSFULLY or U2NET_MODEL is None:
        return {"error": "Mô hình U2NET chưa sẵn sàng hoặc tải lỗi. Vui lòng kiểm tra console server."}

    request_id = uuid.uuid4().hex
    
    # --- 1. Lưu ảnh tải lên ---
    person_image_name = f"{request_id}_person{os.path.splitext(person_image.filename)[1]}"
    clothing_image_name = f"{request_id}_clothing{os.path.splitext(clothing_image.filename)[1]}"
    
    original_person_image_path = os.path.join(UPLOADED_FILES_DIR, person_image_name)
    original_clothing_image_path = os.path.join(UPLOADED_FILES_DIR, clothing_image_name)

    try:
        with open(original_person_image_path, "wb") as buffer:
            shutil.copyfileobj(person_image.file, buffer)
        with open(original_clothing_image_path, "wb") as buffer:
            shutil.copyfileobj(clothing_image.file, buffer)
    except Exception as e:
        return {"error": f"Lỗi khi lưu file tải lên: {e}"}
    finally:
        person_image.file.close()
        clothing_image.file.close()

    # --- 2. Xử lý ảnh người bằng U2NET ---
    # Đường dẫn cho mặt nạ người (person mask)
    person_mask_filename = f"{request_id}_person_mask.png"
    generated_person_mask_path = os.path.join(UNET_OUTPUT_DIR, person_mask_filename)

    # Đường dẫn cho ảnh người đã xử lý (nền trắng, resized) - đây sẽ là "img" cho VITON
    # Kích thước (192, 256) WxH như trong unet.py gốc
    processed_person_image_filename = f"{request_id}_person_processed.png"
    processed_person_image_path = os.path.join(UNET_OUTPUT_DIR, processed_person_image_filename)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        print(f"[{request_id}] Bắt đầu tạo mặt nạ cho: {original_person_image_path}")
        actual_person_mask_path = generate_person_mask_from_image(
            original_person_image_path, 
            generated_person_mask_path,
            device
        )
        print(f"[{request_id}] Đã tạo mặt nạ người tại: {actual_person_mask_path}")

        print(f"[{request_id}] Bắt đầu xử lý ảnh người: {original_person_image_path}")
        # "img" cho VITON (192W x 256H, nền trắng)
        actual_processed_person_image_path = apply_mask_and_prepare_person_image(
            original_image_path=original_person_image_path,
            person_mask_path=actual_person_mask_path,
            output_image_path=processed_person_image_path,
            target_size=(192, 256) # (W, H)
        )
        print(f"[{request_id}] Đã xử lý ảnh người tại: {actual_processed_person_image_path}")

    except FileNotFoundError as e:
        return {"error": f"Lỗi file trong quá trình xử lý U2NET: {e}"}
    except RuntimeError as e: # Bắt lỗi Runtime từ U2NET (ví dụ model chưa tải)
        return {"error": f"Lỗi Runtime U2NET: {e}"}
    except Exception as e:
        return {"error": f"Lỗi không xác định trong quá trình xử lý U2NET: {e}"}

    # --- 3. Gọi mô hình VITON-Extends (Placeholder) ---
    # Đầu vào cho mô hình VITON:
    # - img: actual_processed_person_image_path (ảnh người 192x256, nền trắng)
    # - clothes: original_clothing_image_path (ảnh trang phục gốc)
    # - edge: actual_person_mask_path (mặt nạ người, kích thước gốc)
    try:
        print(f"[{request_id}] Bắt đầu gọi mô hình VITON (giả lập)...")
        final_try_on_image_path = run_viton_try_on_model(
            person_image_path=actual_processed_person_image_path,
            clothing_image_path=original_clothing_image_path,
            person_mask_path=actual_person_mask_path, # "edge" là mặt nạ người
            output_dir=VITON_OUTPUT_DIR
        )
        print(f"[{request_id}] Mô hình VITON (giả lập) hoàn tất. Kết quả: {final_try_on_image_path}")

        if not os.path.exists(final_try_on_image_path):
            return {"error": "Mô hình VITON không tạo ra file output."}
        
        # Trả về file ảnh kết quả
        return FileResponse(final_try_on_image_path, media_type="image/png")

    except Exception as e:
        # Cân nhắc dọn dẹp file tạm ở đây nếu cần
        return {"error": f"Lỗi trong quá trình chạy mô hình VITON: {e}"}

# --- Chạy FastAPI app (ví dụ với uvicorn) ---
# Để chạy: mở terminal, cd vào thư mục chứa file này và chạy:
# uvicorn main:app --reload
# (main là tên file .py, app là tên biến FastAPI instance)

if __name__ == "__main__":
    import uvicorn
    print("Để chạy API, sử dụng lệnh: uvicorn main:app --reload --host 0.0.0.0 --port 8000")
    uvicorn.run(app, host="0.0.0.0", port=8000) # Bỏ comment để chạy trực tiếp khi thực thi file