|
|
|
|
|
""" |
|
|
Extract and save empty_embeds for conditional dropout. |
|
|
|
|
|
This script extracts the empty embedding (from empty string prompt) |
|
|
and saves it to a file that can be loaded during training with precomputed features. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from transformers import T5EncoderModel, T5Tokenizer |
|
|
from dataset_utils import tokenize_prompt, encode_prompt |
|
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the empty-embedding extraction script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv`` (useful for testing / programmatic use). Defaults to
            ``None``, which keeps the original behavior of reading the
            process's command line.

    Returns:
        argparse.Namespace with ``text_encoder_architecture``, ``output_path``,
        ``device`` and ``dtype`` attributes.
    """
    parser = argparse.ArgumentParser(description="Extract empty embeddings for conditional dropout")

    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save the empty_embeds (will save as .pt file and metadata as .json)",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Prefer GPU when available; the encoder forward pass is the only
        # device-bound work in this script.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for encoding",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type for saving embeddings",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def main():
    """Encode the empty prompt and save its embedding plus a metadata sidecar.

    Loads the T5/UMT5 text encoder selected on the command line, encodes the
    empty string ``""`` with gradients disabled, moves the resulting embedding
    to CPU in the requested dtype, and writes it with ``torch.save`` to
    ``--output_path``. A ``.json`` file next to it records how the embedding
    was produced.

    Raises:
        ValueError: if the architecture name has no known model id (defensive;
            argparse ``choices`` already restricts the input).
    """
    args = parse_args()

    # Map the CLI architecture name to its Hugging Face model id.
    model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    if args.text_encoder_architecture not in model_ids:
        raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
    model_id = model_ids[args.text_encoder_architecture]

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    logger.info("Loading text encoder: %s", model_id)
    logger.info("Device: %s, Dtype: %s", args.device, args.dtype)

    text_encoder = T5EncoderModel.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    # Inference-only: freeze the encoder and move it to the target device.
    text_encoder.to(device=args.device, dtype=dtype)
    text_encoder.eval()
    text_encoder.requires_grad_(False)

    logger.info("Extracting empty embedding from empty string...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
        empty_input_ids = empty_input_ids.to(args.device)

        # encode_prompt returns a pair; only the first (the embedding of the
        # tokenized prompt) is needed here, so discard the second.
        empty_embeds, _ = encode_prompt(
            text_encoder,
            empty_input_ids,
            args.text_encoder_architecture,
        )

    # Detach from the device and cast so the saved file loads anywhere.
    empty_embeds = empty_embeds.cpu().to(dtype)

    logger.info("Empty embedding shape: %s", empty_embeds.shape)
    logger.info("Empty embedding dtype: %s", empty_embeds.dtype)

    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    torch.save(empty_embeds, args.output_path)
    logger.info("Saved empty_embeds to: %s", args.output_path)

    # Use splitext rather than str.replace: replace('.pt', '.json') would
    # leave the metadata path equal to the output path when the path has no
    # ".pt" suffix (overwriting the tensor file), and would corrupt paths
    # containing ".pt" elsewhere (e.g. "exp.pt_v2/embeds.bin").
    metadata_path = os.path.splitext(args.output_path)[0] + ".json"
    metadata = {
        "text_encoder_architecture": args.text_encoder_architecture,
        "model_id": model_id,
        "empty_embeds_shape": list(empty_embeds.shape),
        "empty_embeds_dtype": str(empty_embeds.dtype),
        "device": args.device,
        "dtype": args.dtype,
    }

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    logger.info("Saved metadata to: %s", metadata_path)

    logger.info("Done!")
|
|
|
|
|
|
|
|
# Script entry point: run the extraction only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|