# --- train_mamba_sd.py --- import argparse import logging import math import os import shutil import random from pathlib import Path import traceback import io # <<< NEED THIS import requests # <<< NEED THIS AGAIN for URL fetching import accelerate import datasets import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset # No special Features needed now from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers from PIL import Image, UnidentifiedImageError # Need PIL and error handling import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available # Import Mamba block and replacement function try: # Assuming msd_utils.py is in the same directory or Python path from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba print("Successfully imported MambaSequentialBlock and replacement function from msd_utils.py.") except ImportError as e: print("="*50); print("ERROR: Failed to import from msd_utils.py!"); print("Ensure 'msd_utils.py' exists and contains necessary definitions."); print(f"Import error: {e}"); print("="*50); exit(1) # Import BasicTransformerBlock for type checking in unfreeze logic from diffusers.models.attention import BasicTransformerBlock check_min_version("0.28.0") # Define default columns 
# Default columns FOR URL/TEXT datasets
DEFAULT_IMAGE_COLUMN = "URL"
DEFAULT_CAPTION_COLUMN = "TEXT"


# --- Argument Parsing ---
def parse_args():
    """Parse, validate, and post-process the command-line arguments.

    Returns:
        argparse.Namespace: the validated training configuration.

    Raises:
        ValueError: if neither --dataset_name nor --train_data_dir is given,
            or if --max_train_samples is set but not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Train Stable Diffusion with Mamba using a URL/Text dataset (e.g., MS_COCO_2017_URL_TEXT).")
    # Model Paths
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5",
                        help="Path to pretrained model or model identifier from huggingface.co/models.")
    # Dataset Arguments
    parser.add_argument("--dataset_name", type=str, default="ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
                        help="The HuggingFace dataset identifier for a dataset with URL and TEXT columns.")
    parser.add_argument("--train_data_dir", type=str, default=None,
                        help="A folder containing the training data (Not recommended for URL datasets). Overrides --dataset_name.")
    parser.add_argument("--image_column", type=str, default=DEFAULT_IMAGE_COLUMN,
                        help="The column of the dataset containing image URLs.")
    parser.add_argument("--caption_column", type=str, default=DEFAULT_CAPTION_COLUMN,
                        help="The column of the dataset containing single text captions.")
    parser.add_argument("--max_train_samples", type=int, default=5000,
                        help="Limit the number of training examples. Loads dataset metadata first, then selects.")
    # Validation Arguments
    parser.add_argument("--validation_prompt", type=str, default="A photo of a busy city street with cars and pedestrians.",
                        help="A prompt that is used during validation to verify that the model is learning.")
    parser.add_argument("--num_validation_images", type=int, default=4,
                        help="Number of images that should be generated during validation with `validation_prompt`.")
    parser.add_argument("--validation_epochs", type=int, default=1,
                        help="Run validation every X epochs (ignored if validation_steps is set).")
    parser.add_argument("--validation_steps", type=int, default=500,
                        help="Run validation every X steps. Overrides validation_epochs.")
    # Output and Saving Arguments
    parser.add_argument("--output_dir", type=str, default="sd-mamba-trained-mscoco-urltext-5k",
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir", type=str, default=None, help="Directory to cache downloaded models and datasets.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="The name of the repository to keep in sync with the local `output_dir`.")
    # Preprocessing Arguments
    parser.add_argument("--resolution", type=int, default=512,
                        help="The resolution for input images, all images will be resized to this size.")
    # FIX: these flags previously used action="store_true" combined with default=True,
    # which made them impossible to disable from the CLI (passing the flag was a no-op).
    # BooleanOptionalAction keeps `--flag` working exactly as before and additionally
    # provides `--no-flag` to turn the behavior off (backward compatible, Python 3.9+).
    parser.add_argument("--center_crop", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to center crop images after downloading.")
    parser.add_argument("--random_flip", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to randomly flip images horizontally.")
    # Training Hyperparameters
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
    parser.add_argument("--max_train_steps", type=int, default=None,
                        help="Total number of training steps to perform. Overrides num_train_epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to use gradient checkpointing to save memory.")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--scale_lr", action="store_true", default=False,
                        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
    parser.add_argument("--lr_scheduler", type=str, default="cosine",
                        help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]')
    parser.add_argument("--lr_warmup_steps", type=int, default=100,
                        help="Number of steps for the warmup in the learning rate scheduler.")
    parser.add_argument("--use_8bit_adam", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to use 8-bit AdamW optimizer.")
    parser.add_argument("--allow_tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to allow TF32 on Ampere GPUs. Can speed up training.")
    parser.add_argument("--dataloader_num_workers", type=int, default=8,
                        help="Number of subprocesses to use for data loading.")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # Accelerator Arguments
    parser.add_argument("--logging_dir", type=str, default="logs", help="Location for TensorBoard logs.")
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],
                        help="Whether to use mixed precision.")
    parser.add_argument("--report_to", type=str, default="tensorboard",
                        help='The integration to report the results and logs to.')
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    # Checkpointing Arguments
    parser.add_argument("--checkpointing_steps", type=int, default=500,
                        help="Save a checkpoint of the training state every X updates steps.")
    parser.add_argument("--checkpoints_total_limit", type=int, default=3, help="Max number of checkpoints to store.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Whether to resume training from a previous checkpoint directory or 'latest'.")
    # Mamba Specific Arguments
    parser.add_argument("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension.")
    parser.add_argument("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension.")
    parser.add_argument("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor.")
    # Preprocessing Specific Arguments
    parser.add_argument("--preprocessing_num_workers", type=int, default=None,
                        help="The number of processes to use for data preprocessing (defaults to cpu count capped at 16).")

    args = parser.parse_args()

    # torchrun-style launchers communicate the rank via the environment;
    # the environment value wins over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        print(f"INFO: Overriding local_rank {args.local_rank} with environment variable LOCAL_RANK {env_local_rank}")
        args.local_rank = env_local_rank

    # Validation checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either --dataset_name or --train_data_dir")
    if args.dataset_name and args.train_data_dir:
        print("WARNING: Both --dataset_name and --train_data_dir provided. Using --dataset_name as URL dataset is specified.")
        args.train_data_dir = None  # Prefer dataset_name for URL datasets

    # Set default preprocessing workers if not specified
    if args.preprocessing_num_workers is None:
        try:
            # os.sched_getaffinity honors CPU affinity masks (Linux only)
            args.preprocessing_num_workers = min(len(os.sched_getaffinity(0)), 16)
        except AttributeError:
            # FIX: os.cpu_count() may return None; fall back to 1 in that case
            # instead of raising TypeError inside min().
            args.preprocessing_num_workers = min(os.cpu_count() or 1, 16)
        print(f"INFO: Auto-detected preprocessing_num_workers: {args.preprocessing_num_workers}")

    # Ensure max_train_samples is positive if set
    if args.max_train_samples is not None and args.max_train_samples <= 0:
        raise ValueError("--max_train_samples must be a positive integer.")
    return args
# --- Dataset Handling ---
def prepare_dataset(args, tokenizer, logger):  # Pass logger explicitly
    """Loads, selects, preprocesses (downloads URLs), and filters the dataset.

    Args:
        args: parsed CLI namespace (uses dataset_name, cache_dir, image_column,
            caption_column, max_train_samples, seed, resolution, center_crop,
            random_flip, preprocessing_num_workers).
        tokenizer: CLIP tokenizer used to encode captions to fixed-length ids.
        logger: accelerate logger instance.

    Returns:
        tuple: (processed_dataset, collate_fn, new_count, original_count) —
        new_count is the size after filtering, original_count the size right
        after the map step.
    """
    if args.dataset_name is not None:
        logger.info(f"Loading dataset '{args.dataset_name}' metadata...")
        try:
            # Load the dataset using the provided name (no config likely needed)
            dataset = load_dataset(
                args.dataset_name,
                cache_dir=args.cache_dir,
                # Consider adding split='train' directly if sure it exists
                # split="train",  # You might add this if you know 'train' is always the split
            )
            logger.info("Dataset metadata loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load dataset '{args.dataset_name}': {e}", exc_info=True)
            raise
        # Select 'train' split (most common), handle if not present
        split_to_use = "train"
        if split_to_use not in dataset:
            available_splits = list(dataset.keys())
            if len(available_splits) == 1:
                split_to_use = available_splits[0]
                logger.warning(f"'train' split not found. Using the only available split: '{split_to_use}'.")
            else:
                raise ValueError(f"'train' split not found in dataset '{args.dataset_name}'. Available splits: {available_splits}. Please check the dataset structure or specify the split.")
        dataset = dataset[split_to_use]
        logger.info(f"Using '{split_to_use}' split. Initial size: {len(dataset)}")
        logger.info(f"Dataset features: {dataset.features}")
    else:
        # Should not happen with current checks, but keep for safety
        logger.error("Local data directory loading (--train_data_dir) is not the intended use case for this script modification.")
        raise NotImplementedError("This script is modified for URL datasets via --dataset_name.")

    # --- Check Columns ---
    column_names = dataset.column_names
    logger.info(f"Original dataset columns: {column_names}")
    if args.image_column not in column_names:  # Should be "URL"
        raise ValueError(f"--image_column '{args.image_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")
    if args.caption_column not in column_names:  # Should be "TEXT"
        raise ValueError(f"--caption_column '{args.caption_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")

    # --- SELECT SAMPLES (AFTER loading metadata, BEFORE downloading/mapping) ---
    if args.max_train_samples is not None:
        num_samples = len(dataset)
        max_samples_to_select = min(args.max_train_samples, num_samples)
        if args.max_train_samples > num_samples:
            logger.warning(
                f"--max_train_samples ({args.max_train_samples}) is larger than the dataset size ({num_samples}). "
                f"Using all {num_samples} samples."
            )
        logger.info(f"Selecting {max_samples_to_select} samples from the dataset (shuffling first).")
        # Shuffle before selecting for randomness if max_train_samples is less than total
        if max_samples_to_select < num_samples:
            dataset = dataset.shuffle(seed=args.seed).select(range(max_samples_to_select))
        else:
            # No need to shuffle if using all samples, map will handle shuffling later if needed
            dataset = dataset.select(range(max_samples_to_select))  # Selects all
        logger.info(f"Dataset size after selecting samples: {len(dataset)}")
        if len(dataset) == 0:
            raise ValueError(f"Selected 0 samples. Check --max_train_samples ({args.max_train_samples}) and dataset availability.")

    # --- Image Transforms (Applied after download) ---
    # NOTE(review): RandomHorizontalFlip executes inside the cached `map` step
    # below, so any flip is baked into the cached dataset rather than re-drawn
    # each epoch — confirm this is intended.
    train_transforms = transforms.Compose([
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    logger.info("Image transforms defined.")

    # --- Preprocess Function (Handles one example: downloads URL, tokenizes TEXT) ---
    def preprocess_train_single(example):
        # Returns {"pixel_values", "input_ids"} on success, or None to signal
        # "skip this row".
        # NOTE(review): Dataset.map() expects a dict result per row; returning
        # None here may raise instead of silently skipping — confirm against
        # the installed `datasets` version.
        image_url = example[args.image_column]
        caption = example[args.caption_column]
        # 1. Download and Process Image from URL
        processed_image_tensor = None  # Initialize outside try
        try:
            # Basic check if URL seems valid (optional, requests handles most)
            if not isinstance(image_url, str) or not image_url.startswith(("http://", "https://")):
                # Use debug level for frequent skips
                # logger.debug(f"Skipping invalid URL format: {str(image_url)[:100]}...")
                return None  # Signal failure
            # --- INCREASED TIMEOUT ---
            response = requests.get(image_url, timeout=20, stream=False)  # stream=False to download content immediately
            response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx)
            img_bytes = response.content
            if not img_bytes:
                raise ValueError("Empty image content received")
            image_pil = Image.open(io.BytesIO(img_bytes))
            # --- ADDED: Check for extremely large images BEFORE conversion/transforms ---
            # Avoid potential OOM Killer in subprocess for huge images (adjust threshold as needed)
            MAX_PIXELS = 4096 * 4096  # ~16 megapixels
            if image_pil.width * image_pil.height > MAX_PIXELS:
                # Use debug level for frequent skips
                # logger.debug(f"Skipping excessively large image ({image_pil.width}x{image_pil.height}): {image_url}")
                return None  # Signal failure
            image_pil = image_pil.convert("RGB")  # Ensure RGB
            # Apply transforms
            processed_image_tensor = train_transforms(image_pil)
        # --- ADDED/REFINED: More specific error catching ---
        # Each handler deliberately swallows the error and skips the row;
        # only the generic catch-all logs, to keep the log volume manageable.
        except requests.exceptions.Timeout:
            # logger.debug(f"Timeout fetching {image_url}. Skipping.")
            return None
        except requests.exceptions.TooManyRedirects:
            # logger.debug(f"Too many redirects for {image_url}. Skipping.")
            return None
        except requests.exceptions.SSLError:
            # logger.debug(f"SSL Error for {image_url}. Skipping.")
            return None
        except requests.exceptions.RequestException as http_err:
            # Catches other request errors (ConnectionError, HTTPError etc.)
            # logger.debug(f"HTTP Error fetching {image_url}: {http_err}. Skipping.")
            return None
        except UnidentifiedImageError:
            # logger.debug(f"Cannot identify image file from {image_url}. Skipping.")
            return None
        except ValueError as val_err:
            # Catch empty content or other PIL value errors
            # logger.debug(f"Value error processing image {image_url}: {val_err}. Skipping.")
            return None
        except OSError as os_err:
            # Catch potential truncated images or other OS level issues from PIL
            # logger.debug(f"OS error processing image {image_url}: {os_err}. Skipping.")
            return None
        except Exception as img_err:
            # Catch-all for other unexpected errors during image processing/transforms
            logger.warning(f"Generic error processing/transforming image from {image_url}: {img_err}. Skipping.")
            # Consider logging the full traceback here for debugging if needed:
            # logger.exception(f"Traceback for generic image error on {image_url}:")
            return None

        # Check if image processing was successful before proceeding
        if processed_image_tensor is None:
            # This case should ideally be caught by exceptions above, but as a safeguard:
            # logger.debug(f"Image tensor is None after try-except block for {image_url}. Skipping.")
            return None

        # 2. Tokenize Caption (Keep previous error handling)
        try:
            caption_str = str(caption) if caption is not None else ""
            if not caption_str:
                # logger.debug(f"Skipping entry with empty caption for URL: {image_url}")
                return None  # Signal failure
            inputs = tokenizer(
                caption_str,
                max_length=tokenizer.model_max_length,
                padding="max_length",  # Pad to max length
                truncation=True,
                return_tensors="pt"  # Return PyTorch tensors
            )
            input_ids_tensor = inputs.input_ids.squeeze(0)  # Remove batch dim added by tokenizer
        except Exception as tok_err:
            logger.warning(f"Error tokenizing caption '{str(caption)[:50]}...' for URL {image_url}: {tok_err}. Skipping.")
            return None  # Signal failure

        # Return dictionary ONLY if both image and text processing succeeded
        return {"pixel_values": processed_image_tensor, "input_ids": input_ids_tensor}

    # --- Apply REVISED Preprocessing using map (non-batched URL download) ---
    num_proc = args.preprocessing_num_workers
    logger.info(f"Preprocessing dataset (downloading URLs, single item processing) using {num_proc} workers...")
    # It's crucial to understand that this `map` step will perform the downloads.
    # This can be slow and network-intensive. Consider using HF datasets caching.
    # The map function needs the list of columns to remove *before* processing
    columns_to_remove = dataset.column_names
    processed_dataset = dataset.map(
        preprocess_train_single,  # Use the single-item URL download function
        batched=False,  # Process item by item is NECESSARY for requests
        num_proc=num_proc,
        remove_columns=columns_to_remove,  # Remove original cols AFTER processing
        load_from_cache_file=True,  # Enable caching of mapped results (highly recommended!)
        desc="Downloading images and tokenizing captions",
    )
    logger.info(f"Dataset size after map (potential download/processing): {len(processed_dataset)}")

    # --- Filter out None results (from errors in preprocess_train_single) ---
    original_count = len(processed_dataset)
    # NOTE(review): filter() receives a dict per row, never None — this predicate
    # looks like a no-op. Failed rows would need a sentinel flag column to be
    # filtered out here. TODO confirm against the map behavior above.
    processed_dataset = processed_dataset.filter(lambda example: example is not None, num_proc=1)
    new_count = len(processed_dataset)
    if original_count != new_count:
        logger.warning(f"Filtered out {original_count - new_count} entries due to download/processing errors.")
    if new_count == 0:
        raise RuntimeError("Dataset is empty after preprocessing and filtering. Check download errors (network, timeouts, invalid URLs/images), dataset integrity, and --max_train_samples.")
    logger.info(f"Final dataset size after filtering: {new_count}")

    # --- Set Format and Collate ---
    try:
        # Ensure columns exist before setting format
        final_columns = processed_dataset.column_names
        columns_to_set = [col for col in ["pixel_values", "input_ids"] if col in final_columns]
        if columns_to_set:
            processed_dataset.set_format(type="torch", columns=columns_to_set)
            logger.info(f"Successfully set dataset format to 'torch' for columns: {columns_to_set}.")
            # Optional: print a sample here to verify tensor types/shapes after set_format.
        else:
            logger.warning(f"Columns {['pixel_values', 'input_ids']} not found after filtering/mapping, skipping set_format. Available columns: {final_columns}")
    except Exception as e:
        logger.error(f"Failed to set dataset format to torch: {e}", exc_info=True)
        # Consider raising the error if this step is critical
        # raise RuntimeError("Failed to set dataset format") from e

    # --- Collate Function (Stacks tensors from the list of dicts) ---
    def collate_fn(examples):
        # Returns {"pixel_values", "input_ids"} batch tensors, or {} on any
        # failure — the training loop MUST handle an empty dict.
        # Filter out any potential None values that might have slipped through (should be rare after .filter)
        valid_examples = [e for e in examples if e is not None and "pixel_values" in e and "input_ids" in e]
        if not valid_examples:
            # This might happen if a whole batch worth of URLs failed *concurrently*
            # logger.warning("Collate function received an empty list of valid examples. Returning empty batch.")
            return {}  # Return empty dict, training loop MUST handle this
        try:
            # Stack tensors from the list of dictionaries
            pixel_values = torch.stack([example["pixel_values"] for example in valid_examples])
            input_ids = torch.stack([example["input_ids"] for example in valid_examples])
        except Exception as e:
            logger.error(f"Error during collation (likely size mismatch or invalid data): {e}", exc_info=True)
            # Log shapes of first few items to help debug
            for i, ex in enumerate(valid_examples[:5]):
                pv_shape = ex["pixel_values"].shape if isinstance(ex.get("pixel_values"), torch.Tensor) else type(ex.get("pixel_values"))
                id_shape = ex["input_ids"].shape if isinstance(ex.get("input_ids"), torch.Tensor) else type(ex.get("input_ids"))
                logger.error(f" Example {i}: PV shape/type={pv_shape}, ID shape/type={id_shape}")
            return {}  # Return empty dict on error
        # Final check for safety
        if pixel_values.shape[0] != input_ids.shape[0]:
            logger.error(f"Collation error: Mismatched batch sizes after stacking. Images: {pixel_values.shape[0]}, Texts: {input_ids.shape[0]}. Skipping batch.")
            return {}
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    logger.info("Dataset preparation function finished.")
    # return processed_dataset, collate_fn
    return processed_dataset, collate_fn, new_count, original_count  # Return final count and pre-filter count


# --- Main Training Function ---
def main():
    # --- Parse Args FIRST ---
    args = parse_args()
    # --- Initialize Accelerator SECOND ---
    logging_dir = Path(args.output_dir, args.logging_dir)  # Use Path object
    accelerator_project_config = ProjectConfiguration(project_dir=str(args.output_dir), logging_dir=str(logging_dir))  # Ensure strings
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    # --- Setup Logging THIRD (Now Accelerator is ready) ---
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s [%(process)d] - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,  # Keep base level INFO
    )
    logger = get_logger(__name__, log_level="INFO")  # Use Accelerate logger
    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    # NOTE(review): main() continues on the following source lines (still mangled
    # in this copy); the remainder is outside this block.
if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # --- Log Accelerator State and Config FOURTH --- logger.info(f"Accelerator state: {accelerator.state}", main_process_only=False) # Set higher level for frequently noisy libraries during download/processing logging.getLogger("PIL").setLevel(logging.WARNING) logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) # --- Log Parsed Arguments FIFTH --- logger.info("Starting training script with arguments:") for k, v in sorted(vars(args).items()): logger.info(f" {k}: {v}") logger.info(f"Using dataset: '{args.dataset_name}'") logger.info(f"Using image column: '{args.image_column}', caption column: '{args.caption_column}'") # --- Set Seed --- if args.seed is not None: set_seed(args.seed) logger.info(f"Set random seed to {args.seed}") # --- Handle Hub Repo and Output Dir --- repo_id = None if accelerator.is_main_process: output_dir_path = Path(args.output_dir) # Use Path object if args.output_dir: output_dir_path.mkdir(parents=True, exist_ok=True) # Use Path object method logger.info(f"Output directory ensured: {args.output_dir}") if args.push_to_hub: # Hub creation logic... 
(kept as is) try: repo_id = create_repo( repo_id=args.hub_model_id or output_dir_path.name, exist_ok=True, token=args.hub_token ).repo_id logger.info(f"Created/verified Hub repo: {repo_id}") except Exception as e: logger.error(f"Failed to create/verify Hub repo: {e}", exc_info=True) logger.warning("Disabling Hub push due to error.") args.push_to_hub = False # --- Load models and tokenizer --- # (Keep this section as is, assuming base model parts are fine) logger.info("Loading tokenizer...") tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", cache_dir=args.cache_dir) logger.info("Loading text encoder...") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", cache_dir=args.cache_dir) logger.info("Loading VAE...") vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", cache_dir=args.cache_dir) logger.info("Loading noise scheduler...") noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # --- Load and modify U-Net --- # (Keep Mamba replacement logic as is) logger.info("Loading base U-Net state dict..."); try: # Use low_cpu_mem_usage=False initially, maybe True causes issues with config loading indirectly? 
original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=False).state_dict() except TypeError: logger.warning("low_cpu_mem_usage=False failed for UNet loading (unexpected), trying with low_cpu_mem_usage=True.") original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=True).state_dict() except Exception as load_err: logger.error(f"Failed to load original UNet state_dict: {load_err}", exc_info=True) raise logger.info("Creating new U-Net structure..."); unet_config = UNet2DConditionModel.load_config(args.pretrained_model_name_or_path, subfolder="unet") unet = UNet2DConditionModel.from_config(unet_config) logger.info("Replacing U-Net Self-Attention with Mamba blocks..."); mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand} unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs) logger.info("Loading partial pre-trained weights into new structure..."); modified_keys = set(unet.state_dict().keys()) filtered_state_dict = { k: v for k, v in original_unet_state_dict.items() if k in modified_keys and unet.state_dict()[k].shape == v.shape } load_result = unet.load_state_dict(filtered_state_dict, strict=False) logger.info(f"U-Net Load Result - Missing Keys: {len(load_result.missing_keys)}, Unexpected Keys: {len(load_result.unexpected_keys)}") # Log some examples if needed for debugging Mamba replacement if accelerator.is_main_process: if load_result.missing_keys: logger.debug(f" Example Missing Keys (likely Mamba): {load_result.missing_keys[:5]}...") if load_result.unexpected_keys: logger.debug(f" Example Unexpected Keys (likely Attention): {load_result.unexpected_keys[:5]}...") del original_unet_state_dict, filtered_state_dict # Free memory # --- Freeze/Unfreeze logic --- # (Keep Mamba unfreezing logic as is) 
vae.requires_grad_(False); text_encoder.requires_grad_(False); unet.requires_grad_(False) logger.info("Froze VAE and Text Encoder.") logger.info("Unfreezing specified Mamba/Norm parameters in U-Net...") unfrozen_params_count = 0; total_params_count = 0; unfrozen_param_names = [] trainable_params = [] # Store parameters to optimize for name, param in unet.named_parameters(): total_params_count += param.numel() module_path_parts = name.split('.') should_unfreeze = False # Check if it's directly within a MambaSequentialBlock # Or if it's norm1 related to a replaced BasicTransformerBlock's attn1 current_module = unet is_in_mamba_block = False try: for part in module_path_parts[:-1]: # Iterate down to the parent module current_module = getattr(current_module, part) if isinstance(current_module, MambaSequentialBlock): is_in_mamba_block = True break if is_in_mamba_block: should_unfreeze = True else: # Check for the norm1 pattern after replacement is_norm1 = name.endswith(".norm1.weight") or name.endswith(".norm1.bias") if is_norm1 and len(module_path_parts) > 2: grandparent_module_path = '.'.join(module_path_parts[:-2]) grandparent_module = unet.get_submodule(grandparent_module_path) # Check if the grandparent used to be a BasicTransformerBlock # and its attn1 is now a Mamba block # This relies on the structure post-replacement. # A safer check might involve inspecting the replacement map if available. 
# Assuming direct replacement: if isinstance(grandparent_module, BasicTransformerBlock) and hasattr(grandparent_module, 'attn1') and isinstance(grandparent_module.attn1, MambaSequentialBlock): should_unfreeze = True except AttributeError: pass # Module path doesn't exist if should_unfreeze: param.requires_grad_(True) unfrozen_params_count += param.numel() unfrozen_param_names.append(name) trainable_params.append(param) # Add to list for optimizer logger.info(f"Unfroze {unfrozen_params_count} / {total_params_count} parameters ({unfrozen_params_count/total_params_count:.2%}) in U-Net.") if unfrozen_params_count > 0 and accelerator.is_main_process: logger.info(f"Example unfrozen parameters: {unfrozen_param_names[:5]}...") elif unfrozen_params_count == 0: logger.error("CRITICAL: No U-Net parameters were unfrozen! Check Mamba replacement and unfreezing logic."); exit(1) # --- Optimizations --- if args.gradient_checkpointing: unet.enable_gradient_checkpointing(); logger.info("Enabled gradient checkpointing for U-Net.") if args.allow_tf32: if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: logger.info("Allowing TF32 for matmul and cuDNN.") torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True else: logger.info("TF32 not enabled (requires Ampere+ GPU or CUDA setup).") if is_xformers_available(): try: unet.enable_xformers_memory_efficient_attention(); logger.info("Enabled xformers memory efficient attention.") except Exception as e: logger.warning(f"Could not enable xformers (may not be relevant if Mamba replaced all): {e}.") # --- Optimizer --- logger.info(f"Number of trainable parameters for optimizer: {len(trainable_params)}") if not trainable_params: logger.error("CRITICAL: No trainable parameters found for optimizer!"); exit(1) if args.use_8bit_adam: try: import bitsandbytes as bnb; optimizer_cls = bnb.optim.AdamW8bit; logger.info("Using 8-bit AdamW optimizer.") except ImportError: logger.warning("bitsandbytes 
not installed. Falling back to standard AdamW."); optimizer_cls = torch.optim.AdamW else: optimizer_cls = torch.optim.AdamW; logger.info("Using standard AdamW optimizer.") # Scale LR? if args.scale_lr: # Note: trainable_params might be only a subset of unet params # Scaling based on total batch size is common effective_total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps args.learning_rate = args.learning_rate * effective_total_batch_size logger.info(f"Scaled learning rate to {args.learning_rate} (original * {effective_total_batch_size})") optimizer = optimizer_cls(trainable_params, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon) # --- Dataset and DataLoader --- logger.info("Preparing dataset and dataloader (will download images during mapping)...") try: train_dataset, collate_fn, final_dataset_size, count_before_filter = prepare_dataset(args, tokenizer, logger) # Receive new counts except Exception as e: logger.error(f"Failed during dataset preparation: {e}", exc_info=True); exit(1) final_dataset_size = len(train_dataset) #logger.info(f"Successfully prepared dataset metadata. Final size after filtering errors: {final_dataset_size}") #if final_dataset_size == 0: logger.error("Training dataset is empty after filtering download/processing errors. Cannot train."); exit(1) # Add a warning if filtering removed a large percentage #if args.max_train_samples: #initial_sample_count = min(args.max_train_samples, len(load_dataset(args.dataset_name, cache_dir=args.cache_dir, split=split_to_use))) # Re-check initial selected count #if initial_sample_count > 0: #filter_ratio = (initial_sample_count - final_dataset_size) / initial_sample_count #if filter_ratio > 0.5: # Warn if > 50% filtered #logger.warning(f"High filtering ratio: Filtered {initial_sample_count - final_dataset_size}/{initial_sample_count} ({filter_ratio:.1%}) samples due to errors. 
Check network/dataset quality.") logger.info(f"Successfully prepared dataset. Final size after filtering errors: {final_dataset_size}") # Use the returned final_dataset_size if final_dataset_size == 0: logger.error("Training dataset is empty after filtering download/processing errors. Cannot train."); exit(1) # Optional: Re-implement the warning using the returned counts (more efficient) if count_before_filter > 0: # Check if we had samples before filtering filter_ratio = (count_before_filter - final_dataset_size) / count_before_filter # Adjust threshold for warning if needed (e.g., warn if > 20% filtered) if filter_ratio > 0.2: logger.warning(f"Filtering ratio: Filtered {count_before_filter - final_dataset_size}/{count_before_filter} ({filter_ratio:.1%}) samples during download/processing. Check network/dataset quality if ratio is high.") elif args.max_train_samples: # This case means even after map, the count was 0. logger.warning(f"Dataset size before filtering was 0, despite requesting samples. Initial map/download may have failed for all items.") train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, # Shuffle the filtered dataset collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, pin_memory=True, # Usually good if workers > 0 persistent_workers=True if args.dataloader_num_workers > 0 else False, # Avoid worker startup overhead ) logger.info("DataLoader created.") # --- Calculate training steps --- # Need to account for possibility of len(train_dataloader) being 0 if batch_size > final_dataset_size if len(train_dataloader) == 0 and final_dataset_size > 0: logger.warning(f"DataLoader length is 0 but dataset size is {final_dataset_size}. Check batch size ({args.train_batch_size}). Effective steps per epoch will be 0.") num_update_steps_per_epoch = 0 elif len(train_dataloader) == 0 and final_dataset_size == 0: logger.error("Both dataset size and dataloader length are 0. 
Cannot train.") exit(1) else: num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: if num_update_steps_per_epoch == 0: logger.error("Cannot calculate max_train_steps (steps per epoch is 0). Please set --max_train_steps explicitly."); exit(1) args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch logger.info(f"Calculated max_train_steps: {args.max_train_steps} ({args.num_train_epochs} epochs * {num_update_steps_per_epoch} steps/epoch)") else: if num_update_steps_per_epoch > 0: args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch); logger.info(f"Training for {args.max_train_steps} steps (~{args.num_train_epochs} epochs).") else: args.num_train_epochs = 0; logger.warning(f"Training for {args.max_train_steps} steps, but calculated steps per epoch is zero.") # --- Scheduler --- lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes) # Scale steps for scheduler? Often done this way. Check get_scheduler docs. Let's stick to global steps for now. 
lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps, # Warmup over global steps num_training_steps=args.max_train_steps # Total global steps ) logger.info(f"Initialized LR scheduler: {args.lr_scheduler} ({args.lr_warmup_steps} warmup, {args.max_train_steps} total steps).") # --- Prepare with Accelerator --- logger.info("Preparing models, optimizer, dataloader, and scheduler with Accelerator...") # Order matters: models, optimizer, dataloader, scheduler unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) logger.info("Accelerator preparation finished.") # --- Cast non-trainable models --- # Determine dtype AFTER accelerator.prepare (as it might change model dtype based on mixed_precision) # However, non-trainable models should be cast manually. weight_dtype = torch.float32 # Default if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 #logger.info(f"Moving VAE and Text Encoder to device {accelerator.device} and casting to {weight_dtype}...") #vae.to(accelerator.device, dtype=weight_dtype) #text_encoder.to(accelerator.device, dtype=weight_dtype) #logger.info("Casting finished.") logger.info(f"Moving VAE and Text Encoder to device {accelerator.device} (keeping float32)...") vae.to(accelerator.device) text_encoder.to(accelerator.device) # --- Init trackers --- if accelerator.is_main_process: tracker_project_name = "mamba-sd-train-url" # Sanitize dataset name for run name clean_dataset_name = args.dataset_name.split('/')[-1].replace('-', '_').replace('/','_').replace('.','_') if args.dataset_name else "local_data" effective_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps run_name = f"{clean_dataset_name}_{args.max_train_samples or 
'all'}samples_lr{args.learning_rate}_bs{effective_batch_size}_mamba{args.mamba_d_state}-{args.mamba_d_conv}-{args.mamba_expand}" try: accelerator.init_trackers(tracker_project_name, config=vars(args), init_kwargs={"wandb": {"name": run_name}}) logger.info(f"Initialized trackers (Project: {tracker_project_name}, Run: {run_name})") except Exception as e: logger.warning(f"Could not initialize trackers ({args.report_to}): {e}.") # --- Resume logic --- # (Keep resume logic as is) global_step = 0; first_epoch = 0; resume_step = 0 if args.resume_from_checkpoint: checkpoint_path = None checkpoint_dir = Path(args.output_dir) if args.resume_from_checkpoint == "latest": # Find the latest checkpoint directory based on step number dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")] if dirs: try: latest_checkpoint = max(dirs, key=lambda d: int(d.name.split('-')[-1])) checkpoint_path = str(latest_checkpoint) logger.info(f"Resuming from latest checkpoint: {checkpoint_path}") except (ValueError, IndexError): logger.warning(f"Could not determine step number from checkpoint names in {checkpoint_dir}. Cannot resume 'latest'.") args.resume_from_checkpoint = None # Disable resume else: logger.info("No 'latest' checkpoint found to resume from."); args.resume_from_checkpoint = None else: checkpoint_path = args.resume_from_checkpoint if checkpoint_path and os.path.isdir(checkpoint_path): logger.info(f"Attempting resume from specific checkpoint: {checkpoint_path}") try: accelerator.load_state(checkpoint_path) # Extract global step from checkpoint directory name path_stem = Path(checkpoint_path).stem global_step = int(path_stem.split("-")[-1]) logger.info(f"Loaded state. 
Resuming from global step {global_step}.") # Recalculate steps per epoch AFTER prepare (dataloader length might change) steps_per_epoch_after_prepare = 0 if len(train_dataloader) > 0: steps_per_epoch_after_prepare = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if steps_per_epoch_after_prepare > 0: first_epoch = global_step // steps_per_epoch_after_prepare resume_step = global_step % steps_per_epoch_after_prepare logger.info(f"Calculated resume point: Epoch {first_epoch}, Step within epoch ~{resume_step}.") else: logger.warning("Steps/epoch is 0 after prepare. Cannot accurately calculate resume epoch/step within epoch. Starting from epoch 0.") first_epoch = 0; resume_step = 0 except FileNotFoundError: logger.error(f"Resume checkpoint directory not found: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0 except (ValueError, IndexError): logger.error(f"Could not parse step number from checkpoint name: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0 except Exception as e: logger.error(f"Failed to load checkpoint state: {e}. Starting fresh.", exc_info=True); global_step=0; first_epoch=0; resume_step=0 elif args.resume_from_checkpoint: logger.warning(f"Resume checkpoint path invalid or not found: '{args.resume_from_checkpoint}'. 
Starting fresh."); global_step=0; first_epoch=0; resume_step=0 else: # Case where resume_from_checkpoint was 'latest' but none found logger.info("Starting training from scratch (no checkpoint to resume)."); global_step=0; first_epoch=0; resume_step=0 # --- Training Loop --- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info(f"***** Running training ({args.dataset_name} - {final_dataset_size} Effective Samples) *****") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (effective) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" Starting Epoch = {first_epoch}") logger.info(f" Starting Global Step = {global_step}") logger.info(f" Resume Step in Epoch (approx) = {resume_step}") # Approx because dataloader shuffling changes order progress_bar = tqdm(range(global_step, args.max_train_steps), initial=global_step, total=args.max_train_steps, desc="Optimization Steps", disable=not accelerator.is_local_main_process) # >>> Determine the weight_dtype based on mixed precision AFTER accelerator is ready <<< # This was likely done before the loop, ensure 'weight_dtype' is defined in this scope # Add it here for clarity if it wasn't defined before the loop in main() weight_dtype = torch.float32 # Default if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Make sure VAE and Text Encoder are kept in float32 as per previous fix for epoch in range(first_epoch, args.num_train_epochs): unet.train() train_loss = 0.0 logger.info(f"--- Starting Epoch {epoch} ---") for step, batch in enumerate(train_dataloader): if not batch or "pixel_values" not in batch or batch["pixel_values"].shape[0] == 
0: if accelerator.is_main_process: if global_step % 100 == 0: logger.warning(f"Skipping empty/invalid batch at raw step {step} (Epoch {epoch}, Global ~{global_step}). Likely due to download/collation errors.") continue # --- Accumulate Gradients --- with accelerator.accumulate(unet): try: # --- >>> MODIFIED FORWARD PASS START <<< --- # pixel_values usually float32 from dataloader/transforms pixel_values = batch["pixel_values"].to(accelerator.device) # 1. VAE Encoding (VAE is float32 on accelerator.device) with torch.no_grad(): # Explicitly cast VAE input to float32 latents = vae.encode(pixel_values.to(dtype=torch.float32)).latent_dist.sample() * vae.config.scaling_factor # 'latents' are float32 output from VAE # 2. Prepare Noise (matches latents dtype -> float32) and Timesteps noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() # 3. Add Noise (scheduler handles dtypes, noisy_latents should be float32) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # 4. Text Encoding (Text Encoder is float32 on accelerator.device) with torch.no_grad(): input_ids = batch["input_ids"].to(accelerator.device) # Output 'encoder_hidden_states' is float32 encoder_hidden_states = text_encoder(input_ids)[0] # --- 5. Cast UNet inputs to mixed precision type --- # 'weight_dtype' is float16 or bfloat16 if mixed precision is enabled noisy_latents_input = noisy_latents.to(dtype=weight_dtype) encoder_hidden_states_input = encoder_hidden_states.to(dtype=weight_dtype) # --- End Cast --- # 6. Predict Noise using UNet (UNet runs in mixed precision) model_pred = unet( noisy_latents_input, timesteps, encoder_hidden_states_input ).sample # 'model_pred' is likely in weight_dtype (e.g., float16) # 7. 
Get Target for Loss (Should be float32) if noise_scheduler.config.prediction_type == "epsilon": target = noise # noise is float32 elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) # float32 else: raise ValueError(f"Unsupported prediction type {noise_scheduler.config.prediction_type}") # 8. Calculate Loss (Cast BOTH model_pred and target to float32) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # 9. Gather Loss for Logging avg_loss = accelerator.gather(loss.unsqueeze(0)).mean() train_loss += avg_loss.item() / args.gradient_accumulation_steps # --- >>> MODIFIED FORWARD PASS END <<< --- # --- Backward Pass --- accelerator.backward(loss) # --- Optimizer Step --- # This happens *outside* the try block but *inside* the sync_gradients check below # DO NOT PUT OPTIMIZER STEP HERE except Exception as forward_err: logger.error(f"Error during training step {step} (Epoch {epoch}, Global ~{global_step}): {forward_err}", exc_info=True) try: pv_shape = batch.get('pixel_values').shape if batch and isinstance(batch.get('pixel_values'), torch.Tensor) else 'N/A' id_shape = batch.get('input_ids').shape if batch and isinstance(batch.get('input_ids'), torch.Tensor) else 'N/A' logger.error(f" Batch Shapes - Pixels: {pv_shape}, IDs: {id_shape}") except Exception as log_err: logger.error(f" (Could not log batch details: {log_err})") continue # Skip to next batch # --- Sync Gradients, Step Optimizer, Log, Checkpoint, Validate --- if accelerator.sync_gradients: try: # Wrap optimizer step and gradient clipping if args.max_grad_norm > 0: accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) except Exception as optim_err: logger.error(f"Error during optimizer step/grad clipping at Global Step {global_step}: {optim_err}", exc_info=True) # Decide if you want to continue or stop on optimizer errors continue # 
Skip to next step for now # --- Progress Bar and Global Step --- progress_bar.update(1) global_step += 1 # --- Log Metrics --- if accelerator.is_main_process: logs = {"train_loss": train_loss} # Log the averaged accumulated loss if hasattr(lr_scheduler, "get_last_lr"): current_lr = lr_scheduler.get_last_lr() logs["lr"] = current_lr[0] if isinstance(current_lr, list) else current_lr else: logs["lr"] = optimizer.param_groups[0]['lr'] # Get LR from optimizer if scheduler doesn't have method try: accelerator.log(logs, step=global_step) except Exception as log_err: logger.warning(f"Logging failed for step {global_step}: {log_err}") train_loss = 0.0 # Reset accumulated loss for next set of accumulations # --- Checkpointing --- if global_step > 0 and global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: save_path = Path(args.output_dir) / f"checkpoint-{global_step}" try: logger.info(f"Saving checkpoint: {save_path}...") accelerator.save_state(str(save_path)) unwrapped_unet = accelerator.unwrap_model(unet) unet_save_path = save_path / "unet_mamba" unwrapped_unet.save_pretrained( str(unet_save_path), state_dict=unwrapped_unet.state_dict(), safe_serialization=True ) logger.info(f"Checkpoint saved to {save_path}") # Delete old checkpoints if args.checkpoints_total_limit is not None and args.checkpoints_total_limit > 0: checkpoint_dir = Path(args.output_dir) ckpts = sorted( [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")], key=lambda d: int(d.name.split("-")[-1]) ) if len(ckpts) > args.checkpoints_total_limit: num_to_delete = len(ckpts) - args.checkpoints_total_limit for old_ckpt in ckpts[:num_to_delete]: logger.info(f"Deleting old checkpoint: {old_ckpt}") shutil.rmtree(old_ckpt, ignore_errors=True) # Add ignore_errors except Exception as ckpt_err: logger.error(f"Checkpoint saving failed for step {global_step}: {ckpt_err}", exc_info=True) # --- Validation --- run_validation = False if args.validation_steps and 
global_step > 0 and global_step % args.validation_steps == 0: run_validation = True elif not args.validation_steps and args.validation_epochs > 0 and (epoch + 1) % args.validation_epochs == 0: is_last_accum_step = step == len(train_dataloader) - 1 if is_last_accum_step: run_validation = True if run_validation and accelerator.is_main_process: logger.info(f"Running validation at Global Step {global_step} (Epoch {epoch})...") log_validation_images = [] pipeline = None original_unet_training_mode = unet.training # Store training mode unet.eval() # Set unet to eval mode for validation try: # Models (VAE, Text Encoder) are already on device and float32 unet_val = accelerator.unwrap_model(unet) # Use unwrapped for pipeline pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=unet_val, vae=vae, # Use the float32 vae text_encoder=text_encoder, # Use the float32 text_encoder tokenizer=tokenizer, scheduler=noise_scheduler, safety_checker=None, torch_dtype=torch.float32, # <<< Run pipeline inference in float32 for stability >>> # or torch_dtype=weight_dtype if you are sure about VAE/TextEncoder fp16 stability cache_dir=args.cache_dir ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None logger.info(f"Generating {args.num_validation_images} validation images...") for i in range(args.num_validation_images): # Autocast might still be useful if internal pipeline ops benefit with torch.autocast(str(accelerator.device).split(":")[0], dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"), torch.no_grad(): image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] log_validation_images.append(np.array(image)) if log_validation_images: logger.info(f"Logging {len(log_validation_images)} validation images to trackers...") try: images_np = 
np.stack(log_validation_images) accelerator.log({"validation_images": images_np}, step=global_step) logger.info("Validation images logged.") except Exception as tracker_err: logger.warning(f"Failed to log validation images: {tracker_err}") else: logger.warning("No validation images were generated.") except Exception as val_err: logger.error(f"Validation failed at step {global_step}: {val_err}", exc_info=True) finally: # Cleanup pipeline and restore UNet training mode if pipeline is not None: del pipeline torch.cuda.empty_cache() unet.train(original_unet_training_mode) # Restore original mode logger.info("Validation run finished.") # --- Update progress bar postfix --- if accelerator.is_main_process: try: loss_val = loss.detach().item() current_lr_val = lr_scheduler.get_last_lr()[0] if hasattr(lr_scheduler, "get_last_lr") else optimizer.param_groups[0]['lr'] logs_postfix = {"loss": f"{loss_val:.4f}", "lr": f"{current_lr_val:.2e}"} progress_bar.set_postfix(**logs_postfix) except NameError: logs_postfix = {"loss": "N/A", "lr": optimizer.param_groups[0]['lr'] if optimizer.param_groups else 'N/A'} progress_bar.set_postfix(**logs_postfix) except Exception as pf_err: logger.debug(f"Postfix update error: {pf_err}") progress_bar.set_postfix({"step_status":"error"}) # --- Check for Training Completion --- if global_step >= args.max_train_steps: logger.info(f"Reached max_train_steps ({args.max_train_steps}). Stopping training.") break # Exit step (inner) loop # --- End of Epoch --- logger.info(f"--- Finished Epoch {epoch} (Reached Global Step {global_step}) ---") if global_step >= args.max_train_steps: break # Exit epoch loop # --- End of Training --- logger.info("Training finished. 
Waiting for all processes..."); accelerator.wait_for_everyone(); if accelerator.is_main_process: progress_bar.close() # Final Save if accelerator.is_main_process: logger.info("Saving final trained U-Net model..."); try: unet_final = accelerator.unwrap_model(unet) final_save_path = Path(args.output_dir) unet_final.save_pretrained( final_save_path / "unet_mamba_final", safe_serialization=True, state_dict=unet_final.state_dict() ) logger.info(f"Final UNet saved to: {final_save_path / 'unet_mamba_final'}") tokenizer.save_pretrained(str(final_save_path / "tokenizer_final")) logger.info(f"Final Tokenizer saved to: {final_save_path / 'tokenizer_final'}") except Exception as e: logger.error(f"Failed to save final UNet/Tokenizer: {e}", exc_info=True) # Hub Push Logic if args.push_to_hub: logger.info("Attempting to push final model to Hub..."); if repo_id is None: logger.warning("Cannot push to Hub (repo_id not defined or Hub creation failed).") else: try: logger.info(f"Pushing contents of {args.output_dir} to repository {repo_id}..."); upload_folder( repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training - Mamba SD URL Text", ignore_patterns=["step_*", "epoch_*", "checkpoint-*/**", "checkpoint-*/", "*.safetensors.index.json", "logs/**"], token=args.hub_token ) logger.info("Push to Hub successful.") except Exception as e: logger.error(f"Hub upload failed: {e}", exc_info=True) logger.info("Ending training script..."); accelerator.end_training(); logger.info("Script finished.") # --- Entry Point --- # (Keep the entry point section as it was) if __name__ == "__main__": try: main() except Exception as e: print(f"\n\n !!! --- FATAL SCRIPT ERROR --- !!!") print(f"Error Type: {type(e).__name__}") print(f"Error Details: {e}") print(f"Traceback:") print(traceback.format_exc()) print(f" !!! --- SCRIPT TERMINATED DUE TO ERROR --- !!!") exit(1)