| | import os |
| | import os.path as osp |
| | import json |
| | from PIL import Image |
| | from tqdm import tqdm |
| | import torch |
| | import random |
| | from train import * |
| |
|
# Hugging Face model id (or local path) of the FLUX.1-dev checkpoint whose
# text encoders/tokenizers are loaded below.
pretrained_path = "black-forest-labs/FLUX.1-dev"
# Directory where the cached prompt-embedding .pth files are written.
OUTPUT_DIR = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds"
# JSONL dataset; each entry supplies a "PLACEHOLDER_prompts" template and a
# "subjects" list (see load_prompts_from_jsonl).
DATASET_JSONL = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard/cuboids__234subjects.jsonl"
| |
|
| |
|
def load_text_encoders(class_one, class_two):
    """Instantiate both FLUX text encoders from the pretrained checkpoint.

    Args:
        class_one: Model class for the primary text encoder
            (checkpoint subfolder "text_encoder").
        class_two: Model class for the secondary text encoder
            (checkpoint subfolder "text_encoder_2").

    Returns:
        Tuple ``(text_encoder_one, text_encoder_two)`` loaded from the
        module-level ``pretrained_path``.
    """
    specs = ((class_one, "text_encoder"), (class_two, "text_encoder_2"))
    return tuple(
        encoder_cls.from_pretrained(
            pretrained_path, subfolder=subfolder, revision=None, variant=None
        )
        for encoder_cls, subfolder in specs
    )
| |
|
| |
|
def load_prompts_from_jsonl(DATASET_JSONL):
    """
    Load unique prompts from a JSONL file with subjects filled in.

    Each JSONL entry is expected to carry a "PLACEHOLDER_prompts" template
    string and a "subjects" list; every occurrence of the literal token
    "PLACEHOLDER" in the template is replaced by the subjects joined with
    " and " (e.g. ["a cat", "a dog"] -> "a cat and a dog").

    Args:
        DATASET_JSONL: Path to the JSONL file. (Parameter name kept for
            backward compatibility; note it shadows the module-level constant.)

    Returns:
        List of unique prompts with subjects filled in (order unspecified,
        since it comes from a set).
    """
    unique_prompts = set()

    print(f"Reading prompts from {DATASET_JSONL}...")
    with open(DATASET_JSONL, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank / whitespace-only lines instead of letting
            # json.loads crash on them.
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)

            placeholder_prompt = entry.get("PLACEHOLDER_prompts", "")
            subjects = entry.get("subjects", [])

            # Skip malformed entries missing either the template or subjects.
            if not placeholder_prompt or not subjects:
                continue

            # ["a cat", "a dog"] -> "a cat and a dog"
            placeholder_text = " and ".join(subjects)

            filled_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
            unique_prompts.add(filled_prompt)

    prompts_list = list(unique_prompts)
    print(f"Found {len(prompts_list)} unique prompts after filling subjects.")

    return prompts_list
| |
|
| |
|
if __name__ == "__main__":
    # Pre-compute and cache FLUX text-encoder embeddings for every dataset
    # prompt (plus "" and " " as negative/empty prompts), one .pth per prompt,
    # so downstream code can load them without keeping the encoders in memory.
    with torch.no_grad():
        accelerator = Accelerator()
        # Resolve the concrete encoder classes from the checkpoint config;
        # import_model_class_from_model_name_or_path comes from `train` (star import).
        text_encoder_cls_one = import_model_class_from_model_name_or_path(
            pretrained_path, revision=None,
        )
        text_encoder_cls_two = import_model_class_from_model_name_or_path(
            pretrained_path, subfolder="text_encoder_2", revision=None,
        )

        text_encoder_one, text_encoder_two = load_text_encoders(
            text_encoder_cls_one, text_encoder_cls_two
        )
        text_encoder_one = text_encoder_one.to(accelerator.device)
        text_encoder_two = text_encoder_two.to(accelerator.device)

        # FLUX pairs a CLIP tokenizer with a T5 tokenizer.
        tokenizer_one = CLIPTokenizer.from_pretrained(
            pretrained_path,
            subfolder="tokenizer",
            revision=None,
        )
        tokenizer_two = T5TokenizerFast.from_pretrained(
            pretrained_path,
            subfolder="tokenizer_2",
            revision=None,
        )

        all_prompts = load_prompts_from_jsonl(DATASET_JSONL)

        # "" -> negative_embeds.pth, " " -> space_prompt.pth (handled below).
        all_prompts.extend(["", " "])

        print(f"Total prompts to cache (including negative embeds): {len(all_prompts)}")

        # Seed from OS entropy so every process shuffles into a different
        # order. NOTE(review): combined with the osp.exists skip below this
        # presumably lets multiple accelerator processes divide the work, but
        # two processes can still race on the same prompt — confirm intended.
        random.seed()
        random.shuffle(all_prompts)
        for prompt in all_prompts:
            if prompt == "":
                latents_path = osp.join(OUTPUT_DIR, "negative_embeds.pth")
            elif prompt == " ":
                latents_path = osp.join(OUTPUT_DIR, "space_prompt.pth")
            else:
                # Filename = prompt with whitespace runs collapsed to "_";
                # the replace() normalizes quadruple underscores left by
                # multi-space gaps.
                latents_path = osp.join(OUTPUT_DIR, f"{'_'.join(prompt.split())}.pth")
                latents_path = latents_path.replace("____", "__")
            if osp.exists(latents_path):
                print(f"Embeds for {prompt = } already cached at {latents_path}, skipping...")
                continue
            print(f"doing {prompt = }")
            # encode_prompt comes from `train` (star import); returns CLIP
            # pooled embeds plus T5 sequence embeds and text ids.
            prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                text_encoders=[text_encoder_one, text_encoder_two],
                tokenizers=[tokenizer_one, tokenizer_two],
                prompt=[prompt],
                max_sequence_length=512,
                device=accelerator.device
            )
            # Sanity check: text_ids are expected to be all-zero, so they are
            # not cached alongside the embeddings.
            assert torch.allclose(text_ids, torch.zeros_like(text_ids)), f"{text_ids = }"

            os.makedirs(osp.dirname(latents_path), exist_ok=True)
            embeds = {
                "prompt": prompt,
                "prompt_embeds": prompt_embeds,
                "pooled_prompt_embeds": pooled_prompt_embeds,
            }
            torch.save(embeds, latents_path)
        # Barrier so no process exits (and prints success) before all ranks
        # finish their share of the caching loop.
        accelerator.wait_for_everyone()
        print(f"\nCaching complete! Embeddings saved to {OUTPUT_DIR}")