|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import math |
|
|
import os |
|
|
import shutil |
|
|
import random |
|
|
from pathlib import Path |
|
|
import traceback |
|
|
import io |
|
|
import requests |
|
|
|
|
|
import accelerate |
|
|
import datasets |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.utils.checkpoint |
|
|
import transformers |
|
|
from accelerate import Accelerator |
|
|
from accelerate.logging import get_logger |
|
|
from accelerate.state import AcceleratorState |
|
|
from accelerate.utils import ProjectConfiguration, set_seed |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import create_repo, upload_folder |
|
|
from packaging import version |
|
|
from torchvision import transforms |
|
|
from tqdm.auto import tqdm |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from transformers.utils import ContextManagers |
|
|
from PIL import Image, UnidentifiedImageError |
|
|
|
|
|
import diffusers |
|
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
|
|
from diffusers.optimization import get_scheduler |
|
|
from diffusers.training_utils import EMAModel |
|
|
from diffusers.utils import check_min_version, deprecate, is_wandb_available |
|
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
|
|
|
|
|
|
# Hard requirement: the Mamba block definitions live in a sibling module.
# Training cannot proceed without them, so fail fast with a clear message.
try:

    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba

    print("Successfully imported MambaSequentialBlock and replacement function from msd_utils.py.")

except ImportError as e:

    # Abort immediately (exit code 1) rather than failing later mid-setup.
    print("="*50); print("ERROR: Failed to import from msd_utils.py!"); print("Ensure 'msd_utils.py' exists and contains necessary definitions."); print(f"Import error: {e}"); print("="*50); exit(1)


# Imported here (after the guard) because it is only needed when the Mamba
# replacement succeeded; used later to locate attn1 slots inside U-Net blocks.
from diffusers.models.attention import BasicTransformerBlock


# Abort early if the installed diffusers version is older than 0.28.0.
check_min_version("0.28.0")
|
|
|
|
|
|
|
|
# Default column names for the URL/TEXT dataset layout
# (e.g. ChristophSchuhmann/MS_COCO_2017_URL_TEXT).
DEFAULT_IMAGE_COLUMN = "URL"
DEFAULT_CAPTION_COLUMN = "TEXT"


