|
|
import openvino |
|
|
import torch |
|
|
import os |
|
|
import time |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class OVWrapperTextEncoder: |
|
|
def __init__(self, ov_model, core, device="CPU", torch_dtype=torch.bfloat16): |
|
|
self.model = ov_model |
|
|
self.compiled_model = core.compile_model(ov_model, device) |
|
|
self.device = device |
|
|
self.dtype = torch_dtype |
|
|
|
|
|
|
|
|
self.input_key = self.compiled_model.inputs[0].get_any_name() |
|
|
self.output_key = self.compiled_model.outputs[0].get_any_name() |
|
|
|
|
|
self.config = type('obj', (object,), { |
|
|
"d_model": 768, |
|
|
"hidden_size": 768, |
|
|
}) |
|
|
|
|
|
def __call__(self, input_ids, attention_mask=None, return_dict=False, output_hidden_states=False, **kwargs): |
|
|
|
|
|
if isinstance(input_ids, torch.Tensor): |
|
|
input_ids = input_ids.cpu().numpy() |
|
|
|
|
|
|
|
|
inputs = {self.input_key: input_ids} |
|
|
if attention_mask is not None and len(self.compiled_model.inputs) > 1: |
|
|
attn_key = self.compiled_model.inputs[1].get_any_name() |
|
|
inputs[attn_key] = attention_mask.cpu().numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask |
|
|
|
|
|
|
|
|
outputs = self.compiled_model(inputs) |
|
|
|
|
|
|
|
|
last_hidden_state = torch.from_numpy(outputs[self.output_key]).to(self.dtype) |
|
|
|
|
|
return (last_hidden_state,) |
|
|
|
|
|
|
|
|
def crop_to_multiple_of_16(img): |
|
|
""" |
|
|
Cắt ảnh để kích thước là bội số của 16. |
|
|
|
|
|
Parameters: |
|
|
img (PIL.Image): Ảnh đầu vào cần cắt. |
|
|
|
|
|
Returns: |
|
|
PIL.Image: Ảnh đã được cắt với kích thước là bội số của 16. |
|
|
""" |
|
|
width, height = img.size |
|
|
|
|
|
|
|
|
new_width = width - (width % 16) |
|
|
new_height = height - (height % 16) |
|
|
|
|
|
|
|
|
left = (width - new_width) // 2 |
|
|
top = (height - new_height) // 2 |
|
|
right = left + new_width |
|
|
bottom = top + new_height |
|
|
|
|
|
|
|
|
cropped_img = img.crop((left, top, right, bottom)) |
|
|
|
|
|
return cropped_img |
|
|
|
|
|
|
|
|
def resize_and_pad_to_size(image, target_width, target_height): |
|
|
""" |
|
|
Thay đổi kích thước ảnh và thêm padding để đạt kích thước mục tiêu mà không làm biến dạng ảnh. |
|
|
|
|
|
Parameters: |
|
|
image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý. |
|
|
target_width (int): Chiều rộng mục tiêu của ảnh sau khi xử lý. |
|
|
target_height (int): Chiều cao mục tiêu của ảnh sau khi xử lý. |
|
|
|
|
|
Returns: |
|
|
tuple: Bao gồm: |
|
|
- PIL.Image: Ảnh đã được resize và padding. |
|
|
- int: Khoảng cách padding bên trái. |
|
|
- int: Khoảng cách padding bên trên. |
|
|
- int: Khoảng cách padding bên phải. |
|
|
- int: Khoảng cách padding bên dưới. |
|
|
""" |
|
|
|
|
|
if isinstance(image, np.ndarray): |
|
|
image = Image.fromarray(image) |
|
|
|
|
|
|
|
|
orig_width, orig_height = image.size |
|
|
|
|
|
|
|
|
target_ratio = target_width / target_height |
|
|
orig_ratio = orig_width / orig_height |
|
|
|
|
|
|
|
|
if orig_ratio > target_ratio: |
|
|
|
|
|
new_width = target_width |
|
|
new_height = int(new_width / orig_ratio) |
|
|
else: |
|
|
|
|
|
new_height = target_height |
|
|
new_width = int(new_height * orig_ratio) |
|
|
|
|
|
|
|
|
resized_image = image.resize((new_width, new_height)) |
|
|
|
|
|
|
|
|
padded_image = Image.new('RGB', (target_width, target_height), 'white') |
|
|
|
|
|
|
|
|
left_padding = (target_width - new_width) // 2 |
|
|
top_padding = (target_height - new_height) // 2 |
|
|
|
|
|
|
|
|
padded_image.paste(resized_image, (left_padding, top_padding)) |
|
|
|
|
|
return padded_image, left_padding, top_padding, target_width - new_width - left_padding, target_height - new_height - top_padding |
|
|
|
|
|
|
|
|
def resize_by_height(image, height): |
|
|
""" |
|
|
Thay đổi kích thước ảnh theo chiều cao đã cho và cắt ảnh để kích thước là bội số của 16. |
|
|
|
|
|
Parameters: |
|
|
image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý. |
|
|
height (int): Chiều cao mục tiêu của ảnh sau khi resize. |
|
|
|
|
|
Returns: |
|
|
PIL.Image: Ảnh đã được resize theo chiều cao và cắt để kích thước là bội số của 16. |
|
|
""" |
|
|
if isinstance(image, np.ndarray): |
|
|
image = Image.fromarray(image) |
|
|
|
|
|
image = image.resize((int(image.width * height / image.height), height)) |
|
|
return crop_to_multiple_of_16(image) |
|
|
|
|
|
|
|
|
from src.pipeline_flux_tryon import FluxTryonPipeline |
|
|
|
|
|
class changeClothes(): |
|
|
def __init__(self, local_dir="./flux_models"): |
|
|
""" |
|
|
local_dir: Đường dẫn đến thư mục đã lưu model |
|
|
""" |
|
|
self.device = torch.device("cuda") |
|
|
self.local_dir = local_dir |
|
|
self.pipe = self.load_models() |
|
|
|
|
|
|
|
|
def load_models(self): |
|
|
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig |
|
|
|
|
|
print("🔄 Loading 4bit transformer ...") |
|
|
|
|
|
|
|
|
bnb_config = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
bnb_4bit_quant_type="nf4", |
|
|
bnb_4bit_compute_dtype=torch.bfloat16, |
|
|
bnb_4bit_use_double_quant=True |
|
|
) |
|
|
|
|
|
|
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
f"{self.local_dir}/transformer", |
|
|
quantization_config=bnb_config, |
|
|
torch_dtype=torch.bfloat16 |
|
|
) |
|
|
|
|
|
|
|
|
print("🔄 Loading optimized FLUX pipeline...") |
|
|
|
|
|
pipe = FluxTryonPipeline.from_pretrained( |
|
|
self.local_dir, |
|
|
transformer=transformer, |
|
|
torch_dtype=torch.bfloat16 |
|
|
) |
|
|
|
|
|
core = openvino.runtime.Core() |
|
|
ov_text_encoder_2 = core.read_model(f"{self.local_dir}/OV_text_encoder_2/openvino_model.xml") |
|
|
pipe.text_encoder_2 = OVWrapperTextEncoder(ov_text_encoder_2, core, torch_dtype=torch.bfloat16) |
|
|
|
|
|
|
|
|
|
|
|
print("Loading LoRA weights .....") |
|
|
pipe.load_lora_weights( |
|
|
"loooooong/Any2anyTryon", |
|
|
weight_name="dev_lora_any2any_tryon.safetensors", |
|
|
adapter_name="tryon" |
|
|
) |
|
|
|
|
|
|
|
|
pipe.fuse_lora() |
|
|
|
|
|
|
|
|
pipe.enable_attention_slicing() |
|
|
pipe.vae.enable_slicing() |
|
|
pipe.vae.enable_tiling() |
|
|
|
|
|
print("Load to CUDA ...........") |
|
|
pipe.to("cuda") |
|
|
|
|
|
return pipe |
|
|
|
|
|
@staticmethod |
|
|
def check_image(image_input): |
|
|
""" |
|
|
Kiểm tra và xử lý đầu vào ảnh |
|
|
|
|
|
Args: |
|
|
image_input: Có thể là None, đối tượng Image từ PIL hoặc đường dẫn đến file ảnh |
|
|
|
|
|
Returns: |
|
|
None nếu đầu vào là None, ngược lại trả về đối tượng Image từ PIL |
|
|
""" |
|
|
|
|
|
if image_input is None: |
|
|
return None |
|
|
|
|
|
|
|
|
if isinstance(image_input, str) and os.path.isfile(image_input): |
|
|
try: |
|
|
return Image.open(image_input) |
|
|
except Exception as e: |
|
|
raise ValueError(f"Không thể mở file ảnh: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
if hasattr(image_input, 'size'): |
|
|
return image_input |
|
|
|
|
|
|
|
|
if isinstance(image_input, np.ndarray): |
|
|
return Image.fromarray(image_input) |
|
|
|
|
|
|
|
|
|
|
|
return Image.fromarray(np.array(image_input)) |
|
|
|
|
|
except Exception as e: |
|
|
raise ValueError(f"Không thể chuyển đối tượng thành Image: {e}") |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate_image(self, model_image, garment_image, prompt="", prompt_2=None, height=384, width=288, |
|
|
seed=0, guidance_scale=3.5, num_inference_steps=30, device4GEN="cuda"): |
|
|
""" |
|
|
Tạo ảnh thử đồ ảo dựa trên mô hình người và hình ảnh quần áo đầu vào. |
|
|
|
|
|
Parameters: |
|
|
model_image (PIL.Image, optional): Ảnh người mẫu, có thể là None. |
|
|
garment_image (PIL.Image, optional): Ảnh quần áo cần thử, có thể là None. |
|
|
prompt (str, optional): Mô tả text cho quá trình tạo ảnh. Mặc định là chuỗi rỗng. |
|
|
prompt_2 (str, optional): Mô tả text thứ hai (nếu cần). Mặc định là None. |
|
|
height (int, optional): Chiều cao của ảnh xử lý. Mặc định là 512. |
|
|
width (int, optional): Chiều rộng của ảnh xử lý. Mặc định là 384. |
|
|
seed (int, optional): Giá trị seed để đảm bảo tính tái sản xuất. Mặc định là 0. |
|
|
guidance_scale (float, optional): Mức độ ảnh hưởng của prompt lên quá trình tạo ảnh. Mặc định là 3.5. |
|
|
num_inference_steps (int, optional): Số bước lặp trong quá trình tạo ảnh. Mặc định là 30. |
|
|
|
|
|
Returns: |
|
|
PIL.Image: Ảnh kết quả sau khi thực hiện thử đồ ảo. Trong trường hợp thử đồ (cả model và garment), |
|
|
ảnh sẽ được resize về kích thước gốc của ảnh người mẫu. |
|
|
|
|
|
Notes: |
|
|
- Hàm này sử dụng decorator @torch.no_grad() để tắt gradient trong quá trình inference. |
|
|
- Kích thước xử lý sẽ được điều chỉnh để là bội số của 16. |
|
|
- Nếu cả model_image và garment_image được cung cấp, kết quả sẽ là ảnh thử đồ ảo và được |
|
|
trả về với kích thước gốc của ảnh người mẫu đầu vào. |
|
|
- Nếu chỉ cung cấp một trong hai, kết quả sẽ là ảnh được tạo dựa trên prompt và ảnh được cung cấp. |
|
|
""" |
|
|
|
|
|
height, width = int(height), int(width) |
|
|
width = width - (width % 16) |
|
|
height = height - (height % 16) |
|
|
|
|
|
concat_image_list = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))] |
|
|
|
|
|
|
|
|
model_image = self.check_image(model_image) |
|
|
garment_image = self.check_image(garment_image) |
|
|
|
|
|
has_model_image = model_image is not None |
|
|
has_garment_image = garment_image is not None |
|
|
|
|
|
if has_model_image: |
|
|
|
|
|
if has_garment_image: |
|
|
input_height, input_width = model_image.size[1], model_image.size[0] |
|
|
model_image, lp, tp, rp, bp = resize_and_pad_to_size(model_image, width, height) |
|
|
else: |
|
|
model_image = resize_by_height(model_image, height) |
|
|
concat_image_list.append(model_image) |
|
|
|
|
|
if has_garment_image: |
|
|
|
|
|
garment_image = resize_by_height(garment_image, height) |
|
|
concat_image_list.append(garment_image) |
|
|
|
|
|
image = Image.fromarray(np.concatenate([np.array(img) for img in concat_image_list], axis=1)) |
|
|
|
|
|
mask = np.zeros_like(np.array(image)) |
|
|
mask[:,:width] = 255 |
|
|
mask_image = Image.fromarray(mask) |
|
|
|
|
|
|
|
|
image = self.pipe( |
|
|
prompt=prompt, |
|
|
prompt_2=prompt_2, |
|
|
image=image, |
|
|
mask_image=mask_image, |
|
|
strength=1., |
|
|
height=height, |
|
|
width=image.width, |
|
|
target_width=width, |
|
|
tryon=has_model_image and has_garment_image, |
|
|
guidance_scale=guidance_scale, |
|
|
num_inference_steps=num_inference_steps, |
|
|
max_sequence_length=512, |
|
|
generator=torch.Generator(device4GEN).manual_seed(seed), |
|
|
output_type="pil", |
|
|
).images[0] |
|
|
|
|
|
if has_model_image and has_garment_image: |
|
|
image = image.crop((lp, tp, image.width-rp, image.height-bp)).resize((input_width, input_height)) |
|
|
|
|
|
return image |
|
|
|
|
|
|
|
|
change_clothes = changeClothes() |
|
|
|
|
|
|
|
|
virtal_tryon_prompt = "<MODEL> a person with fashion garment. <GARMENT> a garment. <TARGET> model with fashion garment" |
|
|
|
|
|
deice4GEN="cuda" |
|
|
model_image_path = "./images/model/model0.jpg" |
|
|
garment_image_path_1 = "./images/garment/garment5.jpg" |
|
|
garment_image_path_2 = "./images/garment/garment6.jpg" |
|
|
|
|
|
|
|
|
print('Warming up ......') |
|
|
output_image = change_clothes.generate_image( |
|
|
model_image=model_image_path, |
|
|
garment_image=garment_image_path_1, |
|
|
prompt=virtal_tryon_prompt, |
|
|
|
|
|
seed=1, |
|
|
guidance_scale=5, |
|
|
num_inference_steps=5, |
|
|
device4GEN=deice4GEN |
|
|
) |
|
|
|
|
|
|
|
|
print('Testing 1 ......') |
|
|
t = time.time() |
|
|
|
|
|
output_image = change_clothes.generate_image( |
|
|
model_image=model_image_path, |
|
|
garment_image=garment_image_path_1, |
|
|
prompt=virtal_tryon_prompt, |
|
|
|
|
|
seed=2, |
|
|
guidance_scale=5, |
|
|
num_inference_steps=30, |
|
|
device4GEN=deice4GEN |
|
|
) |
|
|
|
|
|
print(f"Time: {time.time()-t}") |
|
|
|
|
|
output_path = "./output/output_image_1.png" |
|
|
output_image.save(output_path) |
|
|
print(f"Image saved to {output_path}") |
|
|
|
|
|
|
|
|
print('Testing 2 ......') |
|
|
t = time.time() |
|
|
|
|
|
output_image = change_clothes.generate_image( |
|
|
model_image=model_image_path, |
|
|
garment_image=garment_image_path_2, |
|
|
prompt=virtal_tryon_prompt, |
|
|
|
|
|
seed=2, |
|
|
guidance_scale=5, |
|
|
num_inference_steps=30, |
|
|
device4GEN=deice4GEN |
|
|
) |
|
|
|
|
|
print(f"Time: {time.time()-t}") |
|
|
|
|
|
output_path = "./output/output_image_2.png" |
|
|
output_image.save(output_path) |
|
|
print(f"Image saved to {output_path}") |
|
|
|