# va1bhavagrawa1: setup repo as python package (commit 9d1d414)
import os
import os.path as osp
import json
from PIL import Image
from tqdm import tqdm
import torch
import random
from train import *
# Base FLUX.1-dev checkpoint/repo id used for both text encoders and tokenizers.
pretrained_path = "black-forest-labs/FLUX.1-dev"
# Directory where cached prompt-embedding .pth files are written.
OUTPUT_DIR = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds"
# JSONL dataset of placeholder prompt templates plus subject lists to fill in.
DATASET_JSONL = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard/cuboids__234subjects.jsonl"
def load_text_encoders(class_one, class_two, path=None):
    """Instantiate both FLUX text encoders from a pretrained checkpoint.

    Args:
        class_one: Model class for the primary text encoder
            (loaded from the ``text_encoder`` subfolder).
        class_two: Model class for the secondary text encoder
            (loaded from the ``text_encoder_2`` subfolder).
        path: Checkpoint path or repo id. Defaults to the module-level
            ``pretrained_path`` so existing callers are unaffected.

    Returns:
        Tuple ``(text_encoder_one, text_encoder_two)``.
    """
    if path is None:
        # Backward-compatible default: the module-level checkpoint.
        path = pretrained_path
    text_encoder_one = class_one.from_pretrained(
        path, subfolder="text_encoder", revision=None, variant=None
    )
    text_encoder_two = class_two.from_pretrained(
        path, subfolder="text_encoder_2", revision=None, variant=None
    )
    return text_encoder_one, text_encoder_two
def load_prompts_from_jsonl(DATASET_JSONL):
    """Load unique prompts from a JSONL file with subjects filled in.

    Each JSONL entry is expected to provide ``PLACEHOLDER_prompts`` (a prompt
    template containing the literal token ``PLACEHOLDER``) and ``subjects``
    (a list of subject strings). The subjects are joined with " and "
    ("cat", "cat and dog", ...) and substituted into the template. Entries
    missing either field are skipped.

    Args:
        DATASET_JSONL: Path to the JSONL file.

    Returns:
        List of unique filled-in prompts in first-seen order (deterministic,
        unlike the arbitrary iteration order of a raw ``set``).
    """
    # Dict used as an ordered set: dedupes while keeping insertion order.
    unique_prompts = {}
    print(f"Reading prompts from {DATASET_JSONL}...")
    with open(DATASET_JSONL, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Robustness: tolerate blank/trailing lines in the JSONL.
                continue
            entry = json.loads(line)
            # Get the placeholder prompt and subjects
            placeholder_prompt = entry.get("PLACEHOLDER_prompts", "")
            subjects = entry.get("subjects", [])
            if not placeholder_prompt or not subjects:
                continue
            # Join subjects: "cat" / "cat and dog" / "cat and dog and bird".
            placeholder_text = " and ".join(str(subject) for subject in subjects)
            # Replace PLACEHOLDER with actual subjects
            filled_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
            unique_prompts[filled_prompt] = None
    prompts_list = list(unique_prompts)
    print(f"Found {len(prompts_list)} unique prompts after filling subjects.")
    return prompts_list
if __name__ == "__main__":
    # Precompute and cache FLUX text embeddings for every dataset prompt so
    # downstream runs never need the text encoders resident in memory.
    with torch.no_grad():
        accelerator = Accelerator()

        # Resolve the concrete encoder classes for this checkpoint, then load
        # and move both encoders to the accelerator device.
        encoder_cls_a = import_model_class_from_model_name_or_path(
            pretrained_path, revision=None,
        )
        encoder_cls_b = import_model_class_from_model_name_or_path(
            pretrained_path, subfolder="text_encoder_2", revision=None,
        )
        encoder_a, encoder_b = load_text_encoders(encoder_cls_a, encoder_cls_b)
        encoder_a = encoder_a.to(accelerator.device)
        encoder_b = encoder_b.to(accelerator.device)

        tokenizer_a = CLIPTokenizer.from_pretrained(
            pretrained_path, subfolder="tokenizer", revision=None,
        )
        tokenizer_b = T5TokenizerFast.from_pretrained(
            pretrained_path, subfolder="tokenizer_2", revision=None,
        )

        # Unique dataset prompts, plus "" and " " for negative embeddings.
        all_prompts = load_prompts_from_jsonl(DATASET_JSONL)
        all_prompts += ["", " "]
        print(f"Total prompts to cache (including negative embeds): {len(all_prompts)}")

        # Fresh seed + shuffle: concurrent processes running this script are
        # then unlikely to pick the same uncached prompt at the same moment.
        random.seed()
        random.shuffle(all_prompts)

        for prompt in all_prompts:
            # Map each prompt to its cache file name; the two sentinel prompts
            # get fixed names, everything else is slugified with underscores.
            if prompt == "":
                latents_path = osp.join(OUTPUT_DIR, "negative_embeds.pth")
            elif prompt == " ":
                latents_path = osp.join(OUTPUT_DIR, "space_prompt.pth")
            else:
                latents_path = osp.join(OUTPUT_DIR, f"{'_'.join(prompt.split())}.pth")
                latents_path = latents_path.replace("____", "__")

            if osp.exists(latents_path):
                print(f"Embeds for {prompt = } already cached at {latents_path}, skipping...")
                continue
            print(f"doing {prompt = }")

            prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                text_encoders=[encoder_a, encoder_b],
                tokenizers=[tokenizer_a, tokenizer_b],
                prompt=[prompt],
                max_sequence_length=512,
                device=accelerator.device
            )
            # FLUX text ids are expected to be all-zero here; anything else
            # means the encoding path changed underneath us.
            assert torch.allclose(text_ids, torch.zeros_like(text_ids)), f"{text_ids = }"

            os.makedirs(osp.dirname(latents_path), exist_ok=True)
            torch.save(
                {
                    "prompt": prompt,
                    "prompt_embeds": prompt_embeds,
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                },
                latents_path,
            )

        # Barrier so every process finishes caching before the final report.
        accelerator.wait_for_everyone()
        print(f"\nCaching complete! Embeddings saved to {OUTPUT_DIR}")