# qwenillustrious/train/precompute_embeddings.py
# (uploaded by lsmpp via the upload-large-folder tool, commit d926b4c)
#!/usr/bin/env python3
"""
Precompute Embeddings Script
预计算嵌入脚本 - 提前计算Qwen嵌入和VAE潜在空间编码以加速训练
"""
import argparse
import os
import sys
from pathlib import Path
import torch
from tqdm import tqdm
import traceback
# Add the project root to sys.path so the local `arch` package can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from arch import QwenTextEncoder
from arch.data_loader import QwenIllustriousDataset
from diffusers import AutoencoderKL
from arch.model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler
def parse_args(argv=None):
    """Parse command-line options for the embedding precomputation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` — identical to the
            previous behavior. Exposing this parameter makes the function
            testable without patching ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Precompute embeddings for QwenIllustrious training")
    parser.add_argument(
        "--qwen_model_path",
        type=str,
        default="models/Qwen3-Embedding-0.6B",
        help="Path to Qwen text encoder model"
    )
    parser.add_argument(
        "--sdxl_model_path",
        type=str,
        help="Path to SDXL model (for VAE)"
    )
    parser.add_argument(
        "--vae_model_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        help="Path to VAE model (if different from SDXL)"
    )
    parser.add_argument(
        "--vae_config_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        help="Path to VAE config file"
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="illustrious_generated",
        help="Path to illustrious_generated dataset"
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="illustrious_generated/cache",
        help="Directory to store precomputed embeddings"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help="Mixed precision mode"
    )
    return parser.parse_args(argv)
def main():
    """Precompute and cache Qwen text embeddings and VAE latents for a dataset.

    Workflow:
      1. Load the frozen Qwen text encoder and the VAE (bfloat16).
      2. Build the dataset with its own on-the-fly precomputation disabled.
      3. For each batch: encode prompts once, then save per-item raw text
         embeddings (pre-adapter) and scaled VAE latents as ``.pt`` files
         under the cache directory.
      4. Verify the number of cached files matches the dataset size.

    Any failure while processing an item or a batch prints a full traceback
    and re-raises, aborting the entire run (fail-fast by design).
    """
    args = parse_args()
    print("Setting up models...")
    # Set up the compute device; silently falls back to CPU when CUDA is
    # unavailable, regardless of the --device argument.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Load models
    print("Loading Qwen text encoder...")
    qwen_text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=device,
        freeze_encoder=True
        # torch_dtype=torch.float16 if args.mixed_precision == "fp16" else torch.float32
    )
    # NOTE(review): --mixed_precision is parsed but (apart from the
    # commented-out dtype line above) never applied in this script — confirm
    # whether the encoder should honor it.
    qwen_text_encoder.to(device)
    print("Loading VAE...")
    # VAE runs in bfloat16; image tensors are cast to vae.dtype before encoding.
    vae = load_vae_from_safetensors(args.vae_model_path, args.vae_config_path, device=device, dtype=torch.bfloat16)
    vae.to(device)
    # Note: We don't load adapter here as it's a trainable component
    # Create cache directory
    cache_dir = Path(args.cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Setup dataset (without precomputation initially)
    print("Setting up dataset...")
    dataset = QwenIllustriousDataset(
        dataset_path=args.dataset_path,
        qwen_text_encoder=qwen_text_encoder,
        vae=vae,
        cache_dir=args.cache_dir,
        precompute_embeddings=False # We'll do this manually, batched, below
    )
    print(f"Found {len(dataset)} items to process")
    # Process in batches
    print("Starting precomputation...")
    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"):
            batch_end = min(i + args.batch_size, len(dataset))
            # Collect prompts, metadata and image tensors for the current batch.
            batch_prompts = []
            batch_metadata = []
            batch_images = []
            for j in range(i, batch_end):
                try:
                    item = dataset[j] # This will load image and get prompt
                    batch_prompts.append(item['prompts'])
                    batch_metadata.append(item['metadata'])
                    # unsqueeze(0) adds a leading batch dim so each image is
                    # encoded as a single-sample batch.
                    batch_images.append(item['images'].unsqueeze(0))
                except Exception as e:
                    print(f"Error processing item {j}: {e}")
                    traceback.print_exc() # print the full stack trace
                    raise # re-raise to abort the run
            if not batch_prompts:
                continue
            # Batch process text embeddings
            try:
                print(f"Processing text embeddings for batch {i//args.batch_size + 1}...")
                # Encode texts with Qwen (save raw embeddings for training)
                qwen_embeddings = qwen_text_encoder.encode_prompts(batch_prompts, do_classifier_free_guidance=False)
                # Process each item in the batch
                for k, (prompt, metadata, image_tensor) in enumerate(zip(batch_prompts, batch_metadata, batch_images)):
                    filename_hash = metadata['filename_hash']
                    # Save raw Qwen embeddings (before adapter)
                    text_cache_path = dataset._get_text_cache_path(filename_hash)
                    # Assumes encode_prompts returns a
                    # (text_embeddings, pooled_embeddings) pair, each batched
                    # along dim 0 — TODO confirm against QwenTextEncoder.
                    # k:k+1 slicing keeps the batch dim on the saved tensors.
                    text_data = {
                        'text_embeddings': qwen_embeddings[0][k:k+1].cpu(),
                        'pooled_embeddings': qwen_embeddings[1][k:k+1].cpu()
                    }
                    torch.save(text_data, text_cache_path)
                    # Process VAE latents
                    try:
                        image_tensor = image_tensor.to(device)
                        latents = vae.encode(image_tensor.to(vae.dtype)).latent_dist.sample()
                        # Scale latents so the cached values match training-time
                        # expectations (standard diffusers convention).
                        latents = latents * vae.config.scaling_factor
                        # Save VAE latents
                        vae_cache_path = dataset._get_vae_cache_path(filename_hash)
                        torch.save(latents.cpu(), vae_cache_path)
                    except Exception as e:
                        print(f"Error processing VAE latents for {filename_hash}: {e}")
                        traceback.print_exc()
                        raise
            except Exception as e:
                print(f"Error processing batch {i//args.batch_size + 1}: {e}")
                traceback.print_exc()
                raise # abort the run
    print("Precomputation completed!")
    print(f"Cached embeddings saved to: {cache_dir}")
    # Verify cache: count .pt files under the two cache subdirectories.
    # NOTE(review): these subdirectory names are assumed to match what the
    # dataset's _get_text_cache_path/_get_vae_cache_path produce — verify.
    text_cache_dir = cache_dir / "text_embeddings"
    vae_cache_dir = cache_dir / "vae_latents"
    text_files = list(text_cache_dir.glob("*.pt"))
    vae_files = list(vae_cache_dir.glob("*.pt"))
    print(f"Text embeddings cached: {len(text_files)}")
    print(f"VAE latents cached: {len(vae_files)}")
    print(f"Total dataset size: {len(dataset)}")
    if len(text_files) != len(dataset) or len(vae_files) != len(dataset):
        print("Warning: Not all items were successfully cached!")
    else:
        print("All items successfully cached!")
if __name__ == "__main__":
main()