"""
Precompute Embeddings Script

Precomputes Qwen text embeddings and VAE latent-space encodings ahead of
training so the training loop can read them from cache instead of
re-encoding every step.
"""

import argparse
import os
import sys
import traceback
from pathlib import Path

import torch
from tqdm import tqdm

# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from arch import QwenTextEncoder
from arch.data_loader import QwenIllustriousDataset
from arch.model_loader import load_vae_from_safetensors


def parse_args():
    parser = argparse.ArgumentParser(description="Precompute embeddings for QwenIllustrious training")

    parser.add_argument(
        "--qwen_model_path",
        type=str,
        default="models/Qwen3-Embedding-0.6B",
        help="Path to Qwen text encoder model"
    )
    parser.add_argument(
        "--sdxl_model_path",
        type=str,
        help="Path to full SDXL checkpoint (unused here; the VAE is loaded from --vae_model_path)"
    )
    parser.add_argument(
        "--vae_model_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        help="Path to VAE model (if different from SDXL)"
    )
    parser.add_argument(
        "--vae_config_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        help="Path to VAE config file"
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="illustrious_generated",
        help="Path to illustrious_generated dataset"
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="illustrious_generated/cache",
        help="Directory to store precomputed embeddings"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help="Precision used for VAE encoding"
    )

    return parser.parse_args()


def main():
    args = parse_args()

    print("Setting up models...")

    # Fall back to CPU if CUDA is not available.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading Qwen text encoder...")
    qwen_text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=device,
        freeze_encoder=True
    )
    qwen_text_encoder.to(device)

    print("Loading VAE...")
    # Resolve the --mixed_precision flag to a torch dtype for VAE encoding.
    vae_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
    vae = load_vae_from_safetensors(args.vae_model_path, args.vae_config_path, device=device, dtype=vae_dtype)
    vae.to(device)
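    # Note: some SDXL-family VAEs are numerically unstable in fp16; bf16 (the
    # default here) is the safer choice for encoding.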

    cache_dir = Path(args.cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("Setting up dataset...")
    # precompute_embeddings=False makes the dataset return raw prompts and
    # image tensors; this script performs the encoding and fills the cache.
    dataset = QwenIllustriousDataset(
        dataset_path=args.dataset_path,
        qwen_text_encoder=qwen_text_encoder,
        vae=vae,
        cache_dir=args.cache_dir,
        precompute_embeddings=False
    )

    print(f"Found {len(dataset)} items to process")

    print("Starting precomputation...")

    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"):
            batch_end = min(i + args.batch_size, len(dataset))

            # Collect prompts, metadata, and image tensors for this batch.
            batch_prompts = []
            batch_metadata = []
            batch_images = []

            for j in range(i, batch_end):
                try:
                    item = dataset[j]
                    batch_prompts.append(item['prompts'])
                    batch_metadata.append(item['metadata'])
                    batch_images.append(item['images'].unsqueeze(0))
                except Exception as e:
                    print(f"Error processing item {j}: {e}")
                    traceback.print_exc()
                    raise

            if not batch_prompts:
                continue

            try:
                tqdm.write(f"Processing text embeddings for batch {i//args.batch_size + 1}...")

                # Encode the whole batch of prompts in one pass; encode_prompts
                # returns (sequence_embeddings, pooled_embeddings).
                qwen_embeddings = qwen_text_encoder.encode_prompts(batch_prompts, do_classifier_free_guidance=False)

                for k, (metadata, image_tensor) in enumerate(zip(batch_metadata, batch_images)):
                    filename_hash = metadata['filename_hash']

                    # Cache this item's slice of the batch text embeddings.
                    text_cache_path = dataset._get_text_cache_path(filename_hash)
                    text_data = {
                        'text_embeddings': qwen_embeddings[0][k:k+1].cpu(),
                        'pooled_embeddings': qwen_embeddings[1][k:k+1].cpu()
                    }
                    torch.save(text_data, text_cache_path)

                    # Encode the image to VAE latents, apply the scaling
                    # factor, and cache the result.
                    try:
                        image_tensor = image_tensor.to(device)
                        latents = vae.encode(image_tensor.to(vae.dtype)).latent_dist.sample()
                        latents = latents * vae.config.scaling_factor

                        vae_cache_path = dataset._get_vae_cache_path(filename_hash)
                        torch.save(latents.cpu(), vae_cache_path)

                    except Exception as e:
                        print(f"Error processing VAE latents for {filename_hash}: {e}")
                        traceback.print_exc()
                        raise

            except Exception as e:
                print(f"Error processing batch {i//args.batch_size + 1}: {e}")
                traceback.print_exc()
                raise

    print("Precomputation completed!")
    print(f"Cached embeddings saved to: {cache_dir}")

    # Sanity check: compare cached file counts against the dataset size.
    text_cache_dir = cache_dir / "text_embeddings"
    vae_cache_dir = cache_dir / "vae_latents"

    text_files = list(text_cache_dir.glob("*.pt"))
    vae_files = list(vae_cache_dir.glob("*.pt"))

    print(f"Text embeddings cached: {len(text_files)}")
    print(f"VAE latents cached: {len(vae_files)}")
    print(f"Total dataset size: {len(dataset)}")

    if len(text_files) != len(dataset) or len(vae_files) != len(dataset):
        print("Warning: Not all items were successfully cached!")
    else:
        print("All items successfully cached!")


if __name__ == "__main__":
    main()