# Source: MSD / train_mamba_sd.py
# Uploaded by: root
# Commit message: Initial clean upload: checkpoint + scripts + PNG via LFS
# Commit: 5e7715d
# --- train_mamba_sd.py ---
import argparse
import logging
import math
import os
import shutil
import random
from pathlib import Path
import traceback
import io # <<< NEED THIS
import requests # <<< NEED THIS AGAIN for URL fetching
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset # No special Features needed now
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
from PIL import Image, UnidentifiedImageError # Need PIL and error handling
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Import Mamba block and replacement function
try:
    # msd_utils.py must be importable (same directory or on the Python path).
    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba
    print("Successfully imported MambaSequentialBlock and replacement function from msd_utils.py.")
except ImportError as e:
    # Print a visible banner so the failure is not lost among library start-up logs.
    print("=" * 50)
    print("ERROR: Failed to import from msd_utils.py!")
    print("Ensure 'msd_utils.py' exists and contains necessary definitions.")
    print(f"Import error: {e}")
    print("=" * 50)
    # `exit()` is a convenience injected by the `site` module and is not
    # guaranteed to exist in every interpreter setup; raising SystemExit is
    # always available and terminates with the same status code.
    raise SystemExit(1) from e
# Import BasicTransformerBlock for type checking in unfreeze logic
from diffusers.models.attention import BasicTransformerBlock
check_min_version("0.28.0")
# Default column names for datasets that pair an image URL with a text caption.
DEFAULT_IMAGE_COLUMN = "URL"
DEFAULT_CAPTION_COLUMN = "TEXT"


def _add_bool_flag(parser, name, default, help_text):
    """Register a ``--<name>`` / ``--no-<name>`` pair that writes to one destination.

    A plain ``action="store_true"`` option whose default is ``True`` can never be
    switched off from the command line. Pairing it with a ``store_false`` option
    keeps the original ``--<name>`` spelling working while making the setting
    actually configurable.
    """
    parser.add_argument(f"--{name}", dest=name, action="store_true", default=default, help=help_text)
    parser.add_argument(f"--no-{name}", dest=name, action="store_false", default=default, help=f"Disable --{name}.")


