| import gradio as gr |
| import torch |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import os |
| |
| |
| from models.acgpn import ACGPN |
| from utils.preprocessing import preprocess_image, parse_human |
|
|
| |
# All inference runs on CPU; no CUDA detection is attempted anywhere in this file.
device = torch.device("cpu")
|
|
| |
def load_model():
    """Load the ACGPN checkpoint and return the model ready for inference.

    Returns:
        ACGPN: the model moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file is not present, with a
            clearer message than the raw ``torch.load`` failure.
    """
    checkpoint_path = "checkpoints/acgpn_checkpoint.pth"
    # Fail early with an actionable message instead of torch.load's
    # opaque unpickling error when the weights are missing.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"ACGPN checkpoint not found at '{checkpoint_path}'. "
            "Place the model weights there before launching the app."
        )
    model = ACGPN()
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source. weights_only=True would be safer
    # if the checkpoint format allows it; confirm before switching.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    return model
|
|
# Load the model once at import time so every request reuses the same instance.
model = load_model()
|
|
| |
def virtual_try_on(person_image, cloth_image):
    """Run the ACGPN try-on pipeline on a person / clothing image pair.

    Args:
        person_image: PIL image of the person (Gradio ``type="pil"`` input).
        cloth_image: PIL image of the clothing item.

    Returns:
        PIL.Image.Image: the rendered try-on result.

    Raises:
        gr.Error: wraps any pipeline failure so Gradio shows the message
            in the UI. (Previously an error *string* was returned, but the
            output component is ``gr.Image``, so the string itself failed
            to render.)
    """
    try:
        person_img = np.array(person_image)
        cloth_img = np.array(cloth_image)

        # The preprocessing also yields a person mask, unused here.
        person_processed, _person_mask = preprocess_image(person_img, is_person=True)
        cloth_processed = preprocess_image(cloth_img, is_person=False)

        # Pose keypoint map and human-parsing segmentation for the person.
        pose_map, parse_map = parse_human(person_processed)

        person_tensor = torch.from_numpy(person_processed).float().to(device)
        cloth_tensor = torch.from_numpy(cloth_processed).float().to(device)
        pose_tensor = torch.from_numpy(pose_map).float().to(device)
        parse_tensor = torch.from_numpy(parse_map).float().to(device)

        with torch.no_grad():
            output = model(person_tensor, cloth_tensor, pose_tensor, parse_tensor)
        output = output.cpu().numpy()

        # Clip before the uint8 cast: values outside [0, 1] would
        # otherwise wrap around and show up as visual artifacts.
        # (Assumes the model emits values nominally in [0, 1] — confirm.)
        output_img = (np.clip(output, 0.0, 1.0) * 255).astype(np.uint8)
        return Image.fromarray(output_img)
    except Exception as e:
        # Surface the failure in the Gradio UI and preserve the cause chain
        # instead of returning a non-image value to an Image output.
        raise gr.Error(f"Error: {e}") from e
|
|
| |
# Gradio UI: two PIL image uploads in, one rendered try-on image out.
iface = gr.Interface(
    fn=virtual_try_on,
    inputs=[
        gr.Image(type="pil", label="Upload Person Image"),
        gr.Image(type="pil", label="Upload Clothing Image"),
    ],
    outputs=gr.Image(type="pil", label="Try-On Result"),
    title="ACGPN Virtual Try-On",
    description="Upload a person image and a clothing image to see the virtual try-on result.",
)
|
|
| |
if __name__ == "__main__":
    # 0.0.0.0 exposes the app beyond localhost (e.g. inside a container);
    # 7860 is Gradio's conventional port.
    iface.launch(server_name="0.0.0.0", server_port=7860)