File size: 4,031 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
"""
Extract and save empty_embeds for conditional dropout.

This script extracts the empty embedding (from empty string prompt) 
and saves it to a file that can be loaded during training with precomputed features.
"""

import argparse
import os
import json
import torch
from transformers import T5EncoderModel, T5Tokenizer
from dataset_utils import tokenize_prompt, encode_prompt
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    """Build and evaluate the CLI for the empty-embedding extraction script.

    Returns:
        argparse.Namespace with ``text_encoder_architecture``, ``output_path``,
        ``device`` and ``dtype`` attributes.
    """
    ap = argparse.ArgumentParser(
        description="Extract empty embeddings for conditional dropout"
    )

    # Default to GPU when one is visible; CPU otherwise.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    ap.add_argument("--text_encoder_architecture", type=str, default="umt5-base",
                    choices=["umt5-base", "umt5-xxl", "t5"],
                    help="Text encoder architecture")
    ap.add_argument("--output_path", type=str, required=True,
                    help="Path to save the empty_embeds (will save as .pt file and metadata as .json)")
    ap.add_argument("--device", type=str, default=default_device,
                    help="Device to use for encoding")
    ap.add_argument("--dtype", type=str, default="float16",
                    choices=["float16", "bfloat16", "float32"],
                    help="Data type for saving embeddings")

    return ap.parse_args()


def main():
    """Encode the empty-string prompt and persist the resulting embedding.

    Produces two artifacts:
      * ``args.output_path``          -- the empty-prompt embedding tensor (torch.save)
      * ``<output stem>.json``        -- metadata describing how it was produced

    Raises:
        ValueError: if the requested text encoder architecture is unknown.
    """
    args = parse_args()

    # Map architecture name to a Hugging Face model ID (table style mirrors dtype_map).
    model_id_map = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    try:
        model_id = model_id_map[args.text_encoder_architecture]
    except KeyError:
        raise ValueError(
            f"Unknown text encoder architecture: {args.text_encoder_architecture}"
        )

    # Map dtype name to the torch dtype used for the saved tensor.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    logger.info(f"Loading text encoder: {model_id}")
    logger.info(f"Device: {args.device}, Dtype: {args.dtype}")

    # Load text encoder and tokenizer (frozen, eval-mode: inference only).
    text_encoder = T5EncoderModel.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    text_encoder.to(device=args.device, dtype=dtype)
    text_encoder.eval()
    text_encoder.requires_grad_(False)

    # Encode the empty string; the second return value (per-prompt pooled/cond
    # embedding — see encode_prompt) is not needed here and is discarded.
    logger.info("Extracting empty embedding from empty string...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
        empty_input_ids = empty_input_ids.to(args.device)

        empty_embeds, _ = encode_prompt(
            text_encoder,
            empty_input_ids,
            args.text_encoder_architecture
        )

    # Move to CPU and cast to the requested storage dtype before saving.
    empty_embeds = empty_embeds.cpu().to(dtype)

    logger.info(f"Empty embedding shape: {empty_embeds.shape}")
    logger.info(f"Empty embedding dtype: {empty_embeds.dtype}")

    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the embedding tensor.
    torch.save(empty_embeds, args.output_path)
    logger.info(f"Saved empty_embeds to: {args.output_path}")

    # Derive the metadata path from the output path's extension only.
    # (str.replace('.pt', '.json') would rewrite '.pt' anywhere in the path,
    # and with no '.pt' extension the JSON would overwrite the tensor file.)
    stem, _ext = os.path.splitext(args.output_path)
    metadata_path = stem + ".json"
    if metadata_path == args.output_path:
        # output_path already ended in '.json' — append instead of clobbering.
        metadata_path = args.output_path + ".json"

    metadata = {
        "text_encoder_architecture": args.text_encoder_architecture,
        "model_id": model_id,
        "empty_embeds_shape": list(empty_embeds.shape),
        "empty_embeds_dtype": str(empty_embeds.dtype),
        "device": args.device,
        "dtype": args.dtype,
    }

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    logger.info("Done!")


if __name__ == "__main__":
    main()