# --- Argument Parsing ---
def parse_args(input_args=None):
    """Build the command-line parser, parse the arguments and validate them.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the
            default) the arguments are read from ``sys.argv`` as before.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that is set, ``train_data_dir``
        is cleared when ``dataset_name`` is also given, and
        ``preprocessing_num_workers`` is auto-detected (capped at 16) when omitted.

    Raises:
        ValueError: If neither a dataset name nor a data directory is available,
            or if ``--max_train_samples`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Train Stable Diffusion with Mamba using a URL/Text dataset (e.g., MS_COCO_2017_URL_TEXT).")

    # Model Paths
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5", help="Path to pretrained model or model identifier from huggingface.co/models.")

    # Dataset Arguments
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
        help="The HuggingFace dataset identifier for a dataset with URL and TEXT columns.",
    )
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data (Not recommended for URL datasets). Overrides --dataset_name.")
    parser.add_argument(
        "--image_column",
        type=str,
        default=DEFAULT_IMAGE_COLUMN,
        help="The column of the dataset containing image URLs.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default=DEFAULT_CAPTION_COLUMN,
        help="The column of the dataset containing single text captions.",
    )
    parser.add_argument("--max_train_samples", type=int, default=5000, help="Limit the number of training examples. Loads dataset metadata first, then selects.")

    # Validation Arguments
    parser.add_argument("--validation_prompt", type=str, default="A photo of a busy city street with cars and pedestrians.", help="A prompt that is used during validation to verify that the model is learning.")
    parser.add_argument("--num_validation_images", type=int, default=4, help="Number of images that should be generated during validation with `validation_prompt`.")
    parser.add_argument("--validation_epochs", type=int, default=1, help="Run validation every X epochs (ignored if validation_steps is set).")
    parser.add_argument("--validation_steps", type=int, default=500, help="Run validation every X steps. Overrides validation_epochs.")

    # Output and Saving Arguments
    parser.add_argument("--output_dir", type=str, default="sd-mamba-trained-mscoco-urltext-5k", help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir", type=str, default=None, help="Directory to cache downloaded models and datasets.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.")

    # Preprocessing Arguments
    parser.add_argument("--resolution", type=int, default=512, help="The resolution for input images, all images will be resized to this size.")
    _add_bool_flag(parser, "center_crop", True, "Whether to center crop images after downloading.")
    _add_bool_flag(parser, "random_flip", True, "Whether to randomly flip images horizontally.")

    # Training Hyperparameters
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
    parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. Overrides num_train_epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.")
    _add_bool_flag(parser, "gradient_checkpointing", True, "Whether to use gradient checkpointing to save memory.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
    parser.add_argument("--lr_scheduler", type=str, default="cosine", help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]')
    parser.add_argument("--lr_warmup_steps", type=int, default=100, help="Number of steps for the warmup in the learning rate scheduler.")
    _add_bool_flag(parser, "use_8bit_adam", True, "Whether to use 8-bit AdamW optimizer.")
    _add_bool_flag(parser, "allow_tf32", True, "Whether to allow TF32 on Ampere GPUs. Can speed up training.")
    parser.add_argument("--dataloader_num_workers", type=int, default=8, help="Number of subprocesses to use for data loading.")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    # Accelerator Arguments
    parser.add_argument("--logging_dir", type=str, default="logs", help="Location for TensorBoard logs.")
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"], help="Whether to use mixed precision.")
    parser.add_argument("--report_to", type=str, default="tensorboard", help='The integration to report the results and logs to.')
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    # Checkpointing Arguments
    parser.add_argument("--checkpointing_steps", type=int, default=500, help="Save a checkpoint of the training state every X updates steps.")
    parser.add_argument("--checkpoints_total_limit", type=int, default=3, help="Max number of checkpoints to store.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Whether to resume training from a previous checkpoint directory or 'latest'.")

    # Mamba Specific Arguments
    parser.add_argument("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension.")
    parser.add_argument("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension.")
    parser.add_argument("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor.")

    # Preprocessing Specific Arguments
    parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for data preprocessing (defaults to cpu count capped at 16).")

    args = parser.parse_args(input_args)

    # The launcher-provided LOCAL_RANK environment variable wins over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        print(f"INFO: Overriding local_rank {args.local_rank} with environment variable LOCAL_RANK {env_local_rank}")
        args.local_rank = env_local_rank

    # Validation checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either --dataset_name or --train_data_dir")
    if args.dataset_name and args.train_data_dir:
        print("WARNING: Both --dataset_name and --train_data_dir provided. Using --dataset_name as URL dataset is specified.")
        args.train_data_dir = None  # Prefer dataset_name for URL datasets

    # Set default preprocessing workers if not specified
    if args.preprocessing_num_workers is None:
        try:
            # sched_getaffinity honours CPU pinning but only exists on some platforms.
            available_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.cpu_count() may return None when the count cannot be determined.
            available_cpus = os.cpu_count() or 1
        args.preprocessing_num_workers = min(available_cpus, 16)
        print(f"INFO: Auto-detected preprocessing_num_workers: {args.preprocessing_num_workers}")

    # Ensure max_train_samples is positive if set
    if args.max_train_samples is not None and args.max_train_samples <= 0:
        raise ValueError("--max_train_samples must be a positive integer.")
    return args
# --- Dataset Handling ---
def prepare_dataset(args, tokenizer, logger):  # Pass logger explicitly
    """Load the URL/caption dataset, download the images and tokenize the captions.

    The dataset named by ``args.dataset_name`` is loaded, optionally sub-sampled
    to ``args.max_train_samples`` rows, and mapped through a function that
    downloads every image URL, applies the image transforms and tokenizes the
    caption. Rows whose download, decoding or tokenization fails are dropped.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
        tokenizer: Tokenizer called on each caption; must expose
            ``model_max_length`` and return an object with ``input_ids``.
        logger: Logger used for progress and error messages.

    Returns:
        A 4-tuple ``(processed_dataset, collate_fn, final_count, selected_count)``
        where ``final_count`` is the number of usable rows and ``selected_count``
        is the number of rows that were submitted for download.

    Raises:
        NotImplementedError: If ``args.dataset_name`` is ``None``.
        ValueError: If the split or the configured columns cannot be found, or
            if no samples are selected.
        RuntimeError: If no row survives preprocessing, or the dataset cannot be
            switched to torch format.
    """
    if args.dataset_name is None:  # Should not happen with current checks, but keep for safety
        logger.error("Local data directory loading (--train_data_dir) is not the intended use case for this script modification.")
        raise NotImplementedError("This script is modified for URL datasets via --dataset_name.")

    logger.info(f"Loading dataset '{args.dataset_name}' metadata...")
    try:
        dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)
    except Exception as e:
        logger.error(f"Failed to load dataset '{args.dataset_name}': {e}", exc_info=True)
        raise
    logger.info("Dataset metadata loaded successfully.")

    # Select 'train' split (most common); fall back to the only split if unambiguous.
    split_to_use = "train"
    if split_to_use not in dataset:
        available_splits = list(dataset.keys())
        if len(available_splits) != 1:
            raise ValueError(f"'train' split not found in dataset '{args.dataset_name}'. Available splits: {available_splits}. Please check the dataset structure or specify the split.")
        split_to_use = available_splits[0]
        logger.warning(f"'train' split not found. Using the only available split: '{split_to_use}'.")
    dataset = dataset[split_to_use]
    logger.info(f"Using '{split_to_use}' split. Initial size: {len(dataset)}")
    logger.info(f"Dataset features: {dataset.features}")

    # --- Check Columns ---
    column_names = dataset.column_names
    logger.info(f"Original dataset columns: {column_names}")
    if args.image_column not in column_names:
        raise ValueError(f"--image_column '{args.image_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")
    if args.caption_column not in column_names:
        raise ValueError(f"--caption_column '{args.caption_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")

    # --- SELECT SAMPLES (AFTER loading metadata, BEFORE downloading/mapping) ---
    if args.max_train_samples is not None:
        num_samples = len(dataset)
        max_samples_to_select = min(args.max_train_samples, num_samples)
        if args.max_train_samples > num_samples:
            logger.warning(
                f"--max_train_samples ({args.max_train_samples}) is larger than the dataset size ({num_samples}). "
                f"Using all {num_samples} samples."
            )
        logger.info(f"Selecting {max_samples_to_select} samples from the dataset (shuffling first).")
        # Shuffling is only needed when a strict subset is taken.
        if max_samples_to_select < num_samples:
            dataset = dataset.shuffle(seed=args.seed).select(range(max_samples_to_select))
        logger.info(f"Dataset size after selecting samples: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(f"Selected 0 samples. Check --max_train_samples ({args.max_train_samples}) and dataset availability.")
    selected_count = len(dataset)

    # --- Image Transforms (Applied after download) ---
    # NOTE: these run once inside `map`, so the random crop/flip outcome is
    # stored with the processed dataset rather than re-drawn every epoch.
    train_transforms = transforms.Compose([
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    logger.info("Image transforms defined.")

    request_timeout_s = 20
    # Reject huge images before conversion to avoid exhausting worker memory.
    max_pixels = 4096 * 4096  # ~16 megapixels

    def _load_example(image_url, caption):
        """Return ``(pixel_values, input_ids)`` as numpy arrays, or ``None`` to skip the row."""
        if not isinstance(image_url, str) or not image_url.startswith(("http://", "https://")):
            return None
        # Reject unusable captions before spending a network request on the image.
        caption_str = str(caption) if caption is not None else ""
        if not caption_str:
            return None

        # 1. Download and process the image.
        try:
            response = requests.get(image_url, timeout=request_timeout_s)
            response.raise_for_status()  # Raises HTTPError for 4xx/5xx responses
            img_bytes = response.content
            if not img_bytes:
                return None
            image_pil = Image.open(io.BytesIO(img_bytes))
            if image_pil.width * image_pil.height > max_pixels:
                return None
            pixel_values = train_transforms(image_pil.convert("RGB"))
        except (requests.exceptions.RequestException, UnidentifiedImageError, ValueError, OSError):
            # Expected failures: timeouts, HTTP/SSL errors, corrupt or truncated images.
            return None
        except Exception as img_err:
            logger.warning(f"Generic error processing/transforming image from {image_url}: {img_err}. Skipping.")
            return None

        # 2. Tokenize the caption.
        try:
            inputs = tokenizer(
                caption_str,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = inputs.input_ids.squeeze(0)  # Remove batch dim added by tokenizer
        except Exception as tok_err:
            logger.warning(f"Error tokenizing caption '{str(caption)[:50]}...' for URL {image_url}: {tok_err}. Skipping.")
            return None

        return pixel_values.numpy(), input_ids.numpy()

    def preprocess_train_batch(batch):
        """Process a batch of rows and return only those that succeeded.

        A per-example `map` function cannot remove rows (returning ``None``
        does not drop the example), so failures are dropped here by returning
        fewer rows than were received, which batched mapping permits once the
        input columns are removed.
        """
        pixel_values, input_ids = [], []
        for image_url, caption in zip(batch[args.image_column], batch[args.caption_column]):
            loaded = _load_example(image_url, caption)
            if loaded is None:
                continue
            pixel_values.append(loaded[0])
            input_ids.append(loaded[1])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # Declaring the output schema up front means a shard in which every
    # download fails still yields a valid (empty) table instead of an
    # un-inferable schema.
    output_features = datasets.Features({
        "pixel_values": datasets.Array3D(shape=(3, args.resolution, args.resolution), dtype="float32"),
        "input_ids": datasets.Sequence(datasets.Value("int64")),
    })

    # --- Apply preprocessing using map (this is where the downloads happen) ---
    num_proc = max(1, min(args.preprocessing_num_workers or 1, selected_count))
    logger.info(f"Preprocessing dataset (downloading URLs, single item processing) using {num_proc} workers...")
    processed_dataset = dataset.map(
        preprocess_train_batch,
        batched=True,
        batch_size=16,
        num_proc=num_proc,
        remove_columns=dataset.column_names,  # Required so the row count may change
        features=output_features,
        load_from_cache_file=True,  # Re-use downloaded results on subsequent runs
        desc="Downloading images and tokenizing captions",
    )
    new_count = len(processed_dataset)
    logger.info(f"Dataset size after map (potential download/processing): {new_count}")
    if selected_count != new_count:
        logger.warning(f"Filtered out {selected_count - new_count} entries due to download/processing errors.")
    if new_count == 0:
        raise RuntimeError("Dataset is empty after preprocessing and filtering. Check download errors (network, timeouts, invalid URLs/images), dataset integrity, and --max_train_samples.")
    logger.info(f"Final dataset size after filtering: {new_count}")

    # --- Set Format ---
    # The collate function stacks tensors, so failing to switch formats would
    # make every batch unusable; surface the error instead of continuing.
    columns_to_set = ["pixel_values", "input_ids"]
    try:
        processed_dataset.set_format(type="torch", columns=columns_to_set)
    except Exception as e:
        logger.error(f"Failed to set dataset format to torch: {e}", exc_info=True)
        raise RuntimeError("Failed to set dataset format") from e
    logger.info(f"Successfully set dataset format to 'torch' for columns: {columns_to_set}.")

    # --- Collate Function (Stacks tensors from the list of dicts) ---
    def collate_fn(examples):
        """Stack per-example tensors into a batch; return ``{}`` when that is impossible.

        The training loop MUST handle the empty-dict case.
        """
        valid_examples = [e for e in examples if e is not None and "pixel_values" in e and "input_ids" in e]
        if not valid_examples:
            return {}
        try:
            pixel_values = torch.stack([example["pixel_values"] for example in valid_examples])
            input_ids = torch.stack([example["input_ids"] for example in valid_examples])
        except Exception as e:
            logger.error(f"Error during collation (likely size mismatch or invalid data): {e}", exc_info=True)
            # Log shapes of the first few items to help debugging.
            for i, ex in enumerate(valid_examples[:5]):
                pv_shape = ex["pixel_values"].shape if isinstance(ex.get("pixel_values"), torch.Tensor) else type(ex.get("pixel_values"))
                id_shape = ex["input_ids"].shape if isinstance(ex.get("input_ids"), torch.Tensor) else type(ex.get("input_ids"))
                logger.error(f" Example {i}: PV shape/type={pv_shape}, ID shape/type={id_shape}")
            return {}
        if pixel_values.shape[0] != input_ids.shape[0]:
            logger.error(f"Collation error: Mismatched batch sizes after stacking. Images: {pixel_values.shape[0]}, Texts: {input_ids.shape[0]}. Skipping batch.")
            return {}
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    logger.info("Dataset preparation function finished.")
    # Final count and the count submitted for download (used for the filter ratio).
    return processed_dataset, collate_fn, new_count, selected_count
# --- Main Training Function ---
def main():
# --- Parse Args FIRST ---
args = parse_args()
# --- Initialize Accelerator SECOND ---
logging_dir = Path(args.output_dir, args.logging_dir) # Use Path object
accelerator_project_config = ProjectConfiguration(project_dir=str(args.output_dir), logging_dir=str(logging_dir)) # Ensure strings
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# --- Setup Logging THIRD (Now Accelerator is ready) ---
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s [%(process)d] - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO, # Keep base level INFO
)
logger = get_logger(__name__, log_level="INFO") # Use Accelerate logger
# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# --- Log Accelerator State and Config FOURTH ---
logger.info(f"Accelerator state: {accelerator.state}", main_process_only=False)
# Set higher level for frequently noisy libraries during download/processing
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
# --- Log Parsed Arguments FIFTH ---
logger.info("Starting training script with arguments:")
for k, v in sorted(vars(args).items()):
logger.info(f" {k}: {v}")
logger.info(f"Using dataset: '{args.dataset_name}'")
logger.info(f"Using image column: '{args.image_column}', caption column: '{args.caption_column}'")
# --- Set Seed ---
if args.seed is not None:
set_seed(args.seed)
logger.info(f"Set random seed to {args.seed}")
# --- Handle Hub Repo and Output Dir ---
repo_id = None
if accelerator.is_main_process:
output_dir_path = Path(args.output_dir) # Use Path object
if args.output_dir:
output_dir_path.mkdir(parents=True, exist_ok=True) # Use Path object method
logger.info(f"Output directory ensured: {args.output_dir}")
if args.push_to_hub:
# Hub creation logic... (kept as is)
try:
repo_id = create_repo(
repo_id=args.hub_model_id or output_dir_path.name, exist_ok=True, token=args.hub_token
).repo_id
logger.info(f"Created/verified Hub repo: {repo_id}")
except Exception as e:
logger.error(f"Failed to create/verify Hub repo: {e}", exc_info=True)
logger.warning("Disabling Hub push due to error.")
args.push_to_hub = False
# --- Load models and tokenizer ---
# (Keep this section as is, assuming base model parts are fine)
logger.info("Loading tokenizer...")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", cache_dir=args.cache_dir)
logger.info("Loading text encoder...")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", cache_dir=args.cache_dir)
logger.info("Loading VAE...")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", cache_dir=args.cache_dir)
logger.info("Loading noise scheduler...")
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# --- Load and modify U-Net ---
# (Keep Mamba replacement logic as is)
logger.info("Loading base U-Net state dict...");
try:
# Use low_cpu_mem_usage=False initially, maybe True causes issues with config loading indirectly?
original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=False).state_dict()
except TypeError:
logger.warning("low_cpu_mem_usage=False failed for UNet loading (unexpected), trying with low_cpu_mem_usage=True.")
original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=True).state_dict()
except Exception as load_err:
logger.error(f"Failed to load original UNet state_dict: {load_err}", exc_info=True)
raise
logger.info("Creating new U-Net structure...");
unet_config = UNet2DConditionModel.load_config(args.pretrained_model_name_or_path, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config)
logger.info("Replacing U-Net Self-Attention with Mamba blocks...");
mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand}
unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
logger.info("Loading partial pre-trained weights into new structure...");
modified_keys = set(unet.state_dict().keys())
filtered_state_dict = {
k: v for k, v in original_unet_state_dict.items()
if k in modified_keys and unet.state_dict()[k].shape == v.shape
}
load_result = unet.load_state_dict(filtered_state_dict, strict=False)
logger.info(f"U-Net Load Result - Missing Keys: {len(load_result.missing_keys)}, Unexpected Keys: {len(load_result.unexpected_keys)}")
# Log some examples if needed for debugging Mamba replacement
if accelerator.is_main_process:
if load_result.missing_keys: logger.debug(f" Example Missing Keys (likely Mamba): {load_result.missing_keys[:5]}...")
if load_result.unexpected_keys: logger.debug(f" Example Unexpected Keys (likely Attention): {load_result.unexpected_keys[:5]}...")
del original_unet_state_dict, filtered_state_dict # Free memory
# --- Freeze/Unfreeze logic ---
# (Keep Mamba unfreezing logic as is)
vae.requires_grad_(False); text_encoder.requires_grad_(False); unet.requires_grad_(False)
logger.info("Froze VAE and Text Encoder.")
logger.info("Unfreezing specified Mamba/Norm parameters in U-Net...")
unfrozen_params_count = 0; total_params_count = 0; unfrozen_param_names = []
trainable_params = [] # Store parameters to optimize
for name, param in unet.named_parameters():
total_params_count += param.numel()
module_path_parts = name.split('.')
should_unfreeze = False
# Check if it's directly within a MambaSequentialBlock
# Or if it's norm1 related to a replaced BasicTransformerBlock's attn1
current_module = unet
is_in_mamba_block = False
try:
for part in module_path_parts[:-1]: # Iterate down to the parent module
current_module = getattr(current_module, part)
if isinstance(current_module, MambaSequentialBlock):
is_in_mamba_block = True
break
if is_in_mamba_block:
should_unfreeze = True
else:
# Check for the norm1 pattern after replacement
is_norm1 = name.endswith(".norm1.weight") or name.endswith(".norm1.bias")
if is_norm1 and len(module_path_parts) > 2:
grandparent_module_path = '.'.join(module_path_parts[:-2])
grandparent_module = unet.get_submodule(grandparent_module_path)
# Check if the grandparent used to be a BasicTransformerBlock
# and its attn1 is now a Mamba block
# This relies on the structure post-replacement.
# A safer check might involve inspecting the replacement map if available.
# Assuming direct replacement:
if isinstance(grandparent_module, BasicTransformerBlock) and hasattr(grandparent_module, 'attn1') and isinstance(grandparent_module.attn1, MambaSequentialBlock):
should_unfreeze = True
except AttributeError:
pass # Module path doesn't exist
if should_unfreeze:
param.requires_grad_(True)
unfrozen_params_count += param.numel()
unfrozen_param_names.append(name)
trainable_params.append(param) # Add to list for optimizer
logger.info(f"Unfroze {unfrozen_params_count} / {total_params_count} parameters ({unfrozen_params_count/total_params_count:.2%}) in U-Net.")
if unfrozen_params_count > 0 and accelerator.is_main_process: logger.info(f"Example unfrozen parameters: {unfrozen_param_names[:5]}...")
elif unfrozen_params_count == 0: logger.error("CRITICAL: No U-Net parameters were unfrozen! Check Mamba replacement and unfreezing logic."); exit(1)
# --- Optimizations ---
if args.gradient_checkpointing: unet.enable_gradient_checkpointing(); logger.info("Enabled gradient checkpointing for U-Net.")
if args.allow_tf32:
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
logger.info("Allowing TF32 for matmul and cuDNN.")
torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True
else: logger.info("TF32 not enabled (requires Ampere+ GPU or CUDA setup).")
if is_xformers_available():
try: unet.enable_xformers_memory_efficient_attention(); logger.info("Enabled xformers memory efficient attention.")
except Exception as e: logger.warning(f"Could not enable xformers (may not be relevant if Mamba replaced all): {e}.")
# --- Optimizer ---
logger.info(f"Number of trainable parameters for optimizer: {len(trainable_params)}")
if not trainable_params: logger.error("CRITICAL: No trainable parameters found for optimizer!"); exit(1)
if args.use_8bit_adam:
try: import bitsandbytes as bnb; optimizer_cls = bnb.optim.AdamW8bit; logger.info("Using 8-bit AdamW optimizer.")
except ImportError: logger.warning("bitsandbytes not installed. Falling back to standard AdamW."); optimizer_cls = torch.optim.AdamW
else: optimizer_cls = torch.optim.AdamW; logger.info("Using standard AdamW optimizer.")
# Scale LR?
if args.scale_lr:
# Note: trainable_params might be only a subset of unet params
# Scaling based on total batch size is common
effective_total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
args.learning_rate = args.learning_rate * effective_total_batch_size
logger.info(f"Scaled learning rate to {args.learning_rate} (original * {effective_total_batch_size})")
optimizer = optimizer_cls(trainable_params, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
# --- Dataset and DataLoader ---
logger.info("Preparing dataset and dataloader (will download images during mapping)...")
try: train_dataset, collate_fn, final_dataset_size, count_before_filter = prepare_dataset(args, tokenizer, logger) # Receive new counts
except Exception as e: logger.error(f"Failed during dataset preparation: {e}", exc_info=True); exit(1)
final_dataset_size = len(train_dataset)
#logger.info(f"Successfully prepared dataset metadata. Final size after filtering errors: {final_dataset_size}")
#if final_dataset_size == 0: logger.error("Training dataset is empty after filtering download/processing errors. Cannot train."); exit(1)
# Add a warning if filtering removed a large percentage
#if args.max_train_samples:
#initial_sample_count = min(args.max_train_samples, len(load_dataset(args.dataset_name, cache_dir=args.cache_dir, split=split_to_use))) # Re-check initial selected count
#if initial_sample_count > 0:
#filter_ratio = (initial_sample_count - final_dataset_size) / initial_sample_count
#if filter_ratio > 0.5: # Warn if > 50% filtered
#logger.warning(f"High filtering ratio: Filtered {initial_sample_count - final_dataset_size}/{initial_sample_count} ({filter_ratio:.1%}) samples due to errors. Check network/dataset quality.")
logger.info(f"Successfully prepared dataset. Final size after filtering errors: {final_dataset_size}") # Use the returned final_dataset_size
if final_dataset_size == 0: logger.error("Training dataset is empty after filtering download/processing errors. Cannot train."); exit(1)
# Optional: Re-implement the warning using the returned counts (more efficient)
if count_before_filter > 0: # Check if we had samples before filtering
filter_ratio = (count_before_filter - final_dataset_size) / count_before_filter
# Adjust threshold for warning if needed (e.g., warn if > 20% filtered)
if filter_ratio > 0.2:
logger.warning(f"Filtering ratio: Filtered {count_before_filter - final_dataset_size}/{count_before_filter} ({filter_ratio:.1%}) samples during download/processing. Check network/dataset quality if ratio is high.")
elif args.max_train_samples:
# This case means even after map, the count was 0.
logger.warning(f"Dataset size before filtering was 0, despite requesting samples. Initial map/download may have failed for all items.")
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True, # Shuffle the filtered dataset
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
pin_memory=True, # Usually good if workers > 0
persistent_workers=True if args.dataloader_num_workers > 0 else False, # Avoid worker startup overhead
)
logger.info("DataLoader created.")
# --- Calculate training steps ---
# Need to account for possibility of len(train_dataloader) being 0 if batch_size > final_dataset_size
if len(train_dataloader) == 0 and final_dataset_size > 0:
logger.warning(f"DataLoader length is 0 but dataset size is {final_dataset_size}. Check batch size ({args.train_batch_size}). Effective steps per epoch will be 0.")
num_update_steps_per_epoch = 0
elif len(train_dataloader) == 0 and final_dataset_size == 0:
logger.error("Both dataset size and dataloader length are 0. Cannot train.")
exit(1)
else:
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if num_update_steps_per_epoch == 0: logger.error("Cannot calculate max_train_steps (steps per epoch is 0). Please set --max_train_steps explicitly."); exit(1)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
logger.info(f"Calculated max_train_steps: {args.max_train_steps} ({args.num_train_epochs} epochs * {num_update_steps_per_epoch} steps/epoch)")
else:
if num_update_steps_per_epoch > 0: args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch); logger.info(f"Training for {args.max_train_steps} steps (~{args.num_train_epochs} epochs).")
else: args.num_train_epochs = 0; logger.warning(f"Training for {args.max_train_steps} steps, but calculated steps per epoch is zero.")
# --- Scheduler ---
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes) # Scale steps for scheduler? Often done this way. Check get_scheduler docs. Let's stick to global steps for now.
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps, # Warmup over global steps
num_training_steps=args.max_train_steps # Total global steps
)
logger.info(f"Initialized LR scheduler: {args.lr_scheduler} ({args.lr_warmup_steps} warmup, {args.max_train_steps} total steps).")
# --- Prepare with Accelerator ---
logger.info("Preparing models, optimizer, dataloader, and scheduler with Accelerator...")
# Order matters: models, optimizer, dataloader, scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
logger.info("Accelerator preparation finished.")
# --- Cast non-trainable models ---
# Determine dtype AFTER accelerator.prepare (as it might change model dtype based on mixed_precision)
# However, non-trainable models should be cast manually.
weight_dtype = torch.float32 # Default
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
#logger.info(f"Moving VAE and Text Encoder to device {accelerator.device} and casting to {weight_dtype}...")
#vae.to(accelerator.device, dtype=weight_dtype)
#text_encoder.to(accelerator.device, dtype=weight_dtype)
#logger.info("Casting finished.")
logger.info(f"Moving VAE and Text Encoder to device {accelerator.device} (keeping float32)...")
vae.to(accelerator.device)
text_encoder.to(accelerator.device)
# --- Init trackers ---
if accelerator.is_main_process:
tracker_project_name = "mamba-sd-train-url"
# Sanitize dataset name for run name
clean_dataset_name = args.dataset_name.split('/')[-1].replace('-', '_').replace('/','_').replace('.','_') if args.dataset_name else "local_data"
effective_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
run_name = f"{clean_dataset_name}_{args.max_train_samples or 'all'}samples_lr{args.learning_rate}_bs{effective_batch_size}_mamba{args.mamba_d_state}-{args.mamba_d_conv}-{args.mamba_expand}"
try:
accelerator.init_trackers(tracker_project_name, config=vars(args), init_kwargs={"wandb": {"name": run_name}})
logger.info(f"Initialized trackers (Project: {tracker_project_name}, Run: {run_name})")
except Exception as e: logger.warning(f"Could not initialize trackers ({args.report_to}): {e}.")
# --- Resume logic ---
# (Keep resume logic as is)
global_step = 0; first_epoch = 0; resume_step = 0
if args.resume_from_checkpoint:
checkpoint_path = None
checkpoint_dir = Path(args.output_dir)
if args.resume_from_checkpoint == "latest":
# Find the latest checkpoint directory based on step number
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")]
if dirs:
try:
latest_checkpoint = max(dirs, key=lambda d: int(d.name.split('-')[-1]))
checkpoint_path = str(latest_checkpoint)
logger.info(f"Resuming from latest checkpoint: {checkpoint_path}")
except (ValueError, IndexError):
logger.warning(f"Could not determine step number from checkpoint names in {checkpoint_dir}. Cannot resume 'latest'.")
args.resume_from_checkpoint = None # Disable resume
else: logger.info("No 'latest' checkpoint found to resume from."); args.resume_from_checkpoint = None
else: checkpoint_path = args.resume_from_checkpoint
if checkpoint_path and os.path.isdir(checkpoint_path):
logger.info(f"Attempting resume from specific checkpoint: {checkpoint_path}")
try:
accelerator.load_state(checkpoint_path)
# Extract global step from checkpoint directory name
path_stem = Path(checkpoint_path).stem
global_step = int(path_stem.split("-")[-1])
logger.info(f"Loaded state. Resuming from global step {global_step}.")
# Recalculate steps per epoch AFTER prepare (dataloader length might change)
steps_per_epoch_after_prepare = 0
if len(train_dataloader) > 0:
steps_per_epoch_after_prepare = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if steps_per_epoch_after_prepare > 0:
first_epoch = global_step // steps_per_epoch_after_prepare
resume_step = global_step % steps_per_epoch_after_prepare
logger.info(f"Calculated resume point: Epoch {first_epoch}, Step within epoch ~{resume_step}.")
else:
logger.warning("Steps/epoch is 0 after prepare. Cannot accurately calculate resume epoch/step within epoch. Starting from epoch 0.")
first_epoch = 0; resume_step = 0
except FileNotFoundError: logger.error(f"Resume checkpoint directory not found: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
except (ValueError, IndexError): logger.error(f"Could not parse step number from checkpoint name: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
except Exception as e: logger.error(f"Failed to load checkpoint state: {e}. Starting fresh.", exc_info=True); global_step=0; first_epoch=0; resume_step=0
elif args.resume_from_checkpoint: logger.warning(f"Resume checkpoint path invalid or not found: '{args.resume_from_checkpoint}'. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
else: # Case where resume_from_checkpoint was 'latest' but none found
logger.info("Starting training from scratch (no checkpoint to resume)."); global_step=0; first_epoch=0; resume_step=0
# --- Training Loop ---
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info(f"***** Running training ({args.dataset_name} - {final_dataset_size} Effective Samples) *****")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (effective) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
logger.info(f" Starting Epoch = {first_epoch}")
logger.info(f" Starting Global Step = {global_step}")
logger.info(f" Resume Step in Epoch (approx) = {resume_step}") # Approx because dataloader shuffling changes order
progress_bar = tqdm(range(global_step, args.max_train_steps), initial=global_step, total=args.max_train_steps, desc="Optimization Steps", disable=not accelerator.is_local_main_process)
# >>> Determine the weight_dtype based on mixed precision AFTER accelerator is ready <<<
# This was likely done before the loop, ensure 'weight_dtype' is defined in this scope
# Add it here for clarity if it wasn't defined before the loop in main()
weight_dtype = torch.float32 # Default
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Make sure VAE and Text Encoder are kept in float32 as per previous fix
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
logger.info(f"--- Starting Epoch {epoch} ---")
for step, batch in enumerate(train_dataloader):
if not batch or "pixel_values" not in batch or batch["pixel_values"].shape[0] == 0:
if accelerator.is_main_process:
if global_step % 100 == 0:
logger.warning(f"Skipping empty/invalid batch at raw step {step} (Epoch {epoch}, Global ~{global_step}). Likely due to download/collation errors.")
continue
# --- Accumulate Gradients ---
with accelerator.accumulate(unet):
try:
# --- >>> MODIFIED FORWARD PASS START <<< ---
# pixel_values usually float32 from dataloader/transforms
pixel_values = batch["pixel_values"].to(accelerator.device)
# 1. VAE Encoding (VAE is float32 on accelerator.device)
with torch.no_grad():
# Explicitly cast VAE input to float32
latents = vae.encode(pixel_values.to(dtype=torch.float32)).latent_dist.sample() * vae.config.scaling_factor
# 'latents' are float32 output from VAE
# 2. Prepare Noise (matches latents dtype -> float32) and Timesteps
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
# 3. Add Noise (scheduler handles dtypes, noisy_latents should be float32)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# 4. Text Encoding (Text Encoder is float32 on accelerator.device)
with torch.no_grad():
input_ids = batch["input_ids"].to(accelerator.device)
# Output 'encoder_hidden_states' is float32
encoder_hidden_states = text_encoder(input_ids)[0]
# --- 5. Cast UNet inputs to mixed precision type ---
# 'weight_dtype' is float16 or bfloat16 if mixed precision is enabled
noisy_latents_input = noisy_latents.to(dtype=weight_dtype)
encoder_hidden_states_input = encoder_hidden_states.to(dtype=weight_dtype)
# --- End Cast ---
# 6. Predict Noise using UNet (UNet runs in mixed precision)
model_pred = unet(
noisy_latents_input,
timesteps,
encoder_hidden_states_input
).sample
# 'model_pred' is likely in weight_dtype (e.g., float16)
# 7. Get Target for Loss (Should be float32)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise # noise is float32
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps) # float32
else:
raise ValueError(f"Unsupported prediction type {noise_scheduler.config.prediction_type}")
# 8. Calculate Loss (Cast BOTH model_pred and target to float32)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# 9. Gather Loss for Logging
avg_loss = accelerator.gather(loss.unsqueeze(0)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# --- >>> MODIFIED FORWARD PASS END <<< ---
# --- Backward Pass ---
accelerator.backward(loss)
# --- Optimizer Step ---
# This happens *outside* the try block but *inside* the sync_gradients check below
# DO NOT PUT OPTIMIZER STEP HERE
except Exception as forward_err:
logger.error(f"Error during training step {step} (Epoch {epoch}, Global ~{global_step}): {forward_err}", exc_info=True)
try:
pv_shape = batch.get('pixel_values').shape if batch and isinstance(batch.get('pixel_values'), torch.Tensor) else 'N/A'
id_shape = batch.get('input_ids').shape if batch and isinstance(batch.get('input_ids'), torch.Tensor) else 'N/A'
logger.error(f" Batch Shapes - Pixels: {pv_shape}, IDs: {id_shape}")
except Exception as log_err:
logger.error(f" (Could not log batch details: {log_err})")
continue # Skip to next batch
# --- Sync Gradients, Step Optimizer, Log, Checkpoint, Validate ---
if accelerator.sync_gradients:
try: # Wrap optimizer step and gradient clipping
if args.max_grad_norm > 0:
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
except Exception as optim_err:
logger.error(f"Error during optimizer step/grad clipping at Global Step {global_step}: {optim_err}", exc_info=True)
# Decide if you want to continue or stop on optimizer errors
continue # Skip to next step for now
# --- Progress Bar and Global Step ---
progress_bar.update(1)
global_step += 1
# --- Log Metrics ---
if accelerator.is_main_process:
logs = {"train_loss": train_loss} # Log the averaged accumulated loss
if hasattr(lr_scheduler, "get_last_lr"):
current_lr = lr_scheduler.get_last_lr()
logs["lr"] = current_lr[0] if isinstance(current_lr, list) else current_lr
else:
logs["lr"] = optimizer.param_groups[0]['lr'] # Get LR from optimizer if scheduler doesn't have method
try: accelerator.log(logs, step=global_step)
except Exception as log_err: logger.warning(f"Logging failed for step {global_step}: {log_err}")
train_loss = 0.0 # Reset accumulated loss for next set of accumulations
# --- Checkpointing ---
if global_step > 0 and global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = Path(args.output_dir) / f"checkpoint-{global_step}"
try:
logger.info(f"Saving checkpoint: {save_path}...")
accelerator.save_state(str(save_path))
unwrapped_unet = accelerator.unwrap_model(unet)
unet_save_path = save_path / "unet_mamba"
unwrapped_unet.save_pretrained(
str(unet_save_path),
state_dict=unwrapped_unet.state_dict(),
safe_serialization=True
)
logger.info(f"Checkpoint saved to {save_path}")
# Delete old checkpoints
if args.checkpoints_total_limit is not None and args.checkpoints_total_limit > 0:
checkpoint_dir = Path(args.output_dir)
ckpts = sorted(
[d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")],
key=lambda d: int(d.name.split("-")[-1])
)
if len(ckpts) > args.checkpoints_total_limit:
num_to_delete = len(ckpts) - args.checkpoints_total_limit
for old_ckpt in ckpts[:num_to_delete]:
logger.info(f"Deleting old checkpoint: {old_ckpt}")
shutil.rmtree(old_ckpt, ignore_errors=True) # Add ignore_errors
except Exception as ckpt_err: logger.error(f"Checkpoint saving failed for step {global_step}: {ckpt_err}", exc_info=True)
# --- Validation ---
run_validation = False
if args.validation_steps and global_step > 0 and global_step % args.validation_steps == 0:
run_validation = True
elif not args.validation_steps and args.validation_epochs > 0 and (epoch + 1) % args.validation_epochs == 0:
is_last_accum_step = step == len(train_dataloader) - 1
if is_last_accum_step: run_validation = True
if run_validation and accelerator.is_main_process:
logger.info(f"Running validation at Global Step {global_step} (Epoch {epoch})...")
log_validation_images = []
pipeline = None
original_unet_training_mode = unet.training # Store training mode
unet.eval() # Set unet to eval mode for validation
try:
# Models (VAE, Text Encoder) are already on device and float32
unet_val = accelerator.unwrap_model(unet) # Use unwrapped for pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet_val,
vae=vae, # Use the float32 vae
text_encoder=text_encoder, # Use the float32 text_encoder
tokenizer=tokenizer,
scheduler=noise_scheduler,
safety_checker=None,
torch_dtype=torch.float32, # <<< Run pipeline inference in float32 for stability >>>
# or torch_dtype=weight_dtype if you are sure about VAE/TextEncoder fp16 stability
cache_dir=args.cache_dir
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
logger.info(f"Generating {args.num_validation_images} validation images...")
for i in range(args.num_validation_images):
# Autocast might still be useful if internal pipeline ops benefit
with torch.autocast(str(accelerator.device).split(":")[0], dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"), torch.no_grad():
image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
log_validation_images.append(np.array(image))
if log_validation_images:
logger.info(f"Logging {len(log_validation_images)} validation images to trackers...")
try:
images_np = np.stack(log_validation_images)
accelerator.log({"validation_images": images_np}, step=global_step)
logger.info("Validation images logged.")
except Exception as tracker_err: logger.warning(f"Failed to log validation images: {tracker_err}")
else: logger.warning("No validation images were generated.")
except Exception as val_err:
logger.error(f"Validation failed at step {global_step}: {val_err}", exc_info=True)
finally:
# Cleanup pipeline and restore UNet training mode
if pipeline is not None: del pipeline
torch.cuda.empty_cache()
unet.train(original_unet_training_mode) # Restore original mode
logger.info("Validation run finished.")
# --- Update progress bar postfix ---
if accelerator.is_main_process:
try:
loss_val = loss.detach().item()
current_lr_val = lr_scheduler.get_last_lr()[0] if hasattr(lr_scheduler, "get_last_lr") else optimizer.param_groups[0]['lr']
logs_postfix = {"loss": f"{loss_val:.4f}", "lr": f"{current_lr_val:.2e}"}
progress_bar.set_postfix(**logs_postfix)
except NameError:
logs_postfix = {"loss": "N/A", "lr": optimizer.param_groups[0]['lr'] if optimizer.param_groups else 'N/A'}
progress_bar.set_postfix(**logs_postfix)
except Exception as pf_err:
logger.debug(f"Postfix update error: {pf_err}")
progress_bar.set_postfix({"step_status":"error"})
# --- Check for Training Completion ---
if global_step >= args.max_train_steps:
logger.info(f"Reached max_train_steps ({args.max_train_steps}). Stopping training.")
break # Exit step (inner) loop
# --- End of Epoch ---
logger.info(f"--- Finished Epoch {epoch} (Reached Global Step {global_step}) ---")
if global_step >= args.max_train_steps:
break # Exit epoch loop
# --- End of Training ---
logger.info("Training finished. Waiting for all processes...");
accelerator.wait_for_everyone();
if accelerator.is_main_process: progress_bar.close()
# Final Save
if accelerator.is_main_process:
logger.info("Saving final trained U-Net model...");
try:
unet_final = accelerator.unwrap_model(unet)
final_save_path = Path(args.output_dir)
unet_final.save_pretrained(
final_save_path / "unet_mamba_final",
safe_serialization=True,
state_dict=unet_final.state_dict()
)
logger.info(f"Final UNet saved to: {final_save_path / 'unet_mamba_final'}")
tokenizer.save_pretrained(str(final_save_path / "tokenizer_final"))
logger.info(f"Final Tokenizer saved to: {final_save_path / 'tokenizer_final'}")
except Exception as e: logger.error(f"Failed to save final UNet/Tokenizer: {e}", exc_info=True)
# Hub Push Logic
if args.push_to_hub:
logger.info("Attempting to push final model to Hub...");
if repo_id is None: logger.warning("Cannot push to Hub (repo_id not defined or Hub creation failed).")
else:
try:
logger.info(f"Pushing contents of {args.output_dir} to repository {repo_id}...");
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training - Mamba SD URL Text",
ignore_patterns=["step_*", "epoch_*", "checkpoint-*/**", "checkpoint-*/", "*.safetensors.index.json", "logs/**"],
token=args.hub_token
)
logger.info("Push to Hub successful.")
except Exception as e: logger.error(f"Hub upload failed: {e}", exc_info=True)
logger.info("Ending training script...");
accelerator.end_training();
logger.info("Script finished.")
# --- Entry Point ---
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    # Top-level boundary: report any uncaught error raised by main() and
    # terminate with a non-zero exit status. SystemExit (raised by the
    # exit(1) calls inside main) and KeyboardInterrupt derive from
    # BaseException, so they are not caught here and propagate as usual.
    try:
        main()
    except Exception as e:
        print("\n\n !!! --- FATAL SCRIPT ERROR --- !!!")
        print(f"Error Type: {type(e).__name__}")
        print(f"Error Details: {e}")
        print("Traceback:")
        print(traceback.format_exc())
        print(" !!! --- SCRIPT TERMINATED DUE TO ERROR --- !!!")
        # Raise SystemExit directly instead of calling exit(): exit() is a
        # convenience object installed by the site module for interactive
        # sessions and is missing when the interpreter runs without site
        # (e.g. `python -S`). The resulting exit status is the same (1).
        raise SystemExit(1)