# Source: Meissonic / train / extract_empty_embeds.py
# Uploaded via huggingface_hub by BryanW (commit 3d1c0e1, verified)
#!/usr/bin/env python3
"""
Extract and save empty_embeds for conditional dropout.
This script extracts the empty embedding (from empty string prompt)
and saves it to a file that can be loaded during training with precomputed features.
"""
import argparse
import os
import json
import torch
from transformers import T5EncoderModel, T5Tokenizer
from dataset_utils import tokenize_prompt, encode_prompt
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse command-line arguments for empty-embedding extraction.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            arguments are read from ``sys.argv`` as before; passing a list
            makes the function testable and reusable from other code.

    Returns:
        argparse.Namespace with ``text_encoder_architecture``,
        ``output_path``, ``device`` and ``dtype`` attributes.
    """
    parser = argparse.ArgumentParser(description="Extract empty embeddings for conditional dropout")
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save the empty_embeds (will save as .pt file and metadata as .json)",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Default to GPU when one is visible; evaluated once at parse time.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for encoding",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type for saving embeddings",
    )
    return parser.parse_args(argv)
def main():
    """Extract the text encoder's embedding of the empty prompt and save it.

    Loads the T5/UMT5 text encoder selected on the command line, encodes the
    empty string "", and writes the resulting embedding tensor to
    ``--output_path`` (a torch ``.pt`` file) plus a JSON metadata sidecar with
    the same stem. The saved tensor is intended to be loaded during training
    with precomputed features to implement conditional dropout.
    """
    args = parse_args()

    # Map the short architecture name to a Hugging Face model ID.
    model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    try:
        model_id = model_ids[args.text_encoder_architecture]
    except KeyError:
        # argparse `choices` already restricts values; kept as a defensive check.
        raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")

    # Map the dtype name to the torch dtype used for the saved tensor.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    logger.info(f"Loading text encoder: {model_id}")
    logger.info(f"Device: {args.device}, Dtype: {args.dtype}")

    # Load text encoder and tokenizer (downloads weights on first use).
    text_encoder = T5EncoderModel.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    # Inference-only: move to target device/dtype, disable dropout and grads.
    text_encoder.to(device=args.device, dtype=dtype)
    text_encoder.eval()
    text_encoder.requires_grad_(False)

    # Encode the empty prompt "" — this is the unconditional embedding used
    # when the text condition is dropped during training.
    logger.info("Extracting empty embedding from empty string...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
        empty_input_ids = empty_input_ids.to(args.device)
        # encode_prompt returns (empty_embeds, cond_embeds); only the first is saved.
        empty_embeds, cond_embeds = encode_prompt(
            text_encoder,
            empty_input_ids,
            args.text_encoder_architecture
        )

    # Move to CPU and cast to the requested storage dtype before serializing.
    empty_embeds = empty_embeds.cpu().to(dtype)

    logger.info(f"Empty embedding shape: {empty_embeds.shape}")
    logger.info(f"Empty embedding dtype: {empty_embeds.dtype}")

    # Ensure the destination directory exists.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the tensor as a .pt file.
    torch.save(empty_embeds, args.output_path)
    logger.info(f"Saved empty_embeds to: {args.output_path}")

    # Derive the metadata path by swapping the extension. Using
    # str.replace('.pt', '.json') here would be wrong: it rewrites the FIRST
    # '.pt' anywhere in the path (e.g. 'ckpt.pt_dir/out.pt'), and when the
    # output path has no '.pt' suffix it leaves the path unchanged, so the
    # JSON dump would overwrite the tensor file just saved.
    metadata_path = os.path.splitext(args.output_path)[0] + '.json'
    metadata = {
        "text_encoder_architecture": args.text_encoder_architecture,
        "model_id": model_id,
        "empty_embeds_shape": list(empty_embeds.shape),
        "empty_embeds_dtype": str(empty_embeds.dtype),
        "device": args.device,
        "dtype": args.dtype,
    }
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    logger.info("Done!")
# Script entry point: run the extraction only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()