BF_TryOn / test_4bit.py
hnpinq's picture
Update test_4bit.py
f2c6e39 verified
import openvino
import torch
import os
import time
from PIL import Image
import numpy as np
class OVWrapperTextEncoder:
def __init__(self, ov_model, core, device="CPU", torch_dtype=torch.bfloat16):
self.model = ov_model
self.compiled_model = core.compile_model(ov_model, device)
self.device = device
self.dtype = torch_dtype
# Lấy các tên đầu vào/ra của mô hình
self.input_key = self.compiled_model.inputs[0].get_any_name()
self.output_key = self.compiled_model.outputs[0].get_any_name()
self.config = type('obj', (object,), {
"d_model": 768,
"hidden_size": 768,
})
def __call__(self, input_ids, attention_mask=None, return_dict=False, output_hidden_states=False, **kwargs):
# Chuyển đổi torch tensor sang numpy
if isinstance(input_ids, torch.Tensor):
input_ids = input_ids.cpu().numpy()
# Chuẩn bị input
inputs = {self.input_key: input_ids}
if attention_mask is not None and len(self.compiled_model.inputs) > 1:
attn_key = self.compiled_model.inputs[1].get_any_name()
inputs[attn_key] = attention_mask.cpu().numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask
# Thực hiện suy luận
outputs = self.compiled_model(inputs)
# Chuyển đổi kết quả về torch tensor
last_hidden_state = torch.from_numpy(outputs[self.output_key]).to(self.dtype)
return (last_hidden_state,)
def crop_to_multiple_of_16(img):
"""
Cắt ảnh để kích thước là bội số của 16.
Parameters:
img (PIL.Image): Ảnh đầu vào cần cắt.
Returns:
PIL.Image: Ảnh đã được cắt với kích thước là bội số của 16.
"""
width, height = img.size
# Calculate new dimensions that are multiples of 8
new_width = width - (width % 16)
new_height = height - (height % 16)
# Calculate crop box coordinates
left = (width - new_width) // 2
top = (height - new_height) // 2
right = left + new_width
bottom = top + new_height
# Crop the image
cropped_img = img.crop((left, top, right, bottom))
return cropped_img
def resize_and_pad_to_size(image, target_width, target_height):
"""
Thay đổi kích thước ảnh và thêm padding để đạt kích thước mục tiêu mà không làm biến dạng ảnh.
Parameters:
image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý.
target_width (int): Chiều rộng mục tiêu của ảnh sau khi xử lý.
target_height (int): Chiều cao mục tiêu của ảnh sau khi xử lý.
Returns:
tuple: Bao gồm:
- PIL.Image: Ảnh đã được resize và padding.
- int: Khoảng cách padding bên trái.
- int: Khoảng cách padding bên trên.
- int: Khoảng cách padding bên phải.
- int: Khoảng cách padding bên dưới.
"""
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Get original dimensions
orig_width, orig_height = image.size
# Calculate aspect ratios
target_ratio = target_width / target_height
orig_ratio = orig_width / orig_height
# Calculate new dimensions while maintaining aspect ratio
if orig_ratio > target_ratio:
# Image is wider than target ratio - scale by width
new_width = target_width
new_height = int(new_width / orig_ratio)
else:
# Image is taller than target ratio - scale by height
new_height = target_height
new_width = int(new_height * orig_ratio)
# Resize image
resized_image = image.resize((new_width, new_height))
# Create white background image of target size
padded_image = Image.new('RGB', (target_width, target_height), 'white')
# Calculate padding to center the image
left_padding = (target_width - new_width) // 2
top_padding = (target_height - new_height) // 2
# Paste resized image onto padded background
padded_image.paste(resized_image, (left_padding, top_padding))
return padded_image, left_padding, top_padding, target_width - new_width - left_padding, target_height - new_height - top_padding
def resize_by_height(image, height):
"""
Thay đổi kích thước ảnh theo chiều cao đã cho và cắt ảnh để kích thước là bội số của 16.
Parameters:
image (PIL.Image hoặc np.ndarray): Ảnh đầu vào cần xử lý.
height (int): Chiều cao mục tiêu của ảnh sau khi resize.
Returns:
PIL.Image: Ảnh đã được resize theo chiều cao và cắt để kích thước là bội số của 16.
"""
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# image is a PIL image
image = image.resize((int(image.width * height / image.height), height))
return crop_to_multiple_of_16(image)
from src.pipeline_flux_tryon import FluxTryonPipeline
class changeClothes():
def __init__(self, local_dir="./flux_models"):
"""
local_dir: Đường dẫn đến thư mục đã lưu model
"""
self.device = torch.device("cuda")
self.local_dir = local_dir
self.pipe = self.load_models()
def load_models(self):
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
print("🔄 Loading 4bit transformer ...")
# 1. Quantize Transformer (core component)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
# Load quantized transformer
transformer = FluxTransformer2DModel.from_pretrained(
f"{self.local_dir}/transformer",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16
)
# Load với hardware quantization
print("🔄 Loading optimized FLUX pipeline...")
pipe = FluxTryonPipeline.from_pretrained(
self.local_dir,
transformer=transformer,
torch_dtype=torch.bfloat16
)
core = openvino.runtime.Core()
ov_text_encoder_2 = core.read_model(f"{self.local_dir}/OV_text_encoder_2/openvino_model.xml")
pipe.text_encoder_2 = OVWrapperTextEncoder(ov_text_encoder_2, core, torch_dtype=torch.bfloat16)
# Load LoRA
print("Loading LoRA weights .....")
pipe.load_lora_weights(
"loooooong/Any2anyTryon",
weight_name="dev_lora_any2any_tryon.safetensors",
adapter_name="tryon"
)
# Fuse LoRA for speed
pipe.fuse_lora()
# Memory optimizations
pipe.enable_attention_slicing()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
print("Load to CUDA ...........")
pipe.to("cuda")
return pipe
@staticmethod
def check_image(image_input):
"""
Kiểm tra và xử lý đầu vào ảnh
Args:
image_input: Có thể là None, đối tượng Image từ PIL hoặc đường dẫn đến file ảnh
Returns:
None nếu đầu vào là None, ngược lại trả về đối tượng Image từ PIL
"""
# Trường hợp 1: input là None
if image_input is None:
return None
# Trường hợp 2: input là đường dẫn đến file ảnh
if isinstance(image_input, str) and os.path.isfile(image_input):
try:
return Image.open(image_input)
except Exception as e:
raise ValueError(f"Không thể mở file ảnh: {e}")
# Trường hợp 3: input đã là một loại ảnh nào đó
try:
# Kiểm tra xem đã có thuộc tính size chưa
if hasattr(image_input, 'size'):
return image_input
# Nếu là numpy array, chuyển sang PIL Image
if isinstance(image_input, np.ndarray):
return Image.fromarray(image_input)
# Nếu là đối tượng có thể chuyển về ảnh (như OpenCV image)
# Thử chuyển đổi thành Image
return Image.fromarray(np.array(image_input))
except Exception as e:
raise ValueError(f"Không thể chuyển đối tượng thành Image: {e}")
@torch.no_grad()
def generate_image(self, model_image, garment_image, prompt="", prompt_2=None, height=384, width=288,
seed=0, guidance_scale=3.5, num_inference_steps=30, device4GEN="cuda"):
"""
Tạo ảnh thử đồ ảo dựa trên mô hình người và hình ảnh quần áo đầu vào.
Parameters:
model_image (PIL.Image, optional): Ảnh người mẫu, có thể là None.
garment_image (PIL.Image, optional): Ảnh quần áo cần thử, có thể là None.
prompt (str, optional): Mô tả text cho quá trình tạo ảnh. Mặc định là chuỗi rỗng.
prompt_2 (str, optional): Mô tả text thứ hai (nếu cần). Mặc định là None.
height (int, optional): Chiều cao của ảnh xử lý. Mặc định là 512.
width (int, optional): Chiều rộng của ảnh xử lý. Mặc định là 384.
seed (int, optional): Giá trị seed để đảm bảo tính tái sản xuất. Mặc định là 0.
guidance_scale (float, optional): Mức độ ảnh hưởng của prompt lên quá trình tạo ảnh. Mặc định là 3.5.
num_inference_steps (int, optional): Số bước lặp trong quá trình tạo ảnh. Mặc định là 30.
Returns:
PIL.Image: Ảnh kết quả sau khi thực hiện thử đồ ảo. Trong trường hợp thử đồ (cả model và garment),
ảnh sẽ được resize về kích thước gốc của ảnh người mẫu.
Notes:
- Hàm này sử dụng decorator @torch.no_grad() để tắt gradient trong quá trình inference.
- Kích thước xử lý sẽ được điều chỉnh để là bội số của 16.
- Nếu cả model_image và garment_image được cung cấp, kết quả sẽ là ảnh thử đồ ảo và được
trả về với kích thước gốc của ảnh người mẫu đầu vào.
- Nếu chỉ cung cấp một trong hai, kết quả sẽ là ảnh được tạo dựa trên prompt và ảnh được cung cấp.
"""
height, width = int(height), int(width)
width = width - (width % 16)
height = height - (height % 16)
concat_image_list = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))]
# Kiểm tra tính hợp lệ của ảnh, đồng thời đọc ảnh bằng PIL nếu là path
model_image = self.check_image(model_image)
garment_image = self.check_image(garment_image)
has_model_image = model_image is not None
has_garment_image = garment_image is not None
if has_model_image:
# model_image = Image.open(model_image_path)
if has_garment_image:
input_height, input_width = model_image.size[1], model_image.size[0]
model_image, lp, tp, rp, bp = resize_and_pad_to_size(model_image, width, height)
else:
model_image = resize_by_height(model_image, height)
concat_image_list.append(model_image)
if has_garment_image:
# garment_image = Image.open(garment_image_path)
garment_image = resize_by_height(garment_image, height)
concat_image_list.append(garment_image)
image = Image.fromarray(np.concatenate([np.array(img) for img in concat_image_list], axis=1))
mask = np.zeros_like(np.array(image))
mask[:,:width] = 255
mask_image = Image.fromarray(mask)
image = self.pipe(
prompt=prompt,
prompt_2=prompt_2,
image=image,
mask_image=mask_image,
strength=1.,
height=height,
width=image.width,
target_width=width,
tryon=has_model_image and has_garment_image,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
max_sequence_length=512,
generator=torch.Generator(device4GEN).manual_seed(seed),
output_type="pil",
).images[0]
if has_model_image and has_garment_image:
image = image.crop((lp, tp, image.width-rp, image.height-bp)).resize((input_width, input_height))
return image
change_clothes = changeClothes()
virtal_tryon_prompt = "<MODEL> a person with fashion garment. <GARMENT> a garment. <TARGET> model with fashion garment"
# prompt_2 = "remove the black outerwear, wear only the new garment"
deice4GEN="cuda"
model_image_path = "./images/model/model0.jpg"
garment_image_path_1 = "./images/garment/garment5.jpg"
garment_image_path_2 = "./images/garment/garment6.jpg"
print('Warming up ......')
output_image = change_clothes.generate_image(
model_image=model_image_path,
garment_image=garment_image_path_1,
prompt=virtal_tryon_prompt,
# prompt_2=prompt_2,
seed=1,
guidance_scale=5,
num_inference_steps=5,
device4GEN=deice4GEN
)
print('Testing 1 ......')
t = time.time()
output_image = change_clothes.generate_image(
model_image=model_image_path,
garment_image=garment_image_path_1,
prompt=virtal_tryon_prompt,
# prompt_2=prompt_2,
seed=2,
guidance_scale=5,
num_inference_steps=30,
device4GEN=deice4GEN
)
print(f"Time: {time.time()-t}")
output_path = "./output/output_image_1.png"
output_image.save(output_path)
print(f"Image saved to {output_path}")
print('Testing 2 ......')
t = time.time()
output_image = change_clothes.generate_image(
model_image=model_image_path,
garment_image=garment_image_path_2,
prompt=virtal_tryon_prompt,
# prompt_2=prompt_2,
seed=2,
guidance_scale=5,
num_inference_steps=30,
device4GEN=deice4GEN
)
print(f"Time: {time.time()-t}")
output_path = "./output/output_image_2.png"
output_image.save(output_path)
print(f"Image saved to {output_path}")