#!/usr/bin/env python3
"""
Extract and save empty_embeds for conditional dropout.

This script extracts the empty embedding (from an empty string prompt) and saves
it to a file that can be loaded during training with precomputed features.
"""
import argparse
import os
import json
import logging

import torch
from transformers import T5EncoderModel, T5Tokenizer

from dataset_utils import tokenize_prompt, encode_prompt

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Extract empty embeddings for conditional dropout")
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save the empty_embeds (will save as .pt file and metadata as .json)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for encoding",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type for saving embeddings",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Map architecture to model ID
    if args.text_encoder_architecture == "umt5-base":
        model_id = "google/umt5-base"
    elif args.text_encoder_architecture == "umt5-xxl":
        model_id = "google/umt5-xxl"
    elif args.text_encoder_architecture == "t5":
        model_id = "t5-base"
    else:
        raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")

    # Map dtype string to torch dtype
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    logger.info(f"Loading text encoder: {model_id}")
    logger.info(f"Device: {args.device}, Dtype: {args.dtype}")

    # Load text encoder and tokenizer
    text_encoder = T5EncoderModel.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    # Move to device, set dtype, and freeze the encoder
    text_encoder.to(device=args.device, dtype=dtype)
    text_encoder.eval()
    text_encoder.requires_grad_(False)

    # Extract empty embedding
    logger.info("Extracting empty embedding from empty string...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
        empty_input_ids = empty_input_ids.to(args.device)
        empty_embeds, cond_embeds = encode_prompt(
            text_encoder, empty_input_ids, args.text_encoder_architecture
        )

    # Convert to CPU and target dtype
    empty_embeds = empty_embeds.cpu().to(dtype)

    logger.info(f"Empty embedding shape: {empty_embeds.shape}")
    logger.info(f"Empty embedding dtype: {empty_embeds.dtype}")

    # Save empty_embeds
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save as .pt file
    torch.save(empty_embeds, args.output_path)
    logger.info(f"Saved empty_embeds to: {args.output_path}")

    # Save metadata alongside the embeddings
    metadata_path = args.output_path.replace('.pt', '.json')
    metadata = {
        "text_encoder_architecture": args.text_encoder_architecture,
        "model_id": model_id,
        "empty_embeds_shape": list(empty_embeds.shape),
        "empty_embeds_dtype": str(empty_embeds.dtype),
        "device": args.device,
        "dtype": args.dtype,
    }
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    logger.info("Done!")


if __name__ == "__main__":
    main()
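
# Example invocation (a minimal sketch; the script filename, output directory,
# and chosen architecture below are illustrative placeholders, not fixed names):
#
#   python extract_empty_embeds.py \
#       --text_encoder_architecture umt5-base \
#       --output_path precomputed/empty_embeds.pt \
#       --dtype bfloat16
#
# Assuming an output path ending in ".pt", this writes the empty embedding tensor
# to precomputed/empty_embeds.pt and the shape/dtype metadata to
# precomputed/empty_embeds.json.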