File size: 4,031 Bytes
3d1c0e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
#!/usr/bin/env python3
"""
Extract and save empty_embeds for conditional dropout.
This script extracts the empty embedding (from empty string prompt)
and saves it to a file that can be loaded during training with precomputed features.
"""
import argparse
import os
import json
import torch
from transformers import T5EncoderModel, T5Tokenizer
from dataset_utils import tokenize_prompt, encode_prompt
import logging
# Configure logging at import time so logger.info output is visible when run as a script;
# per-module logger follows the standard logging convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
    """Build and parse the command-line arguments for empty-embed extraction.

    Returns:
        argparse.Namespace with text_encoder_architecture, output_path,
        device, and dtype attributes.
    """
    p = argparse.ArgumentParser(description="Extract empty embeddings for conditional dropout")

    # Which pretrained text encoder family to load.
    p.add_argument(
        "--text_encoder_architecture",
        type=str,
        choices=["umt5-base", "umt5-xxl", "t5"],
        default="umt5-base",
        help="Text encoder architecture",
    )
    # Destination for the .pt tensor; a sibling .json metadata file is derived from it.
    p.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save the empty_embeds (will save as .pt file and metadata as .json)",
    )
    # Prefer GPU when one is visible to torch.
    p.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for encoding",
    )
    # Storage precision of the saved embedding tensor.
    p.add_argument(
        "--dtype",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
        help="Data type for saving embeddings",
    )

    return p.parse_args()
def main():
    """Encode the empty-string prompt and save the embedding plus metadata.

    Loads the text encoder selected by --text_encoder_architecture, encodes
    the empty prompt "", converts the result to the requested dtype on CPU,
    and writes it to --output_path (.pt) alongside a .json metadata sidecar.

    Raises:
        ValueError: if the architecture is not one of the known choices
            (normally unreachable — argparse already restricts the value).
    """
    args = parse_args()

    # Map architecture name to Hugging Face model ID.
    model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    if args.text_encoder_architecture not in model_ids:
        raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
    model_id = model_ids[args.text_encoder_architecture]

    # Map dtype string to torch dtype.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    logger.info(f"Loading text encoder: {model_id}")
    logger.info(f"Device: {args.device}, Dtype: {args.dtype}")

    # Load text encoder and tokenizer (downloads weights on first use).
    text_encoder = T5EncoderModel.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    # Inference-only: move to target device/dtype and freeze parameters.
    text_encoder.to(device=args.device, dtype=dtype)
    text_encoder.eval()
    text_encoder.requires_grad_(False)

    # Extract the embedding of the empty prompt; the second return value of
    # encode_prompt (conditioning embeds) is not needed here.
    logger.info("Extracting empty embedding from empty string...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
        empty_input_ids = empty_input_ids.to(args.device)
        empty_embeds, _ = encode_prompt(
            text_encoder,
            empty_input_ids,
            args.text_encoder_architecture
        )

    # Move to CPU and cast to the requested storage dtype before saving.
    empty_embeds = empty_embeds.cpu().to(dtype)

    logger.info(f"Empty embedding shape: {empty_embeds.shape}")
    logger.info(f"Empty embedding dtype: {empty_embeds.dtype}")

    # Ensure the destination directory exists.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the tensor as a .pt file.
    torch.save(empty_embeds, args.output_path)
    logger.info(f"Saved empty_embeds to: {args.output_path}")

    # Derive the metadata path by swapping the file extension for .json.
    # (os.path.splitext only touches the final extension — str.replace('.pt', ...)
    # would corrupt paths containing '.pt' elsewhere, e.g. '/ckpt.pt.bak/e.pt'.)
    metadata_path = os.path.splitext(args.output_path)[0] + '.json'
    metadata = {
        "text_encoder_architecture": args.text_encoder_architecture,
        "model_id": model_id,
        "empty_embeds_shape": list(empty_embeds.shape),
        "empty_embeds_dtype": str(empty_embeds.dtype),
        "device": args.device,
        "dtype": args.dtype,
    }
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")
    logger.info("Done!")
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|