def parse_args():
    """Parse and validate command-line arguments for Mamba-SD training.

    Returns:
        argparse.Namespace with all training options. ``local_rank`` is
        overridden by the ``LOCAL_RANK`` environment variable when it is set,
        and ``preprocessing_num_workers`` is auto-detected (capped at 16)
        when not supplied.

    Raises:
        ValueError: if neither --dataset_name nor --train_data_dir is given,
            or if --max_train_samples is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Train Stable Diffusion with Mamba using a URL/Text dataset (e.g., MS_COCO_2017_URL_TEXT).")

    # --- Model and dataset selection ---
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5", help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
        help="The HuggingFace dataset identifier for a dataset with URL and TEXT columns."
    )
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data (Not recommended for URL datasets). Overrides --dataset_name.")
    parser.add_argument(
        "--image_column",
        type=str,
        default=DEFAULT_IMAGE_COLUMN,
        help="The column of the dataset containing image URLs."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default=DEFAULT_CAPTION_COLUMN,
        help="The column of the dataset containing single text captions."
    )
    parser.add_argument("--max_train_samples", type=int, default=5000, help="Limit the number of training examples. Loads dataset metadata first, then selects.")

    # --- Validation ---
    parser.add_argument("--validation_prompt", type=str, default="A photo of a busy city street with cars and pedestrians.", help="A prompt that is used during validation to verify that the model is learning.")
    parser.add_argument("--num_validation_images", type=int, default=4, help="Number of images that should be generated during validation with `validation_prompt`.")
    parser.add_argument("--validation_epochs", type=int, default=1, help="Run validation every X epochs (ignored if validation_steps is set).")
    parser.add_argument("--validation_steps", type=int, default=500, help="Run validation every X steps. Overrides validation_epochs.")

    # --- Output / Hub ---
    parser.add_argument("--output_dir", type=str, default="sd-mamba-trained-mscoco-urltext-5k", help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir", type=str, default=None, help="Directory to cache downloaded models and datasets.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.")

    # --- Image preprocessing ---
    # NOTE(review): action="store_true" combined with default=True means these
    # flags can never be disabled from the CLI; consider
    # argparse.BooleanOptionalAction (Python 3.9+) if opting out is needed.
    parser.add_argument("--resolution", type=int, default=512, help="The resolution for input images, all images will be resized to this size.")
    parser.add_argument("--center_crop", action="store_true", default=True, help="Whether to center crop images after downloading.")
    parser.add_argument("--random_flip", action="store_true", default=True, help="Whether to randomly flip images horizontally.")

    # --- Training schedule / optimizer ---
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
    parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. Overrides num_train_epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True, help="Whether to use gradient checkpointing to save memory.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
    parser.add_argument("--lr_scheduler", type=str, default="cosine", help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]')
    parser.add_argument("--lr_warmup_steps", type=int, default=100, help="Number of steps for the warmup in the learning rate scheduler.")
    parser.add_argument("--use_8bit_adam", action="store_true", default=True, help="Whether to use 8-bit AdamW optimizer.")
    parser.add_argument("--allow_tf32", action="store_true", default=True, help="Whether to allow TF32 on Ampere GPUs. Can speed up training.")
    parser.add_argument("--dataloader_num_workers", type=int, default=8, help="Number of subprocesses to use for data loading.")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    # --- Logging / precision / distributed ---
    parser.add_argument("--logging_dir", type=str, default="logs", help="Location for TensorBoard logs.")
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"], help="Whether to use mixed precision.")
    parser.add_argument("--report_to", type=str, default="tensorboard", help='The integration to report the results and logs to.')
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    # --- Checkpointing ---
    parser.add_argument("--checkpointing_steps", type=int, default=500, help="Save a checkpoint of the training state every X updates steps.")
    parser.add_argument("--checkpoints_total_limit", type=int, default=3, help="Max number of checkpoints to store.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Whether to resume training from a previous checkpoint directory or 'latest'.")

    # --- Mamba block hyperparameters ---
    parser.add_argument("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension.")
    parser.add_argument("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension.")
    parser.add_argument("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor.")

    # --- Dataset preprocessing ---
    parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for data preprocessing (defaults to cpu count capped at 16).")

    args = parser.parse_args()

    # Torch distributed launchers export LOCAL_RANK; it takes precedence over
    # the CLI flag so the same command line works under torchrun.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        print(f"INFO: Overriding local_rank {args.local_rank} with environment variable LOCAL_RANK {env_local_rank}")
        args.local_rank = env_local_rank

    # Exactly one data source must be in effect; URL datasets win when both are given.
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either --dataset_name or --train_data_dir")
    if args.dataset_name and args.train_data_dir:
        print("WARNING: Both --dataset_name and --train_data_dir provided. Using --dataset_name as URL dataset is specified.")
        args.train_data_dir = None

    if args.preprocessing_num_workers is None:
        try:
            # Prefer the CPUs actually available to this process
            # (affinity/cgroup aware), capped at 16.
            args.preprocessing_num_workers = min(len(os.sched_getaffinity(0)), 16)
        except AttributeError:
            # os.sched_getaffinity is unavailable on macOS/Windows.
            # os.cpu_count() may return None, which would make min() raise
            # TypeError, so fall back to 1 in that case.
            args.preprocessing_num_workers = min(os.cpu_count() or 1, 16)
        print(f"INFO: Auto-detected preprocessing_num_workers: {args.preprocessing_num_workers}")

    if args.max_train_samples is not None and args.max_train_samples <= 0:
        raise ValueError("--max_train_samples must be a positive integer.")

    return args
|
|
|
|
|
|
|
|
|
|
|
def prepare_dataset(args, tokenizer, logger):
    """Loads, selects, preprocesses (downloads URLs), and filters the dataset.

    Pipeline: load dataset metadata -> pick the 'train' split -> optionally
    shuffle and select ``args.max_train_samples`` rows -> map each row by
    downloading its image URL, applying torchvision transforms and tokenizing
    its caption -> filter -> set torch format.

    Args:
        args: parsed CLI namespace; reads dataset_name, cache_dir,
            image_column, caption_column, max_train_samples, seed,
            resolution, center_crop, random_flip, preprocessing_num_workers.
        tokenizer: CLIP tokenizer used to encode captions to fixed-length ids.
        logger: logger used for progress and error reporting.

    Returns:
        Tuple of (processed_dataset, collate_fn, new_count, original_count)
        where new_count is the size after filtering and original_count the
        size before filtering.
    """
    if args.dataset_name is not None:
        logger.info(f"Loading dataset '{args.dataset_name}' metadata...")
        try:
            # Downloads only metadata/index here; images live behind URLs.
            dataset = load_dataset(
                args.dataset_name,
                cache_dir=args.cache_dir,
            )
            logger.info("Dataset metadata loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load dataset '{args.dataset_name}': {e}", exc_info=True)
            raise

        # Resolve which split to train on: prefer 'train', otherwise accept a
        # dataset that exposes exactly one split; anything else is an error.
        split_to_use = "train"
        if split_to_use not in dataset:
            available_splits = list(dataset.keys())
            if len(available_splits) == 1:
                split_to_use = available_splits[0]
                logger.warning(f"'train' split not found. Using the only available split: '{split_to_use}'.")
            else:
                raise ValueError(f"'train' split not found in dataset '{args.dataset_name}'. Available splits: {available_splits}. Please check the dataset structure or specify the split.")
        dataset = dataset[split_to_use]
        logger.info(f"Using '{split_to_use}' split. Initial size: {len(dataset)}")
        logger.info(f"Dataset features: {dataset.features}")

    else:
        # This script variant only supports URL datasets via --dataset_name.
        logger.error("Local data directory loading (--train_data_dir) is not the intended use case for this script modification.")
        raise NotImplementedError("This script is modified for URL datasets via --dataset_name.")

    # Validate that the configured URL and caption columns exist before mapping.
    column_names = dataset.column_names
    logger.info(f"Original dataset columns: {column_names}")
    if args.image_column not in column_names:
        raise ValueError(f"--image_column '{args.image_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")
    if args.caption_column not in column_names:
        raise ValueError(f"--caption_column '{args.caption_column}' not found in dataset '{args.dataset_name}'. Available columns: {column_names}")

    # Optionally cap the dataset size; shuffle first so the subset is random.
    if args.max_train_samples is not None:
        num_samples = len(dataset)
        max_samples_to_select = min(args.max_train_samples, num_samples)
        if args.max_train_samples > num_samples:
            logger.warning(
                f"--max_train_samples ({args.max_train_samples}) is larger than the dataset size ({num_samples}). "
                f"Using all {num_samples} samples."
            )
        logger.info(f"Selecting {max_samples_to_select} samples from the dataset (shuffling first).")

        if max_samples_to_select < num_samples:
            dataset = dataset.shuffle(seed=args.seed).select(range(max_samples_to_select))
        else:
            # Taking everything: skip the shuffle, selection order is irrelevant.
            dataset = dataset.select(range(max_samples_to_select))

        logger.info(f"Dataset size after selecting samples: {len(dataset)}")
        if len(dataset) == 0:
            raise ValueError(f"Selected 0 samples. Check --max_train_samples ({args.max_train_samples}) and dataset availability.")

    # Standard SD preprocessing: resize shorter side, crop to square,
    # optional flip, then scale pixels to [-1, 1] via Normalize(0.5, 0.5).
    train_transforms = transforms.Compose([
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    logger.info("Image transforms defined.")

    def preprocess_train_single(example):
        # Download one image by URL, transform it, and tokenize its caption.
        # Returns None on any failure so the row can be discarded downstream.
        image_url = example[args.image_column]
        caption = example[args.caption_column]

        processed_image_tensor = None
        try:
            # Reject non-string or non-HTTP(S) URL values outright.
            if not isinstance(image_url, str) or not image_url.startswith(("http://", "https://")):
                return None

            response = requests.get(image_url, timeout=20, stream=False)
            response.raise_for_status()
            img_bytes = response.content
            if not img_bytes:
                raise ValueError("Empty image content received")

            image_pil = Image.open(io.BytesIO(img_bytes))

            # Guard against decompression-bomb-sized images (> 16 MP).
            MAX_PIXELS = 4096 * 4096
            if image_pil.width * image_pil.height > MAX_PIXELS:
                return None

            # Force 3-channel RGB (handles palette/greyscale/RGBA inputs).
            image_pil = image_pil.convert("RGB")

            processed_image_tensor = train_transforms(image_pil)

        # Each network/decoding failure mode is skipped silently to keep the
        # multi-process map quiet; only unexpected errors are logged below.
        except requests.exceptions.Timeout:
            return None
        except requests.exceptions.TooManyRedirects:
            return None
        except requests.exceptions.SSLError:
            return None
        except requests.exceptions.RequestException as http_err:
            return None
        except UnidentifiedImageError:
            return None
        except ValueError as val_err:
            return None
        except OSError as os_err:
            return None
        except Exception as img_err:
            logger.warning(f"Generic error processing/transforming image from {image_url}: {img_err}. Skipping.")
            return None

        if processed_image_tensor is None:
            return None

        # Tokenize the caption to a fixed-length id tensor (padded/truncated
        # to the tokenizer's model_max_length).
        try:
            caption_str = str(caption) if caption is not None else ""
            if not caption_str:
                return None

            inputs = tokenizer(
                caption_str,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            # squeeze(0): drop the batch dim added by return_tensors="pt".
            input_ids_tensor = inputs.input_ids.squeeze(0)

        except Exception as tok_err:
            logger.warning(f"Error tokenizing caption '{str(caption)[:50]}...' for URL {image_url}: {tok_err}. Skipping.")
            return None

        return {"pixel_values": processed_image_tensor, "input_ids": input_ids_tensor}

    num_proc = args.preprocessing_num_workers
    logger.info(f"Preprocessing dataset (downloading URLs, single item processing) using {num_proc} workers...")

    # NOTE(review): preprocess_train_single returns None on failure; with
    # batched=False, datasets.map does not drop such rows — failed examples
    # may survive with missing/null columns. TODO confirm against the
    # datasets version in use.
    columns_to_remove = dataset.column_names
    processed_dataset = dataset.map(
        preprocess_train_single,
        batched=False,
        num_proc=num_proc,
        remove_columns=columns_to_remove,
        load_from_cache_file=True,
        desc="Downloading images and tokenizing captions",
    )
    logger.info(f"Dataset size after map (potential download/processing): {len(processed_dataset)}")

    original_count = len(processed_dataset)

    # NOTE(review): Dataset.filter passes a dict per example, never None, so
    # this predicate looks like it is always True and filters nothing; to drop
    # failed downloads it should probably test a column (e.g.
    # example["pixel_values"] is not None) — TODO confirm.
    processed_dataset = processed_dataset.filter(lambda example: example is not None, num_proc=1)
    new_count = len(processed_dataset)
    if original_count != new_count:
        logger.warning(f"Filtered out {original_count - new_count} entries due to download/processing errors.")
    if new_count == 0:
        raise RuntimeError("Dataset is empty after preprocessing and filtering. Check download errors (network, timeouts, invalid URLs/images), dataset integrity, and --max_train_samples.")
    logger.info(f"Final dataset size after filtering: {new_count}")

    # Ask datasets to hand back torch tensors for the training columns.
    # Failure here is logged but non-fatal (collate_fn stacks tensors anyway).
    try:
        final_columns = processed_dataset.column_names
        columns_to_set = [col for col in ["pixel_values", "input_ids"] if col in final_columns]
        if columns_to_set:
            processed_dataset.set_format(type="torch", columns=columns_to_set)
            logger.info(f"Successfully set dataset format to 'torch' for columns: {columns_to_set}.")
        else:
            logger.warning(f"Columns {['pixel_values', 'input_ids']} not found after filtering/mapping, skipping set_format. Available columns: {final_columns}")

    except Exception as e:
        logger.error(f"Failed to set dataset format to torch: {e}", exc_info=True)

    def collate_fn(examples):
        # Stack per-example tensors into a batch; returns {} (an empty batch)
        # instead of raising when the batch is unusable, so the training loop
        # can skip it.
        valid_examples = [e for e in examples if e is not None and "pixel_values" in e and "input_ids" in e]

        if not valid_examples:
            return {}

        try:
            pixel_values = torch.stack([example["pixel_values"] for example in valid_examples])
            input_ids = torch.stack([example["input_ids"] for example in valid_examples])
        except Exception as e:
            logger.error(f"Error during collation (likely size mismatch or invalid data): {e}", exc_info=True)
            # Dump shapes/types of the first few examples to aid debugging.
            for i, ex in enumerate(valid_examples[:5]):
                pv_shape = ex["pixel_values"].shape if isinstance(ex.get("pixel_values"), torch.Tensor) else type(ex.get("pixel_values"))
                id_shape = ex["input_ids"].shape if isinstance(ex.get("input_ids"), torch.Tensor) else type(ex.get("input_ids"))
                logger.error(f" Example {i}: PV shape/type={pv_shape}, ID shape/type={id_shape}")
            return {}

        # Sanity check: image and text batch dimensions must agree.
        if pixel_values.shape[0] != input_ids.shape[0]:
            logger.error(f"Collation error: Mismatched batch sizes after stacking. Images: {pixel_values.shape[0]}, Texts: {input_ids.shape[0]}. Skipping batch.")
            return {}

        return {"pixel_values": pixel_values, "input_ids": input_ids}

    logger.info("Dataset preparation function finished.")

    return processed_dataset, collate_fn, new_count, original_count
|
|
|
|
|
def main(): |
|
|
|
|
|
args = parse_args() |
|
|
|
|
|
|
|
|
logging_dir = Path(args.output_dir, args.logging_dir) |
|
|
accelerator_project_config = ProjectConfiguration(project_dir=str(args.output_dir), logging_dir=str(logging_dir)) |
|
|
accelerator = Accelerator( |
|
|
gradient_accumulation_steps=args.gradient_accumulation_steps, |
|
|
mixed_precision=args.mixed_precision, |
|
|
log_with=args.report_to, |
|
|
project_config=accelerator_project_config, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
format="%(asctime)s - %(levelname)s - %(name)s [%(process)d] - %(message)s", |
|
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
|
level=logging.INFO, |
|
|
) |
|
|
logger = get_logger(__name__, log_level="INFO") |
|
|
|
|
|
|
|
|
|
|
|
if accelerator.is_local_main_process: |
|
|
datasets.utils.logging.set_verbosity_warning() |
|
|
transformers.utils.logging.set_verbosity_warning() |
|
|
diffusers.utils.logging.set_verbosity_info() |
|
|
else: |
|
|
datasets.utils.logging.set_verbosity_error() |
|
|
transformers.utils.logging.set_verbosity_error() |
|
|
diffusers.utils.logging.set_verbosity_error() |
|
|
|
|
|
|
|
|
logger.info(f"Accelerator state: {accelerator.state}", main_process_only=False) |
|
|
|
|
|
logging.getLogger("PIL").setLevel(logging.WARNING) |
|
|
logging.getLogger("requests").setLevel(logging.WARNING) |
|
|
logging.getLogger("urllib3").setLevel(logging.WARNING) |
|
|
|
|
|
|
|
|
|
|
|
logger.info("Starting training script with arguments:") |
|
|
for k, v in sorted(vars(args).items()): |
|
|
logger.info(f" {k}: {v}") |
|
|
logger.info(f"Using dataset: '{args.dataset_name}'") |
|
|
logger.info(f"Using image column: '{args.image_column}', caption column: '{args.caption_column}'") |
|
|
|
|
|
|
|
|
|
|
|
if args.seed is not None: |
|
|
set_seed(args.seed) |
|
|
logger.info(f"Set random seed to {args.seed}") |
|
|
|
|
|
|
|
|
repo_id = None |
|
|
if accelerator.is_main_process: |
|
|
output_dir_path = Path(args.output_dir) |
|
|
if args.output_dir: |
|
|
output_dir_path.mkdir(parents=True, exist_ok=True) |
|
|
logger.info(f"Output directory ensured: {args.output_dir}") |
|
|
if args.push_to_hub: |
|
|
|
|
|
try: |
|
|
repo_id = create_repo( |
|
|
repo_id=args.hub_model_id or output_dir_path.name, exist_ok=True, token=args.hub_token |
|
|
).repo_id |
|
|
logger.info(f"Created/verified Hub repo: {repo_id}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to create/verify Hub repo: {e}", exc_info=True) |
|
|
logger.warning("Disabling Hub push due to error.") |
|
|
args.push_to_hub = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("Loading tokenizer...") |
|
|
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", cache_dir=args.cache_dir) |
|
|
logger.info("Loading text encoder...") |
|
|
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", cache_dir=args.cache_dir) |
|
|
logger.info("Loading VAE...") |
|
|
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", cache_dir=args.cache_dir) |
|
|
logger.info("Loading noise scheduler...") |
|
|
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") |
|
|
|
|
|
|
|
|
|
|
|
logger.info("Loading base U-Net state dict..."); |
|
|
try: |
|
|
|
|
|
original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=False).state_dict() |
|
|
except TypeError: |
|
|
logger.warning("low_cpu_mem_usage=False failed for UNet loading (unexpected), trying with low_cpu_mem_usage=True.") |
|
|
original_unet_state_dict = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", cache_dir=args.cache_dir, low_cpu_mem_usage=True).state_dict() |
|
|
except Exception as load_err: |
|
|
logger.error(f"Failed to load original UNet state_dict: {load_err}", exc_info=True) |
|
|
raise |
|
|
logger.info("Creating new U-Net structure..."); |
|
|
unet_config = UNet2DConditionModel.load_config(args.pretrained_model_name_or_path, subfolder="unet") |
|
|
unet = UNet2DConditionModel.from_config(unet_config) |
|
|
logger.info("Replacing U-Net Self-Attention with Mamba blocks..."); |
|
|
mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand} |
|
|
unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs) |
|
|
logger.info("Loading partial pre-trained weights into new structure..."); |
|
|
modified_keys = set(unet.state_dict().keys()) |
|
|
filtered_state_dict = { |
|
|
k: v for k, v in original_unet_state_dict.items() |
|
|
if k in modified_keys and unet.state_dict()[k].shape == v.shape |
|
|
} |
|
|
load_result = unet.load_state_dict(filtered_state_dict, strict=False) |
|
|
logger.info(f"U-Net Load Result - Missing Keys: {len(load_result.missing_keys)}, Unexpected Keys: {len(load_result.unexpected_keys)}") |
|
|
|
|
|
if accelerator.is_main_process: |
|
|
if load_result.missing_keys: logger.debug(f" Example Missing Keys (likely Mamba): {load_result.missing_keys[:5]}...") |
|
|
if load_result.unexpected_keys: logger.debug(f" Example Unexpected Keys (likely Attention): {load_result.unexpected_keys[:5]}...") |
|
|
del original_unet_state_dict, filtered_state_dict |
|
|
|
|
|
|
|
|
|
|
|
vae.requires_grad_(False); text_encoder.requires_grad_(False); unet.requires_grad_(False) |
|
|
logger.info("Froze VAE and Text Encoder.") |
|
|
logger.info("Unfreezing specified Mamba/Norm parameters in U-Net...") |
|
|
unfrozen_params_count = 0; total_params_count = 0; unfrozen_param_names = [] |
|
|
trainable_params = [] |
|
|
for name, param in unet.named_parameters(): |
|
|
total_params_count += param.numel() |
|
|
module_path_parts = name.split('.') |
|
|
should_unfreeze = False |
|
|
|
|
|
|
|
|
current_module = unet |
|
|
is_in_mamba_block = False |
|
|
try: |
|
|
for part in module_path_parts[:-1]: |
|
|
current_module = getattr(current_module, part) |
|
|
if isinstance(current_module, MambaSequentialBlock): |
|
|
is_in_mamba_block = True |
|
|
break |
|
|
if is_in_mamba_block: |
|
|
should_unfreeze = True |
|
|
else: |
|
|
|
|
|
is_norm1 = name.endswith(".norm1.weight") or name.endswith(".norm1.bias") |
|
|
if is_norm1 and len(module_path_parts) > 2: |
|
|
grandparent_module_path = '.'.join(module_path_parts[:-2]) |
|
|
grandparent_module = unet.get_submodule(grandparent_module_path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(grandparent_module, BasicTransformerBlock) and hasattr(grandparent_module, 'attn1') and isinstance(grandparent_module.attn1, MambaSequentialBlock): |
|
|
should_unfreeze = True |
|
|
|
|
|
except AttributeError: |
|
|
pass |
|
|
|
|
|
if should_unfreeze: |
|
|
param.requires_grad_(True) |
|
|
unfrozen_params_count += param.numel() |
|
|
unfrozen_param_names.append(name) |
|
|
trainable_params.append(param) |
|
|
|
|
|
logger.info(f"Unfroze {unfrozen_params_count} / {total_params_count} parameters ({unfrozen_params_count/total_params_count:.2%}) in U-Net.") |
|
|
if unfrozen_params_count > 0 and accelerator.is_main_process: logger.info(f"Example unfrozen parameters: {unfrozen_param_names[:5]}...") |
|
|
elif unfrozen_params_count == 0: logger.error("CRITICAL: No U-Net parameters were unfrozen! Check Mamba replacement and unfreezing logic."); exit(1) |
|
|
|
|
|
|
|
|
if args.gradient_checkpointing: unet.enable_gradient_checkpointing(); logger.info("Enabled gradient checkpointing for U-Net.") |
|
|
if args.allow_tf32: |
|
|
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: |
|
|
logger.info("Allowing TF32 for matmul and cuDNN.") |
|
|
torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True |
|
|
else: logger.info("TF32 not enabled (requires Ampere+ GPU or CUDA setup).") |
|
|
if is_xformers_available(): |
|
|
try: unet.enable_xformers_memory_efficient_attention(); logger.info("Enabled xformers memory efficient attention.") |
|
|
except Exception as e: logger.warning(f"Could not enable xformers (may not be relevant if Mamba replaced all): {e}.") |
|
|
|
|
|
|
|
|
logger.info(f"Number of trainable parameters for optimizer: {len(trainable_params)}") |
|
|
if not trainable_params: logger.error("CRITICAL: No trainable parameters found for optimizer!"); exit(1) |
|
|
if args.use_8bit_adam: |
|
|
try: import bitsandbytes as bnb; optimizer_cls = bnb.optim.AdamW8bit; logger.info("Using 8-bit AdamW optimizer.") |
|
|
except ImportError: logger.warning("bitsandbytes not installed. Falling back to standard AdamW."); optimizer_cls = torch.optim.AdamW |
|
|
else: optimizer_cls = torch.optim.AdamW; logger.info("Using standard AdamW optimizer.") |
|
|
|
|
|
|
|
|
if args.scale_lr: |
|
|
|
|
|
|
|
|
effective_total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
|
|
args.learning_rate = args.learning_rate * effective_total_batch_size |
|
|
logger.info(f"Scaled learning rate to {args.learning_rate} (original * {effective_total_batch_size})") |
|
|
|
|
|
|
|
|
optimizer = optimizer_cls(trainable_params, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon) |
|
|
|
|
|
|
|
|
logger.info("Preparing dataset and dataloader (will download images during mapping)...") |
|
|
try: train_dataset, collate_fn, final_dataset_size, count_before_filter = prepare_dataset(args, tokenizer, logger) |
|
|
except Exception as e: logger.error(f"Failed during dataset preparation: {e}", exc_info=True); exit(1) |
|
|
final_dataset_size = len(train_dataset) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"Successfully prepared dataset. Final size after filtering errors: {final_dataset_size}") |
|
|
if final_dataset_size == 0: logger.error("Training dataset is empty after filtering download/processing errors. Cannot train."); exit(1) |
|
|
|
|
|
|
|
|
if count_before_filter > 0: |
|
|
filter_ratio = (count_before_filter - final_dataset_size) / count_before_filter |
|
|
|
|
|
if filter_ratio > 0.2: |
|
|
logger.warning(f"Filtering ratio: Filtered {count_before_filter - final_dataset_size}/{count_before_filter} ({filter_ratio:.1%}) samples during download/processing. Check network/dataset quality if ratio is high.") |
|
|
elif args.max_train_samples: |
|
|
|
|
|
logger.warning(f"Dataset size before filtering was 0, despite requesting samples. Initial map/download may have failed for all items.") |
|
|
train_dataloader = torch.utils.data.DataLoader( |
|
|
train_dataset, |
|
|
shuffle=True, |
|
|
collate_fn=collate_fn, |
|
|
batch_size=args.train_batch_size, |
|
|
num_workers=args.dataloader_num_workers, |
|
|
pin_memory=True, |
|
|
persistent_workers=True if args.dataloader_num_workers > 0 else False, |
|
|
) |
|
|
logger.info("DataLoader created.") |
|
|
|
|
|
|
|
|
|
|
|
if len(train_dataloader) == 0 and final_dataset_size > 0: |
|
|
logger.warning(f"DataLoader length is 0 but dataset size is {final_dataset_size}. Check batch size ({args.train_batch_size}). Effective steps per epoch will be 0.") |
|
|
num_update_steps_per_epoch = 0 |
|
|
elif len(train_dataloader) == 0 and final_dataset_size == 0: |
|
|
logger.error("Both dataset size and dataloader length are 0. Cannot train.") |
|
|
exit(1) |
|
|
else: |
|
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
|
|
|
|
|
if args.max_train_steps is None: |
|
|
if num_update_steps_per_epoch == 0: logger.error("Cannot calculate max_train_steps (steps per epoch is 0). Please set --max_train_steps explicitly."); exit(1) |
|
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
|
|
logger.info(f"Calculated max_train_steps: {args.max_train_steps} ({args.num_train_epochs} epochs * {num_update_steps_per_epoch} steps/epoch)") |
|
|
else: |
|
|
if num_update_steps_per_epoch > 0: args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch); logger.info(f"Training for {args.max_train_steps} steps (~{args.num_train_epochs} epochs).") |
|
|
else: args.num_train_epochs = 0; logger.warning(f"Training for {args.max_train_steps} steps, but calculated steps per epoch is zero.") |
|
|
|
|
|
|
|
|
|
|
|
# Create the learning-rate scheduler.
#
# BUG FIX (review): the scheduler was previously constructed TWICE in a row —
# a first call that scaled num_warmup_steps/num_training_steps by
# accelerator.num_processes was immediately discarded and overwritten by an
# unscaled call. The dead first call has been removed; the unscaled variant
# (the one that was actually in effect) is kept, so behavior is unchanged.
# NOTE(review): the upstream diffusers training examples multiply both step
# counts by accelerator.num_processes for multi-process runs — confirm which
# behavior is intended before scaling out to multiple GPUs.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=args.max_train_steps,
)
logger.info(f"Initialized LR scheduler: {args.lr_scheduler} ({args.lr_warmup_steps} warmup, {args.max_train_steps} total steps).")
|
|
|
|
|
|
|
|
# Hand the trainable pieces to Accelerate: wraps the UNet for distributed /
# mixed-precision execution and shards the dataloader across processes.
# VAE and text encoder are deliberately NOT prepared — they stay frozen and
# are only used for feature extraction below.
logger.info("Preparing models, optimizer, dataloader, and scheduler with Accelerator...")

unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
logger.info("Accelerator preparation finished.")

# Dtype used to cast the UNet's inputs under mixed precision.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Frozen auxiliary models live on-device in float32; their outputs are cast
# to `weight_dtype` just before the UNet forward pass in the training loop.
logger.info(f"Moving VAE and Text Encoder to device {accelerator.device} (keeping float32)...")
vae.to(accelerator.device)
text_encoder.to(accelerator.device)
|
|
|
|
|
# Initialize experiment trackers (wandb/tensorboard per --report_to) on the
# main process only; tracker failure is non-fatal (training proceeds).
if accelerator.is_main_process:
    tracker_project_name = "mamba-sd-train-url"
    # Build a human-readable run name from dataset id + key hyperparameters.
    # NOTE(review): the .replace('/','_') after split('/')[-1] is redundant
    # (the '/' separator is already gone) but harmless.
    clean_dataset_name = args.dataset_name.split('/')[-1].replace('-', '_').replace('/','_').replace('.','_') if args.dataset_name else "local_data"
    # Effective batch size = per-device batch * world size * accumulation.
    effective_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    run_name = f"{clean_dataset_name}_{args.max_train_samples or 'all'}samples_lr{args.learning_rate}_bs{effective_batch_size}_mamba{args.mamba_d_state}-{args.mamba_d_conv}-{args.mamba_expand}"
    try:
        accelerator.init_trackers(tracker_project_name, config=vars(args), init_kwargs={"wandb": {"name": run_name}})
        logger.info(f"Initialized trackers (Project: {tracker_project_name}, Run: {run_name})")
    except Exception as e: logger.warning(f"Could not initialize trackers ({args.report_to}): {e}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Optionally resume from a saved checkpoint. The global step is parsed from
# the checkpoint directory name ("checkpoint-<step>"); epoch/step-in-epoch
# are then back-computed from the post-prepare dataloader length. Every
# failure path falls back to a fresh start (step 0) rather than aborting.
# ---------------------------------------------------------------------------
global_step = 0; first_epoch = 0; resume_step = 0

if args.resume_from_checkpoint:
    checkpoint_path = None
    checkpoint_dir = Path(args.output_dir)
    if args.resume_from_checkpoint == "latest":
        # Pick the checkpoint-<N> directory with the highest N.
        dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")]
        if dirs:
            try:
                latest_checkpoint = max(dirs, key=lambda d: int(d.name.split('-')[-1]))
                checkpoint_path = str(latest_checkpoint)
                logger.info(f"Resuming from latest checkpoint: {checkpoint_path}")
            except (ValueError, IndexError):
                # Directory names didn't end in an integer step — give up on
                # "latest" (resume_from_checkpoint is cleared so the invalid-
                # path warning below does not also fire).
                logger.warning(f"Could not determine step number from checkpoint names in {checkpoint_dir}. Cannot resume 'latest'.")
                args.resume_from_checkpoint = None
        else: logger.info("No 'latest' checkpoint found to resume from."); args.resume_from_checkpoint = None
    else: checkpoint_path = args.resume_from_checkpoint

    if checkpoint_path and os.path.isdir(checkpoint_path):
        logger.info(f"Attempting resume from specific checkpoint: {checkpoint_path}")
        try:
            # Restores model/optimizer/scheduler/RNG state saved by
            # accelerator.save_state().
            accelerator.load_state(checkpoint_path)
            # Recover the global step from the "checkpoint-<step>" name.
            path_stem = Path(checkpoint_path).stem
            global_step = int(path_stem.split("-")[-1])
            logger.info(f"Loaded state. Resuming from global step {global_step}.")
            # Recompute steps/epoch AFTER accelerator.prepare() — the
            # dataloader length may differ from the pre-prepare value when
            # sharded across processes.
            steps_per_epoch_after_prepare = 0
            if len(train_dataloader) > 0:
                steps_per_epoch_after_prepare = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
            if steps_per_epoch_after_prepare > 0:
                first_epoch = global_step // steps_per_epoch_after_prepare
                resume_step = global_step % steps_per_epoch_after_prepare
                logger.info(f"Calculated resume point: Epoch {first_epoch}, Step within epoch ~{resume_step}.")
            else:
                logger.warning("Steps/epoch is 0 after prepare. Cannot accurately calculate resume epoch/step within epoch. Starting from epoch 0.")
                first_epoch = 0; resume_step = 0
        except FileNotFoundError: logger.error(f"Resume checkpoint directory not found: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
        except (ValueError, IndexError): logger.error(f"Could not parse step number from checkpoint name: {checkpoint_path}. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
        except Exception as e: logger.error(f"Failed to load checkpoint state: {e}. Starting fresh.", exc_info=True); global_step=0; first_epoch=0; resume_step=0
    elif args.resume_from_checkpoint: logger.warning(f"Resume checkpoint path invalid or not found: '{args.resume_from_checkpoint}'. Starting fresh."); global_step=0; first_epoch=0; resume_step=0
else:
    logger.info("Starting training from scratch (no checkpoint to resume)."); global_step=0; first_epoch=0; resume_step=0
|
|
|
|
|
|
|
|
|
|
|
# Log the effective training configuration and build the progress bar over
# OPTIMIZER steps (not dataloader batches); it starts at global_step so
# resumed runs show the correct position.
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info(f"***** Running training ({args.dataset_name} - {final_dataset_size} Effective Samples) *****")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (effective) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
logger.info(f"  Starting Epoch = {first_epoch}")
logger.info(f"  Starting Global Step = {global_step}")
logger.info(f"  Resume Step in Epoch (approx) = {resume_step}")

# Only the local main process renders the bar; others run silently.
progress_bar = tqdm(range(global_step, args.max_train_steps), initial=global_step, total=args.max_train_steps, desc="Optimization Steps", disable=not accelerator.is_local_main_process)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# DEAD CODE REMOVED (review): this block recomputed `weight_dtype` from
# accelerator.mixed_precision, producing the exact same value already
# assigned earlier in this function (right after accelerator.prepare, before
# moving the VAE/text encoder). Nothing between the two computations changes
# accelerator.mixed_precision, so the duplicate was a no-op and is deleted.
|
|
|
|
|
|
|
|
# ===========================================================================
# Main training loop.
#
# Per optimizer step: encode images to latents with the frozen VAE, add
# scheduler noise at a random timestep, encode captions with the frozen CLIP
# text encoder, predict the noise (or velocity) with the Mamba-augmented
# UNet, and regress it with MSE. Checkpointing and validation are driven by
# global_step and run on the main process only.
#
# NOTE(review): source indentation was mangled in extraction; the nesting
# below (optimizer/bookkeeping under `if accelerator.sync_gradients:` inside
# the accumulate context) is reconstructed from control-flow semantics —
# verify against the original file.
# ===========================================================================
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    # Running loss, averaged across processes, accumulated over one
    # accumulation window; reset after each accelerator.log() call.
    train_loss = 0.0
    logger.info(f"--- Starting Epoch {epoch} ---")

    for step, batch in enumerate(train_dataloader):
        # URL-based image download can fail; the collate_fn may then yield an
        # empty/invalid batch. Skip it (warning is rate-limited to every
        # ~100 global steps to avoid log spam).
        if not batch or "pixel_values" not in batch or batch["pixel_values"].shape[0] == 0:
            if accelerator.is_main_process:
                if global_step % 100 == 0:
                    logger.warning(f"Skipping empty/invalid batch at raw step {step} (Epoch {epoch}, Global ~{global_step}). Likely due to download/collation errors.")
            continue

        # accumulate() defers gradient sync until the configured number of
        # micro-batches has been processed.
        with accelerator.accumulate(unet):
            try:
                pixel_values = batch["pixel_values"].to(accelerator.device)

                # Frozen VAE: images -> scaled latent samples (no grads).
                with torch.no_grad():
                    latents = vae.encode(pixel_values.to(dtype=torch.float32)).latent_dist.sample() * vae.config.scaling_factor

                # Sample noise and a uniform random timestep per sample.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()

                # Forward diffusion: noise the latents at the drawn timesteps.
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Frozen text encoder: token ids -> conditioning embeddings.
                with torch.no_grad():
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = text_encoder(input_ids)[0]

                # Cast UNet inputs to the mixed-precision dtype.
                noisy_latents_input = noisy_latents.to(dtype=weight_dtype)
                encoder_hidden_states_input = encoder_hidden_states.to(dtype=weight_dtype)

                # UNet predicts noise/velocity depending on scheduler config.
                model_pred = unet(
                    noisy_latents_input,
                    timesteps,
                    encoder_hidden_states_input
                ).sample

                # Regression target per the scheduler's prediction type.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unsupported prediction type {noise_scheduler.config.prediction_type}")

                # Loss in float32 for numerical stability under fp16/bf16.
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Average the scalar loss across all processes for logging.
                avg_loss = accelerator.gather(loss.unsqueeze(0)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)

            except Exception as forward_err:
                # Best-effort recovery: log the failure (plus batch shapes if
                # retrievable) and move on to the next batch. NOTE(review):
                # skipping mid-window leaves partial gradients accumulated.
                logger.error(f"Error during training step {step} (Epoch {epoch}, Global ~{global_step}): {forward_err}", exc_info=True)
                try:
                    pv_shape = batch.get('pixel_values').shape if batch and isinstance(batch.get('pixel_values'), torch.Tensor) else 'N/A'
                    id_shape = batch.get('input_ids').shape if batch and isinstance(batch.get('input_ids'), torch.Tensor) else 'N/A'
                    logger.error(f"  Batch Shapes - Pixels: {pv_shape}, IDs: {id_shape}")
                except Exception as log_err:
                    logger.error(f"  (Could not log batch details: {log_err})")
                continue

            # True only on the final micro-batch of an accumulation window —
            # i.e. exactly once per real optimizer step.
            if accelerator.sync_gradients:
                try:
                    if args.max_grad_norm > 0:
                        accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                except Exception as optim_err:
                    logger.error(f"Error during optimizer step/grad clipping at Global Step {global_step}: {optim_err}", exc_info=True)
                    # Skip bookkeeping for a failed optimizer step.
                    continue

                progress_bar.update(1)
                global_step += 1

                # --- scalar logging (main process) ---
                if accelerator.is_main_process:
                    logs = {"train_loss": train_loss}
                    if hasattr(lr_scheduler, "get_last_lr"):
                        current_lr = lr_scheduler.get_last_lr()
                        logs["lr"] = current_lr[0] if isinstance(current_lr, list) else current_lr
                    else:
                        logs["lr"] = optimizer.param_groups[0]['lr']
                    try: accelerator.log(logs, step=global_step)
                    except Exception as log_err: logger.warning(f"Logging failed for step {global_step}: {log_err}")
                    train_loss = 0.0

                # --- periodic checkpointing (main process) ---
                if global_step > 0 and global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = Path(args.output_dir) / f"checkpoint-{global_step}"
                        try:
                            logger.info(f"Saving checkpoint: {save_path}...")
                            # Full resumable state (optimizer/scheduler/RNG).
                            accelerator.save_state(str(save_path))
                            # Plus a standalone UNet export for inference.
                            unwrapped_unet = accelerator.unwrap_model(unet)
                            unet_save_path = save_path / "unet_mamba"
                            unwrapped_unet.save_pretrained(
                                str(unet_save_path),
                                state_dict=unwrapped_unet.state_dict(),
                                safe_serialization=True
                            )
                            logger.info(f"Checkpoint saved to {save_path}")
                            # Rotate old checkpoints beyond the retention limit.
                            if args.checkpoints_total_limit is not None and args.checkpoints_total_limit > 0:
                                checkpoint_dir = Path(args.output_dir)
                                ckpts = sorted(
                                    [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")],
                                    key=lambda d: int(d.name.split("-")[-1])
                                )
                                if len(ckpts) > args.checkpoints_total_limit:
                                    num_to_delete = len(ckpts) - args.checkpoints_total_limit
                                    for old_ckpt in ckpts[:num_to_delete]:
                                        logger.info(f"Deleting old checkpoint: {old_ckpt}")
                                        shutil.rmtree(old_ckpt, ignore_errors=True)
                        except Exception as ckpt_err: logger.error(f"Checkpoint saving failed for step {global_step}: {ckpt_err}", exc_info=True)

                # --- validation trigger: step-based takes precedence over
                #     epoch-based (epoch-based fires on the epoch's last batch).
                #     NOTE(review): `if args.validation_steps` treats 0 as
                #     "unset", same truthiness caveat as args.seed below. ---
                run_validation = False
                if args.validation_steps and global_step > 0 and global_step % args.validation_steps == 0:
                    run_validation = True
                elif not args.validation_steps and args.validation_epochs > 0 and (epoch + 1) % args.validation_epochs == 0:
                    is_last_accum_step = step == len(train_dataloader) - 1
                    if is_last_accum_step: run_validation = True

                if run_validation and accelerator.is_main_process:
                    logger.info(f"Running validation at Global Step {global_step} (Epoch {epoch})...")
                    log_validation_images = []
                    pipeline = None
                    # Remember train/eval mode so it can be restored in finally.
                    original_unet_training_mode = unet.training
                    unet.eval()
                    try:
                        # Build a temporary inference pipeline around the
                        # current (unwrapped) UNet and the frozen components.
                        unet_val = accelerator.unwrap_model(unet)
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=unet_val,
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            scheduler=noise_scheduler,
                            safety_checker=None,
                            torch_dtype=torch.float32,
                            cache_dir=args.cache_dir
                        )
                        pipeline = pipeline.to(accelerator.device)
                        pipeline.set_progress_bar_config(disable=True)
                        # NOTE(review): `if args.seed` treats seed=0 as unset.
                        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None

                        logger.info(f"Generating {args.num_validation_images} validation images...")
                        for i in range(args.num_validation_images):
                            # Autocast on the device type (e.g. "cuda"),
                            # disabled when mixed precision is off.
                            with torch.autocast(str(accelerator.device).split(":")[0], dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"), torch.no_grad():
                                image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
                            log_validation_images.append(np.array(image))

                        if log_validation_images:
                            logger.info(f"Logging {len(log_validation_images)} validation images to trackers...")
                            try:
                                images_np = np.stack(log_validation_images)
                                accelerator.log({"validation_images": images_np}, step=global_step)
                                logger.info("Validation images logged.")
                            except Exception as tracker_err: logger.warning(f"Failed to log validation images: {tracker_err}")
                        else: logger.warning("No validation images were generated.")

                    except Exception as val_err:
                        logger.error(f"Validation failed at step {global_step}: {val_err}", exc_info=True)
                    finally:
                        # Always free the pipeline and restore training mode,
                        # even if generation/logging failed.
                        if pipeline is not None: del pipeline
                        torch.cuda.empty_cache()
                        unet.train(original_unet_training_mode)
                        logger.info("Validation run finished.")

                # --- progress-bar postfix (best effort; `loss` may be unbound
                #     if every batch so far failed, hence the NameError arm) ---
                if accelerator.is_main_process:
                    try:
                        loss_val = loss.detach().item()
                        current_lr_val = lr_scheduler.get_last_lr()[0] if hasattr(lr_scheduler, "get_last_lr") else optimizer.param_groups[0]['lr']
                        logs_postfix = {"loss": f"{loss_val:.4f}", "lr": f"{current_lr_val:.2e}"}
                        progress_bar.set_postfix(**logs_postfix)
                    except NameError:
                        logs_postfix = {"loss": "N/A", "lr": optimizer.param_groups[0]['lr'] if optimizer.param_groups else 'N/A'}
                        progress_bar.set_postfix(**logs_postfix)
                    except Exception as pf_err:
                        logger.debug(f"Postfix update error: {pf_err}")
                        progress_bar.set_postfix({"step_status":"error"})

                # Stop mid-epoch once the step budget is exhausted.
                if global_step >= args.max_train_steps:
                    logger.info(f"Reached max_train_steps ({args.max_train_steps}). Stopping training.")
                    break

    logger.info(f"--- Finished Epoch {epoch} (Reached Global Step {global_step}) ---")
    # Propagate the inner break out of the epoch loop.
    if global_step >= args.max_train_steps:
        break
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# End of training: synchronize all processes, then (main process only) save
# the final UNet + tokenizer and optionally push the output dir to the Hub.
# ---------------------------------------------------------------------------
logger.info("Training finished. Waiting for all processes...");
accelerator.wait_for_everyone();
if accelerator.is_main_process: progress_bar.close()

if accelerator.is_main_process:
    logger.info("Saving final trained U-Net model...");
    try:
        unet_final = accelerator.unwrap_model(unet)
        final_save_path = Path(args.output_dir)
        # Standalone safetensors export of the Mamba-augmented UNet.
        unet_final.save_pretrained(
            final_save_path / "unet_mamba_final",
            safe_serialization=True,
            state_dict=unet_final.state_dict()
        )
        logger.info(f"Final UNet saved to: {final_save_path / 'unet_mamba_final'}")
        # Tokenizer is saved alongside for self-contained inference loading.
        tokenizer.save_pretrained(str(final_save_path / "tokenizer_final"))
        logger.info(f"Final Tokenizer saved to: {final_save_path / 'tokenizer_final'}")
    except Exception as e: logger.error(f"Failed to save final UNet/Tokenizer: {e}", exc_info=True)

    if args.push_to_hub:
        logger.info("Attempting to push final model to Hub...");
        # repo_id is created earlier in the script; None means Hub setup failed.
        if repo_id is None: logger.warning("Cannot push to Hub (repo_id not defined or Hub creation failed).")
        else:
            try:
                logger.info(f"Pushing contents of {args.output_dir} to repository {repo_id}...");
                # Intermediate checkpoints and logs are excluded from the push.
                upload_folder(
                    repo_id=repo_id,
                    folder_path=args.output_dir,
                    commit_message="End of training - Mamba SD URL Text",
                    ignore_patterns=["step_*", "epoch_*", "checkpoint-*/**", "checkpoint-*/", "*.safetensors.index.json", "logs/**"],
                    token=args.hub_token
                )
                logger.info("Push to Hub successful.")
            except Exception as e: logger.error(f"Hub upload failed: {e}", exc_info=True)

logger.info("Ending training script...");
accelerator.end_training();
logger.info("Script finished.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Top-level boundary: catch anything main() lets escape, print a loud
    # banner plus the full traceback, and exit non-zero so callers (shell
    # scripts, schedulers) see the failure.
    try:
        main()
    except Exception as e:
        # FIX (review): removed pointless f-prefixes from strings with no
        # placeholders (ruff F541); printed output is byte-identical.
        print("\n\n !!! --- FATAL SCRIPT ERROR --- !!!")
        print(f"Error Type: {type(e).__name__}")
        print(f"Error Details: {e}")
        print("Traceback:")
        print(traceback.format_exc())
        print(" !!! --- SCRIPT TERMINATED DUE TO ERROR --- !!!")
        exit(1)