import openvino import torch import os import time from PIL import Image import numpy as np class OVWrapperTextEncoder: def __init__(self, ov_model, core, device="CPU", torch_dtype=torch.bfloat16): self.model = ov_model self.compiled_model = core.compile_model(ov_model, device) self.device = device self.dtype = torch_dtype # Lấy các tên đầu vào/ra của mô hình self.input_key = self.compiled_model.inputs[0].get_any_name() self.output_key = self.compiled_model.outputs[0].get_any_name() self.config = type('obj', (object,), { "d_model": 768, "hidden_size": 768, }) def __call__(self, input_ids, attention_mask=None, return_dict=False, output_hidden_states=False, **kwargs): # Chuyển đổi torch tensor sang numpy if isinstance(input_ids, torch.Tensor): input_ids = input_ids.cpu().numpy() # Chuẩn bị input inputs = {self.input_key: input_ids} if attention_mask is not None and len(self.compiled_model.inputs) > 1: attn_key = self.compiled_model.inputs[1].get_any_name() inputs[attn_key] = attention_mask.cpu().numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask # Thực hiện suy luận outputs = self.compiled_model(inputs) # Chuyển đổi kết quả về torch tensor last_hidden_state = torch.from_numpy(outputs[self.output_key]).to(self.dtype) return (last_hidden_state,) def crop_to_multiple_of_16(img): """ Cắt ảnh để kích thước là bội số của 16. Parameters: img (PIL.Image): Ảnh đầu vào cần cắt. Returns: PIL.Image: Ảnh đã được cắt với kích thước là bội số của 16. """ width, height = img.size # Calculate new dimensions that are multiples of 8 new_width = width - (width % 16) new_height = height - (height % 16) # Calculate crop box coordinates left = (width - new_width) // 2 top = (height - new_height) // 2 right = left + new_width bottom = top + new_height # Crop the image cropped_img = img.crop((left, top, right, bottom)) return cropped_img def resize_and_pad_to_size(image, target_width, target_height): """ Thay đổi kích thước ảnh và thêm padding để đạt kích thước mục tiêu mà không làm biến dạng ảnh. Parameters: image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý. target_width (int): Chiều rộng mục tiêu của ảnh sau khi xử lý. target_height (int): Chiều cao mục tiêu của ảnh sau khi xử lý. Returns: tuple: Bao gồm: - PIL.Image: Ảnh đã được resize và padding. - int: Khoảng cách padding bên trái. - int: Khoảng cách padding bên trên. - int: Khoảng cách padding bên phải. - int: Khoảng cách padding bên dưới. """ # Convert numpy array to PIL Image if needed if isinstance(image, np.ndarray): image = Image.fromarray(image) # Get original dimensions orig_width, orig_height = image.size # Calculate aspect ratios target_ratio = target_width / target_height orig_ratio = orig_width / orig_height # Calculate new dimensions while maintaining aspect ratio if orig_ratio > target_ratio: # Image is wider than target ratio - scale by width new_width = target_width new_height = int(new_width / orig_ratio) else: # Image is taller than target ratio - scale by height new_height = target_height new_width = int(new_height * orig_ratio) # Resize image resized_image = image.resize((new_width, new_height)) # Create white background image of target size padded_image = Image.new('RGB', (target_width, target_height), 'white') # Calculate padding to center the image left_padding = (target_width - new_width) // 2 top_padding = (target_height - new_height) // 2 # Paste resized image onto padded background padded_image.paste(resized_image, (left_padding, top_padding)) return padded_image, left_padding, top_padding, target_width - new_width - left_padding, target_height - new_height - top_padding def resize_by_height(image, height): """ Thay đổi kích thước ảnh theo chiều cao đã cho và cắt ảnh để kích thước là bội số của 16. Parameters: image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý. height (int): Chiều cao mục tiêu của ảnh sau khi resize. Returns: PIL.Image: Ảnh đã được resize theo chiều cao và cắt để kích thước là bội số của 16. """ if isinstance(image, np.ndarray): image = Image.fromarray(image) # image is a PIL image image = image.resize((int(image.width * height / image.height), height)) return crop_to_multiple_of_16(image) from src.pipeline_flux_tryon import FluxTryonPipeline class changeClothes(): def __init__(self, local_dir="./flux_models"): """ local_dir: Đường dẫn đến thư mục đã lưu model """ self.device = torch.device("cuda") self.local_dir = local_dir self.pipe = self.load_models() def load_models(self): from diffusers import FluxTransformer2DModel, BitsAndBytesConfig print("🔄 Loading 4bit transformer ...") # 1. Quantize Transformer (core component) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True ) # Load quantized transformer transformer = FluxTransformer2DModel.from_pretrained( f"{self.local_dir}/transformer", quantization_config=bnb_config, torch_dtype=torch.bfloat16 ) # Load với hardware quantization print("🔄 Loading optimized FLUX pipeline...") pipe = FluxTryonPipeline.from_pretrained( self.local_dir, transformer=transformer, torch_dtype=torch.bfloat16 ) core = openvino.runtime.Core() ov_text_encoder_2 = core.read_model(f"{self.local_dir}/OV_text_encoder_2/openvino_model.xml") pipe.text_encoder_2 = OVWrapperTextEncoder(ov_text_encoder_2, core, torch_dtype=torch.bfloat16) # Load LoRA print("Loading LoRA weights .....") pipe.load_lora_weights( "loooooong/Any2anyTryon", weight_name="dev_lora_any2any_tryon.safetensors", adapter_name="tryon" ) # Fuse LoRA for speed pipe.fuse_lora() # Memory optimizations pipe.enable_attention_slicing() pipe.vae.enable_slicing() pipe.vae.enable_tiling() print("Load to CUDA ...........") pipe.to("cuda") return pipe @staticmethod def check_image(image_input): """ Kiểm tra và xử lý đầu vào ảnh Args: image_input: Có thể là None, đối tượng Image từ PIL hoặc đường dẫn đến file ảnh Returns: None nếu đầu vào là None, ngược lại trả về đối tượng Image từ PIL """ # Trường hợp 1: input là None if image_input is None: return None # Trường hợp 2: input là đường dẫn đến file ảnh if isinstance(image_input, str) and os.path.isfile(image_input): try: return Image.open(image_input) except Exception as e: raise ValueError(f"Không thể mở file ảnh: {e}") # Trường hợp 3: input đã là một loại ảnh nào đó try: # Kiểm tra xem đã có thuộc tính size chưa if hasattr(image_input, 'size'): return image_input # Nếu là numpy array, chuyển sang PIL Image if isinstance(image_input, np.ndarray): return Image.fromarray(image_input) # Nếu là đối tượng có thể chuyển về ảnh (như OpenCV image) # Thử chuyển đổi thành Image return Image.fromarray(np.array(image_input)) except Exception as e: raise ValueError(f"Không thể chuyển đối tượng thành Image: {e}") @torch.no_grad() def generate_image(self, model_image, garment_image, prompt="", prompt_2=None, height=384, width=288, seed=0, guidance_scale=3.5, num_inference_steps=30, device4GEN="cuda"): """ Tạo ảnh thử đồ ảo dựa trên mô hình người và hình ảnh quần áo đầu vào. Parameters: model_image (PIL.Image, optional): Ảnh người mẫu, có thể là None. garment_image (PIL.Image, optional): Ảnh quần áo cần thử, có thể là None. prompt (str, optional): Mô tả text cho quá trình tạo ảnh. Mặc định là chuỗi rỗng. prompt_2 (str, optional): Mô tả text thứ hai (nếu cần). Mặc định là None. height (int, optional): Chiều cao của ảnh xử lý. Mặc định là 512. width (int, optional): Chiều rộng của ảnh xử lý. Mặc định là 384. seed (int, optional): Giá trị seed để đảm bảo tính tái sản xuất. Mặc định là 0. guidance_scale (float, optional): Mức độ ảnh hưởng của prompt lên quá trình tạo ảnh. Mặc định là 3.5. num_inference_steps (int, optional): Số bước lặp trong quá trình tạo ảnh. Mặc định là 30. Returns: PIL.Image: Ảnh kết quả sau khi thực hiện thử đồ ảo. Trong trường hợp thử đồ (cả model và garment), ảnh sẽ được resize về kích thước gốc của ảnh người mẫu. Notes: - Hàm này sử dụng decorator @torch.no_grad() để tắt gradient trong quá trình inference. - Kích thước xử lý sẽ được điều chỉnh để là bội số của 16. - Nếu cả model_image và garment_image được cung cấp, kết quả sẽ là ảnh thử đồ ảo và được trả về với kích thước gốc của ảnh người mẫu đầu vào. - Nếu chỉ cung cấp một trong hai, kết quả sẽ là ảnh được tạo dựa trên prompt và ảnh được cung cấp. """ height, width = int(height), int(width) width = width - (width % 16) height = height - (height % 16) concat_image_list = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))] # Kiểm tra tính hợp lệ của ảnh, đồng thời đọc ảnh bằng PIL nếu là path model_image = self.check_image(model_image) garment_image = self.check_image(garment_image) has_model_image = model_image is not None has_garment_image = garment_image is not None if has_model_image: # model_image = Image.open(model_image_path) if has_garment_image: input_height, input_width = model_image.size[1], model_image.size[0] model_image, lp, tp, rp, bp = resize_and_pad_to_size(model_image, width, height) else: model_image = resize_by_height(model_image, height) concat_image_list.append(model_image) if has_garment_image: # garment_image = Image.open(garment_image_path) garment_image = resize_by_height(garment_image, height) concat_image_list.append(garment_image) image = Image.fromarray(np.concatenate([np.array(img) for img in concat_image_list], axis=1)) mask = np.zeros_like(np.array(image)) mask[:,:width] = 255 mask_image = Image.fromarray(mask) image = self.pipe( prompt=prompt, prompt_2=prompt_2, image=image, mask_image=mask_image, strength=1., height=height, width=image.width, target_width=width, tryon=has_model_image and has_garment_image, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, max_sequence_length=512, generator=torch.Generator(device4GEN).manual_seed(seed), output_type="pil", ).images[0] if has_model_image and has_garment_image: image = image.crop((lp, tp, image.width-rp, image.height-bp)).resize((input_width, input_height)) return image change_clothes = changeClothes() virtal_tryon_prompt = " a person with fashion garment. a garment. model with fashion garment" # prompt_2 = "remove the black outerwear, wear only the new garment" deice4GEN="cuda" model_image_path = "./images/model/model0.jpg" garment_image_path_1 = "./images/garment/garment5.jpg" garment_image_path_2 = "./images/garment/garment6.jpg" print('Warming up ......') output_image = change_clothes.generate_image( model_image=model_image_path, garment_image=garment_image_path_1, prompt=virtal_tryon_prompt, # prompt_2=prompt_2, seed=1, guidance_scale=5, num_inference_steps=5, device4GEN=deice4GEN ) print('Testing 1 ......') t = time.time() output_image = change_clothes.generate_image( model_image=model_image_path, garment_image=garment_image_path_1, prompt=virtal_tryon_prompt, # prompt_2=prompt_2, seed=2, guidance_scale=5, num_inference_steps=30, device4GEN=deice4GEN ) print(f"Time: {time.time()-t}") output_path = "./output/output_image_1.png" output_image.save(output_path) print(f"Image saved to {output_path}") print('Testing 2 ......') t = time.time() output_image = change_clothes.generate_image( model_image=model_image_path, garment_image=garment_image_path_2, prompt=virtal_tryon_prompt, # prompt_2=prompt_2, seed=2, guidance_scale=5, num_inference_steps=30, device4GEN=deice4GEN ) print(f"Time: {time.time()-t}") output_path = "./output/output_image_2.png" output_image.save(output_path) print(f"Image saved to {output_path}")