| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import glob |
| | import hashlib |
| |
|
| | import pandas as pd |
| | import torch |
| | from transformers import T5EncoderModel |
| |
|
| | from diffusers import StableDiffusion3Pipeline |
| |
|
| |
|
# Default instance prompt (DreamBooth-style rare-token prompt) to encode.
PROMPT = "a photo of sks dog"
# Maximum token sequence length passed to the text encoders.
MAX_SEQ_LENGTH = 77
# Directory scanned for `*.jpeg` instance images.
LOCAL_DATA_DIR = "dog"
# Destination parquet file for the serialized embeddings.
OUTPUT_PATH = "sample_embeddings.parquet"
|
def bytes_to_giga_bytes(num_bytes):
    """Convert a byte count to gibibytes (1 GiB = 1024**3 bytes).

    Args:
        num_bytes: Size in bytes (int or float).

    Returns:
        float: The equivalent size in GiB.
    """
    # Parameter renamed from `bytes`, which shadowed the builtin type.
    return num_bytes / 1024**3
|
def generate_image_hash(image_path):
    """Return the SHA-256 hex digest of the file at ``image_path``."""
    hasher = hashlib.sha256()
    with open(image_path, "rb") as fp:
        hasher.update(fp.read())
    return hasher.hexdigest()
|
def load_sd3_pipeline():
    """Load SD3 with only the text encoders active (8-bit T5).

    The transformer and VAE are explicitly set to ``None`` because this
    script only needs prompt embeddings — dropping them keeps memory low.

    Returns:
        StableDiffusion3Pipeline: text-encoder-only pipeline.
    """
    # Renamed local from `id`, which shadowed the builtin.
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
    # NOTE(review): `load_in_8bit=True` is a deprecated shortcut in recent
    # transformers releases — confirm version, and prefer
    # `quantization_config=BitsAndBytesConfig(load_in_8bit=True)` if so.
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto"
    )
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
    )
    return pipeline
|
@torch.no_grad()
def compute_embeddings(pipeline, prompt, max_sequence_length):
    """Encode *prompt* with the pipeline's text encoders.

    Args:
        pipeline: An object exposing ``encode_prompt`` (e.g. a
            ``StableDiffusion3Pipeline`` with text encoders loaded).
        prompt: The instance prompt to encode.
        max_sequence_length: Token limit for the T5 encoder.

    Returns:
        Tuple of (prompt_embeds, negative_prompt_embeds,
        pooled_prompt_embeds, negative_pooled_prompt_embeds) tensors.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)

    # Fixed: the last field previously lacked the `=` debug specifier, so
    # it printed the bare shape without its label, unlike the other three.
    print(
        f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape=}"
    )

    # Only report peak CUDA memory when CUDA exists; on CPU-only hosts the
    # counter is meaningless (always 0).
    if torch.cuda.is_available():
        max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
        print(f"Max memory allocated: {max_memory:.3f} GB")
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
def run(args):
    """Compute prompt embeddings once and serialize one row per image.

    Every image shares the same prompt embeddings; rows are keyed by the
    SHA-256 hash of the image bytes so embeddings can be joined back to
    their source images later.

    Args:
        args: Parsed CLI namespace with ``prompt``, ``max_sequence_length``,
            ``local_data_dir``, and ``output_path`` attributes.
    """
    pipeline = load_sd3_pipeline()
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
        pipeline, args.prompt, args.max_sequence_length
    )

    # The four embeddings are identical for every row, so flatten each
    # tensor to a Python list ONCE here instead of repeating the
    # tensor -> numpy -> list conversion per row via DataFrame.apply.
    flat_embeddings = [
        t.cpu().numpy().flatten().tolist()
        for t in (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
    ]

    image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
    data = [(generate_image_hash(path), *flat_embeddings) for path in image_paths]

    df = pd.DataFrame(
        data,
        columns=[
            "image_hash",
            "prompt_embeds",
            "negative_prompt_embeds",
            "pooled_prompt_embeds",
            "negative_pooled_prompt_embeds",
        ],
    )

    df.to_parquet(args.output_path)
    print(f"Data successfully serialized to {args.output_path}")
|
if __name__ == "__main__":
    # CLI entry point: defaults mirror the module-level constants above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
    )
    parser.add_argument(
        "--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
    )
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")

    run(parser.parse_args())
|