Spaces:
Paused
Paused
| import os | |
| import torch | |
| from PIL import Image | |
| from src.models.stage1_prior_transformer import Stage1_PriorTransformer | |
| from src.pipelines.stage1_prior_pipeline import Stage1_PriorPipeline | |
| import torch.nn.functional as F | |
| from transformers import ( | |
| CLIPVisionModelWithProjection, | |
| CLIPImageProcessor, | |
| ) | |
| import argparse | |
| import numpy as np | |
| import torch.multiprocessing as mp | |
| import json | |
| import time | |
| # Read a text file and convert the coordinates into a tensor | |
| def read_coordinates_file(file_path): | |
| coordinates_list = [] | |
| with open(file_path, 'r') as file: | |
| for line in file: | |
| x, y = map(float, line.strip().split()) | |
| coordinates_list.extend([x, y]) | |
| coordinates_tensor = torch.tensor(coordinates_list, dtype=torch.float32).view(1, -1) | |
| return coordinates_tensor | |
| def split_list_into_chunks(lst, n): | |
| chunk_size = len(lst) // n | |
| chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] | |
| if len(chunks) > n: | |
| last_chunk = chunks.pop() | |
| chunks[-1].extend(last_chunk) | |
| return chunks | |
| def main(args): | |
| device = torch.device("cuda") | |
| generator = torch.Generator(device=device).manual_seed(args.seed_number) | |
| # save path | |
| save_dir = "{}/guidancescale{}_seed{}_numsteps{}/".format(args.save_path, args.guidance_scale, args.seed_number, args.num_inference_steps) | |
| if not os.path.exists(save_dir): | |
| os.makedirs(save_dir, exist_ok=True) | |
| # prepare data aug | |
| clip_image_processor = CLIPImageProcessor() | |
| # prepare model | |
| model_ckpt = args.weights_name | |
| pipe = Stage1_PriorPipeline.from_pretrained(args.pretrained_model_name_or_path).to(device) | |
| pipe.prior= Stage1_PriorTransformer.from_pretrained(args.pretrained_model_name_or_path, subfolder="prior", num_embeddings=2,embedding_dim=1024, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device) | |
| prior_dict = torch.load(model_ckpt, map_location="cpu")["module"] | |
| pipe.prior.load_state_dict(prior_dict) | |
| pipe.enable_xformers_memory_efficient_attention() | |
| image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).eval().to(device) | |
| print('====================== model load finish ===================') | |
| # start test | |
| start_time = time.time() | |
| #prepare data | |
| s_img_path = 'imgs/sm.png' | |
| t_img_path = 'imgs/target.png' | |
| s_pose_path = args.pose_path + select_test_data['source_image'].replace('.jpg', '.txt') | |
| t_pose_path = (args.pose_path + select_test_data["target_image"].replace(".jpg", ".txt")) | |
| # image_pair | |
| s_image = Image.open(s_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC) | |
| #t_image = Image.open(t_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC) | |
| s_pose = read_coordinates_file(s_pose_path).to(device).unsqueeze(1) | |
| t_pose = read_coordinates_file(t_pose_path).to(device).unsqueeze(1) | |
| clip_s_image = clip_image_processor(images=s_image, return_tensors="pt").pixel_values | |
| #clip_t_image = clip_image_processor(images=t_image, return_tensors="pt").pixel_values | |
| with torch.no_grad(): | |
| s_img_embed = (image_encoder(clip_s_image.to(device)).image_embeds).unsqueeze(1) | |
| #target_embed = image_encoder(clip_t_image.to(device)).image_embeds | |
| output = pipe( | |
| s_embed = s_img_embed, | |
| s_pose = s_pose, | |
| t_pose = t_pose, | |
| num_images_per_prompt=1, | |
| num_inference_steps = args.num_inference_steps, | |
| generator = generator, | |
| guidance_scale = args.guidance_scale, | |
| ) | |
| # save features | |
| feature = output[0].cpu().detach().numpy() | |
| np.save('embed.npy', feature) | |
| # computer scores | |
| predict_embed = output[0] | |
| #cosine_similarities = F.cosine_similarity(predict_embed, target_embed) | |
| #sum_simm += cosine_similarities.item() | |
| end_time =time.time() | |
| print(end_time-start_time) | |
| """ | |
| avg_simm = sum_simm/number | |
| with open (save_dir+'/a_results.txt', 'a') as ff: | |
| ff.write('number is {}, guidance_scale is {}, all averge simm is :{} \n'.format(number, args.guidance_scale, avg_simm)) | |
| print('number is {}, guidance_scale is {}, all averge simm is :{}'.format(number, args.guidance_scale, avg_simm)) | |
| """ | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Simple example of a prior model of stage1 script.") | |
| parser.add_argument("--pretrained_model_name_or_path",type=str,default="./kandinsky-2-2-prior", | |
| help="Path to pretrained model or model identifier from huggingface.co/models.",) | |
| parser.add_argument("--image_encoder_path",type=str,default="./OpenCLIP-ViT-H-14", | |
| help="Path to pretrained model or model identifier from huggingface.co/models.",) | |
| parser.add_argument("--img_path", type=str, default="./datasets/deepfashing/train_all_png/", help="image path", ) | |
| parser.add_argument("--pose_path", type=str, default="./datasets/deepfashing/normalized_pose_txt/", help="pose path", ) | |
| parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/test_data.json", help="json path", ) | |
| parser.add_argument("--save_path", type=str, default="./save_data/stage1", help="save path", ) | |
| parser.add_argument("--guidance_scale",type=int,default=0,help="guidance_scale",) | |
| parser.add_argument("--seed_number",type=int,default=42,help="seed number",) | |
| parser.add_argument("--num_inference_steps",type=int,default=20,help="num_inference_steps",) | |
| parser.add_argument("--img_width",type=int,default=512,help="image width",) | |
| parser.add_argument("--img_height",type=int,default=512,help="image height",) | |
| parser.add_argument("--weights_name",type=str,default="s1_512.pt",help="weights number",) | |
| args = parser.parse_args() | |
| print(args) | |
| """ | |
| # Set the number of GPUs. | |
| num_devices = torch.cuda.device_count() | |
| print("Using {} GPUs inference".format(num_devices)) | |
| # load data | |
| test_data = json.load(open(args.json_path)) | |
| select_test_datas = test_data | |
| print('The number of test data: {}'.format(len(select_test_datas))) | |
| # Create a process pool | |
| mp.set_start_method("spawn") | |
| data_list = split_list_into_chunks(select_test_datas, num_devices) | |
| processes = [] | |
| for rank in range(num_devices): | |
| p = mp.Process(target=main, args=(args,rank, data_list[rank], )) | |
| processes.append(p) | |
| p.start() | |
| for rank, p in enumerate(processes): | |
| p.join() | |
| """ | |
| main(args) | |