diff --git a/Rectified_Noise/GVP-Disp/README.md b/Rectified_Noise/GVP-Disp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..116a845bb2f32f1d32b074e4335c9cf71eaacf60 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/README.md @@ -0,0 +1,92 @@ +# [AAAI 2026] Rectified Noise: A Generative Model Using Positive-incentive Noise + +![Visualization of the $\pi$-noise by $\Delta$RN.](assests/visual.png) + +
+ +HuggingFace + +## Introduction +This is a [Pytorch](https://pytorch.org) implementation of **Rectified Noise**, a generative model using positive-incentive noise to enhance model's sampling. + +![Overview of Laytrol](assests/pipeline.png) + +## Setup + +We provide an `environment.yml` file that can be used to create a Conda environment. + +```bash +conda env create -f environment.yml +conda activate RN +``` + +## Usage + +### Training +1. We provide a training script for RN in `train_rectified_noise.py` + + Run: + +```bash +torchrun --nnodes=1 --nproc_per_node=4 train_rectified_noise.py \ +--data-path /path/to/data \ +--num-classes 3 \ +--path-type Linear \ +--prediction velocity \ +--ckpt /path/to/pretrained_model \ +--model SiT-B/2 +--learn-mu True \ +--depth 1 \ +``` + +You can find relevant checkpoint files from the previous Hugging Face link. + +2. Parameters: + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--data-path ` | str | `-` | Path to the dataset. | +| `--num-classes` | int | `-` | Number of classes. | +| `--path-type` | str | `Linear` | Directory to save the generated images. | +| `--prediction` | str | `velocity` | Output type of network. | +| `--ckpt` | str | `-` | Path to pretrained model checkpoint. | +| `--model` | str | `SiT-B/2` | Model type, any option from the model list. | +| `--learn-mu` | bool | `True` | Whether to learn the mu parameter. | +| `--depth` | int | `1` | Depth parameter for the SiTF2 model(Extra SiT Block). | + +**Sampling** + +1. Using the trained RN model to enhance the pre-trained model + +```bash +torchrun --nnodes=1 --nproc_per_node=4 train_rectified_noise.py \ +--path-type Linear \ +--prediction velocity \ +--ckpt /path/to/pretrained_model \ +--sitf2-ckpt /path/to/pretrained_RN \ +--model SiT-B/2 +--learn-mu True \ +--depth 1 \ +``` + +## Ackownledgement +This repo benefits from [SiT](https://github.com/willisma/SiT). Thanks for their excellent works. 
+ +## Contact +If you have any question about this project, please contact mguzhenyu@outlook.com. + +## Citation + +If you find the code useful for your research, please consider citing our work: + +``` +@misc{gu2025rectifiednoisegenerativemodel, + title={Rectified Noise: A Generative Model Using Positive-incentive Noise}, + author={Zhenyu Gu and Yanchen Xu and Sida Huang and Yubin Guo and Hongyuan Zhang}, + year={2025}, + eprint={2511.07911}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2511.07911}, +} +``` diff --git a/Rectified_Noise/GVP-Disp/W_False.log b/Rectified_Noise/GVP-Disp/W_False.log new file mode 100644 index 0000000000000000000000000000000000000000..b07397afd2eed57cb9358c342e01b1754891cbc6 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/W_False.log @@ -0,0 +1,5 @@ +[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds. +Starting rank=0, seed=0, world_size=1. +Saving .png samples at GVP_samples/depth-mu-2-threshold-1.0-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04 +Total number of images that will be sampled: 3008 + 0%| | 0/47 [00:00 eval_threshold_0.0.log 2>&1 & + +# Evaluate threshold 0.15 on GPU 1 +CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.15-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.15.log 2>&1 & + +# Evaluate threshold 0.25 on GPU 2 +CUDA_VISIBLE_DEVICES=2 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.25-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.25.log 2>&1 & + +# Evaluate threshold 0.5 on GPU 3 +CUDA_VISIBLE_DEVICES=3 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.5-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > 
eval_threshold_0.5.log 2>&1 & + +# Evaluate threshold 0.75 on GPU 4 +CUDA_VISIBLE_DEVICES=0 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.75-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.75.log 2>&1 & + +# Evaluate threshold 1.0 on GPU 5 +CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-1.0-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_1.0.log 2>&1 & + +# Wait for all background jobs to complete +echo "All evaluation tasks started. Waiting for completion..." +wait + +echo "All evaluation tasks completed!" +echo "" +echo "Results saved in:" +echo " - eval_threshold_0.0.log" +echo " - eval_threshold_0.15.log" +echo " - eval_threshold_0.25.log" +echo " - eval_threshold_0.5.log" +echo " - eval_threshold_0.75.log" +echo " - eval_threshold_1.0.log" diff --git a/Rectified_Noise/GVP-Disp/evaluator.py b/Rectified_Noise/GVP-Disp/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..62bdd4b20c4da0db41fe0328d3b2ee6040935d23 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/evaluator.py @@ -0,0 +1,689 @@ +import argparse +import io +import os +import random +import warnings +import zipfile +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from multiprocessing import cpu_count +from multiprocessing.pool import ThreadPool +from typing import Iterable, Optional, Tuple, Union + +import numpy as np +import requests +import tensorflow.compat.v1 as tf +from scipy import linalg +from tqdm.auto import tqdm +from datetime import timedelta +import torch + + + +INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" +INCEPTION_V3_PATH = "classify_image_graph_def.pb" + +FID_POOL_NAME = "pool_3:0" +FID_SPATIAL_NAME = "mixed_6/conv:0" 
+ + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ref_batch", default='/gemini/space/zhaozy/zhy/dataset/VIRTUAL_imagenet256_labeled.npz',help="path to reference batch npz file") + parser.add_argument("--sample_batch", default='/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/last_samples_depth_2/depth-mu-28-0050000-2000000-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz', help="path to sample batch npz file") + args = parser.parse_args() + + config = tf.ConfigProto( + allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph + ) + config.gpu_options.allow_growth = True + evaluator = Evaluator(tf.Session(config=config)) + + print("warming up TensorFlow...") + # This will cause TF to print a bunch of verbose stuff now rather + # than after the next print(), to help prevent confusion. + evaluator.warmup() + + print("computing reference batch activations...") + ref_acts = evaluator.read_activations(args.ref_batch) + print("computing/reading reference batch statistics...") + ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) + + print("computing sample batch activations...") + sample_acts = evaluator.read_activations(args.sample_batch) + print("computing/reading sample batch statistics...") + sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts) + + print("Computing evaluations...") + #print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) + print("FID:", sample_stats.frechet_distance(ref_stats)) + #print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) + #prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) + #print("Precision:", prec) + #print("Recall:", recall) + + +class InvalidFIDException(Exception): + pass + + +class FIDStatistics: + def __init__(self, mu: np.ndarray, sigma: np.ndarray): + self.mu = mu + self.sigma = sigma + + def frechet_distance(self, other, eps=1e-6): 
+ """ + Compute the Frechet distance between two sets of statistics. + """ + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 + mu1, sigma1 = self.mu, self.sigma + mu2, sigma2 = other.mu, other.sigma + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" + assert ( + sigma1.shape == sigma2.shape + ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; adding %s to diagonal of cov estimates" + % eps + ) + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + #虚部报错部分 + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1): + m = np.max(np.abs(covmean.imag)) + print(f"Real component: {covmean.real}") + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +class Evaluator: + def __init__( + self, + session, + batch_size=64, + softmax_batch_size=512, + ): + self.sess = session + self.batch_size = batch_size + self.softmax_batch_size = softmax_batch_size + self.manifold_estimator = ManifoldEstimator(session) + with self.sess.graph.as_default(): + self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) + self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) + self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) + 
self.softmax = _create_softmax_graph(self.softmax_input) + + def warmup(self): + self.compute_activations(np.zeros([1, 8, 64, 64, 3])) + + def read_activations(self, npz_path: Union[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + if isinstance(npz_path, str): + # If npz_path is a string, treat it as a file path and read the .npz file + with open_npz_array(npz_path, "arr_0") as reader: + return self.compute_activations(reader.read_batches(self.batch_size)) + elif isinstance(npz_path, np.ndarray): + # If npz_path is a numpy array, split it into batches manually + print("--------line 140-----------") + batches = np.array_split(npz_path, range(self.batch_size, npz_path.shape[0], self.batch_size)) + print("--------line 143-----------") + return self.compute_activations(batches) + else: + raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)") + + + def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute image features for downstream evals. + + :param batches: a iterator over NHWC numpy arrays in [0, 255]. + :return: a tuple of numpy arrays of shape [N x X], where X is a feature + dimension. The tuple is (pool_3, spatial). 
+ """ + preds = [] + spatial_preds = [] + for batch in tqdm(batches): + # print("--------line 164-----------") + + # # 识别当前进程信息 + # if 'RANK' in os.environ: + # rank = int(os.environ['RANK']) + # local_rank = int(os.environ.get('LOCAL_RANK', rank % torch.cuda.device_count())) + # print(f"Distributed training - Global Rank: {rank}, Local Rank: {local_rank}") + # print(f"Current GPU device: {torch.cuda.current_device()}" if torch.cuda.is_available() else "No CUDA") + # else: + # print("Single process mode") + + # print(f"Process PID: {os.getpid()}") + + batch = batch.astype(np.float32) + pred, spatial_pred = self.sess.run( + [self.pool_features, self.spatial_features], {self.image_input: batch} + ) + # print("--------line 169-----------") + preds.append(pred.reshape([pred.shape[0], -1])) + spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) + return ( + np.concatenate(preds, axis=0), + np.concatenate(spatial_preds, axis=0), + ) + + def read_statistics( + self, npz_path: Union[str, np.ndarray], activations: Tuple[np.ndarray, np.ndarray] + ) -> Tuple[FIDStatistics, FIDStatistics]: + if isinstance(npz_path, str): + obj = np.load(npz_path) + if "mu" in list(obj.keys()): + return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( + obj["mu_s"], obj["sigma_s"] + ) + elif isinstance(npz_path, np.ndarray): + obj = npz_path + else: + raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)") + return tuple(self.compute_statistics(x) for x in activations) + + def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return FIDStatistics(mu, sigma) + + def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: + softmax_out = [] + for i in range(0, len(activations), self.softmax_batch_size): + acts = activations[i : i + self.softmax_batch_size] + softmax_out.append(self.sess.run(self.softmax, 
feed_dict={self.softmax_input: acts})) + preds = np.concatenate(softmax_out, axis=0) + # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 + scores = [] + for i in range(0, len(preds), split_size): + part = preds[i : i + split_size] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)) + + def compute_prec_recall( + self, activations_ref: np.ndarray, activations_sample: np.ndarray + ) -> Tuple[float, float]: + radii_1 = self.manifold_estimator.manifold_radii(activations_ref) + radii_2 = self.manifold_estimator.manifold_radii(activations_sample) + pr = self.manifold_estimator.evaluate_pr( + activations_ref, radii_1, activations_sample, radii_2 + ) + return (float(pr[0][0]), float(pr[1][0])) + + +class ManifoldEstimator: + """ + A helper for comparing manifolds of feature vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 + """ + + def __init__( + self, + session, + row_batch_size=10000, + col_batch_size=10000, + nhood_sizes=(3,), + clamp_to_percentile=None, + eps=1e-5, + ): + """ + Estimate the manifold of given feature vectors. + + :param session: the TensorFlow session. + :param row_batch_size: row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + :param col_batch_size: column batch size to compute pairwise distances. + :param nhood_sizes: number of neighbors used to estimate the manifold. + :param clamp_to_percentile: prune hyperspheres that have radius larger than + the given percentile. + :param eps: small number for numerical stability. 
+ """ + self.distance_block = DistanceBlock(session) + self.row_batch_size = row_batch_size + self.col_batch_size = col_batch_size + self.nhood_sizes = nhood_sizes + self.num_nhoods = len(nhood_sizes) + self.clamp_to_percentile = clamp_to_percentile + self.eps = eps + + def warmup(self): + feats, radii = ( + np.zeros([1, 2048], dtype=np.float32), + np.zeros([1, 1], dtype=np.float32), + ) + self.evaluate_pr(feats, radii, feats, radii) + + def manifold_radii(self, features: np.ndarray) -> np.ndarray: + num_images = len(features) + + # Estimate manifold of features by calculating distances to k-NN of each sample. + radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) + distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) + seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) + + for begin1 in range(0, num_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(row_batch, col_batch) + + # Find the k-nearest neighbor from the current batch. + radii[begin1:end1, :] = np.concatenate( + [ + x[:, self.nhood_sizes] + for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1) + ], + axis=0, + ) + + if self.clamp_to_percentile is not None: + max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) + radii[radii > max_distances] = 0 + return radii + + def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray): + """ + Evaluate if new feature vectors are at the manifold. 
+ """ + num_eval_images = eval_features.shape[0] + num_ref_images = radii.shape[0] + distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) + batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) + max_realism_score = np.zeros([num_eval_images], dtype=np.float32) + nearest_indices = np.zeros([num_eval_images], dtype=np.int32) + + for begin1 in range(0, num_eval_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_eval_images) + feature_batch = eval_features[begin1:end1] + + for begin2 in range(0, num_ref_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_ref_images) + ref_batch = features[begin2:end2] + + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) + + # From the minibatch of new feature vectors, determine if they are in the estimated manifold. + # If a feature vector is inside a hypersphere of some reference sample, then + # the new sample lies at the estimated manifold. + # The radii of the hyperspheres are determined from distances of neighborhood size k. + samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii + batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) + + max_realism_score[begin1:end1] = np.max( + radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 + ) + nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1) + + return { + "fraction": float(np.mean(batch_predictions)), + "batch_predictions": batch_predictions, + "max_realisim_score": max_realism_score, + "nearest_indices": nearest_indices, + } + + def evaluate_pr( + self, + features_1: np.ndarray, + radii_1: np.ndarray, + features_2: np.ndarray, + radii_2: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Evaluate precision and recall efficiently. 
+ + :param features_1: [N1 x D] feature vectors for reference batch. + :param radii_1: [N1 x K1] radii for reference vectors. + :param features_2: [N2 x D] feature vectors for the other batch. + :param radii_2: [N x K2] radii for other vectors. + :return: a tuple of arrays for (precision, recall): + - precision: an np.ndarray of length K1 + - recall: an np.ndarray of length K2 + """ + features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) + features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) + for begin_1 in range(0, len(features_1), self.row_batch_size): + end_1 = begin_1 + self.row_batch_size + batch_1 = features_1[begin_1:end_1] + for begin_2 in range(0, len(features_2), self.col_batch_size): + end_2 = begin_2 + self.col_batch_size + batch_2 = features_2[begin_2:end_2] + batch_1_in, batch_2_in = self.distance_block.less_thans( + batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] + ) + features_1_status[begin_1:end_1] |= batch_1_in + features_2_status[begin_2:end_2] |= batch_2_in + return ( + np.mean(features_2_status.astype(np.float64), axis=0), + np.mean(features_1_status.astype(np.float64), axis=0), + ) + + +class DistanceBlock: + """ + Calculate pairwise distances between vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 + """ + + def __init__(self, session): + self.session = session + + # Initialize TF graph to calculate pairwise distances. 
+ with session.graph.as_default(): + self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) + self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) + distance_block_16 = _batch_pairwise_distances( + tf.cast(self._features_batch1, tf.float16), + tf.cast(self._features_batch2, tf.float16), + ) + self.distance_block = tf.cond( + tf.reduce_all(tf.math.is_finite(distance_block_16)), + lambda: tf.cast(distance_block_16, tf.float32), + lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2), + ) + + # Extra logic for less thans. + self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) + self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) + dist32 = tf.cast(self.distance_block, tf.float32)[..., None] + self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) + self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) + + def pairwise_distances(self, U, V): + """ + Evaluate pairwise distances between two batches of feature vectors. + """ + return self.session.run( + self.distance_block, + feed_dict={self._features_batch1: U, self._features_batch2: V}, + ) + + def less_thans(self, batch_1, radii_1, batch_2, radii_2): + return self.session.run( + [self._batch_1_in, self._batch_2_in], + feed_dict={ + self._features_batch1: batch_1, + self._features_batch2: batch_2, + self._radii1: radii_1, + self._radii2: radii_2, + }, + ) + + +def _batch_pairwise_distances(U, V): + """ + Compute pairwise distances between two batches of feature vectors. + """ + with tf.variable_scope("pairwise_dist_block"): + # Squared norms of each row in U and V. + norm_u = tf.reduce_sum(tf.square(U), 1) + norm_v = tf.reduce_sum(tf.square(V), 1) + + # norm_u as a column and norm_v as a row vectors. + norm_u = tf.reshape(norm_u, [-1, 1]) + norm_v = tf.reshape(norm_v, [1, -1]) + + # Pairwise squared Euclidean distances. 
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) + + return D + + +class NpzArrayReader(ABC): + @abstractmethod + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + pass + + @abstractmethod + def remaining(self) -> int: + pass + + def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: + def gen_fn(): + while True: + batch = self.read_batch(batch_size) + if batch is None: + break + yield batch + + rem = self.remaining() + num_batches = rem // batch_size + int(rem % batch_size != 0) + return BatchIterator(gen_fn, num_batches) + + +class BatchIterator: + def __init__(self, gen_fn, length): + self.gen_fn = gen_fn + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + return self.gen_fn() + + +class StreamingNpzArrayReader(NpzArrayReader): + def __init__(self, arr_f, shape, dtype): + self.arr_f = arr_f + self.shape = shape + self.dtype = dtype + self.idx = 0 + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.shape[0]: + return None + + bs = min(batch_size, self.shape[0] - self.idx) + self.idx += bs + + if self.dtype.itemsize == 0: + return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) + + read_count = bs * np.prod(self.shape[1:]) + read_size = int(read_count * self.dtype.itemsize) + data = _read_bytes(self.arr_f, read_size, "array data") + return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) + + def remaining(self) -> int: + return max(0, self.shape[0] - self.idx) + + +class MemoryNpzArrayReader(NpzArrayReader): + def __init__(self, arr): + self.arr = arr + self.idx = 0 + + @classmethod + def load(cls, path: str, arr_name: str): + with open(path, "rb") as f: + arr = np.load(f)[arr_name] + return cls(arr) + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.arr.shape[0]: + return None + + res = self.arr[self.idx : self.idx + batch_size] + self.idx += batch_size + return res + + def 
remaining(self) -> int: + return max(0, self.arr.shape[0] - self.idx) + + +@contextmanager +def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: + with _open_npy_file(path, arr_name) as arr_f: + version = np.lib.format.read_magic(arr_f) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(arr_f) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(arr_f) + else: + yield MemoryNpzArrayReader.load(path, arr_name) + return + shape, fortran, dtype = header + if fortran or dtype.hasobject: + yield MemoryNpzArrayReader.load(path, arr_name) + else: + yield StreamingNpzArrayReader(arr_f, shape, dtype) + + +def _read_bytes(fp, size, error_template="ran out of data"): + """ + Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 + + Read from file-like object until size bytes are read. + Raises ValueError if not EOF is encountered before size bytes are read. + Non-blocking objects only supported if they derive from io objects. + Required as e.g. ZipExtFile in python 2.6 can return less data than + requested. + """ + data = bytes() + while True: + # io files (default in python3) return None or raise on + # would-block, python2 file will truncate, probably nothing can be + # done about that. 
note that regular files can't be non-blocking + try: + r = fp.read(size - len(data)) + data += r + if len(r) == 0 or len(data) == size: + break + except io.BlockingIOError: + pass + if len(data) != size: + msg = "EOF: reading %s, expected %d bytes got %d" + raise ValueError(msg % (error_template, size, len(data))) + else: + return data + + +@contextmanager +def _open_npy_file(path: str, arr_name: str): + with open(path, "rb") as f: + with zipfile.ZipFile(f, "r") as zip_f: + if f"{arr_name}.npy" not in zip_f.namelist(): + raise ValueError(f"missing {arr_name} in npz file") + with zip_f.open(f"{arr_name}.npy", "r") as arr_f: + yield arr_f + + +def _download_inception_model(): + if os.path.exists(INCEPTION_V3_PATH): + return + print("downloading InceptionV3 model...") + with requests.get(INCEPTION_V3_URL, stream=True) as r: + r.raise_for_status() + tmp_path = INCEPTION_V3_PATH + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in tqdm(r.iter_content(chunk_size=8192)): + f.write(chunk) + os.rename(tmp_path, INCEPTION_V3_PATH) + + +def _create_feature_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + pool3, spatial = tf.import_graph_def( + graph_def, + input_map={f"ExpandDims:0": input_batch}, + return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], + name=prefix, + ) + _update_shapes(pool3) + spatial = spatial[..., :7] + return pool3, spatial + + +def _create_softmax_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + (matmul,) = tf.import_graph_def( + graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix + ) + w = matmul.inputs[1] + logits = tf.matmul(input_batch, w) + return tf.nn.softmax(logits) + + +def 
_update_shapes(pool3): + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 + ops = pool3.graph.get_operations() + for op in ops: + for o in op.outputs: + shape = o.get_shape() + if shape._dims is not None: # pylint: disable=protected-access + # shape = [s.value for s in shape] TF 1.x + shape = [s for s in shape] # TF 2.x + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__["_shape_val"] = tf.TensorShape(new_shape) + return pool3 + + +def _numpy_partition(arr, kth, **kwargs): + num_workers = min(cpu_count(), len(arr)) + chunk_size = len(arr) // num_workers + extra = len(arr) % num_workers + + start_idx = 0 + batches = [] + for i in range(num_workers): + size = chunk_size + (1 if i < extra else 0) + batches.append(arr[start_idx : start_idx + size]) + start_idx += size + + with ThreadPool(num_workers) as pool: + return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-000-SiT-XL-2-GVP-velocity-None/log.txt b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-000-SiT-XL-2-GVP-velocity-None/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..2bc560625dda92d8a4d85f0e69f02b5c4bac4c17 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-000-SiT-XL-2-GVP-velocity-None/log.txt @@ -0,0 +1,11 @@ +[2026-02-03 06:38:01] Experiment directory created at results_256_gvp_disp/depth-mu-2-000-SiT-XL-2-GVP-velocity-None +[2026-02-03 06:38:35] Combined_model Parameters: 729,629,632 +[2026-02-03 06:38:35] Total trainable parameters: 53,910,176 +[2026-02-03 06:38:38] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/) +[2026-02-03 06:38:38] Training for 100000 epochs... 
+[2026-02-03 06:38:38] Beginning epoch 0... +[2026-02-03 06:39:30] (step=0000100) Train Loss: -1.8935, Train Steps/Sec: 1.91 +[2026-02-03 06:40:20] (step=0000200) Train Loss: -2.2925, Train Steps/Sec: 2.04 +[2026-02-03 06:41:10] (step=0000300) Train Loss: -2.2953, Train Steps/Sec: 1.99 +[2026-02-03 06:42:00] (step=0000400) Train Loss: -2.2904, Train Steps/Sec: 1.99 +[2026-02-03 06:42:50] (step=0000500) Train Loss: -2.2938, Train Steps/Sec: 2.00 diff --git a/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-001-SiT-XL-2-GVP-velocity-None/log.txt b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-001-SiT-XL-2-GVP-velocity-None/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..3cea30af66e6073c8e73c7eecfabf44f8a4e6112 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-001-SiT-XL-2-GVP-velocity-None/log.txt @@ -0,0 +1 @@ +[2026-02-03 06:44:16] Experiment directory created at results_256_gvp_disp/depth-mu-2-001-SiT-XL-2-GVP-velocity-None diff --git a/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/log.txt b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..c257ed1eda9c5429b0b3d6d40be24430bfde4a18 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/log.txt @@ -0,0 +1,500 @@ +[2026-02-03 06:45:00] Experiment directory created at results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None +[2026-02-03 06:45:32] Combined_model Parameters: 729,629,632 +[2026-02-03 06:45:32] Total trainable parameters: 53,910,176 +[2026-02-03 06:45:34] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/) +[2026-02-03 06:45:34] Training for 100000 epochs... +[2026-02-03 06:45:34] Beginning epoch 0... 
+[2026-02-03 06:47:01] (step=0000100) Train Loss: -3.1750, Train Steps/Sec: 1.15 +[2026-02-03 06:48:24] (step=0000200) Train Loss: -3.6610, Train Steps/Sec: 1.20 +[2026-02-03 06:49:47] (step=0000300) Train Loss: -3.6752, Train Steps/Sec: 1.20 +[2026-02-03 06:51:10] (step=0000400) Train Loss: -3.6767, Train Steps/Sec: 1.20 +[2026-02-03 06:52:33] (step=0000500) Train Loss: -3.6782, Train Steps/Sec: 1.20 +[2026-02-03 06:53:56] (step=0000600) Train Loss: -3.6781, Train Steps/Sec: 1.20 +[2026-02-03 06:55:21] (step=0000700) Train Loss: -3.6788, Train Steps/Sec: 1.18 +[2026-02-03 06:57:50] (step=0000800) Train Loss: -3.6797, Train Steps/Sec: 0.67 +[2026-02-03 07:00:57] (step=0000900) Train Loss: -3.6833, Train Steps/Sec: 0.54 +[2026-02-03 07:04:02] (step=0001000) Train Loss: -3.6793, Train Steps/Sec: 0.54 +[2026-02-03 07:07:08] (step=0001100) Train Loss: -3.6790, Train Steps/Sec: 0.54 +[2026-02-03 07:10:14] (step=0001200) Train Loss: -3.6799, Train Steps/Sec: 0.54 +[2026-02-03 07:13:20] (step=0001300) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-03 07:19:05] (step=0001400) Train Loss: -3.6833, Train Steps/Sec: 0.29 +[2026-02-03 07:22:12] (step=0001500) Train Loss: -3.6796, Train Steps/Sec: 0.53 +[2026-02-03 07:25:19] (step=0001600) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-03 07:28:25] (step=0001700) Train Loss: -3.6843, Train Steps/Sec: 0.54 +[2026-02-03 07:31:32] (step=0001800) Train Loss: -3.6813, Train Steps/Sec: 0.53 +[2026-02-03 07:34:38] (step=0001900) Train Loss: -3.6828, Train Steps/Sec: 0.54 +[2026-02-03 07:37:45] (step=0002000) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-03 07:40:51] (step=0002100) Train Loss: -3.6799, Train Steps/Sec: 0.54 +[2026-02-03 07:43:58] (step=0002200) Train Loss: -3.6784, Train Steps/Sec: 0.53 +[2026-02-03 07:47:06] (step=0002300) Train Loss: -3.6824, Train Steps/Sec: 0.53 +[2026-02-03 07:50:12] (step=0002400) Train Loss: -3.6787, Train Steps/Sec: 0.54 +[2026-02-03 07:53:19] (step=0002500) Train Loss: 
-3.6771, Train Steps/Sec: 0.54 +[2026-02-03 07:53:23] Beginning epoch 1... +[2026-02-03 07:56:29] (step=0002600) Train Loss: -3.6847, Train Steps/Sec: 0.53 +[2026-02-03 07:59:35] (step=0002700) Train Loss: -3.6829, Train Steps/Sec: 0.54 +[2026-02-03 08:02:42] (step=0002800) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 08:05:49] (step=0002900) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-03 08:08:55] (step=0003000) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 08:12:01] (step=0003100) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 08:15:09] (step=0003200) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 08:18:16] (step=0003300) Train Loss: -3.6800, Train Steps/Sec: 0.53 +[2026-02-03 08:21:20] (step=0003400) Train Loss: -3.6797, Train Steps/Sec: 0.54 +[2026-02-03 08:24:27] (step=0003500) Train Loss: -3.6802, Train Steps/Sec: 0.54 +[2026-02-03 08:27:34] (step=0003600) Train Loss: -3.6834, Train Steps/Sec: 0.53 +[2026-02-03 08:30:40] (step=0003700) Train Loss: -3.6810, Train Steps/Sec: 0.54 +[2026-02-03 08:33:48] (step=0003800) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 08:36:55] (step=0003900) Train Loss: -3.6817, Train Steps/Sec: 0.53 +[2026-02-03 08:40:01] (step=0004000) Train Loss: -3.6794, Train Steps/Sec: 0.54 +[2026-02-03 08:43:08] (step=0004100) Train Loss: -3.6801, Train Steps/Sec: 0.54 +[2026-02-03 08:46:15] (step=0004200) Train Loss: -3.6850, Train Steps/Sec: 0.54 +[2026-02-03 08:49:21] (step=0004300) Train Loss: -3.6801, Train Steps/Sec: 0.54 +[2026-02-03 08:52:28] (step=0004400) Train Loss: -3.6816, Train Steps/Sec: 0.54 +[2026-02-03 08:55:35] (step=0004500) Train Loss: -3.6820, Train Steps/Sec: 0.53 +[2026-02-03 08:58:42] (step=0004600) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-03 09:01:49] (step=0004700) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 09:04:56] (step=0004800) Train Loss: -3.6797, Train Steps/Sec: 0.53 +[2026-02-03 09:08:03] (step=0004900) Train Loss: -3.6800, 
Train Steps/Sec: 0.53 +[2026-02-03 09:11:10] (step=0005000) Train Loss: -3.6831, Train Steps/Sec: 0.54 +[2026-02-03 09:11:18] Beginning epoch 2... +[2026-02-03 09:14:20] (step=0005100) Train Loss: -3.6803, Train Steps/Sec: 0.52 +[2026-02-03 09:17:27] (step=0005200) Train Loss: -3.6804, Train Steps/Sec: 0.53 +[2026-02-03 09:20:34] (step=0005300) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-03 09:23:40] (step=0005400) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 09:26:47] (step=0005500) Train Loss: -3.6819, Train Steps/Sec: 0.53 +[2026-02-03 09:29:54] (step=0005600) Train Loss: -3.6834, Train Steps/Sec: 0.54 +[2026-02-03 09:33:01] (step=0005700) Train Loss: -3.6805, Train Steps/Sec: 0.53 +[2026-02-03 09:36:08] (step=0005800) Train Loss: -3.6827, Train Steps/Sec: 0.53 +[2026-02-03 09:39:15] (step=0005900) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 09:42:20] (step=0006000) Train Loss: -3.6807, Train Steps/Sec: 0.54 +[2026-02-03 09:45:27] (step=0006100) Train Loss: -3.6814, Train Steps/Sec: 0.53 +[2026-02-03 09:48:34] (step=0006200) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 09:51:40] (step=0006300) Train Loss: -3.6799, Train Steps/Sec: 0.54 +[2026-02-03 09:54:46] (step=0006400) Train Loss: -3.6797, Train Steps/Sec: 0.54 +[2026-02-03 09:57:54] (step=0006500) Train Loss: -3.6820, Train Steps/Sec: 0.53 +[2026-02-03 10:01:01] (step=0006600) Train Loss: -3.6789, Train Steps/Sec: 0.53 +[2026-02-03 10:04:08] (step=0006700) Train Loss: -3.6804, Train Steps/Sec: 0.53 +[2026-02-03 10:07:15] (step=0006800) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-03 10:10:22] (step=0006900) Train Loss: -3.6787, Train Steps/Sec: 0.54 +[2026-02-03 10:13:29] (step=0007000) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-03 10:16:35] (step=0007100) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-03 10:19:42] (step=0007200) Train Loss: -3.6820, Train Steps/Sec: 0.54 +[2026-02-03 10:22:49] (step=0007300) Train Loss: -3.6810, Train 
Steps/Sec: 0.53 +[2026-02-03 10:25:56] (step=0007400) Train Loss: -3.6828, Train Steps/Sec: 0.53 +[2026-02-03 10:29:04] (step=0007500) Train Loss: -3.6821, Train Steps/Sec: 0.53 +[2026-02-03 10:29:16] Beginning epoch 3... +[2026-02-03 10:32:13] (step=0007600) Train Loss: -3.6794, Train Steps/Sec: 0.53 +[2026-02-03 10:35:20] (step=0007700) Train Loss: -3.6809, Train Steps/Sec: 0.53 +[2026-02-03 10:38:27] (step=0007800) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 10:41:34] (step=0007900) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-03 10:44:41] (step=0008000) Train Loss: -3.6852, Train Steps/Sec: 0.53 +[2026-02-03 10:47:47] (step=0008100) Train Loss: -3.6820, Train Steps/Sec: 0.54 +[2026-02-03 10:50:54] (step=0008200) Train Loss: -3.6798, Train Steps/Sec: 0.54 +[2026-02-03 10:54:01] (step=0008300) Train Loss: -3.6772, Train Steps/Sec: 0.54 +[2026-02-03 10:57:07] (step=0008400) Train Loss: -3.6800, Train Steps/Sec: 0.54 +[2026-02-03 11:00:13] (step=0008500) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-03 11:03:19] (step=0008600) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 11:06:23] (step=0008700) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 11:09:29] (step=0008800) Train Loss: -3.6762, Train Steps/Sec: 0.54 +[2026-02-03 11:12:36] (step=0008900) Train Loss: -3.6838, Train Steps/Sec: 0.54 +[2026-02-03 11:15:43] (step=0009000) Train Loss: -3.6826, Train Steps/Sec: 0.53 +[2026-02-03 11:18:50] (step=0009100) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 11:21:57] (step=0009200) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 11:25:04] (step=0009300) Train Loss: -3.6819, Train Steps/Sec: 0.54 +[2026-02-03 11:28:11] (step=0009400) Train Loss: -3.6785, Train Steps/Sec: 0.53 +[2026-02-03 11:31:17] (step=0009500) Train Loss: -3.6769, Train Steps/Sec: 0.54 +[2026-02-03 11:34:24] (step=0009600) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-03 11:37:31] (step=0009700) Train Loss: -3.6856, Train Steps/Sec: 
0.54 +[2026-02-03 11:40:38] (step=0009800) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-03 11:43:45] (step=0009900) Train Loss: -3.6805, Train Steps/Sec: 0.54 +[2026-02-03 11:46:51] (step=0010000) Train Loss: -3.6819, Train Steps/Sec: 0.54 +[2026-02-03 11:47:07] Beginning epoch 4... +[2026-02-03 11:50:01] (step=0010100) Train Loss: -3.6850, Train Steps/Sec: 0.53 +[2026-02-03 11:53:08] (step=0010200) Train Loss: -3.6816, Train Steps/Sec: 0.53 +[2026-02-03 11:56:15] (step=0010300) Train Loss: -3.6836, Train Steps/Sec: 0.53 +[2026-02-03 11:59:22] (step=0010400) Train Loss: -3.6789, Train Steps/Sec: 0.53 +[2026-02-03 12:02:29] (step=0010500) Train Loss: -3.6793, Train Steps/Sec: 0.54 +[2026-02-03 12:05:36] (step=0010600) Train Loss: -3.6834, Train Steps/Sec: 0.54 +[2026-02-03 12:08:42] (step=0010700) Train Loss: -3.6842, Train Steps/Sec: 0.54 +[2026-02-03 12:11:49] (step=0010800) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-03 12:14:56] (step=0010900) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-03 12:18:03] (step=0011000) Train Loss: -3.6843, Train Steps/Sec: 0.53 +[2026-02-03 12:21:09] (step=0011100) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 12:24:15] (step=0011200) Train Loss: -3.6787, Train Steps/Sec: 0.54 +[2026-02-03 12:27:20] (step=0011300) Train Loss: -3.6828, Train Steps/Sec: 0.54 +[2026-02-03 12:30:27] (step=0011400) Train Loss: -3.6830, Train Steps/Sec: 0.53 +[2026-02-03 12:33:34] (step=0011500) Train Loss: -3.6784, Train Steps/Sec: 0.53 +[2026-02-03 12:36:41] (step=0011600) Train Loss: -3.6831, Train Steps/Sec: 0.53 +[2026-02-03 12:39:48] (step=0011700) Train Loss: -3.6834, Train Steps/Sec: 0.53 +[2026-02-03 12:42:55] (step=0011800) Train Loss: -3.6808, Train Steps/Sec: 0.53 +[2026-02-03 12:46:02] (step=0011900) Train Loss: -3.6810, Train Steps/Sec: 0.54 +[2026-02-03 12:49:09] (step=0012000) Train Loss: -3.6821, Train Steps/Sec: 0.53 +[2026-02-03 12:52:16] (step=0012100) Train Loss: -3.6827, Train Steps/Sec: 0.53 
+[2026-02-03 12:55:23] (step=0012200) Train Loss: -3.6827, Train Steps/Sec: 0.54 +[2026-02-03 12:58:30] (step=0012300) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-03 13:01:37] (step=0012400) Train Loss: -3.6818, Train Steps/Sec: 0.53 +[2026-02-03 13:04:44] (step=0012500) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-03 13:05:03] Beginning epoch 5... +[2026-02-03 13:07:54] (step=0012600) Train Loss: -3.6814, Train Steps/Sec: 0.52 +[2026-02-03 13:11:01] (step=0012700) Train Loss: -3.6842, Train Steps/Sec: 0.53 +[2026-02-03 13:14:08] (step=0012800) Train Loss: -3.6816, Train Steps/Sec: 0.54 +[2026-02-03 13:17:15] (step=0012900) Train Loss: -3.6790, Train Steps/Sec: 0.53 +[2026-02-03 13:20:22] (step=0013000) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 13:23:29] (step=0013100) Train Loss: -3.6792, Train Steps/Sec: 0.53 +[2026-02-03 13:26:36] (step=0013200) Train Loss: -3.6836, Train Steps/Sec: 0.53 +[2026-02-03 13:29:43] (step=0013300) Train Loss: -3.6845, Train Steps/Sec: 0.54 +[2026-02-03 13:32:50] (step=0013400) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 13:35:57] (step=0013500) Train Loss: -3.6798, Train Steps/Sec: 0.53 +[2026-02-03 13:39:04] (step=0013600) Train Loss: -3.6828, Train Steps/Sec: 0.54 +[2026-02-03 13:42:11] (step=0013700) Train Loss: -3.6799, Train Steps/Sec: 0.54 +[2026-02-03 13:45:18] (step=0013800) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 13:48:22] (step=0013900) Train Loss: -3.6831, Train Steps/Sec: 0.54 +[2026-02-03 13:51:29] (step=0014000) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-03 13:54:36] (step=0014100) Train Loss: -3.6823, Train Steps/Sec: 0.53 +[2026-02-03 13:57:43] (step=0014200) Train Loss: -3.6795, Train Steps/Sec: 0.54 +[2026-02-03 14:00:50] (step=0014300) Train Loss: -3.6795, Train Steps/Sec: 0.53 +[2026-02-03 14:03:57] (step=0014400) Train Loss: -3.6838, Train Steps/Sec: 0.54 +[2026-02-03 14:07:04] (step=0014500) Train Loss: -3.6832, Train Steps/Sec: 0.53 
+[2026-02-03 14:10:11] (step=0014600) Train Loss: -3.6832, Train Steps/Sec: 0.53 +[2026-02-03 14:13:18] (step=0014700) Train Loss: -3.6784, Train Steps/Sec: 0.54 +[2026-02-03 14:16:24] (step=0014800) Train Loss: -3.6824, Train Steps/Sec: 0.54 +[2026-02-03 14:19:31] (step=0014900) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 14:22:38] (step=0015000) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 14:23:01] Beginning epoch 6... +[2026-02-03 14:25:48] (step=0015100) Train Loss: -3.6831, Train Steps/Sec: 0.53 +[2026-02-03 14:28:55] (step=0015200) Train Loss: -3.6786, Train Steps/Sec: 0.53 +[2026-02-03 14:32:02] (step=0015300) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-03 14:35:08] (step=0015400) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-03 14:38:15] (step=0015500) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 14:41:21] (step=0015600) Train Loss: -3.6796, Train Steps/Sec: 0.54 +[2026-02-03 14:44:28] (step=0015700) Train Loss: -3.6839, Train Steps/Sec: 0.54 +[2026-02-03 14:47:36] (step=0015800) Train Loss: -3.6846, Train Steps/Sec: 0.53 +[2026-02-03 14:50:43] (step=0015900) Train Loss: -3.6828, Train Steps/Sec: 0.53 +[2026-02-03 14:53:50] (step=0016000) Train Loss: -3.6828, Train Steps/Sec: 0.54 +[2026-02-03 14:56:57] (step=0016100) Train Loss: -3.6789, Train Steps/Sec: 0.53 +[2026-02-03 15:00:04] (step=0016200) Train Loss: -3.6810, Train Steps/Sec: 0.53 +[2026-02-03 15:03:11] (step=0016300) Train Loss: -3.6799, Train Steps/Sec: 0.53 +[2026-02-03 15:06:19] (step=0016400) Train Loss: -3.6806, Train Steps/Sec: 0.53 +[2026-02-03 15:09:24] (step=0016500) Train Loss: -3.6828, Train Steps/Sec: 0.54 +[2026-02-03 15:12:31] (step=0016600) Train Loss: -3.6781, Train Steps/Sec: 0.54 +[2026-02-03 15:15:37] (step=0016700) Train Loss: -3.6830, Train Steps/Sec: 0.54 +[2026-02-03 15:18:44] (step=0016800) Train Loss: -3.6756, Train Steps/Sec: 0.54 +[2026-02-03 15:21:51] (step=0016900) Train Loss: -3.6798, Train Steps/Sec: 0.54 
+[2026-02-03 15:24:58] (step=0017000) Train Loss: -3.6813, Train Steps/Sec: 0.53 +[2026-02-03 15:28:04] (step=0017100) Train Loss: -3.6807, Train Steps/Sec: 0.54 +[2026-02-03 15:31:11] (step=0017200) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-03 15:34:18] (step=0017300) Train Loss: -3.6800, Train Steps/Sec: 0.54 +[2026-02-03 15:37:25] (step=0017400) Train Loss: -3.6836, Train Steps/Sec: 0.53 +[2026-02-03 15:40:32] (step=0017500) Train Loss: -3.6807, Train Steps/Sec: 0.53 +[2026-02-03 15:40:59] Beginning epoch 7... +[2026-02-03 15:43:42] (step=0017600) Train Loss: -3.6829, Train Steps/Sec: 0.53 +[2026-02-03 15:46:49] (step=0017700) Train Loss: -3.6790, Train Steps/Sec: 0.53 +[2026-02-03 15:49:56] (step=0017800) Train Loss: -3.6850, Train Steps/Sec: 0.53 +[2026-02-03 15:53:04] (step=0017900) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-03 15:56:11] (step=0018000) Train Loss: -3.6835, Train Steps/Sec: 0.53 +[2026-02-03 15:59:18] (step=0018100) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 16:02:25] (step=0018200) Train Loss: -3.6788, Train Steps/Sec: 0.53 +[2026-02-03 16:05:31] (step=0018300) Train Loss: -3.6786, Train Steps/Sec: 0.54 +[2026-02-03 16:08:39] (step=0018400) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 16:11:46] (step=0018500) Train Loss: -3.6809, Train Steps/Sec: 0.53 +[2026-02-03 16:14:52] (step=0018600) Train Loss: -3.6803, Train Steps/Sec: 0.54 +[2026-02-03 16:17:59] (step=0018700) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-03 16:21:06] (step=0018800) Train Loss: -3.6819, Train Steps/Sec: 0.53 +[2026-02-03 16:24:12] (step=0018900) Train Loss: -3.6834, Train Steps/Sec: 0.54 +[2026-02-03 16:27:19] (step=0019000) Train Loss: -3.6824, Train Steps/Sec: 0.54 +[2026-02-03 16:30:24] (step=0019100) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 16:33:31] (step=0019200) Train Loss: -3.6826, Train Steps/Sec: 0.53 +[2026-02-03 16:36:38] (step=0019300) Train Loss: -3.6774, Train Steps/Sec: 0.53 
+[2026-02-03 16:39:45] (step=0019400) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-03 16:42:51] (step=0019500) Train Loss: -3.6837, Train Steps/Sec: 0.54 +[2026-02-03 16:45:59] (step=0019600) Train Loss: -3.6828, Train Steps/Sec: 0.53 +[2026-02-03 16:49:06] (step=0019700) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-03 16:52:13] (step=0019800) Train Loss: -3.6828, Train Steps/Sec: 0.53 +[2026-02-03 16:55:20] (step=0019900) Train Loss: -3.6832, Train Steps/Sec: 0.53 +[2026-02-03 16:58:27] (step=0020000) Train Loss: -3.6837, Train Steps/Sec: 0.54 +[2026-02-03 16:58:57] Beginning epoch 8... +[2026-02-03 17:01:37] (step=0020100) Train Loss: -3.6820, Train Steps/Sec: 0.52 +[2026-02-03 17:04:45] (step=0020200) Train Loss: -3.6798, Train Steps/Sec: 0.53 +[2026-02-03 17:07:52] (step=0020300) Train Loss: -3.6807, Train Steps/Sec: 0.53 +[2026-02-03 17:10:59] (step=0020400) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 17:14:05] (step=0020500) Train Loss: -3.6794, Train Steps/Sec: 0.54 +[2026-02-03 17:17:13] (step=0020600) Train Loss: -3.6833, Train Steps/Sec: 0.53 +[2026-02-03 17:20:20] (step=0020700) Train Loss: -3.6802, Train Steps/Sec: 0.53 +[2026-02-03 17:23:27] (step=0020800) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 17:26:34] (step=0020900) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-03 17:29:41] (step=0021000) Train Loss: -3.6795, Train Steps/Sec: 0.53 +[2026-02-03 17:32:48] (step=0021100) Train Loss: -3.6794, Train Steps/Sec: 0.53 +[2026-02-03 17:35:55] (step=0021200) Train Loss: 3.9167, Train Steps/Sec: 0.53 +[2026-02-03 17:39:02] (step=0021300) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 17:42:09] (step=0021400) Train Loss: -3.6805, Train Steps/Sec: 0.53 +[2026-02-03 17:45:16] (step=0021500) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-03 17:48:23] (step=0021600) Train Loss: -3.6812, Train Steps/Sec: 0.54 +[2026-02-03 17:51:28] (step=0021700) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-03 
17:54:34] (step=0021800) Train Loss: -3.6846, Train Steps/Sec: 0.54 +[2026-02-03 17:57:41] (step=0021900) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 18:00:48] (step=0022000) Train Loss: -3.6807, Train Steps/Sec: 0.54 +[2026-02-03 18:03:55] (step=0022100) Train Loss: -3.6799, Train Steps/Sec: 0.53 +[2026-02-03 18:07:02] (step=0022200) Train Loss: -3.6788, Train Steps/Sec: 0.53 +[2026-02-03 18:10:09] (step=0022300) Train Loss: -3.6821, Train Steps/Sec: 0.53 +[2026-02-03 18:13:16] (step=0022400) Train Loss: -3.6808, Train Steps/Sec: 0.53 +[2026-02-03 18:16:24] (step=0022500) Train Loss: -3.6836, Train Steps/Sec: 0.53 +[2026-02-03 18:16:58] Beginning epoch 9... +[2026-02-03 18:19:34] (step=0022600) Train Loss: -3.6835, Train Steps/Sec: 0.53 +[2026-02-03 18:22:40] (step=0022700) Train Loss: -3.6848, Train Steps/Sec: 0.54 +[2026-02-03 18:25:47] (step=0022800) Train Loss: -3.6778, Train Steps/Sec: 0.54 +[2026-02-03 18:28:53] (step=0022900) Train Loss: -3.6829, Train Steps/Sec: 0.54 +[2026-02-03 18:32:00] (step=0023000) Train Loss: -3.6807, Train Steps/Sec: 0.54 +[2026-02-03 18:35:07] (step=0023100) Train Loss: -3.6846, Train Steps/Sec: 0.53 +[2026-02-03 18:38:14] (step=0023200) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-03 18:41:21] (step=0023300) Train Loss: -3.6807, Train Steps/Sec: 0.53 +[2026-02-03 18:44:28] (step=0023400) Train Loss: -3.6812, Train Steps/Sec: 0.54 +[2026-02-03 18:47:35] (step=0023500) Train Loss: -3.6811, Train Steps/Sec: 0.53 +[2026-02-03 18:50:42] (step=0023600) Train Loss: -3.6800, Train Steps/Sec: 0.53 +[2026-02-03 18:53:49] (step=0023700) Train Loss: -3.6848, Train Steps/Sec: 0.53 +[2026-02-03 18:56:56] (step=0023800) Train Loss: -3.6824, Train Steps/Sec: 0.54 +[2026-02-03 19:00:03] (step=0023900) Train Loss: -3.6820, Train Steps/Sec: 0.54 +[2026-02-03 19:03:09] (step=0024000) Train Loss: -3.6848, Train Steps/Sec: 0.54 +[2026-02-03 19:06:16] (step=0024100) Train Loss: -3.6791, Train Steps/Sec: 0.54 +[2026-02-03 19:09:22] 
(step=0024200) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 19:12:30] (step=0024300) Train Loss: -3.6800, Train Steps/Sec: 0.53 +[2026-02-03 19:15:35] (step=0024400) Train Loss: -3.6792, Train Steps/Sec: 0.54 +[2026-02-03 19:18:42] (step=0024500) Train Loss: -3.6807, Train Steps/Sec: 0.53 +[2026-02-03 19:21:49] (step=0024600) Train Loss: -3.6796, Train Steps/Sec: 0.53 +[2026-02-03 19:24:56] (step=0024700) Train Loss: -3.6814, Train Steps/Sec: 0.53 +[2026-02-03 19:28:03] (step=0024800) Train Loss: -3.6832, Train Steps/Sec: 0.54 +[2026-02-03 19:31:10] (step=0024900) Train Loss: -3.6832, Train Steps/Sec: 0.54 +[2026-02-03 19:34:18] (step=0025000) Train Loss: -3.6782, Train Steps/Sec: 0.53 +[2026-02-03 19:34:18] Saved checkpoint to results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt +[2026-02-03 19:34:56] Beginning epoch 10... +[2026-02-03 19:37:29] (step=0025100) Train Loss: -3.6836, Train Steps/Sec: 0.52 +[2026-02-03 19:40:21] Generating EMA samples... 
+[2026-02-03 19:40:36] (step=0025200) Train Loss: -3.6796, Train Steps/Sec: 0.53 +[2026-02-03 19:43:43] (step=0025300) Train Loss: -3.6818, Train Steps/Sec: 0.53 +[2026-02-03 19:46:50] (step=0025400) Train Loss: -3.6789, Train Steps/Sec: 0.54 +[2026-02-03 19:49:58] (step=0025500) Train Loss: -3.6817, Train Steps/Sec: 0.53 +[2026-02-03 19:53:05] (step=0025600) Train Loss: -3.6804, Train Steps/Sec: 0.53 +[2026-02-03 19:56:11] (step=0025700) Train Loss: -3.6800, Train Steps/Sec: 0.54 +[2026-02-03 19:59:19] (step=0025800) Train Loss: -3.6832, Train Steps/Sec: 0.53 +[2026-02-03 20:02:25] (step=0025900) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 20:05:32] (step=0026000) Train Loss: -3.6812, Train Steps/Sec: 0.54 +[2026-02-03 20:08:39] (step=0026100) Train Loss: -3.6827, Train Steps/Sec: 0.54 +[2026-02-03 20:11:47] (step=0026200) Train Loss: -3.6793, Train Steps/Sec: 0.53 +[2026-02-03 20:14:54] (step=0026300) Train Loss: -3.6817, Train Steps/Sec: 0.53 +[2026-02-03 20:18:01] (step=0026400) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-03 20:21:07] (step=0026500) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-03 20:24:14] (step=0026600) Train Loss: -3.6842, Train Steps/Sec: 0.54 +[2026-02-03 20:27:20] (step=0026700) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-03 20:30:27] (step=0026800) Train Loss: -3.6849, Train Steps/Sec: 0.53 +[2026-02-03 20:33:34] (step=0026900) Train Loss: -3.6802, Train Steps/Sec: 0.53 +[2026-02-03 20:36:39] (step=0027000) Train Loss: -3.6792, Train Steps/Sec: 0.54 +[2026-02-03 20:39:46] (step=0027100) Train Loss: -3.6843, Train Steps/Sec: 0.54 +[2026-02-03 20:42:52] (step=0027200) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 20:45:59] (step=0027300) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-03 20:49:06] (step=0027400) Train Loss: -3.6775, Train Steps/Sec: 0.54 +[2026-02-03 20:52:12] (step=0027500) Train Loss: -3.6800, Train Steps/Sec: 0.54 +[2026-02-03 20:52:54] Beginning epoch 11... 
+[2026-02-03 20:55:23] (step=0027600) Train Loss: -3.6853, Train Steps/Sec: 0.53 +[2026-02-03 20:58:29] (step=0027700) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-03 21:01:37] (step=0027800) Train Loss: -3.6811, Train Steps/Sec: 0.53 +[2026-02-03 21:04:43] (step=0027900) Train Loss: -3.6810, Train Steps/Sec: 0.54 +[2026-02-03 21:07:50] (step=0028000) Train Loss: -3.6827, Train Steps/Sec: 0.53 +[2026-02-03 21:10:57] (step=0028100) Train Loss: -3.6839, Train Steps/Sec: 0.53 +[2026-02-03 21:14:04] (step=0028200) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-03 21:17:11] (step=0028300) Train Loss: -3.6830, Train Steps/Sec: 0.53 +[2026-02-03 21:20:18] (step=0028400) Train Loss: -3.6797, Train Steps/Sec: 0.53 +[2026-02-03 21:23:25] (step=0028500) Train Loss: -3.6797, Train Steps/Sec: 0.53 +[2026-02-03 21:26:32] (step=0028600) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 21:29:39] (step=0028700) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 21:32:45] (step=0028800) Train Loss: -3.6812, Train Steps/Sec: 0.54 +[2026-02-03 21:35:53] (step=0028900) Train Loss: -3.6858, Train Steps/Sec: 0.53 +[2026-02-03 21:38:59] (step=0029000) Train Loss: -3.6842, Train Steps/Sec: 0.54 +[2026-02-03 21:42:06] (step=0029100) Train Loss: -3.6836, Train Steps/Sec: 0.54 +[2026-02-03 21:45:14] (step=0029200) Train Loss: -3.6813, Train Steps/Sec: 0.53 +[2026-02-03 21:48:20] (step=0029300) Train Loss: -3.6783, Train Steps/Sec: 0.54 +[2026-02-03 21:51:27] (step=0029400) Train Loss: -3.6829, Train Steps/Sec: 0.53 +[2026-02-03 21:54:34] (step=0029500) Train Loss: -3.6812, Train Steps/Sec: 0.54 +[2026-02-03 21:57:39] (step=0029600) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 22:00:46] (step=0029700) Train Loss: -3.6828, Train Steps/Sec: 0.53 +[2026-02-03 22:03:53] (step=0029800) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-03 22:06:59] (step=0029900) Train Loss: -3.6814, Train Steps/Sec: 0.54 +[2026-02-03 22:10:06] (step=0030000) Train Loss: 
-3.6837, Train Steps/Sec: 0.54 +[2026-02-03 22:10:51] Beginning epoch 12... +[2026-02-03 22:13:16] (step=0030100) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 22:16:22] (step=0030200) Train Loss: -3.6787, Train Steps/Sec: 0.54 +[2026-02-03 22:19:29] (step=0030300) Train Loss: -3.6815, Train Steps/Sec: 0.53 +[2026-02-03 22:22:37] (step=0030400) Train Loss: -3.6806, Train Steps/Sec: 0.53 +[2026-02-03 22:25:44] (step=0030500) Train Loss: -3.6825, Train Steps/Sec: 0.53 +[2026-02-03 22:28:51] (step=0030600) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-03 22:31:58] (step=0030700) Train Loss: -3.6838, Train Steps/Sec: 0.54 +[2026-02-03 22:35:05] (step=0030800) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 22:38:11] (step=0030900) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-03 22:41:18] (step=0031000) Train Loss: -3.6815, Train Steps/Sec: 0.54 +[2026-02-03 22:44:25] (step=0031100) Train Loss: -3.6796, Train Steps/Sec: 0.53 +[2026-02-03 22:47:32] (step=0031200) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-03 22:50:39] (step=0031300) Train Loss: -3.6806, Train Steps/Sec: 0.53 +[2026-02-03 22:53:46] (step=0031400) Train Loss: -3.6822, Train Steps/Sec: 0.53 +[2026-02-03 22:56:53] (step=0031500) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 23:00:00] (step=0031600) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-03 23:03:07] (step=0031700) Train Loss: -3.6843, Train Steps/Sec: 0.53 +[2026-02-03 23:06:14] (step=0031800) Train Loss: -3.6832, Train Steps/Sec: 0.53 +[2026-02-03 23:09:21] (step=0031900) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-03 23:12:28] (step=0032000) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-03 23:15:34] (step=0032100) Train Loss: -3.6786, Train Steps/Sec: 0.54 +[2026-02-03 23:18:39] (step=0032200) Train Loss: -3.6814, Train Steps/Sec: 0.54 +[2026-02-03 23:21:46] (step=0032300) Train Loss: -3.6839, Train Steps/Sec: 0.54 +[2026-02-03 23:24:52] (step=0032400) Train Loss: -3.6822, 
Train Steps/Sec: 0.54 +[2026-02-03 23:27:59] (step=0032500) Train Loss: -3.6809, Train Steps/Sec: 0.53 +[2026-02-03 23:28:48] Beginning epoch 13... +[2026-02-03 23:31:09] (step=0032600) Train Loss: -3.6846, Train Steps/Sec: 0.53 +[2026-02-03 23:34:16] (step=0032700) Train Loss: -3.6841, Train Steps/Sec: 0.53 +[2026-02-03 23:37:24] (step=0032800) Train Loss: -3.6813, Train Steps/Sec: 0.53 +[2026-02-03 23:40:31] (step=0032900) Train Loss: -3.6792, Train Steps/Sec: 0.53 +[2026-02-03 23:43:38] (step=0033000) Train Loss: -3.6782, Train Steps/Sec: 0.53 +[2026-02-03 23:46:45] (step=0033100) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-03 23:49:52] (step=0033200) Train Loss: -3.6819, Train Steps/Sec: 0.53 +[2026-02-03 23:52:59] (step=0033300) Train Loss: -3.6793, Train Steps/Sec: 0.54 +[2026-02-03 23:56:06] (step=0033400) Train Loss: -3.6810, Train Steps/Sec: 0.54 +[2026-02-03 23:59:13] (step=0033500) Train Loss: -3.6816, Train Steps/Sec: 0.53 +[2026-02-04 00:02:20] (step=0033600) Train Loss: -3.6831, Train Steps/Sec: 0.54 +[2026-02-04 00:05:26] (step=0033700) Train Loss: -3.6831, Train Steps/Sec: 0.54 +[2026-02-04 00:08:33] (step=0033800) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-04 00:11:40] (step=0033900) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-04 00:14:46] (step=0034000) Train Loss: -3.6789, Train Steps/Sec: 0.54 +[2026-02-04 00:17:54] (step=0034100) Train Loss: -3.6814, Train Steps/Sec: 0.53 +[2026-02-04 00:21:00] (step=0034200) Train Loss: -3.6805, Train Steps/Sec: 0.54 +[2026-02-04 00:24:07] (step=0034300) Train Loss: -3.6837, Train Steps/Sec: 0.53 +[2026-02-04 00:27:14] (step=0034400) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-04 00:30:20] (step=0034500) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-04 00:33:27] (step=0034600) Train Loss: -3.6821, Train Steps/Sec: 0.54 +[2026-02-04 00:36:34] (step=0034700) Train Loss: -3.6799, Train Steps/Sec: 0.54 +[2026-02-04 00:39:38] (step=0034800) Train Loss: -3.6823, Train 
Steps/Sec: 0.54 +[2026-02-04 00:42:45] (step=0034900) Train Loss: -3.6820, Train Steps/Sec: 0.54 +[2026-02-04 00:45:52] (step=0035000) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-04 00:46:45] Beginning epoch 14... +[2026-02-04 00:49:01] (step=0035100) Train Loss: -3.6794, Train Steps/Sec: 0.53 +[2026-02-04 00:52:08] (step=0035200) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-04 00:55:15] (step=0035300) Train Loss: -3.6825, Train Steps/Sec: 0.53 +[2026-02-04 00:58:22] (step=0035400) Train Loss: -3.6817, Train Steps/Sec: 0.53 +[2026-02-04 01:01:29] (step=0035500) Train Loss: -3.6840, Train Steps/Sec: 0.54 +[2026-02-04 01:04:35] (step=0035600) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-04 01:07:42] (step=0035700) Train Loss: -3.6796, Train Steps/Sec: 0.53 +[2026-02-04 01:10:50] (step=0035800) Train Loss: -3.6834, Train Steps/Sec: 0.53 +[2026-02-04 01:13:56] (step=0035900) Train Loss: -3.6763, Train Steps/Sec: 0.54 +[2026-02-04 01:17:03] (step=0036000) Train Loss: -3.6837, Train Steps/Sec: 0.53 +[2026-02-04 01:20:10] (step=0036100) Train Loss: -3.6806, Train Steps/Sec: 0.53 +[2026-02-04 01:23:18] (step=0036200) Train Loss: -3.6821, Train Steps/Sec: 0.53 +[2026-02-04 01:26:24] (step=0036300) Train Loss: -3.6772, Train Steps/Sec: 0.54 +[2026-02-04 01:29:31] (step=0036400) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-04 01:32:38] (step=0036500) Train Loss: -3.6816, Train Steps/Sec: 0.54 +[2026-02-04 01:35:45] (step=0036600) Train Loss: -3.6792, Train Steps/Sec: 0.53 +[2026-02-04 01:38:51] (step=0036700) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-04 01:41:59] (step=0036800) Train Loss: -3.6835, Train Steps/Sec: 0.53 +[2026-02-04 01:45:05] (step=0036900) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-04 01:48:12] (step=0037000) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-04 01:51:18] (step=0037100) Train Loss: -3.6775, Train Steps/Sec: 0.54 +[2026-02-04 01:54:25] (step=0037200) Train Loss: -3.6796, Train Steps/Sec: 
0.54 +[2026-02-04 01:57:31] (step=0037300) Train Loss: -3.6806, Train Steps/Sec: 0.54 +[2026-02-04 02:00:38] (step=0037400) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-04 02:03:43] (step=0037500) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-04 02:04:39] Beginning epoch 15... +[2026-02-04 02:06:52] (step=0037600) Train Loss: -3.6847, Train Steps/Sec: 0.53 +[2026-02-04 02:10:00] (step=0037700) Train Loss: -3.6837, Train Steps/Sec: 0.53 +[2026-02-04 02:13:06] (step=0037800) Train Loss: -3.6796, Train Steps/Sec: 0.54 +[2026-02-04 02:16:13] (step=0037900) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-04 02:19:20] (step=0038000) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-04 02:22:26] (step=0038100) Train Loss: -3.6803, Train Steps/Sec: 0.54 +[2026-02-04 02:25:33] (step=0038200) Train Loss: -3.6813, Train Steps/Sec: 0.54 +[2026-02-04 02:28:40] (step=0038300) Train Loss: -3.6798, Train Steps/Sec: 0.53 +[2026-02-04 02:31:47] (step=0038400) Train Loss: -3.6797, Train Steps/Sec: 0.53 +[2026-02-04 02:34:54] (step=0038500) Train Loss: -3.6817, Train Steps/Sec: 0.54 +[2026-02-04 02:38:01] (step=0038600) Train Loss: -3.6818, Train Steps/Sec: 0.54 +[2026-02-04 02:41:08] (step=0038700) Train Loss: -3.6824, Train Steps/Sec: 0.54 +[2026-02-04 02:44:14] (step=0038800) Train Loss: -3.6800, Train Steps/Sec: 0.54 +[2026-02-04 02:47:22] (step=0038900) Train Loss: -3.6812, Train Steps/Sec: 0.53 +[2026-02-04 02:50:28] (step=0039000) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-04 02:53:35] (step=0039100) Train Loss: -3.6807, Train Steps/Sec: 0.53 +[2026-02-04 02:56:42] (step=0039200) Train Loss: -3.6831, Train Steps/Sec: 0.54 +[2026-02-04 02:59:48] (step=0039300) Train Loss: -3.6822, Train Steps/Sec: 0.54 +[2026-02-04 03:02:55] (step=0039400) Train Loss: -3.6803, Train Steps/Sec: 0.54 +[2026-02-04 03:06:01] (step=0039500) Train Loss: -3.6815, Train Steps/Sec: 0.54 +[2026-02-04 03:09:08] (step=0039600) Train Loss: -3.6830, Train Steps/Sec: 0.53 
+[2026-02-04 03:12:15] (step=0039700) Train Loss: -3.6771, Train Steps/Sec: 0.54 +[2026-02-04 03:15:21] (step=0039800) Train Loss: -3.6791, Train Steps/Sec: 0.54 +[2026-02-04 03:18:28] (step=0039900) Train Loss: -3.6797, Train Steps/Sec: 0.54 +[2026-02-04 03:21:34] (step=0040000) Train Loss: -3.6815, Train Steps/Sec: 0.54 +[2026-02-04 03:22:33] Beginning epoch 16... +[2026-02-04 03:24:43] (step=0040100) Train Loss: -3.6799, Train Steps/Sec: 0.53 +[2026-02-04 03:27:50] (step=0040200) Train Loss: -3.6823, Train Steps/Sec: 0.53 +[2026-02-04 03:30:57] (step=0040300) Train Loss: -3.6805, Train Steps/Sec: 0.53 +[2026-02-04 03:34:04] (step=0040400) Train Loss: -3.6829, Train Steps/Sec: 0.54 +[2026-02-04 03:37:11] (step=0040500) Train Loss: -3.6786, Train Steps/Sec: 0.53 +[2026-02-04 03:40:18] (step=0040600) Train Loss: -3.6811, Train Steps/Sec: 0.54 +[2026-02-04 03:43:24] (step=0040700) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-04 03:46:32] (step=0040800) Train Loss: -3.6860, Train Steps/Sec: 0.53 +[2026-02-04 03:49:38] (step=0040900) Train Loss: -3.6804, Train Steps/Sec: 0.54 +[2026-02-04 03:52:44] (step=0041000) Train Loss: -3.6803, Train Steps/Sec: 0.54 +[2026-02-04 03:55:52] (step=0041100) Train Loss: -3.6803, Train Steps/Sec: 0.53 +[2026-02-04 03:58:59] (step=0041200) Train Loss: -3.6801, Train Steps/Sec: 0.53 +[2026-02-04 04:02:06] (step=0041300) Train Loss: -3.6794, Train Steps/Sec: 0.53 +[2026-02-04 04:05:14] (step=0041400) Train Loss: -3.6816, Train Steps/Sec: 0.53 +[2026-02-04 04:08:20] (step=0041500) Train Loss: -3.6858, Train Steps/Sec: 0.54 +[2026-02-04 04:11:27] (step=0041600) Train Loss: -3.6811, Train Steps/Sec: 0.53 +[2026-02-04 04:14:34] (step=0041700) Train Loss: -3.6859, Train Steps/Sec: 0.53 +[2026-02-04 04:17:41] (step=0041800) Train Loss: -3.6823, Train Steps/Sec: 0.54 +[2026-02-04 04:20:47] (step=0041900) Train Loss: -3.6838, Train Steps/Sec: 0.54 +[2026-02-04 04:23:54] (step=0042000) Train Loss: -3.6809, Train Steps/Sec: 0.54 
+[2026-02-04 04:27:00] (step=0042100) Train Loss: -3.6781, Train Steps/Sec: 0.54 +[2026-02-04 04:30:07] (step=0042200) Train Loss: -3.6826, Train Steps/Sec: 0.54 +[2026-02-04 04:33:13] (step=0042300) Train Loss: -3.6835, Train Steps/Sec: 0.54 +[2026-02-04 04:36:20] (step=0042400) Train Loss: -3.6816, Train Steps/Sec: 0.54 +[2026-02-04 04:39:27] (step=0042500) Train Loss: -3.6802, Train Steps/Sec: 0.53 +[2026-02-04 04:40:31] Beginning epoch 17... +[2026-02-04 04:42:37] (step=0042600) Train Loss: -3.6831, Train Steps/Sec: 0.53 +[2026-02-04 04:45:42] (step=0042700) Train Loss: -3.6778, Train Steps/Sec: 0.54 +[2026-02-04 04:48:48] (step=0042800) Train Loss: -3.6846, Train Steps/Sec: 0.54 +[2026-02-04 04:51:55] (step=0042900) Train Loss: -3.6827, Train Steps/Sec: 0.53 +[2026-02-04 04:55:02] (step=0043000) Train Loss: -3.6820, Train Steps/Sec: 0.54 +[2026-02-04 04:58:08] (step=0043100) Train Loss: -3.6803, Train Steps/Sec: 0.54 +[2026-02-04 05:01:15] (step=0043200) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-04 05:04:22] (step=0043300) Train Loss: -3.6838, Train Steps/Sec: 0.53 +[2026-02-04 05:07:29] (step=0043400) Train Loss: -3.6809, Train Steps/Sec: 0.54 +[2026-02-04 05:10:36] (step=0043500) Train Loss: -3.6757, Train Steps/Sec: 0.53 +[2026-02-04 05:13:43] (step=0043600) Train Loss: -3.6808, Train Steps/Sec: 0.54 +[2026-02-04 05:16:50] (step=0043700) Train Loss: -3.6807, Train Steps/Sec: 0.54 +[2026-02-04 05:19:56] (step=0043800) Train Loss: -3.6825, Train Steps/Sec: 0.54 +[2026-02-04 05:23:03] (step=0043900) Train Loss: -3.6811, Train Steps/Sec: 0.53 +[2026-02-04 05:26:10] (step=0044000) Train Loss: -3.6819, Train Steps/Sec: 0.54 +[2026-02-04 05:29:17] (step=0044100) Train Loss: -3.6801, Train Steps/Sec: 0.54 +[2026-02-04 05:32:24] (step=0044200) Train Loss: -3.6785, Train Steps/Sec: 0.54 +[2026-02-04 05:35:31] (step=0044300) Train Loss: -3.6841, Train Steps/Sec: 0.53 +[2026-02-04 05:38:38] (step=0044400) Train Loss: -3.6841, Train Steps/Sec: 0.53 
+[2026-02-04 05:41:01] (step=0044500) Train Loss: -3.6791, Train Steps/Sec: 0.70 +[2026-02-04 05:42:24] (step=0044600) Train Loss: -3.6843, Train Steps/Sec: 1.20 +[2026-02-04 05:43:47] (step=0044700) Train Loss: -3.6815, Train Steps/Sec: 1.21 +[2026-02-04 05:45:10] (step=0044800) Train Loss: -3.6785, Train Steps/Sec: 1.21 +[2026-02-04 05:46:33] (step=0044900) Train Loss: -3.6820, Train Steps/Sec: 1.21 +[2026-02-04 05:47:56] (step=0045000) Train Loss: -3.6847, Train Steps/Sec: 1.20 +[2026-02-04 05:48:26] Beginning epoch 18... +[2026-02-04 05:49:22] (step=0045100) Train Loss: -3.6816, Train Steps/Sec: 1.16 +[2026-02-04 05:50:45] (step=0045200) Train Loss: -3.6834, Train Steps/Sec: 1.20 +[2026-02-04 05:52:08] (step=0045300) Train Loss: -3.6787, Train Steps/Sec: 1.21 +[2026-02-04 05:53:31] (step=0045400) Train Loss: -3.6844, Train Steps/Sec: 1.20 +[2026-02-04 05:54:54] (step=0045500) Train Loss: -3.6823, Train Steps/Sec: 1.20 +[2026-02-04 05:56:17] (step=0045600) Train Loss: -3.6806, Train Steps/Sec: 1.20 +[2026-02-04 05:57:40] (step=0045700) Train Loss: -3.6797, Train Steps/Sec: 1.21 +[2026-02-04 05:59:03] (step=0045800) Train Loss: -3.6819, Train Steps/Sec: 1.20 +[2026-02-04 06:00:26] (step=0045900) Train Loss: -3.6807, Train Steps/Sec: 1.20 +[2026-02-04 06:01:49] (step=0046000) Train Loss: -3.6814, Train Steps/Sec: 1.21 +[2026-02-04 06:03:12] (step=0046100) Train Loss: -3.6827, Train Steps/Sec: 1.21 +[2026-02-04 06:04:35] (step=0046200) Train Loss: -3.6824, Train Steps/Sec: 1.20 +[2026-02-04 06:05:58] (step=0046300) Train Loss: -3.6825, Train Steps/Sec: 1.20 +[2026-02-04 06:07:21] (step=0046400) Train Loss: -3.6826, Train Steps/Sec: 1.20 +[2026-02-04 06:08:44] (step=0046500) Train Loss: -3.6778, Train Steps/Sec: 1.20 +[2026-02-04 06:10:07] (step=0046600) Train Loss: -3.6820, Train Steps/Sec: 1.20 +[2026-02-04 06:11:30] (step=0046700) Train Loss: -3.6830, Train Steps/Sec: 1.21 +[2026-02-04 06:12:53] (step=0046800) Train Loss: -3.6808, Train Steps/Sec: 1.20 
+[2026-02-04 06:14:16] (step=0046900) Train Loss: -3.6812, Train Steps/Sec: 1.20 +[2026-02-04 06:15:39] (step=0047000) Train Loss: -3.6836, Train Steps/Sec: 1.20 +[2026-02-04 06:17:02] (step=0047100) Train Loss: -3.6806, Train Steps/Sec: 1.20 +[2026-02-04 06:18:25] (step=0047200) Train Loss: -3.6813, Train Steps/Sec: 1.20 +[2026-02-04 06:19:48] (step=0047300) Train Loss: -3.6828, Train Steps/Sec: 1.20 +[2026-02-04 06:21:11] (step=0047400) Train Loss: -3.6842, Train Steps/Sec: 1.21 diff --git a/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-003-SiT-XL-2-GVP-velocity-None/log.txt b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-003-SiT-XL-2-GVP-velocity-None/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..02ca309f74660d5681a636aa5c7ecc5c45b9cd3b --- /dev/null +++ b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-003-SiT-XL-2-GVP-velocity-None/log.txt @@ -0,0 +1,6 @@ +[2026-02-03 06:53:41] Experiment directory created at results_256_gvp_disp/depth-mu-2-003-SiT-XL-2-GVP-velocity-None +[2026-02-03 06:54:17] Combined_model Parameters: 729,629,632 +[2026-02-03 06:54:17] Total trainable parameters: 53,910,176 +[2026-02-03 06:54:19] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/) +[2026-02-03 06:54:19] Training for 100000 epochs... +[2026-02-03 06:54:19] Beginning epoch 0... 
diff --git a/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/log.txt b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..2766ed572bbdc4b097714b9ff950854ea366b006 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/log.txt @@ -0,0 +1,863 @@ +[2026-02-03 06:55:12] Experiment directory created at results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None +[2026-02-03 06:55:47] Combined_model Parameters: 729,629,632 +[2026-02-03 06:55:47] Total trainable parameters: 53,910,176 +[2026-02-03 06:55:50] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/) +[2026-02-03 06:55:50] Training for 100000 epochs... +[2026-02-03 06:55:50] Beginning epoch 0... +[2026-02-03 06:57:30] (step=0000100) Train Loss: -2.4789, Train Steps/Sec: 1.00 +[2026-02-03 06:59:08] (step=0000200) Train Loss: -2.9649, Train Steps/Sec: 1.02 +[2026-02-03 07:00:47] (step=0000300) Train Loss: -2.9777, Train Steps/Sec: 1.01 +[2026-02-03 07:02:27] (step=0000400) Train Loss: -2.9828, Train Steps/Sec: 1.00 +[2026-02-03 07:04:08] (step=0000500) Train Loss: -2.9877, Train Steps/Sec: 0.99 +[2026-02-03 07:05:49] (step=0000600) Train Loss: -2.9875, Train Steps/Sec: 0.99 +[2026-02-03 07:07:28] (step=0000700) Train Loss: -2.9882, Train Steps/Sec: 1.01 +[2026-02-03 07:09:08] (step=0000800) Train Loss: -2.9861, Train Steps/Sec: 1.00 +[2026-02-03 07:10:49] (step=0000900) Train Loss: -2.9862, Train Steps/Sec: 0.99 +[2026-02-03 07:12:30] (step=0001000) Train Loss: -2.9886, Train Steps/Sec: 0.99 +[2026-02-03 07:14:12] (step=0001100) Train Loss: -2.9849, Train Steps/Sec: 0.98 +[2026-02-03 07:18:10] (step=0001200) Train Loss: -2.9885, Train Steps/Sec: 0.42 +[2026-02-03 07:20:07] (step=0001300) Train Loss: -2.9864, Train Steps/Sec: 0.85 +[2026-02-03 07:21:45] (step=0001400) 
Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 07:23:22] (step=0001500) Train Loss: -2.9863, Train Steps/Sec: 1.03 +[2026-02-03 07:25:00] (step=0001600) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 07:26:37] (step=0001700) Train Loss: -2.9930, Train Steps/Sec: 1.03 +[2026-02-03 07:28:14] (step=0001800) Train Loss: -2.9892, Train Steps/Sec: 1.03 +[2026-02-03 07:29:52] (step=0001900) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 07:31:30] (step=0002000) Train Loss: -2.9857, Train Steps/Sec: 1.03 +[2026-02-03 07:33:07] (step=0002100) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 07:34:45] (step=0002200) Train Loss: -2.9829, Train Steps/Sec: 1.03 +[2026-02-03 07:36:23] (step=0002300) Train Loss: -2.9862, Train Steps/Sec: 1.02 +[2026-02-03 07:38:00] (step=0002400) Train Loss: -2.9895, Train Steps/Sec: 1.03 +[2026-02-03 07:39:37] (step=0002500) Train Loss: -2.9878, Train Steps/Sec: 1.03 +[2026-02-03 07:41:15] (step=0002600) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 07:42:53] (step=0002700) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-03 07:44:31] (step=0002800) Train Loss: -2.9915, Train Steps/Sec: 1.02 +[2026-02-03 07:46:09] (step=0002900) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-03 07:47:47] (step=0003000) Train Loss: -2.9850, Train Steps/Sec: 1.03 +[2026-02-03 07:49:25] (step=0003100) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 07:51:03] (step=0003200) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 07:52:41] (step=0003300) Train Loss: -2.9943, Train Steps/Sec: 1.02 +[2026-02-03 07:54:15] (step=0003400) Train Loss: -2.9891, Train Steps/Sec: 1.06 +[2026-02-03 07:55:53] (step=0003500) Train Loss: -2.9845, Train Steps/Sec: 1.03 +[2026-02-03 07:57:30] (step=0003600) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 07:59:08] (step=0003700) Train Loss: -2.9916, Train Steps/Sec: 1.03 +[2026-02-03 08:00:46] (step=0003800) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 
08:02:23] (step=0003900) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 08:04:01] (step=0004000) Train Loss: -2.9929, Train Steps/Sec: 1.02 +[2026-02-03 08:05:39] (step=0004100) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-03 08:07:17] (step=0004200) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 08:08:54] (step=0004300) Train Loss: -2.9849, Train Steps/Sec: 1.02 +[2026-02-03 08:10:32] (step=0004400) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-03 08:12:10] (step=0004500) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-03 08:13:47] (step=0004600) Train Loss: -2.9874, Train Steps/Sec: 1.03 +[2026-02-03 08:15:26] (step=0004700) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 08:17:04] (step=0004800) Train Loss: -2.9844, Train Steps/Sec: 1.02 +[2026-02-03 08:18:41] (step=0004900) Train Loss: -2.9825, Train Steps/Sec: 1.02 +[2026-02-03 08:20:19] (step=0005000) Train Loss: -2.9846, Train Steps/Sec: 1.02 +[2026-02-03 08:20:24] Beginning epoch 1... 
+[2026-02-03 08:21:59] (step=0005100) Train Loss: -2.9935, Train Steps/Sec: 1.00 +[2026-02-03 08:23:37] (step=0005200) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 08:25:15] (step=0005300) Train Loss: -2.9927, Train Steps/Sec: 1.02 +[2026-02-03 08:26:53] (step=0005400) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-03 08:28:31] (step=0005500) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 08:30:09] (step=0005600) Train Loss: -2.9912, Train Steps/Sec: 1.02 +[2026-02-03 08:31:46] (step=0005700) Train Loss: -2.9920, Train Steps/Sec: 1.03 +[2026-02-03 08:33:24] (step=0005800) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 08:35:02] (step=0005900) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-03 08:36:39] (step=0006000) Train Loss: -2.9900, Train Steps/Sec: 1.03 +[2026-02-03 08:38:17] (step=0006100) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 08:39:55] (step=0006200) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-03 08:41:33] (step=0006300) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 08:43:11] (step=0006400) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 08:44:49] (step=0006500) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 08:46:27] (step=0006600) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 08:48:05] (step=0006700) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 08:49:43] (step=0006800) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 08:51:20] (step=0006900) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 08:52:58] (step=0007000) Train Loss: -2.9848, Train Steps/Sec: 1.02 +[2026-02-03 08:54:36] (step=0007100) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 08:56:13] (step=0007200) Train Loss: -2.9936, Train Steps/Sec: 1.02 +[2026-02-03 08:57:51] (step=0007300) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 08:59:29] (step=0007400) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 09:01:07] (step=0007500) Train Loss: 
-2.9907, Train Steps/Sec: 1.02 +[2026-02-03 09:02:45] (step=0007600) Train Loss: -2.9875, Train Steps/Sec: 1.03 +[2026-02-03 09:04:22] (step=0007700) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-03 09:06:01] (step=0007800) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 09:07:39] (step=0007900) Train Loss: -2.9846, Train Steps/Sec: 1.02 +[2026-02-03 09:09:16] (step=0008000) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 09:10:54] (step=0008100) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 09:12:28] (step=0008200) Train Loss: -2.9826, Train Steps/Sec: 1.07 +[2026-02-03 09:14:06] (step=0008300) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 09:15:44] (step=0008400) Train Loss: -2.9933, Train Steps/Sec: 1.03 +[2026-02-03 09:17:21] (step=0008500) Train Loss: -2.9858, Train Steps/Sec: 1.02 +[2026-02-03 09:18:59] (step=0008600) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 09:20:37] (step=0008700) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 09:22:15] (step=0008800) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 09:23:53] (step=0008900) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-03 09:25:31] (step=0009000) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 09:27:09] (step=0009100) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 09:28:47] (step=0009200) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 09:30:24] (step=0009300) Train Loss: -2.9888, Train Steps/Sec: 1.03 +[2026-02-03 09:32:02] (step=0009400) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 09:33:40] (step=0009500) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 09:35:18] (step=0009600) Train Loss: -2.9820, Train Steps/Sec: 1.02 +[2026-02-03 09:36:56] (step=0009700) Train Loss: -2.9845, Train Steps/Sec: 1.02 +[2026-02-03 09:38:34] (step=0009800) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 09:40:12] (step=0009900) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-03 09:41:50] 
(step=0010000) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 09:41:58] Beginning epoch 2... +[2026-02-03 09:43:30] (step=0010100) Train Loss: -2.9883, Train Steps/Sec: 1.00 +[2026-02-03 09:45:08] (step=0010200) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-03 09:46:45] (step=0010300) Train Loss: -2.9880, Train Steps/Sec: 1.03 +[2026-02-03 09:48:22] (step=0010400) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 09:50:00] (step=0010500) Train Loss: -2.9857, Train Steps/Sec: 1.02 +[2026-02-03 09:51:38] (step=0010600) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 09:53:16] (step=0010700) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 09:54:55] (step=0010800) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 09:56:33] (step=0010900) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-03 09:58:11] (step=0011000) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 09:59:49] (step=0011100) Train Loss: -2.9910, Train Steps/Sec: 1.02 +[2026-02-03 10:01:27] (step=0011200) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 10:03:06] (step=0011300) Train Loss: -2.9864, Train Steps/Sec: 1.01 +[2026-02-03 10:04:43] (step=0011400) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 10:06:21] (step=0011500) Train Loss: -2.9933, Train Steps/Sec: 1.02 +[2026-02-03 10:08:00] (step=0011600) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 10:09:37] (step=0011700) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 10:11:16] (step=0011800) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 10:12:54] (step=0011900) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 10:14:32] (step=0012000) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 10:16:10] (step=0012100) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 10:17:48] (step=0012200) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 10:19:26] (step=0012300) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 10:21:04] 
(step=0012400) Train Loss: -2.9944, Train Steps/Sec: 1.02 +[2026-02-03 10:22:41] (step=0012500) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-03 10:24:20] (step=0012600) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 10:25:57] (step=0012700) Train Loss: -2.9851, Train Steps/Sec: 1.02 +[2026-02-03 10:27:35] (step=0012800) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-03 10:29:13] (step=0012900) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 10:30:47] (step=0013000) Train Loss: -2.9892, Train Steps/Sec: 1.06 +[2026-02-03 10:32:25] (step=0013100) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-03 10:34:02] (step=0013200) Train Loss: -2.9860, Train Steps/Sec: 1.03 +[2026-02-03 10:35:40] (step=0013300) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 10:37:18] (step=0013400) Train Loss: -2.9860, Train Steps/Sec: 1.02 +[2026-02-03 10:38:56] (step=0013500) Train Loss: -2.9910, Train Steps/Sec: 1.03 +[2026-02-03 10:40:33] (step=0013600) Train Loss: -2.9834, Train Steps/Sec: 1.02 +[2026-02-03 10:42:11] (step=0013700) Train Loss: -2.9847, Train Steps/Sec: 1.02 +[2026-02-03 10:43:49] (step=0013800) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 10:45:27] (step=0013900) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-03 10:47:05] (step=0014000) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 10:48:42] (step=0014100) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 10:50:20] (step=0014200) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 10:51:58] (step=0014300) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 10:53:35] (step=0014400) Train Loss: -2.9889, Train Steps/Sec: 1.03 +[2026-02-03 10:55:13] (step=0014500) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 10:56:51] (step=0014600) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 10:58:28] (step=0014700) Train Loss: -2.9864, Train Steps/Sec: 1.03 +[2026-02-03 11:00:06] (step=0014800) Train Loss: -2.9927, Train Steps/Sec: 
1.02 +[2026-02-03 11:01:43] (step=0014900) Train Loss: -2.9881, Train Steps/Sec: 1.03 +[2026-02-03 11:03:20] (step=0015000) Train Loss: -2.9892, Train Steps/Sec: 1.03 +[2026-02-03 11:03:33] Beginning epoch 3... +[2026-02-03 11:05:01] (step=0015100) Train Loss: -2.9843, Train Steps/Sec: 1.00 +[2026-02-03 11:06:39] (step=0015200) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 11:08:16] (step=0015300) Train Loss: -2.9872, Train Steps/Sec: 1.03 +[2026-02-03 11:09:54] (step=0015400) Train Loss: -2.9896, Train Steps/Sec: 1.03 +[2026-02-03 11:11:32] (step=0015500) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 11:13:10] (step=0015600) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 11:14:48] (step=0015700) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 11:16:25] (step=0015800) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 11:18:04] (step=0015900) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 11:19:42] (step=0016000) Train Loss: -2.9963, Train Steps/Sec: 1.02 +[2026-02-03 11:21:19] (step=0016100) Train Loss: -2.9911, Train Steps/Sec: 1.02 +[2026-02-03 11:22:57] (step=0016200) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 11:24:35] (step=0016300) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 11:26:13] (step=0016400) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-03 11:27:51] (step=0016500) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 11:29:29] (step=0016600) Train Loss: -2.9835, Train Steps/Sec: 1.02 +[2026-02-03 11:31:07] (step=0016700) Train Loss: -2.9855, Train Steps/Sec: 1.02 +[2026-02-03 11:32:44] (step=0016800) Train Loss: -2.9885, Train Steps/Sec: 1.03 +[2026-02-03 11:34:21] (step=0016900) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 11:35:59] (step=0017000) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 11:37:38] (step=0017100) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 11:39:16] (step=0017200) Train Loss: -2.9916, Train Steps/Sec: 1.02 
+[2026-02-03 11:40:54] (step=0017300) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 11:42:31] (step=0017400) Train Loss: -2.9858, Train Steps/Sec: 1.02 +[2026-02-03 11:44:10] (step=0017500) Train Loss: -2.9834, Train Steps/Sec: 1.02 +[2026-02-03 11:45:48] (step=0017600) Train Loss: -2.9826, Train Steps/Sec: 1.02 +[2026-02-03 11:47:22] (step=0017700) Train Loss: -2.9870, Train Steps/Sec: 1.06 +[2026-02-03 11:49:00] (step=0017800) Train Loss: -2.9945, Train Steps/Sec: 1.02 +[2026-02-03 11:50:38] (step=0017900) Train Loss: -2.9841, Train Steps/Sec: 1.02 +[2026-02-03 11:52:16] (step=0018000) Train Loss: -2.9945, Train Steps/Sec: 1.02 +[2026-02-03 11:53:53] (step=0018100) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-03 11:55:32] (step=0018200) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 11:57:10] (step=0018300) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 11:58:48] (step=0018400) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-03 12:00:26] (step=0018500) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-03 12:02:03] (step=0018600) Train Loss: -2.9869, Train Steps/Sec: 1.03 +[2026-02-03 12:03:41] (step=0018700) Train Loss: -2.9855, Train Steps/Sec: 1.03 +[2026-02-03 12:05:18] (step=0018800) Train Loss: -2.9842, Train Steps/Sec: 1.03 +[2026-02-03 12:06:56] (step=0018900) Train Loss: -2.9824, Train Steps/Sec: 1.03 +[2026-02-03 12:08:33] (step=0019000) Train Loss: -2.9857, Train Steps/Sec: 1.03 +[2026-02-03 12:10:11] (step=0019100) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-03 12:11:48] (step=0019200) Train Loss: -2.9880, Train Steps/Sec: 1.03 +[2026-02-03 12:13:26] (step=0019300) Train Loss: -2.9942, Train Steps/Sec: 1.03 +[2026-02-03 12:15:03] (step=0019400) Train Loss: -2.9905, Train Steps/Sec: 1.02 +[2026-02-03 12:16:41] (step=0019500) Train Loss: -2.9895, Train Steps/Sec: 1.02 +[2026-02-03 12:18:19] (step=0019600) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 12:19:57] (step=0019700) Train Loss: 
-2.9901, Train Steps/Sec: 1.02 +[2026-02-03 12:21:35] (step=0019800) Train Loss: -2.9845, Train Steps/Sec: 1.02 +[2026-02-03 12:23:12] (step=0019900) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 12:24:50] (step=0020000) Train Loss: -2.9927, Train Steps/Sec: 1.02 +[2026-02-03 12:25:06] Beginning epoch 4... +[2026-02-03 12:26:30] (step=0020100) Train Loss: -2.9944, Train Steps/Sec: 1.00 +[2026-02-03 12:28:07] (step=0020200) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 12:29:45] (step=0020300) Train Loss: -2.9873, Train Steps/Sec: 1.03 +[2026-02-03 12:31:23] (step=0020400) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 12:33:00] (step=0020500) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 12:34:38] (step=0020600) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 12:36:16] (step=0020700) Train Loss: -2.9836, Train Steps/Sec: 1.02 +[2026-02-03 12:37:54] (step=0020800) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 12:39:32] (step=0020900) Train Loss: -2.9839, Train Steps/Sec: 1.02 +[2026-02-03 12:41:10] (step=0021000) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 12:42:48] (step=0021100) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-03 12:44:25] (step=0021200) Train Loss: -2.9904, Train Steps/Sec: 1.03 +[2026-02-03 12:46:03] (step=0021300) Train Loss: -2.9917, Train Steps/Sec: 1.02 +[2026-02-03 12:47:41] (step=0021400) Train Loss: -2.9911, Train Steps/Sec: 1.02 +[2026-02-03 12:49:19] (step=0021500) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 12:50:57] (step=0021600) Train Loss: -2.9900, Train Steps/Sec: 1.02 +[2026-02-03 12:52:35] (step=0021700) Train Loss: -2.9857, Train Steps/Sec: 1.02 +[2026-02-03 12:54:13] (step=0021800) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 12:55:50] (step=0021900) Train Loss: -2.9898, Train Steps/Sec: 1.03 +[2026-02-03 12:57:28] (step=0022000) Train Loss: -2.9929, Train Steps/Sec: 1.02 +[2026-02-03 12:59:06] (step=0022100) Train Loss: -2.9851, 
Train Steps/Sec: 1.03 +[2026-02-03 13:00:44] (step=0022200) Train Loss: -2.9931, Train Steps/Sec: 1.02 +[2026-02-03 13:02:22] (step=0022300) Train Loss: -2.9841, Train Steps/Sec: 1.02 +[2026-02-03 13:03:59] (step=0022400) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 13:05:34] (step=0022500) Train Loss: -2.9891, Train Steps/Sec: 1.05 +[2026-02-03 13:07:12] (step=0022600) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 13:08:50] (step=0022700) Train Loss: -2.9920, Train Steps/Sec: 1.02 +[2026-02-03 13:10:28] (step=0022800) Train Loss: -2.9864, Train Steps/Sec: 1.03 +[2026-02-03 13:12:05] (step=0022900) Train Loss: -2.9827, Train Steps/Sec: 1.03 +[2026-02-03 13:13:43] (step=0023000) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 13:15:21] (step=0023100) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 13:16:59] (step=0023200) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 13:18:37] (step=0023300) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 13:20:15] (step=0023400) Train Loss: 78736.1641, Train Steps/Sec: 1.02 +[2026-02-03 13:21:53] (step=0023500) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 13:23:31] (step=0023600) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 13:25:09] (step=0023700) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 13:26:47] (step=0023800) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 13:28:25] (step=0023900) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 13:30:03] (step=0024000) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 13:31:41] (step=0024100) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 13:33:19] (step=0024200) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 13:34:56] (step=0024300) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-03 13:36:34] (step=0024400) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 13:38:12] (step=0024500) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 13:39:49] 
(step=0024600) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 13:41:27] (step=0024700) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 13:43:05] (step=0024800) Train Loss: -2.9900, Train Steps/Sec: 1.02 +[2026-02-03 13:44:43] (step=0024900) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 13:46:21] (step=0025000) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 13:46:22] Saved checkpoint to results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt +[2026-02-03 13:46:42] Beginning epoch 5... +[2026-02-03 13:48:02] (step=0025100) Train Loss: -2.9903, Train Steps/Sec: 0.99 +[2026-02-03 13:49:31] Generating EMA samples... +[2026-02-03 13:49:39] (step=0025200) Train Loss: -2.9859, Train Steps/Sec: 1.03 +[2026-02-03 13:51:17] (step=0025300) Train Loss: -2.9951, Train Steps/Sec: 1.02 +[2026-02-03 13:52:55] (step=0025400) Train Loss: -2.9863, Train Steps/Sec: 1.02 +[2026-02-03 13:54:33] (step=0025500) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 13:56:10] (step=0025600) Train Loss: -2.9899, Train Steps/Sec: 1.03 +[2026-02-03 13:57:48] (step=0025700) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 13:59:26] (step=0025800) Train Loss: -2.9830, Train Steps/Sec: 1.02 +[2026-02-03 14:01:03] (step=0025900) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 14:02:40] (step=0026000) Train Loss: -2.9889, Train Steps/Sec: 1.03 +[2026-02-03 14:04:18] (step=0026100) Train Loss: -2.9845, Train Steps/Sec: 1.02 +[2026-02-03 14:05:56] (step=0026200) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-03 14:07:34] (step=0026300) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-03 14:09:12] (step=0026400) Train Loss: -2.9909, Train Steps/Sec: 1.02 +[2026-02-03 14:10:51] (step=0026500) Train Loss: -2.9932, Train Steps/Sec: 1.02 +[2026-02-03 14:12:28] (step=0026600) Train Loss: -2.9909, Train Steps/Sec: 1.02 +[2026-02-03 14:14:06] (step=0026700) Train Loss: -2.9931, Train Steps/Sec: 1.02 +[2026-02-03 
14:15:43] (step=0026800) Train Loss: -2.9852, Train Steps/Sec: 1.03 +[2026-02-03 14:17:21] (step=0026900) Train Loss: -2.9819, Train Steps/Sec: 1.02 +[2026-02-03 14:18:59] (step=0027000) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 14:20:37] (step=0027100) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 14:22:15] (step=0027200) Train Loss: -2.9917, Train Steps/Sec: 1.02 +[2026-02-03 14:23:50] (step=0027300) Train Loss: -2.9884, Train Steps/Sec: 1.06 +[2026-02-03 14:25:27] (step=0027400) Train Loss: -2.9851, Train Steps/Sec: 1.02 +[2026-02-03 14:27:05] (step=0027500) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 14:28:43] (step=0027600) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 14:30:21] (step=0027700) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 14:31:59] (step=0027800) Train Loss: -2.9925, Train Steps/Sec: 1.02 +[2026-02-03 14:33:36] (step=0027900) Train Loss: -2.9893, Train Steps/Sec: 1.03 +[2026-02-03 14:35:13] (step=0028000) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-03 14:36:51] (step=0028100) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 14:38:29] (step=0028200) Train Loss: -2.9917, Train Steps/Sec: 1.02 +[2026-02-03 14:40:07] (step=0028300) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 14:41:46] (step=0028400) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 14:43:23] (step=0028500) Train Loss: -2.9831, Train Steps/Sec: 1.02 +[2026-02-03 14:45:02] (step=0028600) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 14:46:39] (step=0028700) Train Loss: -2.9925, Train Steps/Sec: 1.02 +[2026-02-03 14:48:17] (step=0028800) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 14:49:55] (step=0028900) Train Loss: -2.9915, Train Steps/Sec: 1.02 +[2026-02-03 14:51:33] (step=0029000) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 14:53:10] (step=0029100) Train Loss: -2.9889, Train Steps/Sec: 1.03 +[2026-02-03 14:54:48] (step=0029200) Train Loss: -2.9917, Train 
Steps/Sec: 1.03 +[2026-02-03 14:56:26] (step=0029300) Train Loss: -2.9851, Train Steps/Sec: 1.02 +[2026-02-03 14:58:04] (step=0029400) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 14:59:42] (step=0029500) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 15:01:20] (step=0029600) Train Loss: -2.9896, Train Steps/Sec: 1.03 +[2026-02-03 15:02:57] (step=0029700) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 15:04:35] (step=0029800) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 15:06:13] (step=0029900) Train Loss: -2.9887, Train Steps/Sec: 1.03 +[2026-02-03 15:07:51] (step=0030000) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 15:08:15] Beginning epoch 6... +[2026-02-03 15:09:31] (step=0030100) Train Loss: -2.9889, Train Steps/Sec: 1.00 +[2026-02-03 15:11:09] (step=0030200) Train Loss: -2.9905, Train Steps/Sec: 1.02 +[2026-02-03 15:12:47] (step=0030300) Train Loss: -2.9833, Train Steps/Sec: 1.02 +[2026-02-03 15:14:24] (step=0030400) Train Loss: -2.9880, Train Steps/Sec: 1.03 +[2026-02-03 15:16:02] (step=0030500) Train Loss: -2.9881, Train Steps/Sec: 1.03 +[2026-02-03 15:17:39] (step=0030600) Train Loss: -2.9924, Train Steps/Sec: 1.02 +[2026-02-03 15:19:17] (step=0030700) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 15:20:56] (step=0030800) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 15:22:33] (step=0030900) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 15:24:11] (step=0031000) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 15:25:49] (step=0031100) Train Loss: -2.9836, Train Steps/Sec: 1.02 +[2026-02-03 15:27:27] (step=0031200) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 15:29:05] (step=0031300) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 15:30:42] (step=0031400) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 15:32:21] (step=0031500) Train Loss: -2.9930, Train Steps/Sec: 1.02 +[2026-02-03 15:33:59] (step=0031600) Train Loss: -2.9894, Train Steps/Sec: 
1.02 +[2026-02-03 15:35:36] (step=0031700) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 15:37:14] (step=0031800) Train Loss: -2.9910, Train Steps/Sec: 1.02 +[2026-02-03 15:38:52] (step=0031900) Train Loss: -2.9891, Train Steps/Sec: 1.03 +[2026-02-03 15:40:29] (step=0032000) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-03 15:42:04] (step=0032100) Train Loss: -2.9853, Train Steps/Sec: 1.06 +[2026-02-03 15:43:42] (step=0032200) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 15:45:20] (step=0032300) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 15:46:58] (step=0032400) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 15:48:36] (step=0032500) Train Loss: -2.9862, Train Steps/Sec: 1.02 +[2026-02-03 15:50:14] (step=0032600) Train Loss: -2.9863, Train Steps/Sec: 1.02 +[2026-02-03 15:51:52] (step=0032700) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 15:53:30] (step=0032800) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 15:55:08] (step=0032900) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 15:56:46] (step=0033000) Train Loss: -2.9921, Train Steps/Sec: 1.02 +[2026-02-03 15:58:24] (step=0033100) Train Loss: -2.9846, Train Steps/Sec: 1.02 +[2026-02-03 16:00:02] (step=0033200) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-03 16:01:39] (step=0033300) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 16:03:17] (step=0033400) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 16:04:55] (step=0033500) Train Loss: -2.9814, Train Steps/Sec: 1.02 +[2026-02-03 16:06:33] (step=0033600) Train Loss: -2.9837, Train Steps/Sec: 1.02 +[2026-02-03 16:08:11] (step=0033700) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 16:09:49] (step=0033800) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 16:11:27] (step=0033900) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 16:13:05] (step=0034000) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 16:14:43] (step=0034100) Train Loss: 
-2.9872, Train Steps/Sec: 1.02 +[2026-02-03 16:16:20] (step=0034200) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 16:17:58] (step=0034300) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 16:19:36] (step=0034400) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 16:21:14] (step=0034500) Train Loss: -2.9865, Train Steps/Sec: 1.03 +[2026-02-03 16:22:51] (step=0034600) Train Loss: -2.9868, Train Steps/Sec: 1.03 +[2026-02-03 16:24:29] (step=0034700) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 16:26:07] (step=0034800) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-03 16:27:45] (step=0034900) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 16:29:22] (step=0035000) Train Loss: -2.9874, Train Steps/Sec: 1.03 +[2026-02-03 16:29:50] Beginning epoch 7... +[2026-02-03 16:31:03] (step=0035100) Train Loss: -2.9881, Train Steps/Sec: 1.00 +[2026-02-03 16:32:41] (step=0035200) Train Loss: -2.9922, Train Steps/Sec: 1.02 +[2026-02-03 16:34:19] (step=0035300) Train Loss: -2.9847, Train Steps/Sec: 1.02 +[2026-02-03 16:35:56] (step=0035400) Train Loss: -2.9855, Train Steps/Sec: 1.03 +[2026-02-03 16:37:35] (step=0035500) Train Loss: -2.9950, Train Steps/Sec: 1.02 +[2026-02-03 16:39:12] (step=0035600) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 16:40:50] (step=0035700) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 16:42:28] (step=0035800) Train Loss: -2.9853, Train Steps/Sec: 1.02 +[2026-02-03 16:44:06] (step=0035900) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 16:45:44] (step=0036000) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 16:47:21] (step=0036100) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-03 16:48:59] (step=0036200) Train Loss: -2.9868, Train Steps/Sec: 1.03 +[2026-02-03 16:50:37] (step=0036300) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 16:52:14] (step=0036400) Train Loss: -2.9840, Train Steps/Sec: 1.03 +[2026-02-03 16:53:52] (step=0036500) Train Loss: -2.9848, 
Train Steps/Sec: 1.02 +[2026-02-03 16:55:30] (step=0036600) Train Loss: -2.9852, Train Steps/Sec: 1.02 +[2026-02-03 16:57:08] (step=0036700) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 16:58:46] (step=0036800) Train Loss: -2.9842, Train Steps/Sec: 1.02 +[2026-02-03 17:00:21] (step=0036900) Train Loss: -2.9858, Train Steps/Sec: 1.05 +[2026-02-03 17:01:58] (step=0037000) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 17:03:36] (step=0037100) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-03 17:05:14] (step=0037200) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-03 17:06:52] (step=0037300) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 17:08:30] (step=0037400) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 17:10:08] (step=0037500) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-03 17:11:46] (step=0037600) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-03 17:13:24] (step=0037700) Train Loss: -2.9885, Train Steps/Sec: 1.03 +[2026-02-03 17:15:02] (step=0037800) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 17:16:40] (step=0037900) Train Loss: -2.9863, Train Steps/Sec: 1.02 +[2026-02-03 17:18:18] (step=0038000) Train Loss: -2.9926, Train Steps/Sec: 1.02 +[2026-02-03 17:19:56] (step=0038100) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 17:21:34] (step=0038200) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 17:23:12] (step=0038300) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 17:24:50] (step=0038400) Train Loss: -2.9900, Train Steps/Sec: 1.02 +[2026-02-03 17:26:27] (step=0038500) Train Loss: -2.9831, Train Steps/Sec: 1.02 +[2026-02-03 17:28:05] (step=0038600) Train Loss: -2.9854, Train Steps/Sec: 1.03 +[2026-02-03 17:29:43] (step=0038700) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 17:31:21] (step=0038800) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-03 17:32:59] (step=0038900) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 17:34:37] 
(step=0039000) Train Loss: -2.9915, Train Steps/Sec: 1.03 +[2026-02-03 17:36:15] (step=0039100) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-03 17:37:53] (step=0039200) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 17:39:30] (step=0039300) Train Loss: -2.9894, Train Steps/Sec: 1.03 +[2026-02-03 17:41:08] (step=0039400) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 17:42:46] (step=0039500) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 17:44:24] (step=0039600) Train Loss: -2.9895, Train Steps/Sec: 1.02 +[2026-02-03 17:46:02] (step=0039700) Train Loss: -2.9927, Train Steps/Sec: 1.02 +[2026-02-03 17:47:39] (step=0039800) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-03 17:49:17] (step=0039900) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 17:50:55] (step=0040000) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 17:51:27] Beginning epoch 8... +[2026-02-03 17:52:35] (step=0040100) Train Loss: -2.9853, Train Steps/Sec: 1.00 +[2026-02-03 17:54:13] (step=0040200) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 17:55:51] (step=0040300) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 17:57:29] (step=0040400) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 17:59:06] (step=0040500) Train Loss: -2.9873, Train Steps/Sec: 1.03 +[2026-02-03 18:00:44] (step=0040600) Train Loss: -2.9879, Train Steps/Sec: 1.03 +[2026-02-03 18:02:22] (step=0040700) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 18:04:00] (step=0040800) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 18:05:38] (step=0040900) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 18:07:16] (step=0041000) Train Loss: -2.9832, Train Steps/Sec: 1.02 +[2026-02-03 18:08:54] (step=0041100) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 18:10:31] (step=0041200) Train Loss: -2.9928, Train Steps/Sec: 1.02 +[2026-02-03 18:12:09] (step=0041300) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-03 18:13:47] 
(step=0041400) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 18:15:25] (step=0041500) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 18:17:00] (step=0041600) Train Loss: -2.9869, Train Steps/Sec: 1.06 +[2026-02-03 18:18:37] (step=0041700) Train Loss: -2.9919, Train Steps/Sec: 1.03 +[2026-02-03 18:20:15] (step=0041800) Train Loss: -2.9858, Train Steps/Sec: 1.02 +[2026-02-03 18:21:53] (step=0041900) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 18:23:30] (step=0042000) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 18:25:08] (step=0042100) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 18:26:46] (step=0042200) Train Loss: -2.9826, Train Steps/Sec: 1.03 +[2026-02-03 18:28:23] (step=0042300) Train Loss: -2.9954, Train Steps/Sec: 1.03 +[2026-02-03 18:30:00] (step=0042400) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 18:31:38] (step=0042500) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-03 18:33:16] (step=0042600) Train Loss: -2.9924, Train Steps/Sec: 1.02 +[2026-02-03 18:34:54] (step=0042700) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-03 18:36:32] (step=0042800) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-03 18:38:10] (step=0042900) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-03 18:39:48] (step=0043000) Train Loss: -2.9860, Train Steps/Sec: 1.02 +[2026-02-03 18:41:25] (step=0043100) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 18:43:03] (step=0043200) Train Loss: -2.9875, Train Steps/Sec: 1.03 +[2026-02-03 18:44:41] (step=0043300) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-03 18:46:19] (step=0043400) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 18:47:57] (step=0043500) Train Loss: -2.9948, Train Steps/Sec: 1.02 +[2026-02-03 18:49:34] (step=0043600) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 18:51:12] (step=0043700) Train Loss: -2.9846, Train Steps/Sec: 1.02 +[2026-02-03 18:52:51] (step=0043800) Train Loss: -2.9913, Train Steps/Sec: 
1.01 +[2026-02-03 18:54:29] (step=0043900) Train Loss: -2.9863, Train Steps/Sec: 1.02 +[2026-02-03 18:56:07] (step=0044000) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 18:57:45] (step=0044100) Train Loss: -2.9862, Train Steps/Sec: 1.02 +[2026-02-03 18:59:23] (step=0044200) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 19:01:00] (step=0044300) Train Loss: -2.9850, Train Steps/Sec: 1.03 +[2026-02-03 19:02:38] (step=0044400) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 19:04:16] (step=0044500) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-03 19:05:53] (step=0044600) Train Loss: -2.9881, Train Steps/Sec: 1.03 +[2026-02-03 19:07:31] (step=0044700) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 19:09:09] (step=0044800) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 19:10:47] (step=0044900) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 19:12:25] (step=0045000) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 19:13:00] Beginning epoch 9... 
+[2026-02-03 19:14:05] (step=0045100) Train Loss: -2.9923, Train Steps/Sec: 1.00 +[2026-02-03 19:15:43] (step=0045200) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-03 19:17:21] (step=0045300) Train Loss: -2.9932, Train Steps/Sec: 1.02 +[2026-02-03 19:18:59] (step=0045400) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 19:20:37] (step=0045500) Train Loss: -2.9825, Train Steps/Sec: 1.02 +[2026-02-03 19:22:15] (step=0045600) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-03 19:23:54] (step=0045700) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 19:25:31] (step=0045800) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 19:27:09] (step=0045900) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 19:28:47] (step=0046000) Train Loss: -2.9868, Train Steps/Sec: 1.03 +[2026-02-03 19:30:25] (step=0046100) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 19:32:03] (step=0046200) Train Loss: -2.9925, Train Steps/Sec: 1.02 +[2026-02-03 19:33:40] (step=0046300) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 19:35:15] (step=0046400) Train Loss: -2.9863, Train Steps/Sec: 1.06 +[2026-02-03 19:36:53] (step=0046500) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 19:38:30] (step=0046600) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 19:40:08] (step=0046700) Train Loss: -2.9853, Train Steps/Sec: 1.02 +[2026-02-03 19:41:47] (step=0046800) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 19:43:25] (step=0046900) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 19:45:03] (step=0047000) Train Loss: -2.9863, Train Steps/Sec: 1.02 +[2026-02-03 19:46:40] (step=0047100) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 19:48:18] (step=0047200) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 19:49:56] (step=0047300) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 19:51:34] (step=0047400) Train Loss: -2.9927, Train Steps/Sec: 1.02 +[2026-02-03 19:53:12] (step=0047500) Train Loss: 
-2.9912, Train Steps/Sec: 1.02 +[2026-02-03 19:54:50] (step=0047600) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 19:56:28] (step=0047700) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-03 19:58:06] (step=0047800) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 19:59:44] (step=0047900) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 20:01:22] (step=0048000) Train Loss: -2.9915, Train Steps/Sec: 1.02 +[2026-02-03 20:03:00] (step=0048100) Train Loss: -2.9858, Train Steps/Sec: 1.02 +[2026-02-03 20:04:38] (step=0048200) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-03 20:06:16] (step=0048300) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 20:07:54] (step=0048400) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-03 20:09:32] (step=0048500) Train Loss: -2.9878, Train Steps/Sec: 1.02 +[2026-02-03 20:11:10] (step=0048600) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 20:12:48] (step=0048700) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 20:14:26] (step=0048800) Train Loss: -2.9853, Train Steps/Sec: 1.02 +[2026-02-03 20:16:04] (step=0048900) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 20:17:41] (step=0049000) Train Loss: -2.9853, Train Steps/Sec: 1.02 +[2026-02-03 20:19:19] (step=0049100) Train Loss: -2.9849, Train Steps/Sec: 1.02 +[2026-02-03 20:20:57] (step=0049200) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 20:22:35] (step=0049300) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-03 20:24:12] (step=0049400) Train Loss: -2.9888, Train Steps/Sec: 1.03 +[2026-02-03 20:25:51] (step=0049500) Train Loss: -2.9911, Train Steps/Sec: 1.02 +[2026-02-03 20:27:29] (step=0049600) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 20:29:07] (step=0049700) Train Loss: -2.9921, Train Steps/Sec: 1.02 +[2026-02-03 20:30:45] (step=0049800) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 20:32:23] (step=0049900) Train Loss: -2.9805, Train Steps/Sec: 1.02 +[2026-02-03 20:34:01] 
(step=0050000) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 20:34:02] Saved checkpoint to results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/checkpoints/0050000.pt +[2026-02-03 20:34:41] Beginning epoch 10... +[2026-02-03 20:35:42] (step=0050100) Train Loss: -2.9896, Train Steps/Sec: 0.99 +[2026-02-03 20:37:20] (step=0050200) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 20:38:58] (step=0050300) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 20:40:20] Generating EMA samples... +[2026-02-03 20:40:35] (step=0050400) Train Loss: -2.9846, Train Steps/Sec: 1.03 +[2026-02-03 20:42:13] (step=0050500) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 20:43:51] (step=0050600) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 20:45:29] (step=0050700) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 20:47:07] (step=0050800) Train Loss: -2.9852, Train Steps/Sec: 1.02 +[2026-02-03 20:48:44] (step=0050900) Train Loss: -2.9878, Train Steps/Sec: 1.03 +[2026-02-03 20:50:21] (step=0051000) Train Loss: -2.9897, Train Steps/Sec: 1.03 +[2026-02-03 20:52:00] (step=0051100) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 20:53:34] (step=0051200) Train Loss: -2.9882, Train Steps/Sec: 1.06 +[2026-02-03 20:55:12] (step=0051300) Train Loss: -2.9856, Train Steps/Sec: 1.02 +[2026-02-03 20:56:50] (step=0051400) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-03 20:58:28] (step=0051500) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 21:00:06] (step=0051600) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-03 21:01:44] (step=0051700) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 21:03:22] (step=0051800) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 21:04:59] (step=0051900) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 21:06:38] (step=0052000) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-03 21:08:16] (step=0052100) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 
21:09:54] (step=0052200) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 21:11:31] (step=0052300) Train Loss: -2.9868, Train Steps/Sec: 1.03 +[2026-02-03 21:13:09] (step=0052400) Train Loss: -2.9857, Train Steps/Sec: 1.02 +[2026-02-03 21:14:47] (step=0052500) Train Loss: -2.9898, Train Steps/Sec: 1.03 +[2026-02-03 21:16:25] (step=0052600) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 21:18:03] (step=0052700) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-03 21:19:40] (step=0052800) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 21:21:18] (step=0052900) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 21:22:56] (step=0053000) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-03 21:24:34] (step=0053100) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-03 21:26:12] (step=0053200) Train Loss: -2.9921, Train Steps/Sec: 1.02 +[2026-02-03 21:27:49] (step=0053300) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 21:29:27] (step=0053400) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-03 21:31:05] (step=0053500) Train Loss: -2.9909, Train Steps/Sec: 1.02 +[2026-02-03 21:32:43] (step=0053600) Train Loss: -2.9928, Train Steps/Sec: 1.02 +[2026-02-03 21:34:21] (step=0053700) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-03 21:35:59] (step=0053800) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 21:37:37] (step=0053900) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 21:39:15] (step=0054000) Train Loss: -2.9844, Train Steps/Sec: 1.02 +[2026-02-03 21:40:53] (step=0054100) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 21:42:31] (step=0054200) Train Loss: -2.9911, Train Steps/Sec: 1.02 +[2026-02-03 21:44:09] (step=0054300) Train Loss: -2.9915, Train Steps/Sec: 1.02 +[2026-02-03 21:45:47] (step=0054400) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-03 21:47:24] (step=0054500) Train Loss: -2.9854, Train Steps/Sec: 1.03 +[2026-02-03 21:49:02] (step=0054600) Train Loss: -2.9923, Train 
Steps/Sec: 1.02 +[2026-02-03 21:50:39] (step=0054700) Train Loss: -2.9864, Train Steps/Sec: 1.03 +[2026-02-03 21:52:17] (step=0054800) Train Loss: -2.9826, Train Steps/Sec: 1.02 +[2026-02-03 21:53:55] (step=0054900) Train Loss: -2.9858, Train Steps/Sec: 1.02 +[2026-02-03 21:55:33] (step=0055000) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-03 21:56:16] Beginning epoch 11... +[2026-02-03 21:57:13] (step=0055100) Train Loss: -2.9926, Train Steps/Sec: 1.00 +[2026-02-03 21:58:50] (step=0055200) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 22:00:28] (step=0055300) Train Loss: -2.9910, Train Steps/Sec: 1.02 +[2026-02-03 22:02:06] (step=0055400) Train Loss: -2.9851, Train Steps/Sec: 1.02 +[2026-02-03 22:03:44] (step=0055500) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 22:05:22] (step=0055600) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 22:07:00] (step=0055700) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-03 22:08:37] (step=0055800) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 22:10:15] (step=0055900) Train Loss: -2.9909, Train Steps/Sec: 1.02 +[2026-02-03 22:11:50] (step=0056000) Train Loss: -2.9884, Train Steps/Sec: 1.06 +[2026-02-03 22:13:28] (step=0056100) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-03 22:15:05] (step=0056200) Train Loss: -2.9904, Train Steps/Sec: 1.03 +[2026-02-03 22:16:43] (step=0056300) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 22:18:21] (step=0056400) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 22:19:59] (step=0056500) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 22:21:37] (step=0056600) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 22:23:15] (step=0056700) Train Loss: -2.9888, Train Steps/Sec: 1.02 +[2026-02-03 22:24:53] (step=0056800) Train Loss: -2.9846, Train Steps/Sec: 1.02 +[2026-02-03 22:26:32] (step=0056900) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 22:28:10] (step=0057000) Train Loss: -2.9846, Train Steps/Sec: 
1.02 +[2026-02-03 22:29:48] (step=0057100) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-03 22:31:26] (step=0057200) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-03 22:33:04] (step=0057300) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-03 22:34:41] (step=0057400) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-03 22:36:19] (step=0057500) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 22:37:57] (step=0057600) Train Loss: -2.9872, Train Steps/Sec: 1.03 +[2026-02-03 22:39:34] (step=0057700) Train Loss: -2.9902, Train Steps/Sec: 1.03 +[2026-02-03 22:41:12] (step=0057800) Train Loss: -2.9949, Train Steps/Sec: 1.02 +[2026-02-03 22:42:50] (step=0057900) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-03 22:44:28] (step=0058000) Train Loss: -2.9903, Train Steps/Sec: 1.02 +[2026-02-03 22:46:06] (step=0058100) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-03 22:47:44] (step=0058200) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-03 22:49:22] (step=0058300) Train Loss: -2.9900, Train Steps/Sec: 1.02 +[2026-02-03 22:50:59] (step=0058400) Train Loss: -2.9865, Train Steps/Sec: 1.03 +[2026-02-03 22:52:37] (step=0058500) Train Loss: -2.9851, Train Steps/Sec: 1.02 +[2026-02-03 22:54:15] (step=0058600) Train Loss: -2.9861, Train Steps/Sec: 1.01 +[2026-02-03 22:55:53] (step=0058700) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-03 22:57:31] (step=0058800) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-03 22:59:09] (step=0058900) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 23:00:47] (step=0059000) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 23:02:25] (step=0059100) Train Loss: -2.9920, Train Steps/Sec: 1.02 +[2026-02-03 23:04:03] (step=0059200) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-03 23:05:41] (step=0059300) Train Loss: -2.9895, Train Steps/Sec: 1.02 +[2026-02-03 23:07:19] (step=0059400) Train Loss: -2.9911, Train Steps/Sec: 1.02 +[2026-02-03 23:08:57] (step=0059500) Train Loss: 
-2.9857, Train Steps/Sec: 1.02 +[2026-02-03 23:10:34] (step=0059600) Train Loss: -2.9925, Train Steps/Sec: 1.03 +[2026-02-03 23:12:12] (step=0059700) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-03 23:13:50] (step=0059800) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-03 23:15:28] (step=0059900) Train Loss: -2.9914, Train Steps/Sec: 1.02 +[2026-02-03 23:17:06] (step=0060000) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-03 23:17:53] Beginning epoch 12... +[2026-02-03 23:18:45] (step=0060100) Train Loss: -2.9931, Train Steps/Sec: 1.00 +[2026-02-03 23:20:23] (step=0060200) Train Loss: -2.9852, Train Steps/Sec: 1.02 +[2026-02-03 23:22:01] (step=0060300) Train Loss: -2.9839, Train Steps/Sec: 1.02 +[2026-02-03 23:23:39] (step=0060400) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-03 23:25:17] (step=0060500) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-03 23:26:55] (step=0060600) Train Loss: -2.9869, Train Steps/Sec: 1.03 +[2026-02-03 23:28:33] (step=0060700) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-03 23:30:07] (step=0060800) Train Loss: -2.9867, Train Steps/Sec: 1.06 +[2026-02-03 23:31:45] (step=0060900) Train Loss: -2.9912, Train Steps/Sec: 1.02 +[2026-02-03 23:33:23] (step=0061000) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-03 23:35:01] (step=0061100) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-03 23:36:39] (step=0061200) Train Loss: -2.9844, Train Steps/Sec: 1.02 +[2026-02-03 23:38:17] (step=0061300) Train Loss: -2.9937, Train Steps/Sec: 1.02 +[2026-02-03 23:39:55] (step=0061400) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-03 23:41:33] (step=0061500) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-03 23:43:10] (step=0061600) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 23:44:48] (step=0061700) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-03 23:46:25] (step=0061800) Train Loss: -2.9888, Train Steps/Sec: 1.03 +[2026-02-03 23:48:03] (step=0061900) Train Loss: -2.9867, 
Train Steps/Sec: 1.03 +[2026-02-03 23:49:41] (step=0062000) Train Loss: -2.9901, Train Steps/Sec: 1.02 +[2026-02-03 23:51:19] (step=0062100) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-03 23:52:56] (step=0062200) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-03 23:54:34] (step=0062300) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-03 23:56:12] (step=0062400) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-03 23:57:50] (step=0062500) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-03 23:59:28] (step=0062600) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-04 00:01:06] (step=0062700) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-04 00:02:44] (step=0062800) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-04 00:04:21] (step=0062900) Train Loss: -2.9891, Train Steps/Sec: 1.03 +[2026-02-04 00:05:58] (step=0063000) Train Loss: -2.9885, Train Steps/Sec: 1.03 +[2026-02-04 00:07:36] (step=0063100) Train Loss: -2.9878, Train Steps/Sec: 1.02 +[2026-02-04 00:09:14] (step=0063200) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-04 00:10:52] (step=0063300) Train Loss: -2.9942, Train Steps/Sec: 1.02 +[2026-02-04 00:12:30] (step=0063400) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-04 00:14:07] (step=0063500) Train Loss: -2.9898, Train Steps/Sec: 1.03 +[2026-02-04 00:15:46] (step=0063600) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-04 00:17:24] (step=0063700) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-04 00:19:02] (step=0063800) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-04 00:20:39] (step=0063900) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-04 00:22:18] (step=0064000) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 00:23:55] (step=0064100) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-04 00:25:33] (step=0064200) Train Loss: -2.9838, Train Steps/Sec: 1.02 +[2026-02-04 00:27:11] (step=0064300) Train Loss: -2.9857, Train Steps/Sec: 1.02 +[2026-02-04 00:28:49] 
(step=0064400) Train Loss: -2.9905, Train Steps/Sec: 1.03 +[2026-02-04 00:30:26] (step=0064500) Train Loss: -2.9910, Train Steps/Sec: 1.02 +[2026-02-04 00:32:05] (step=0064600) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-04 00:33:42] (step=0064700) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-04 00:35:20] (step=0064800) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-04 00:36:58] (step=0064900) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-04 00:38:36] (step=0065000) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-04 00:39:28] Beginning epoch 13... +[2026-02-04 00:40:16] (step=0065100) Train Loss: -2.9899, Train Steps/Sec: 1.00 +[2026-02-04 00:41:54] (step=0065200) Train Loss: -2.9946, Train Steps/Sec: 1.02 +[2026-02-04 00:43:32] (step=0065300) Train Loss: -2.9928, Train Steps/Sec: 1.02 +[2026-02-04 00:45:10] (step=0065400) Train Loss: -2.9897, Train Steps/Sec: 1.02 +[2026-02-04 00:46:46] (step=0065500) Train Loss: -2.9877, Train Steps/Sec: 1.05 +[2026-02-04 00:48:22] (step=0065600) Train Loss: -2.9892, Train Steps/Sec: 1.03 +[2026-02-04 00:50:00] (step=0065700) Train Loss: -2.9847, Train Steps/Sec: 1.02 +[2026-02-04 00:51:38] (step=0065800) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-04 00:53:16] (step=0065900) Train Loss: -2.9838, Train Steps/Sec: 1.03 +[2026-02-04 00:54:54] (step=0066000) Train Loss: -2.9848, Train Steps/Sec: 1.02 +[2026-02-04 00:56:31] (step=0066100) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-04 00:58:08] (step=0066200) Train Loss: -2.9903, Train Steps/Sec: 1.03 +[2026-02-04 00:59:46] (step=0066300) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 01:01:24] (step=0066400) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-04 01:03:02] (step=0066500) Train Loss: -2.9850, Train Steps/Sec: 1.03 +[2026-02-04 01:04:40] (step=0066600) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-04 01:06:18] (step=0066700) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-04 01:07:56] 
(step=0066800) Train Loss: -2.9895, Train Steps/Sec: 1.02 +[2026-02-04 01:09:34] (step=0066900) Train Loss: -2.9862, Train Steps/Sec: 1.02 +[2026-02-04 01:11:11] (step=0067000) Train Loss: -2.9913, Train Steps/Sec: 1.03 +[2026-02-04 01:12:48] (step=0067100) Train Loss: -2.9877, Train Steps/Sec: 1.03 +[2026-02-04 01:14:26] (step=0067200) Train Loss: -2.9923, Train Steps/Sec: 1.03 +[2026-02-04 01:16:04] (step=0067300) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-04 01:17:42] (step=0067400) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-04 01:19:20] (step=0067500) Train Loss: -2.9905, Train Steps/Sec: 1.02 +[2026-02-04 01:20:58] (step=0067600) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-04 01:22:36] (step=0067700) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-04 01:24:14] (step=0067800) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-04 01:25:52] (step=0067900) Train Loss: -2.9875, Train Steps/Sec: 1.03 +[2026-02-04 01:27:29] (step=0068000) Train Loss: -2.9834, Train Steps/Sec: 1.02 +[2026-02-04 01:29:07] (step=0068100) Train Loss: -2.9885, Train Steps/Sec: 1.02 +[2026-02-04 01:30:45] (step=0068200) Train Loss: -2.9882, Train Steps/Sec: 1.02 +[2026-02-04 01:32:22] (step=0068300) Train Loss: -2.9922, Train Steps/Sec: 1.03 +[2026-02-04 01:34:01] (step=0068400) Train Loss: -2.9823, Train Steps/Sec: 1.02 +[2026-02-04 01:35:38] (step=0068500) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-04 01:37:15] (step=0068600) Train Loss: -2.9938, Train Steps/Sec: 1.03 +[2026-02-04 01:38:53] (step=0068700) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-04 01:40:31] (step=0068800) Train Loss: -2.9893, Train Steps/Sec: 1.02 +[2026-02-04 01:42:09] (step=0068900) Train Loss: -2.9892, Train Steps/Sec: 1.02 +[2026-02-04 01:43:47] (step=0069000) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-04 01:45:25] (step=0069100) Train Loss: -2.9871, Train Steps/Sec: 1.02 +[2026-02-04 01:47:02] (step=0069200) Train Loss: -2.9910, Train Steps/Sec: 
1.03 +[2026-02-04 01:48:40] (step=0069300) Train Loss: -2.9894, Train Steps/Sec: 1.03 +[2026-02-04 01:50:17] (step=0069400) Train Loss: -2.9837, Train Steps/Sec: 1.02 +[2026-02-04 01:51:55] (step=0069500) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-04 01:53:33] (step=0069600) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 01:55:11] (step=0069700) Train Loss: -2.9852, Train Steps/Sec: 1.03 +[2026-02-04 01:56:49] (step=0069800) Train Loss: -2.9926, Train Steps/Sec: 1.02 +[2026-02-04 01:58:27] (step=0069900) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-04 02:00:05] (step=0070000) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-04 02:01:01] Beginning epoch 14... +[2026-02-04 02:01:45] (step=0070100) Train Loss: -2.9858, Train Steps/Sec: 1.00 +[2026-02-04 02:03:23] (step=0070200) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-04 02:04:57] (step=0070300) Train Loss: -2.9891, Train Steps/Sec: 1.06 +[2026-02-04 02:06:35] (step=0070400) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-04 02:08:13] (step=0070500) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-04 02:09:51] (step=0070600) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-04 02:11:29] (step=0070700) Train Loss: -2.9853, Train Steps/Sec: 1.02 +[2026-02-04 02:13:06] (step=0070800) Train Loss: -2.9915, Train Steps/Sec: 1.02 +[2026-02-04 02:14:44] (step=0070900) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-04 02:16:22] (step=0071000) Train Loss: -2.9910, Train Steps/Sec: 1.02 +[2026-02-04 02:18:00] (step=0071100) Train Loss: -2.9909, Train Steps/Sec: 1.02 +[2026-02-04 02:19:37] (step=0071200) Train Loss: -2.9857, Train Steps/Sec: 1.03 +[2026-02-04 02:21:15] (step=0071300) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-04 02:22:52] (step=0071400) Train Loss: -2.9858, Train Steps/Sec: 1.03 +[2026-02-04 02:24:30] (step=0071500) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-04 02:26:08] (step=0071600) Train Loss: -2.9936, Train Steps/Sec: 1.02 
+[2026-02-04 02:27:46] (step=0071700) Train Loss: -2.9813, Train Steps/Sec: 1.02 +[2026-02-04 02:29:24] (step=0071800) Train Loss: -2.9841, Train Steps/Sec: 1.02 +[2026-02-04 02:31:01] (step=0071900) Train Loss: -2.9900, Train Steps/Sec: 1.03 +[2026-02-04 02:32:39] (step=0072000) Train Loss: -2.9901, Train Steps/Sec: 1.03 +[2026-02-04 02:34:16] (step=0072100) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-04 02:35:54] (step=0072200) Train Loss: -2.9852, Train Steps/Sec: 1.03 +[2026-02-04 02:37:32] (step=0072300) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-04 02:39:10] (step=0072400) Train Loss: -2.9919, Train Steps/Sec: 1.02 +[2026-02-04 02:40:48] (step=0072500) Train Loss: -2.9843, Train Steps/Sec: 1.02 +[2026-02-04 02:42:26] (step=0072600) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-04 02:44:04] (step=0072700) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-04 02:45:42] (step=0072800) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-04 02:47:19] (step=0072900) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-04 02:48:57] (step=0073000) Train Loss: -2.9901, Train Steps/Sec: 1.03 +[2026-02-04 02:50:35] (step=0073100) Train Loss: -2.9859, Train Steps/Sec: 1.02 +[2026-02-04 02:52:13] (step=0073200) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-04 02:53:50] (step=0073300) Train Loss: -2.9875, Train Steps/Sec: 1.03 +[2026-02-04 02:55:28] (step=0073400) Train Loss: -2.9896, Train Steps/Sec: 1.02 +[2026-02-04 02:57:05] (step=0073500) Train Loss: -2.9940, Train Steps/Sec: 1.03 +[2026-02-04 02:58:43] (step=0073600) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-04 03:00:21] (step=0073700) Train Loss: -2.9883, Train Steps/Sec: 1.02 +[2026-02-04 03:01:58] (step=0073800) Train Loss: -2.9895, Train Steps/Sec: 1.03 +[2026-02-04 03:03:36] (step=0073900) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-04 03:05:14] (step=0074000) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-04 03:06:52] (step=0074100) Train Loss: 
-2.9830, Train Steps/Sec: 1.02 +[2026-02-04 03:08:30] (step=0074200) Train Loss: -2.9861, Train Steps/Sec: 1.02 +[2026-02-04 03:10:08] (step=0074300) Train Loss: -2.9873, Train Steps/Sec: 1.02 +[2026-02-04 03:11:45] (step=0074400) Train Loss: -2.9860, Train Steps/Sec: 1.03 +[2026-02-04 03:13:22] (step=0074500) Train Loss: -2.9887, Train Steps/Sec: 1.03 +[2026-02-04 03:15:00] (step=0074600) Train Loss: -2.9857, Train Steps/Sec: 1.03 +[2026-02-04 03:16:38] (step=0074700) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-04 03:18:15] (step=0074800) Train Loss: -2.9874, Train Steps/Sec: 1.02 +[2026-02-04 03:19:53] (step=0074900) Train Loss: -2.9849, Train Steps/Sec: 1.02 +[2026-02-04 03:21:32] (step=0075000) Train Loss: -2.9902, Train Steps/Sec: 1.02 +[2026-02-04 03:21:33] Saved checkpoint to results_256_gvp_disp/depth-mu-2-004-SiT-XL-2-GVP-velocity-None/checkpoints/0075000.pt +[2026-02-04 03:22:32] Beginning epoch 15... +[2026-02-04 03:23:10] (step=0075100) Train Loss: -2.9908, Train Steps/Sec: 1.02 +[2026-02-04 03:24:48] (step=0075200) Train Loss: -2.9917, Train Steps/Sec: 1.02 +[2026-02-04 03:26:26] (step=0075300) Train Loss: -2.9913, Train Steps/Sec: 1.02 +[2026-02-04 03:28:04] (step=0075400) Train Loss: -2.9900, Train Steps/Sec: 1.02 +[2026-02-04 03:29:42] (step=0075500) Train Loss: -2.9866, Train Steps/Sec: 1.02 +[2026-02-04 03:30:56] Generating EMA samples... 
+[2026-02-04 03:31:20] (step=0075600) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-04 03:32:57] (step=0075700) Train Loss: -2.9845, Train Steps/Sec: 1.03 +[2026-02-04 03:34:36] (step=0075800) Train Loss: -2.9907, Train Steps/Sec: 1.02 +[2026-02-04 03:36:14] (step=0075900) Train Loss: -2.9899, Train Steps/Sec: 1.02 +[2026-02-04 03:37:52] (step=0076000) Train Loss: -2.9894, Train Steps/Sec: 1.02 +[2026-02-04 03:39:30] (step=0076100) Train Loss: -2.9877, Train Steps/Sec: 1.02 +[2026-02-04 03:41:08] (step=0076200) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-04 03:42:45] (step=0076300) Train Loss: -2.9843, Train Steps/Sec: 1.03 +[2026-02-04 03:44:23] (step=0076400) Train Loss: -2.9898, Train Steps/Sec: 1.02 +[2026-02-04 03:46:01] (step=0076500) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-04 03:47:39] (step=0076600) Train Loss: -2.9848, Train Steps/Sec: 1.02 +[2026-02-04 03:49:16] (step=0076700) Train Loss: -2.9864, Train Steps/Sec: 1.03 +[2026-02-04 03:50:54] (step=0076800) Train Loss: -2.9876, Train Steps/Sec: 1.03 +[2026-02-04 03:52:32] (step=0076900) Train Loss: -2.9862, Train Steps/Sec: 1.02 +[2026-02-04 03:54:10] (step=0077000) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-04 03:55:48] (step=0077100) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-04 03:57:26] (step=0077200) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-04 03:59:04] (step=0077300) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-04 04:00:42] (step=0077400) Train Loss: -2.9886, Train Steps/Sec: 1.02 +[2026-02-04 04:02:20] (step=0077500) Train Loss: -2.9870, Train Steps/Sec: 1.02 +[2026-02-04 04:03:58] (step=0077600) Train Loss: -2.9864, Train Steps/Sec: 1.02 +[2026-02-04 04:05:35] (step=0077700) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-04 04:07:13] (step=0077800) Train Loss: -2.9904, Train Steps/Sec: 1.02 +[2026-02-04 04:08:51] (step=0077900) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-04 04:10:29] (step=0078000) Train Loss: 
-2.9941, Train Steps/Sec: 1.02 +[2026-02-04 04:12:07] (step=0078100) Train Loss: -2.9890, Train Steps/Sec: 1.02 +[2026-02-04 04:13:45] (step=0078200) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-04 04:15:22] (step=0078300) Train Loss: -2.9915, Train Steps/Sec: 1.03 +[2026-02-04 04:17:00] (step=0078400) Train Loss: -2.9876, Train Steps/Sec: 1.03 +[2026-02-04 04:18:37] (step=0078500) Train Loss: -2.9893, Train Steps/Sec: 1.03 +[2026-02-04 04:20:15] (step=0078600) Train Loss: -2.9887, Train Steps/Sec: 1.02 +[2026-02-04 04:21:53] (step=0078700) Train Loss: -2.9854, Train Steps/Sec: 1.02 +[2026-02-04 04:23:31] (step=0078800) Train Loss: -2.9884, Train Steps/Sec: 1.03 +[2026-02-04 04:25:08] (step=0078900) Train Loss: -2.9884, Train Steps/Sec: 1.03 +[2026-02-04 04:26:46] (step=0079000) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 04:28:24] (step=0079100) Train Loss: -2.9918, Train Steps/Sec: 1.02 +[2026-02-04 04:30:01] (step=0079200) Train Loss: -2.9873, Train Steps/Sec: 1.03 +[2026-02-04 04:31:39] (step=0079300) Train Loss: -2.9867, Train Steps/Sec: 1.02 +[2026-02-04 04:33:17] (step=0079400) Train Loss: -2.9800, Train Steps/Sec: 1.02 +[2026-02-04 04:34:55] (step=0079500) Train Loss: -2.9873, Train Steps/Sec: 1.03 +[2026-02-04 04:36:32] (step=0079600) Train Loss: -2.9847, Train Steps/Sec: 1.02 +[2026-02-04 04:38:11] (step=0079700) Train Loss: -2.9876, Train Steps/Sec: 1.02 +[2026-02-04 04:39:48] (step=0079800) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-04 04:41:23] (step=0079900) Train Loss: -2.9922, Train Steps/Sec: 1.06 +[2026-02-04 04:43:00] (step=0080000) Train Loss: -2.9857, Train Steps/Sec: 1.03 +[2026-02-04 04:44:04] Beginning epoch 16... 
+[2026-02-04 04:44:40] (step=0080100) Train Loss: -2.9882, Train Steps/Sec: 1.00 +[2026-02-04 04:46:18] (step=0080200) Train Loss: -2.9875, Train Steps/Sec: 1.02 +[2026-02-04 04:47:56] (step=0080300) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 04:49:34] (step=0080400) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 04:51:11] (step=0080500) Train Loss: -2.9847, Train Steps/Sec: 1.03 +[2026-02-04 04:52:49] (step=0080600) Train Loss: -2.9891, Train Steps/Sec: 1.02 +[2026-02-04 04:54:27] (step=0080700) Train Loss: -2.9888, Train Steps/Sec: 1.03 +[2026-02-04 04:56:04] (step=0080800) Train Loss: -2.9902, Train Steps/Sec: 1.03 +[2026-02-04 04:57:42] (step=0080900) Train Loss: -2.9849, Train Steps/Sec: 1.02 +[2026-02-04 04:59:20] (step=0081000) Train Loss: -2.9865, Train Steps/Sec: 1.03 +[2026-02-04 05:00:58] (step=0081100) Train Loss: -2.9868, Train Steps/Sec: 1.02 +[2026-02-04 05:02:36] (step=0081200) Train Loss: -2.9889, Train Steps/Sec: 1.02 +[2026-02-04 05:04:14] (step=0081300) Train Loss: -2.9845, Train Steps/Sec: 1.02 +[2026-02-04 05:05:52] (step=0081400) Train Loss: -2.9906, Train Steps/Sec: 1.02 +[2026-02-04 05:07:29] (step=0081500) Train Loss: -2.9916, Train Steps/Sec: 1.02 +[2026-02-04 05:09:08] (step=0081600) Train Loss: -2.9953, Train Steps/Sec: 1.02 +[2026-02-04 05:10:46] (step=0081700) Train Loss: -2.9884, Train Steps/Sec: 1.02 +[2026-02-04 05:12:24] (step=0081800) Train Loss: -2.9865, Train Steps/Sec: 1.02 +[2026-02-04 05:14:01] (step=0081900) Train Loss: -2.9889, Train Steps/Sec: 1.03 +[2026-02-04 05:15:39] (step=0082000) Train Loss: -2.9850, Train Steps/Sec: 1.02 +[2026-02-04 05:17:17] (step=0082100) Train Loss: -2.9880, Train Steps/Sec: 1.02 +[2026-02-04 05:18:55] (step=0082200) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-04 05:20:33] (step=0082300) Train Loss: -2.9869, Train Steps/Sec: 1.02 +[2026-02-04 05:22:10] (step=0082400) Train Loss: -2.9872, Train Steps/Sec: 1.02 +[2026-02-04 05:23:49] (step=0082500) Train Loss: 
-2.9838, Train Steps/Sec: 1.02 +[2026-02-04 05:25:27] (step=0082600) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-04 05:27:04] (step=0082700) Train Loss: -2.9890, Train Steps/Sec: 1.03 +[2026-02-04 05:28:42] (step=0082800) Train Loss: -2.9881, Train Steps/Sec: 1.02 +[2026-02-04 05:30:19] (step=0082900) Train Loss: -2.9903, Train Steps/Sec: 1.03 +[2026-02-04 05:31:58] (step=0083000) Train Loss: -2.9946, Train Steps/Sec: 1.02 +[2026-02-04 05:33:36] (step=0083100) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-04 05:35:14] (step=0083200) Train Loss: -2.9879, Train Steps/Sec: 1.02 +[2026-02-04 05:36:52] (step=0083300) Train Loss: -2.9939, Train Steps/Sec: 1.02 +[2026-02-04 05:38:30] (step=0083400) Train Loss: -2.9914, Train Steps/Sec: 1.02 +[2026-02-04 05:40:07] (step=0083500) Train Loss: -2.9888, Train Steps/Sec: 1.03 diff --git a/Rectified_Noise/GVP-Disp/run.sh b/Rectified_Noise/GVP-Disp/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c05ab5bb5c16d60db76be761e00369c41555688 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/run.sh @@ -0,0 +1,14 @@ +nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=4 \ + --rdzv_endpoint=localhost:29739 \ + train_rectified_noise.py \ + --depth 2 \ + --results-dir results_256_gvp_disp \ + --data-path /gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/ \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_GVP/base.pt \ + --num-classes 1000 \ + --path-type GVP \ + --prediction velocity \ + --disp \ + > w_training1.log 2>&1 & diff --git a/Rectified_Noise/GVP-Disp/test.sh b/Rectified_Noise/GVP-Disp/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..861c82d88813a7ad3c4a8808ad9b547291a9bc6c --- /dev/null +++ b/Rectified_Noise/GVP-Disp/test.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# Execute all four commands in parallel +# Each command runs in the background using & + +echo "Starting all four sampling tasks in parallel..." 
+ +CUDA_VISIBLE_DEVICES=0 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29110 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir GVP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2 False \ + --use-sitf2-before-t05 False \ + --sitf2-threshold 0.0 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_GVP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt > W_No.log 2>&1 & + +CUDA_VISIBLE_DEVICES=1 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29150 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir GVP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 False \ + --sitf2-threshold 1.0 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_GVP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt > W_False.log 2>&1 & + + +CUDA_VISIBLE_DEVICES=2 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29152 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir GVP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 True \ + --sitf2-threshold 0.5 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_GVP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt > W_True_0.5.log 2>&1 & + +CUDA_VISIBLE_DEVICES=3 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + 
--rdzv_endpoint=localhost:29121 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir GVP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 True \ + --sitf2-threshold 0.15 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_GVP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss/results_256_gvp_disp/depth-mu-2-002-SiT-XL-2-GVP-velocity-None/checkpoints/0025000.pt > W_True_0.15.log 2>&1 & + +# Wait for all background jobs to complete +echo "All tasks started. Waiting for completion..." +wait + +echo "All tasks completed!" \ No newline at end of file diff --git a/Rectified_Noise/GVP-Disp/train_rectified_noise.py b/Rectified_Noise/GVP-Disp/train_rectified_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..47ae11d3ee054026e30198fd8d88bb9ef5973853 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/train_rectified_noise.py @@ -0,0 +1,429 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for SiT using PyTorch DDP. 
+""" +import torch +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +import numpy as np +from collections import OrderedDict +from PIL import Image +from copy import deepcopy +from glob import glob +from time import time +import argparse +import logging +import os + +from models import SiT, SiTF1, SiTF2, CombinedModel +from models import SiT_models +from download import find_model +from transport import create_transport, Sampler +from diffusers.models import AutoencoderKL +from train_utils import parse_transport_args + + + +################################################################################# +# Training Helper Functions # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. 
+ """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + """ + Trains a new SiT model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + dist.init_process_group("nccl") + assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 
+ rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + local_batch_size = int(args.global_batch_size // dist.get_world_size()) + learn_mu = args.learn_mu + depth = args.depth + # Setup an experiment folder: + if rank == 0: + os.makedirs(args.results_dir, exist_ok=True) + experiment_index = len(glob(f"{args.results_dir}/*")) + model_string_name = args.model.replace("/", "-") + if learn_mu: + experiment_name = f"depth-mu-{args.depth}-{experiment_index:03d}-{model_string_name}-" \ + f"{args.path_type}-{args.prediction}-{args.loss_weight}" + else: + experiment_name = f"depth-sigma-{args.depth}-{experiment_index:03d}-{model_string_name}-" \ + f"{args.path_type}-{args.prediction}-{args.loss_weight}" + experiment_dir = f"{args.results_dir}/{experiment_name}" + checkpoint_dir = f"{experiment_dir}/checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + + else: + logger = create_logger(None) + + # Create models: + assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." + latent_size = args.image_size // 8 + + # Get model configuration based on args.model + model_config = SiT_models[args.model] + model_kwargs = model_config().__dict__ # Get the default parameters for this model + + # Extract parameters from the model configuration based on the model name + # Model names follow the format like 'SiT-XL/2', 'SiT-B/4', etc. 
+ model_name = args.model + if 'XL' in model_name: + hidden_size, depth, num_heads = 1152, 28, 16 + elif 'L' in model_name: + hidden_size, depth, num_heads = 1024, 24, 16 + elif 'B' in model_name: + hidden_size, depth, num_heads = 768, 12, 12 + elif 'S' in model_name: + hidden_size, depth, num_heads = 384, 12, 6 + else: + # Default fallback + hidden_size, depth, num_heads = 768, 12, 12 + + # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2 + patch_size = int(model_name.split('/')[-1]) + + sitf1 = SiTF1( + input_size=latent_size, + patch_size=patch_size, + in_channels=4, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=args.num_classes, + learn_sigma=False + ).to(device) + sit = SiT( + input_size=latent_size, + patch_size=patch_size, + in_channels=4, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=args.num_classes, + learn_sigma=False + ).to(device) + sitf2 = SiTF2( + input_size=latent_size, + hidden_size=hidden_size, + out_channels=8, + patch_size=patch_size, + num_heads=num_heads, + mlp_ratio=4.0, + depth=args.depth, # Use the depth for sitf2 as specified by command line + learn_sigma=True, + num_classes=args.num_classes, + learn_mu=learn_mu + ).to(device) + sitf2_ema = deepcopy(sitf2).to(device) + combined_model = CombinedModel(sitf1, sitf2).to(device) + + if args.ckpt is not None: + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + try: + sitf1.load_state_dict(state_dict["model"], strict=False) + sit.load_state_dict(state_dict["model"], strict=False) + except: + sitf1.load_state_dict(state_dict, strict=False) + sit.load_state_dict(state_dict, strict=False) + + + requires_grad(sitf1, False) + requires_grad(sit, False) + requires_grad(sitf2, True) + + opt = torch.optim.AdamW(sitf2.parameters(), lr=1e-4, weight_decay=0) + # Do NOT wrap sitf2 separately in DDP (avoids double-wrapping 
submodules); wrap only the combined model. + combined_model = DDP(combined_model, device_ids=[rank], find_unused_parameters=True) + + # Create transport object: path_type determines the loss form used in training_losses() + # path_type options: "Linear", "GVP", "VP" - each corresponds to a different loss calculation method + transport = create_transport( + args.path_type, # This directly affects how loss is computed in training_losses() + args.prediction, + args.loss_weight, + args.train_eps, + args.sample_eps, + args.disp_loss_weight, + args.temperature + ) + transport_sampler = Sampler(transport) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + logger.info(f"Combined_model Parameters: {sum(p.numel() for p in combined_model.parameters()):,}") + + grad_params = [(n, p.numel()) for n, p in combined_model.named_parameters() if p.requires_grad] + logger.info(f"Total trainable parameters: {sum(cnt for _, cnt in grad_params):,}") + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + dataset = ImageFolder(args.data_path, transform=transform) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=True, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=local_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True + ) + logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") + # Ensure EMA updates target the correct base model (whether sitf2 is wrapped or not) + base_sitf2 = sitf2.module if isinstance(sitf2, torch.nn.parallel.DistributedDataParallel) else sitf2 + update_ema(sitf2_ema, base_sitf2, decay=0) + sitf1.eval() + sit.eval() + sitf2.train() + 
sitf2_ema.eval() + + train_steps = 0 + log_steps = 0 + running_loss = 0 + start_time = time() + ys = torch.randint(1000, size=(local_batch_size,), device=device) + use_cfg = args.cfg_scale > 1.0 + n = ys.size(0) + zs = torch.randn(n, 4, latent_size, latent_size, device=device) + if use_cfg: + zs = torch.cat([zs, zs], 0) + y_null = torch.tensor([1000] * n, device=device) + ys = torch.cat([ys, y_null], 0) + sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale) + model_fn = sitf1.forward_with_cfg + else: + sample_model_kwargs = dict(y=ys) + model_fn = sitf1.forward + def combined_sampling_model(x, t, y=None, **kwargs): + with torch.no_grad(): + sit_out = sit.forward(x, t, y) + combined_out = combined_model.forward(x, t, y) + return sit_out + combined_out + logger.info(f"Training for {args.epochs} epochs...") + for epoch in range(args.epochs): + sampler.set_epoch(epoch) + logger.info(f"Beginning epoch {epoch}...") + for x, y in loader: + x = x.to(device) + y = y.to(device) + with torch.no_grad(): + x_latent = vae.encode(x).latent_dist.sample().mul_(0.18215) + model_kwargs = dict(y=y, return_act=args.disp) + # Compute training loss: the loss form depends on args.path_type (Linear/GVP/VP) + # Each path_type uses a different mathematical formulation for the transport loss + loss_dict = transport.training_losses(sit, x_latent, model_noise=combined_model, model_kwargs=model_kwargs) + loss = loss_dict["loss"].mean() + opt.zero_grad() + loss.backward() + opt.step() + # Update EMA of the trainable sitf2 base model + update_ema(sitf2_ema, base_sitf2) + running_loss += loss.item() + log_steps += 1 + train_steps += 1 + if train_steps % args.log_every == 0: + torch.cuda.synchronize() + end_time = time() + steps_per_sec = log_steps / (end_time - start_time) + avg_loss = torch.tensor(running_loss / log_steps, device=device) + dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + avg_loss = avg_loss.item() / dist.get_world_size() + logger.info(f"(step={train_steps:07d}) Train Loss: 
{avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") + running_loss = 0 + log_steps = 0 + start_time = time() + if train_steps % args.ckpt_every == 0 and train_steps > 0: + print(train_steps) + if rank == 0: + checkpoint = { + "model": sitf2.state_dict(), + "ema": sitf2.state_dict(), + "opt": opt.state_dict(), + "args": args + } + checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" + torch.save(checkpoint, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + dist.barrier() + + if (train_steps % args.sample_every == 0 )and train_steps > 0: + logger.info("Generating EMA samples...") + if epoch == args.epochs: + break + + sitf1.eval() + sit.eval() + sitf2.eval() + logger.info("Final sampling done.") + + logger.info("Done!") + cleanup() + + +def save_samples_grid(out_samples, epoch, experiment_index, args, experiment_name, rank): + if rank == 0: + import os + import numpy as np + from PIL import Image + parent_dir = os.path.dirname(args.results_dir) + pic_dir = os.path.join(parent_dir, "pic") + os.makedirs(pic_dir, exist_ok=True) + experiment_pic_dir = os.path.join(pic_dir, experiment_name) + os.makedirs(experiment_pic_dir, exist_ok=True) + samples_np = torch.clamp(127.5 * out_samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + n_samples = samples_np.shape[0] + grid_size = int(np.ceil(np.sqrt(n_samples))) + canvas_size = grid_size * args.image_size + canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8) + for i, sample in enumerate(samples_np): + row = i // grid_size + col = i % grid_size + canvas[row*args.image_size:(row+1)*args.image_size, col*args.image_size:(col+1)*args.image_size] = sample + combined_image = Image.fromarray(canvas) + combined_image.save(os.path.join(experiment_pic_dir, f"epoch_{epoch:04d}_combined.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--results-dir", 
type=str, default="results_256_linear") + parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=3) + parser.add_argument("--epochs", type=int, default=100000) + parser.add_argument("--global-batch-size", type=int, default=256) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--ckpt-every", type=int, default=25000) + parser.add_argument("--sample-every", type=int, default=25192) + parser.add_argument("--cfg-scale", type=float, default=4.0) + parser.add_argument("--ckpt", type=str, default='/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/2000000.pt', + help="Optional path to a custom SiT checkpoint") + parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True, + help="Whether to learn mu parameter") + parser.add_argument("--depth", type=int, default=1, + help="Depth parameter for SiTF2 model") + parser.add_argument("--disp", action="store_true", + help="Toggle to enable Dispersive Loss") + parser.add_argument("--disp-loss-weight", type=float, default=0.5, + help="Weight λ for dispersive loss (default: 0.5)") + parser.add_argument("--temperature", type=float, default=1.0, + help="Temperature τ for dispersive loss (default: 1.0)") + + # Transport arguments (added by parse_transport_args): + # --path-type: Type of path for loss calculation (default: "GVP") + # Choices: "Linear" (linear interpolation), "GVP" (Geodesic Velocity Path), "VP" (Velocity Path) + # IMPORTANT: This parameter directly affects the loss form computed by transport.training_losses() + # The path_type determines how the 
transport loss is calculated during training. + # Make sure to use the correct path_type that matches your training objective. + # --prediction: Type of prediction (default: "velocity") + # --loss-weight: Loss weight type (default: None) + # --sample-eps, --train-eps: Epsilon values for sampling and training + parse_transport_args(parser) + args = parser.parse_args() + main(args) diff --git a/Rectified_Noise/GVP-Disp/transport/__init__.py b/Rectified_Noise/GVP-Disp/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bdcd3999442f5aef777a1f9ecc2889a0b99b2603 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/transport/__init__.py @@ -0,0 +1,71 @@ +from .transport import Transport, ModelType, WeightType, PathType, Sampler + +def create_transport( + path_type='Linear', + prediction="velocity", + loss_weight=None, + train_eps=None, + sample_eps=None, + disp_loss_weight=0.5, + temperature=1.0, +): + """function for creating Transport object + **Note**: model prediction defaults to velocity + Args: + - path_type: type of path to use; default to linear + - learn_score: set model prediction to score + - learn_noise: set model prediction to noise + - velocity_weighted: weight loss by velocity weight + - likelihood_weighted: weight loss by likelihood weight + - train_eps: small epsilon for avoiding instability during training + - sample_eps: small epsilon for avoiding instability during sampling + - disp_loss_weight: weight λ for dispersive loss (default: 0.5) + - temperature: temperature τ for dispersive loss (default: 1.0) + """ + + if prediction == "noise": + model_type = ModelType.NOISE + elif prediction == "score": + model_type = ModelType.SCORE + else: + model_type = ModelType.VELOCITY + + if loss_weight == "velocity": + loss_type = WeightType.VELOCITY + elif loss_weight == "likelihood": + loss_type = WeightType.LIKELIHOOD + else: + loss_type = WeightType.NONE + + path_choice = { + "Linear": PathType.LINEAR, + "GVP": PathType.GVP, + 
"VP": PathType.VP, + } + + path_type = path_choice[path_type] + + if (path_type in [PathType.VP]): + train_eps_new = 1e-5 if train_eps is None else train_eps + sample_eps_new = 1e-3 if sample_eps is None else sample_eps + train_eps, sample_eps = train_eps_new, sample_eps_new + elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): + train_eps_new = 1e-3 if train_eps is None else train_eps + sample_eps_new = 1e-3 if sample_eps is None else sample_eps + train_eps, sample_eps = train_eps_new, sample_eps_new + else: # velocity & [GVP, LINEAR] is stable everywhere + train_eps = 0 + sample_eps = 0 + + # create flow state + state = Transport( + model_type=model_type, + path_type=path_type, + loss_type=loss_type, + train_eps=train_eps, + sample_eps=sample_eps, + disp_loss_weight=disp_loss_weight, + temperature=temperature, + ) + + return state \ No newline at end of file diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-312.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d5fb5227e4b618507cf7d7d23757f7f664c9125 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-312.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-38.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..467c54269962d466b6384854d290ce00df6853c5 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/__init__.cpython-38.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-312.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90630213413632ff348e4d511263719efca11f19 Binary files /dev/null and 
b/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-312.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-38.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..861d76cda1ec304efb7f94298dd9e6f9acfb32e6 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/integrators.cpython-38.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-312.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3c9b77609d195510003685f61520a0814fcaca5 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-312.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-38.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317dc28dc6becf862c57dc2470c38c1e262df481 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/path.cpython-38.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-312.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7695c20fd0df8ac64979f8ebee7eb3b608876b87 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-312.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-38.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1db429fedcde7d5ae3138fec28324635c049ee4 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/transport.cpython-38.pyc differ diff --git 
a/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-312.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0080eb9831d5c94b540e2971284d76d6b9490e0f Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-312.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-38.pyc b/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31c517795212841649007aae4fd637696b508984 Binary files /dev/null and b/Rectified_Noise/GVP-Disp/transport/__pycache__/utils.cpython-38.pyc differ diff --git a/Rectified_Noise/GVP-Disp/transport/integrators.py b/Rectified_Noise/GVP-Disp/transport/integrators.py new file mode 100644 index 0000000000000000000000000000000000000000..adf7c7b4c50b6ff6c63973e0ddaa65b9759274c0 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/transport/integrators.py @@ -0,0 +1,117 @@ +import numpy as np +import torch as th +import torch.nn as nn +from torchdiffeq import odeint +from functools import partial +from tqdm import tqdm + +class sde: + """SDE solver class""" + def __init__( + self, + drift, + diffusion, + *, + t0, + t1, + num_steps, + sampler_type, + ): + assert t0 < t1, "SDE sampler has to be in forward time" + + self.num_timesteps = num_steps + self.t = th.linspace(t0, t1, num_steps) + self.dt = self.t[1] - self.t[0] + self.drift = drift + self.diffusion = diffusion + self.sampler_type = sampler_type + + def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + t = th.ones(x.size(0)).to(x) * t + dw = w_cur * th.sqrt(self.dt) + drift = self.drift(x, t, model, **model_kwargs) + diffusion = self.diffusion(x, t) + mean_x = x + drift * self.dt + x = mean_x + th.sqrt(2 * diffusion) * dw + return x, mean_x + + def __Heun_step(self, x, _, t, model, **model_kwargs): + w_cur = 
th.randn(x.size()).to(x) + dw = w_cur * th.sqrt(self.dt) + t_cur = th.ones(x.size(0)).to(x) * t + diffusion = self.diffusion(x, t_cur) + xhat = x + th.sqrt(2 * diffusion) * dw + K1 = self.drift(xhat, t_cur, model, **model_kwargs) + xp = xhat + self.dt * K1 + K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) + return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step + + def __forward_fn(self): + """TODO: generalize here by adding all private functions ending with steps to it""" + sampler_dict = { + "Euler": self.__Euler_Maruyama_step, + "Heun": self.__Heun_step, + } + + try: + sampler = sampler_dict[self.sampler_type] + except: + raise NotImplementedError("Smapler type not implemented.") + + return sampler + + def sample(self, init, model, **model_kwargs): + """forward loop of sde""" + x = init + mean_x = init + samples = [] + sampler = self.__forward_fn() + for ti in self.t[:-1]: + with th.no_grad(): + x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) + samples.append(x) + + return samples + +class ode: + """ODE solver class""" + def __init__( + self, + drift, + *, + t0, + t1, + sampler_type, + num_steps, + atol, + rtol, + ): + assert t0 < t1, "ODE sampler has to be in forward time" + + self.drift = drift + self.t = th.linspace(t0, t1, num_steps) + self.atol = atol + self.rtol = rtol + self.sampler_type = sampler_type + + def sample(self, x, model, **model_kwargs): + + device = x[0].device if isinstance(x, tuple) else x.device + def _fn(t, x): + t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t + model_output = self.drift(x, t, model, **model_kwargs) + return model_output + + t = self.t.to(device) + atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] + rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] + samples = odeint( + _fn, + x, + t, + method=self.sampler_type, + atol=atol, + rtol=rtol + ) + return samples \ No 
newline at end of file diff --git a/Rectified_Noise/GVP-Disp/transport/path.py b/Rectified_Noise/GVP-Disp/transport/path.py new file mode 100644 index 0000000000000000000000000000000000000000..156a7b0dea03497a85306ebbeedfe4fbedf87c27 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/transport/path.py @@ -0,0 +1,192 @@ +import torch as th +import numpy as np +from functools import partial + +def expand_t_like_x(t, x): + """Function to reshape time t to broadcastable dimension of x + Args: + t: [batch_dim,], time vector + x: [batch_dim,...], data point + """ + dims = [1] * (len(x.size()) - 1) + t = t.view(t.size(0), *dims) + return t + + +#################### Coupling Plans #################### + +class ICPlan: + """Linear Coupling Plan""" + def __init__(self, sigma=0.0): + self.sigma = sigma + + def compute_alpha_t(self, t): + """Compute the data coefficient along the path""" + return t, 1 + + def compute_sigma_t(self, t): + """Compute the noise coefficient along the path""" + return 1 - t, -1 + + def compute_d_alpha_alpha_ratio_t(self, t): + """Compute the ratio between d_alpha and alpha""" + return 1 / t + + def compute_drift(self, x, t): + """We always output sde according to score parametrization; """ + t = expand_t_like_x(t, x) + alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + drift = alpha_ratio * x + diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t + + return -drift, diffusion + + def compute_diffusion(self, x, t, form="constant", norm=1.0): + """Compute the diffusion term of the SDE + Args: + x: [batch_dim, ...], data point + t: [batch_dim,], time vector + form: str, form of the diffusion term + norm: float, norm of the diffusion term + """ + t = expand_t_like_x(t, x) + choices = { + "constant": norm, + "SBDM": norm * self.compute_drift(x, t)[1], + "sigma": norm * self.compute_sigma_t(t)[0], + "linear": norm * (1 - t), + "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, + 
"inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, + } + + try: + diffusion = choices[form] + except KeyError: + raise NotImplementedError(f"Diffusion form {form} not implemented") + + return diffusion + + def get_score_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to score + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_noise_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to denoiser + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = reverse_alpha_ratio * d_sigma_t - sigma_t + noise = (reverse_alpha_ratio * velocity - mean) / var + return noise + + def get_velocity_from_score(self, score, x, t): + """Wrapper function: transfrom score prediction model to velocity + Args: + score: [batch_dim, ...] shaped tensor; score model output + x: [batch_dim, ...] 
shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + drift, var = self.compute_drift(x, t) + velocity = var * score - drift + return velocity + + def compute_mu_t(self, t, x0, x1): + """Compute the mean of time-dependent density p_t""" + t = expand_t_like_x(t, x1) + alpha_t, _ = self.compute_alpha_t(t) + sigma_t, _ = self.compute_sigma_t(t) + return alpha_t * x1 + sigma_t * x0 + + def compute_xt(self, t, x0, x1): + """Sample xt from time-dependent density p_t; rng is required""" + xt = self.compute_mu_t(t, x0, x1) + return xt + + def compute_ut(self, t, x0, x1, xt): + """Compute the vector field corresponding to p_t""" + t = expand_t_like_x(t, x1) + _, d_alpha_t = self.compute_alpha_t(t) + _, d_sigma_t = self.compute_sigma_t(t) + return d_alpha_t * x1 + d_sigma_t * x0 + + def plan(self, t, x0, x1): + xt = self.compute_xt(t, x0, x1) + ut = self.compute_ut(t, x0, x1, xt) + return t, xt, ut + + +class VPCPlan(ICPlan): + """class for VP path flow matching""" + + def __init__(self, sigma_min=0.1, sigma_max=20.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min + self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min + + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = self.log_mean_coeff(t) + alpha_t = th.exp(alpha_t) + d_alpha_t = alpha_t * self.d_log_mean_coeff(t) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + p_sigma_t = 2 * self.log_mean_coeff(t) + sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) + d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return self.d_log_mean_coeff(t) 
+ + def compute_drift(self, x, t): + """Compute the drift term of the SDE""" + t = expand_t_like_x(t, x) + beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) + return -0.5 * beta_t * x, beta_t / 2 + + +class GVPCPlan(ICPlan): + def __init__(self, sigma=0.0): + super().__init__(sigma) + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = th.sin(t * np.pi / 2) + d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + sigma_t = th.cos(t * np.pi / 2) + d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return np.pi / (2 * th.tan(t * np.pi / 2)) \ No newline at end of file diff --git a/Rectified_Noise/GVP-Disp/transport/transport.py b/Rectified_Noise/GVP-Disp/transport/transport.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2136c4ea1e7538b750f4e6a53913503c5058eb --- /dev/null +++ b/Rectified_Noise/GVP-Disp/transport/transport.py @@ -0,0 +1,501 @@ +import torch as th +import numpy as np +import logging + +import enum + +from . import path +from .utils import EasyDict, log_state, mean_flat +from .integrators import ode, sde + +class ModelType(enum.Enum): + """ + Which type of output the model predicts. + """ + + NOISE = enum.auto() # the model predicts epsilon + SCORE = enum.auto() # the model predicts \nabla \log p(x) + VELOCITY = enum.auto() # the model predicts v(x) + +class PathType(enum.Enum): + """ + Which type of path to use. + """ + + LINEAR = enum.auto() + GVP = enum.auto() + VP = enum.auto() + +class WeightType(enum.Enum): + """ + Which type of weighting to use. 
+ """ + + NONE = enum.auto() + VELOCITY = enum.auto() + LIKELIHOOD = enum.auto() + + +class Transport: + + def __init__( + self, + *, + model_type, + path_type, + loss_type, + train_eps, + sample_eps, + disp_loss_weight=0.5, + temperature=1.0, + ): + path_options = { + PathType.LINEAR: path.ICPlan, + PathType.GVP: path.GVPCPlan, + PathType.VP: path.VPCPlan, + } + + self.loss_type = loss_type + self.model_type = model_type + self.path_sampler = path_options[path_type]() + self.train_eps = train_eps + self.sample_eps = sample_eps + self.disp_loss_weight = disp_loss_weight # λ: weight for dispersive loss + self.temperature = temperature # τ: temperature parameter + + def prior_logp(self, z): + ''' + Standard multivariate normal prior + Assume z is batched + ''' + shape = th.tensor(z.size()) + N = th.prod(shape[1:]) + _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. + return th.vmap(_fn)(z) + + + def check_interval( + self, + train_eps, + sample_eps, + *, + diffusion_form="SBDM", + sde=False, + reverse=False, + eval=False, + last_step_size=0.0, + ): + t0 = 0 + t1 = 1 + eps = train_eps if not eval else sample_eps + if (type(self.path_sampler) in [path.VPCPlan]): + + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ + and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step + + t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + if reverse: + t0, t1 = 1 - t0, 1 - t1 + + return t0, t1 + + + def sample(self, x1): + """Sampling x0 & t based on shape of x1 (if needed) + Args: + x1 - data point; [batch, *dim] + """ + + x0 = th.randn_like(x1) + t0, t1 = self.check_interval(self.train_eps, self.sample_eps) + t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 + t = t.to(x1) + return t, x0, x1 + + def 
disp_loss(self, z): + """Dispersive Loss implementation (InfoNCE-L2 variant) + Args: + z: activation tensor from model layers + """ + z = z.reshape((z.shape[0], -1)) # flatten + diff = th.nn.functional.pdist(z).pow(2) / z.shape[1] # pairwise distance + diff = th.cat((diff, diff, th.zeros(z.shape[0], device=z.device))) # match JAX implementation of full BxB matrix + # Apply temperature scaling: divide by temperature τ + diff = diff / self.temperature + return th.log(th.exp(-diff).mean()) # calculate loss + + def training_losses( + self, + model, + x1, + model_noise=None, + model_kwargs=None + ): + """Loss for training the score model + Args: + - model: backbone model; could be score, noise, or velocity + - x1: datapoint + - model_kwargs: additional arguments for the model + """ + + + if model_kwargs == None: + model_kwargs = {} + + t, x0, x1 = self.sample(x1) + t, xt, ut = self.path_sampler.plan(t, x0, x1) + + # Handle return_act for dispersive loss + disp_loss = 0 + if model_noise==None: + model_output = model(xt, t, **model_kwargs) + # Check if model returns activations (for dispersive loss) + if "return_act" in model_kwargs and model_kwargs['return_act']: + model_output, act = model_output + if act is not None and len(act) > 0: + # Calculate dispersive loss for all blocks + for block_act in act: + disp_loss = disp_loss + self.disp_loss(block_act) + else: + model_output_pre = model(xt, t, **model_kwargs) + # Handle return_act for model_noise + if "return_act" in model_kwargs and model_kwargs['return_act']: + if isinstance(model_output_pre, tuple): + model_output_pre, act_pre = model_output_pre + else: + act_pre = None + else: + act_pre = None + + model_output_noise = model_noise(xt, t, **model_kwargs) + # Handle return_act for model_noise + if "return_act" in model_kwargs and model_kwargs['return_act']: + if isinstance(model_output_noise, tuple): + model_output_noise, act_noise = model_output_noise + else: + act_noise = None + # Calculate dispersive loss for all 
blocks in model_noise (sitf2) + if act_noise is not None and len(act_noise) > 0: + # Calculate dispersive loss for each block and sum them + for block_act in act_noise: + disp_loss = disp_loss + self.disp_loss(block_act) + model_output = model_output_pre + model_output_noise + + B, *_, C = xt.shape + assert model_output.size() == (B, *xt.size()[1:-1], C) + + terms = {} + terms['pred'] = model_output + if self.model_type == ModelType.VELOCITY: + terms['loss'] = mean_flat(((model_output - ut) ** 2)) + else: + _, drift_var = self.path_sampler.compute_drift(xt, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) + if self.loss_type in [WeightType.VELOCITY]: + weight = (drift_var / sigma_t) ** 2 + elif self.loss_type in [WeightType.LIKELIHOOD]: + weight = drift_var / (sigma_t ** 2) + elif self.loss_type in [WeightType.NONE]: + weight = 1 + else: + raise NotImplementedError() + + if self.model_type == ModelType.NOISE: + terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) + else: + terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) + + # Add dispersive loss to the total loss with weight λ + if disp_loss != 0: + terms['loss'] = terms['loss'] + self.disp_loss_weight * disp_loss + + return terms + + + def get_drift( + self + ): + """member function for obtaining the drift of the probability flow ODE""" + def score_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + model_output = model(x, t, **model_kwargs) + return (-drift_mean + drift_var * model_output) # by change of variable + + def noise_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) + model_output = model(x, t, **model_kwargs) + score = model_output / -sigma_t + return (-drift_mean + drift_var * score) + + def velocity_ode(x, t, model, **model_kwargs): + model_output = model(x, t, 
**model_kwargs) + return model_output + + if self.model_type == ModelType.NOISE: + drift_fn = noise_ode + elif self.model_type == ModelType.SCORE: + drift_fn = score_ode + else: + drift_fn = velocity_ode + + def body_fn(x, t, model, **model_kwargs): + model_output = drift_fn(x, t, model, **model_kwargs) + assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" + return model_output + + return body_fn + + + def get_score( + self, + ): + """member function for obtaining score of + x_t = alpha_t * x + sigma_t * eps""" + if self.model_type == ModelType.NOISE: + score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] + elif self.model_type == ModelType.SCORE: + score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) + elif self.model_type == ModelType.VELOCITY: + score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) + else: + raise NotImplementedError() + + return score_fn + + +class Sampler: + """Sampler class for the transport model""" + def __init__( + self, + transport, + ): + """Constructor for a general sampler; supporting different sampling methods + Args: + - transport: an tranport object specify model prediction & interpolant type + """ + + self.transport = transport + self.drift = self.transport.get_drift() + self.score = self.transport.get_score() + + def __get_sde_diffusion_and_drift( + self, + *, + diffusion_form="SBDM", + diffusion_norm=1.0, + ): + + def diffusion_fn(x, t): + diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) + return diffusion + + sde_drift = \ + lambda x, t, model, **kwargs: \ + self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) + + sde_diffusion = diffusion_fn + + return sde_drift, sde_diffusion + + def __get_last_step( + self, + sde_drift, + *, + last_step, + 
last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + elif last_step == "Mean": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + sde_drift(x, t, model, **model_kwargs) * last_step_size + elif last_step == "Tweedie": + alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + self.drift(x, t, model, **model_kwargs) * last_step_size + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama + - diffusion_form: function form of diffusion coefficient; default to be matching SBDM + - diffusion_norm: function magnitude of diffusion coefficient; default to 1 + - last_step: type of the last step; default to identity + - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] + - num_steps: total integration step of SDE + """ + + if last_step is None: + last_step_size = 0.0 + + sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( + diffusion_form=diffusion_form, + diffusion_norm=diffusion_norm, + ) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + diffusion_form=diffusion_form, + sde=True, + eval=True, + reverse=False, + last_step_size=last_step_size, + ) + + _sde = sde( + 
sde_drift, + sde_diffusion, + t0=t0, + t1=t1, + num_steps=num_steps, + sampler_type=sampling_method + ) + + last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) + + + def _sample(init, model, **model_kwargs): + xs = _sde.sample(init, model, **model_kwargs) + ts = th.ones(init.size(0), device=init.device) * t1 + x = last_step_fn(xs[-1], ts, model, **model_kwargs) + xs.append(x) + + assert len(xs) == num_steps, "Samples does not match the number of steps" + + return xs + + return _sample + + def sample_ode( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + reverse=False, + ): + """returns a sampling function with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + - reverse: whether solving the ODE in reverse (data to noise); default to False + """ + if reverse: + drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) + else: + drift = self.drift + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=reverse, + last_step_size=0.0, + ) + + _ode = ode( + drift=drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + return _ode.sample + + def sample_ode_likelihood( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + ): + + """returns a sampling function for calculating likelihood with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver 
(Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + """ + def _likelihood_drift(x, t, model, **model_kwargs): + x, _ = x + eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 + t = th.ones_like(t) * (1 - t) + with th.enable_grad(): + x.requires_grad = True + grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] + logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) + drift = self.drift(x, t, model, **model_kwargs) + return (-drift, logp_grad) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=False, + last_step_size=0.0, + ) + + _ode = ode( + drift=_likelihood_drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + def _sample_fn(x, model, **model_kwargs): + init_logp = th.zeros(x.size(0)).to(x) + input = (x, init_logp) + drift, delta_logp = _ode.sample(input, model, **model_kwargs) + drift, delta_logp = drift[-1], delta_logp[-1] + prior_logp = self.transport.prior_logp(drift) + logp = prior_logp - delta_logp + return logp, drift + + return _sample_fn \ No newline at end of file diff --git a/Rectified_Noise/GVP-Disp/transport/utils.py b/Rectified_Noise/GVP-Disp/transport/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44646035531326b81883727f973900edb4eac494 --- /dev/null +++ b/Rectified_Noise/GVP-Disp/transport/utils.py @@ -0,0 +1,29 @@ +import torch as th + +class EasyDict: + + def __init__(self, sub_dict): + for k, v in sub_dict.items(): + setattr(self, k, v) + + def __getitem__(self, key): + return getattr(self, key) + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. 
+ """ + return th.mean(x, dim=list(range(1, len(x.size())))) + +def log_state(state): + result = [] + + sorted_state = dict(sorted(state.items())) + for key, value in sorted_state.items(): + # Check if the value is an instance of a class + if " + sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')()) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main + run(args) + File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run + elastic_launch( + File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +========================================================== +train_rectified_noise.py FAILED +---------------------------------------------------------- +Failures: + +---------------------------------------------------------- +Root Cause (first observed failure): +[0]: + time : 2026-02-04_05:40:25 + host : cabbd6562a3025dd000330e2d302e8fd-taskrole1-0 + rank : 0 (local_rank: 0) + exitcode : -9 (pid: 72202) + error_file: + traceback : Signal 9 (SIGKILL) received by PID 72202 +========================================================== diff --git a/Rectified_Noise/VP-Disp/README.md b/Rectified_Noise/VP-Disp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..116a845bb2f32f1d32b074e4335c9cf71eaacf60 
--- /dev/null +++ b/Rectified_Noise/VP-Disp/README.md @@ -0,0 +1,92 @@ +# [AAAI 2026] Rectified Noise: A Generative Model Using Positive-incentive Noise + +![Visualization of the $\pi$-noise by $\Delta$RN.](assests/visual.png) + +
+ +HuggingFace + +## Introduction +This is a [Pytorch](https://pytorch.org) implementation of **Rectified Noise**, a generative model using positive-incentive noise to enhance model's sampling. + +![Overview of Laytrol](assests/pipeline.png) + +## Setup + +We provide an `environment.yml` file that can be used to create a Conda environment. + +```bash +conda env create -f environment.yml +conda activate RN +``` + +## Usage + +### Training +1. We provide a training script for RN in `train_rectified_noise.py` + + Run: + +```bash +torchrun --nnodes=1 --nproc_per_node=4 train_rectified_noise.py \ +--data-path /path/to/data \ +--num-classes 3 \ +--path-type Linear \ +--prediction velocity \ +--ckpt /path/to/pretrained_model \ +--model SiT-B/2 +--learn-mu True \ +--depth 1 \ +``` + +You can find relevant checkpoint files from the previous Hugging Face link. + +2. Parameters: + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--data-path ` | str | `-` | Path to the dataset. | +| `--num-classes` | int | `-` | Number of classes. | +| `--path-type` | str | `Linear` | Directory to save the generated images. | +| `--prediction` | str | `velocity` | Output type of network. | +| `--ckpt` | str | `-` | Path to pretrained model checkpoint. | +| `--model` | str | `SiT-B/2` | Model type, any option from the model list. | +| `--learn-mu` | bool | `True` | Whether to learn the mu parameter. | +| `--depth` | int | `1` | Depth parameter for the SiTF2 model(Extra SiT Block). | + +**Sampling** + +1. Using the trained RN model to enhance the pre-trained model + +```bash +torchrun --nnodes=1 --nproc_per_node=4 train_rectified_noise.py \ +--path-type Linear \ +--prediction velocity \ +--ckpt /path/to/pretrained_model \ +--sitf2-ckpt /path/to/pretrained_RN \ +--model SiT-B/2 +--learn-mu True \ +--depth 1 \ +``` + +## Ackownledgement +This repo benefits from [SiT](https://github.com/willisma/SiT). Thanks for their excellent works. 
+ +## Contact +If you have any question about this project, please contact mguzhenyu@outlook.com. + +## Citation + +If you find the code useful for your research, please consider citing our work: + +``` +@misc{gu2025rectifiednoisegenerativemodel, + title={Rectified Noise: A Generative Model Using Positive-incentive Noise}, + author={Zhenyu Gu and Yanchen Xu and Sida Huang and Yubin Guo and Hongyuan Zhang}, + year={2025}, + eprint={2511.07911}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2511.07911}, +} +``` diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000059.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000059.png new file mode 100644 index 0000000000000000000000000000000000000000..90dfd8e184ef67cfaf40df1648df30929958c6cb Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000059.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000169.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000169.png new file mode 100644 index 0000000000000000000000000000000000000000..b1219c9eaeb2fa285a8b93dc5742d7f4cecb7e7b Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000169.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000286.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000286.png new file mode 100644 index 0000000000000000000000000000000000000000..453e52d1bf005484c031d3eb178ca622d8bdcadc 
Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000286.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000545.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000545.png new file mode 100644 index 0000000000000000000000000000000000000000..b97022e762db35f7fff9c365ac817bbacccd8eb3 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000545.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000606.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000606.png new file mode 100644 index 0000000000000000000000000000000000000000..4069b3cd7eff0f9a93773aed1befeff9951adfc8 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000606.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000769.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000769.png new file mode 100644 index 0000000000000000000000000000000000000000..65b51ad09bb578bc08ee478a312be8ba01d9c818 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000769.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001050.png 
b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001050.png new file mode 100644 index 0000000000000000000000000000000000000000..24863a9b3dc45d0d080ceb2e02840b8496149a91 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001050.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001099.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001099.png new file mode 100644 index 0000000000000000000000000000000000000000..0ec70b252ed53dc19890239566bbb1c02c2e528e Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001099.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001346.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001346.png new file mode 100644 index 0000000000000000000000000000000000000000..81221157cd7bfbc23596263f1afce7d04cc5de4e Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001346.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001475.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001475.png new file mode 100644 index 0000000000000000000000000000000000000000..eb5c659959a1a270b428a5a6cb4fdc5182ee82ae Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001475.png 
differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001518.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001518.png new file mode 100644 index 0000000000000000000000000000000000000000..107218f051120a1142ace458dce74e7d26991b0f Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001518.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001644.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001644.png new file mode 100644 index 0000000000000000000000000000000000000000..bc6d6a67207b05b8ffeea915bb5ab9d0b54e49a4 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001644.png differ diff --git a/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001741.png b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001741.png new file mode 100644 index 0000000000000000000000000000000000000000..fad79510227f28a23fd6c8b8f49a7e0111d40486 Binary files /dev/null and b/Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.5-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001741.png differ diff --git a/Rectified_Noise/VP-Disp/W_False.log b/Rectified_Noise/VP-Disp/W_False.log new file mode 100644 index 0000000000000000000000000000000000000000..34b3e1dd1f2503ce269d7aa685cab0fc8434783a --- /dev/null +++ b/Rectified_Noise/VP-Disp/W_False.log @@ -0,0 +1,5 @@ +[NOTICE] The application is pending for GPU resource in asynchronous queue. 
The longest waiting time in queue is 1800 seconds. +Starting rank=0, seed=0, world_size=1. +Saving .png samples at VP_samples/depth-mu-2-threshold-1.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04 +Total number of images that will be sampled: 3008 + 0%| | 0/47 [00:00= 3.8 + - pytorch >= 1.13 + - torchvision + - pytorch-cuda >=11.7 + - pip + - pip: + - timm + - diffusers + - accelerate + - torchdiffeq + - wandb \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/evaluate_samples.sh b/Rectified_Noise/VP-Disp/evaluate_samples.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f8e0c543465b1b6f5ced478267f9d36bcac50a2 --- /dev/null +++ b/Rectified_Noise/VP-Disp/evaluate_samples.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Execute all evaluation tasks in parallel +# Each command runs in the background using & + +echo "Starting all evaluation tasks in parallel..." + +# Reference batch path +REF_BATCH="/gemini/space/zhaozy/zhy/dataset/VIRTUAL_imagenet256_labeled.npz" + +# Base directory for sample files +SAMPLE_DIR="/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/last_samples_depth_2_gvp_0.5" + +# Change to the project root directory +cd /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching + +# Evaluate threshold 0.0 on GPU 0 +CUDA_VISIBLE_DEVICES=0 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.0-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.0.log 2>&1 & + +# Evaluate threshold 0.15 on GPU 1 +CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.15-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.15.log 2>&1 & + +# Evaluate threshold 0.25 on GPU 2 +CUDA_VISIBLE_DEVICES=2 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch 
${SAMPLE_DIR}/depth-mu-2-threshold-0.25-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.25.log 2>&1 & + +# Evaluate threshold 0.5 on GPU 3 +CUDA_VISIBLE_DEVICES=3 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.5-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.5.log 2>&1 & + +# Evaluate threshold 0.75 on GPU 4 +CUDA_VISIBLE_DEVICES=0 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.75-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_0.75.log 2>&1 & + +# Evaluate threshold 1.0 on GPU 5 +CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \ + --ref_batch ${REF_BATCH} \ + --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-1.0-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \ + > eval_threshold_1.0.log 2>&1 & + +# Wait for all background jobs to complete +echo "All evaluation tasks started. Waiting for completion..." +wait + +echo "All evaluation tasks completed!" 
+echo "" +echo "Results saved in:" +echo " - eval_threshold_0.0.log" +echo " - eval_threshold_0.15.log" +echo " - eval_threshold_0.25.log" +echo " - eval_threshold_0.5.log" +echo " - eval_threshold_0.75.log" +echo " - eval_threshold_1.0.log" diff --git a/Rectified_Noise/VP-Disp/evaluator.py b/Rectified_Noise/VP-Disp/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..62bdd4b20c4da0db41fe0328d3b2ee6040935d23 --- /dev/null +++ b/Rectified_Noise/VP-Disp/evaluator.py @@ -0,0 +1,689 @@ +import argparse +import io +import os +import random +import warnings +import zipfile +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from multiprocessing import cpu_count +from multiprocessing.pool import ThreadPool +from typing import Iterable, Optional, Tuple, Union + +import numpy as np +import requests +import tensorflow.compat.v1 as tf +from scipy import linalg +from tqdm.auto import tqdm +from datetime import timedelta +import torch + + + +INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" +INCEPTION_V3_PATH = "classify_image_graph_def.pb" + +FID_POOL_NAME = "pool_3:0" +FID_SPATIAL_NAME = "mixed_6/conv:0" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ref_batch", default='/gemini/space/zhaozy/zhy/dataset/VIRTUAL_imagenet256_labeled.npz',help="path to reference batch npz file") + parser.add_argument("--sample_batch", default='/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/last_samples_depth_2/depth-mu-28-0050000-2000000-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz', help="path to sample batch npz file") + args = parser.parse_args() + + config = tf.ConfigProto( + allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph + ) + config.gpu_options.allow_growth = True + evaluator = Evaluator(tf.Session(config=config)) + + print("warming up TensorFlow...") + # 
This will cause TF to print a bunch of verbose stuff now rather + # than after the next print(), to help prevent confusion. + evaluator.warmup() + + print("computing reference batch activations...") + ref_acts = evaluator.read_activations(args.ref_batch) + print("computing/reading reference batch statistics...") + ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) + + print("computing sample batch activations...") + sample_acts = evaluator.read_activations(args.sample_batch) + print("computing/reading sample batch statistics...") + sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts) + + print("Computing evaluations...") + #print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) + print("FID:", sample_stats.frechet_distance(ref_stats)) + #print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) + #prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) + #print("Precision:", prec) + #print("Recall:", recall) + + +class InvalidFIDException(Exception): + pass + + +class FIDStatistics: + def __init__(self, mu: np.ndarray, sigma: np.ndarray): + self.mu = mu + self.sigma = sigma + + def frechet_distance(self, other, eps=1e-6): + """ + Compute the Frechet distance between two sets of statistics. 
+ """ + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 + mu1, sigma1 = self.mu, self.sigma + mu2, sigma2 = other.mu, other.sigma + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" + assert ( + sigma1.shape == sigma2.shape + ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; adding %s to diagonal of cov estimates" + % eps + ) + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + #虚部报错部分 + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1): + m = np.max(np.abs(covmean.imag)) + print(f"Real component: {covmean.real}") + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +class Evaluator: + def __init__( + self, + session, + batch_size=64, + softmax_batch_size=512, + ): + self.sess = session + self.batch_size = batch_size + self.softmax_batch_size = softmax_batch_size + self.manifold_estimator = ManifoldEstimator(session) + with self.sess.graph.as_default(): + self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) + self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) + self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) + self.softmax = _create_softmax_graph(self.softmax_input) + + def 
warmup(self): + self.compute_activations(np.zeros([1, 8, 64, 64, 3])) + + def read_activations(self, npz_path: Union[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + if isinstance(npz_path, str): + # If npz_path is a string, treat it as a file path and read the .npz file + with open_npz_array(npz_path, "arr_0") as reader: + return self.compute_activations(reader.read_batches(self.batch_size)) + elif isinstance(npz_path, np.ndarray): + # If npz_path is a numpy array, split it into batches manually + print("--------line 140-----------") + batches = np.array_split(npz_path, range(self.batch_size, npz_path.shape[0], self.batch_size)) + print("--------line 143-----------") + return self.compute_activations(batches) + else: + raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)") + + + def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute image features for downstream evals. + + :param batches: a iterator over NHWC numpy arrays in [0, 255]. + :return: a tuple of numpy arrays of shape [N x X], where X is a feature + dimension. The tuple is (pool_3, spatial). 
+ """ + preds = [] + spatial_preds = [] + for batch in tqdm(batches): + # print("--------line 164-----------") + + # # 识别当前进程信息 + # if 'RANK' in os.environ: + # rank = int(os.environ['RANK']) + # local_rank = int(os.environ.get('LOCAL_RANK', rank % torch.cuda.device_count())) + # print(f"Distributed training - Global Rank: {rank}, Local Rank: {local_rank}") + # print(f"Current GPU device: {torch.cuda.current_device()}" if torch.cuda.is_available() else "No CUDA") + # else: + # print("Single process mode") + + # print(f"Process PID: {os.getpid()}") + + batch = batch.astype(np.float32) + pred, spatial_pred = self.sess.run( + [self.pool_features, self.spatial_features], {self.image_input: batch} + ) + # print("--------line 169-----------") + preds.append(pred.reshape([pred.shape[0], -1])) + spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) + return ( + np.concatenate(preds, axis=0), + np.concatenate(spatial_preds, axis=0), + ) + + def read_statistics( + self, npz_path: Union[str, np.ndarray], activations: Tuple[np.ndarray, np.ndarray] + ) -> Tuple[FIDStatistics, FIDStatistics]: + if isinstance(npz_path, str): + obj = np.load(npz_path) + if "mu" in list(obj.keys()): + return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( + obj["mu_s"], obj["sigma_s"] + ) + elif isinstance(npz_path, np.ndarray): + obj = npz_path + else: + raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)") + return tuple(self.compute_statistics(x) for x in activations) + + def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return FIDStatistics(mu, sigma) + + def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: + softmax_out = [] + for i in range(0, len(activations), self.softmax_batch_size): + acts = activations[i : i + self.softmax_batch_size] + softmax_out.append(self.sess.run(self.softmax, 
feed_dict={self.softmax_input: acts})) + preds = np.concatenate(softmax_out, axis=0) + # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 + scores = [] + for i in range(0, len(preds), split_size): + part = preds[i : i + split_size] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)) + + def compute_prec_recall( + self, activations_ref: np.ndarray, activations_sample: np.ndarray + ) -> Tuple[float, float]: + radii_1 = self.manifold_estimator.manifold_radii(activations_ref) + radii_2 = self.manifold_estimator.manifold_radii(activations_sample) + pr = self.manifold_estimator.evaluate_pr( + activations_ref, radii_1, activations_sample, radii_2 + ) + return (float(pr[0][0]), float(pr[1][0])) + + +class ManifoldEstimator: + """ + A helper for comparing manifolds of feature vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 + """ + + def __init__( + self, + session, + row_batch_size=10000, + col_batch_size=10000, + nhood_sizes=(3,), + clamp_to_percentile=None, + eps=1e-5, + ): + """ + Estimate the manifold of given feature vectors. + + :param session: the TensorFlow session. + :param row_batch_size: row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + :param col_batch_size: column batch size to compute pairwise distances. + :param nhood_sizes: number of neighbors used to estimate the manifold. + :param clamp_to_percentile: prune hyperspheres that have radius larger than + the given percentile. + :param eps: small number for numerical stability. 
+ """ + self.distance_block = DistanceBlock(session) + self.row_batch_size = row_batch_size + self.col_batch_size = col_batch_size + self.nhood_sizes = nhood_sizes + self.num_nhoods = len(nhood_sizes) + self.clamp_to_percentile = clamp_to_percentile + self.eps = eps + + def warmup(self): + feats, radii = ( + np.zeros([1, 2048], dtype=np.float32), + np.zeros([1, 1], dtype=np.float32), + ) + self.evaluate_pr(feats, radii, feats, radii) + + def manifold_radii(self, features: np.ndarray) -> np.ndarray: + num_images = len(features) + + # Estimate manifold of features by calculating distances to k-NN of each sample. + radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) + distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) + seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) + + for begin1 in range(0, num_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(row_batch, col_batch) + + # Find the k-nearest neighbor from the current batch. + radii[begin1:end1, :] = np.concatenate( + [ + x[:, self.nhood_sizes] + for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1) + ], + axis=0, + ) + + if self.clamp_to_percentile is not None: + max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) + radii[radii > max_distances] = 0 + return radii + + def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray): + """ + Evaluate if new feature vectors are at the manifold. 
+ """ + num_eval_images = eval_features.shape[0] + num_ref_images = radii.shape[0] + distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) + batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) + max_realism_score = np.zeros([num_eval_images], dtype=np.float32) + nearest_indices = np.zeros([num_eval_images], dtype=np.int32) + + for begin1 in range(0, num_eval_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_eval_images) + feature_batch = eval_features[begin1:end1] + + for begin2 in range(0, num_ref_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_ref_images) + ref_batch = features[begin2:end2] + + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) + + # From the minibatch of new feature vectors, determine if they are in the estimated manifold. + # If a feature vector is inside a hypersphere of some reference sample, then + # the new sample lies at the estimated manifold. + # The radii of the hyperspheres are determined from distances of neighborhood size k. + samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii + batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) + + max_realism_score[begin1:end1] = np.max( + radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 + ) + nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1) + + return { + "fraction": float(np.mean(batch_predictions)), + "batch_predictions": batch_predictions, + "max_realisim_score": max_realism_score, + "nearest_indices": nearest_indices, + } + + def evaluate_pr( + self, + features_1: np.ndarray, + radii_1: np.ndarray, + features_2: np.ndarray, + radii_2: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Evaluate precision and recall efficiently. 
+ + :param features_1: [N1 x D] feature vectors for reference batch. + :param radii_1: [N1 x K1] radii for reference vectors. + :param features_2: [N2 x D] feature vectors for the other batch. + :param radii_2: [N x K2] radii for other vectors. + :return: a tuple of arrays for (precision, recall): + - precision: an np.ndarray of length K1 + - recall: an np.ndarray of length K2 + """ + features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) + features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) + for begin_1 in range(0, len(features_1), self.row_batch_size): + end_1 = begin_1 + self.row_batch_size + batch_1 = features_1[begin_1:end_1] + for begin_2 in range(0, len(features_2), self.col_batch_size): + end_2 = begin_2 + self.col_batch_size + batch_2 = features_2[begin_2:end_2] + batch_1_in, batch_2_in = self.distance_block.less_thans( + batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] + ) + features_1_status[begin_1:end_1] |= batch_1_in + features_2_status[begin_2:end_2] |= batch_2_in + return ( + np.mean(features_2_status.astype(np.float64), axis=0), + np.mean(features_1_status.astype(np.float64), axis=0), + ) + + +class DistanceBlock: + """ + Calculate pairwise distances between vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 + """ + + def __init__(self, session): + self.session = session + + # Initialize TF graph to calculate pairwise distances. 
+ with session.graph.as_default(): + self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) + self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) + distance_block_16 = _batch_pairwise_distances( + tf.cast(self._features_batch1, tf.float16), + tf.cast(self._features_batch2, tf.float16), + ) + self.distance_block = tf.cond( + tf.reduce_all(tf.math.is_finite(distance_block_16)), + lambda: tf.cast(distance_block_16, tf.float32), + lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2), + ) + + # Extra logic for less thans. + self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) + self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) + dist32 = tf.cast(self.distance_block, tf.float32)[..., None] + self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) + self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) + + def pairwise_distances(self, U, V): + """ + Evaluate pairwise distances between two batches of feature vectors. + """ + return self.session.run( + self.distance_block, + feed_dict={self._features_batch1: U, self._features_batch2: V}, + ) + + def less_thans(self, batch_1, radii_1, batch_2, radii_2): + return self.session.run( + [self._batch_1_in, self._batch_2_in], + feed_dict={ + self._features_batch1: batch_1, + self._features_batch2: batch_2, + self._radii1: radii_1, + self._radii2: radii_2, + }, + ) + + +def _batch_pairwise_distances(U, V): + """ + Compute pairwise distances between two batches of feature vectors. + """ + with tf.variable_scope("pairwise_dist_block"): + # Squared norms of each row in U and V. + norm_u = tf.reduce_sum(tf.square(U), 1) + norm_v = tf.reduce_sum(tf.square(V), 1) + + # norm_u as a column and norm_v as a row vectors. + norm_u = tf.reshape(norm_u, [-1, 1]) + norm_v = tf.reshape(norm_v, [1, -1]) + + # Pairwise squared Euclidean distances. 
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) + + return D + + +class NpzArrayReader(ABC): + @abstractmethod + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + pass + + @abstractmethod + def remaining(self) -> int: + pass + + def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: + def gen_fn(): + while True: + batch = self.read_batch(batch_size) + if batch is None: + break + yield batch + + rem = self.remaining() + num_batches = rem // batch_size + int(rem % batch_size != 0) + return BatchIterator(gen_fn, num_batches) + + +class BatchIterator: + def __init__(self, gen_fn, length): + self.gen_fn = gen_fn + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + return self.gen_fn() + + +class StreamingNpzArrayReader(NpzArrayReader): + def __init__(self, arr_f, shape, dtype): + self.arr_f = arr_f + self.shape = shape + self.dtype = dtype + self.idx = 0 + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.shape[0]: + return None + + bs = min(batch_size, self.shape[0] - self.idx) + self.idx += bs + + if self.dtype.itemsize == 0: + return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) + + read_count = bs * np.prod(self.shape[1:]) + read_size = int(read_count * self.dtype.itemsize) + data = _read_bytes(self.arr_f, read_size, "array data") + return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) + + def remaining(self) -> int: + return max(0, self.shape[0] - self.idx) + + +class MemoryNpzArrayReader(NpzArrayReader): + def __init__(self, arr): + self.arr = arr + self.idx = 0 + + @classmethod + def load(cls, path: str, arr_name: str): + with open(path, "rb") as f: + arr = np.load(f)[arr_name] + return cls(arr) + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.arr.shape[0]: + return None + + res = self.arr[self.idx : self.idx + batch_size] + self.idx += batch_size + return res + + def 
remaining(self) -> int: + return max(0, self.arr.shape[0] - self.idx) + + +@contextmanager +def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: + with _open_npy_file(path, arr_name) as arr_f: + version = np.lib.format.read_magic(arr_f) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(arr_f) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(arr_f) + else: + yield MemoryNpzArrayReader.load(path, arr_name) + return + shape, fortran, dtype = header + if fortran or dtype.hasobject: + yield MemoryNpzArrayReader.load(path, arr_name) + else: + yield StreamingNpzArrayReader(arr_f, shape, dtype) + + +def _read_bytes(fp, size, error_template="ran out of data"): + """ + Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 + + Read from file-like object until size bytes are read. + Raises ValueError if not EOF is encountered before size bytes are read. + Non-blocking objects only supported if they derive from io objects. + Required as e.g. ZipExtFile in python 2.6 can return less data than + requested. + """ + data = bytes() + while True: + # io files (default in python3) return None or raise on + # would-block, python2 file will truncate, probably nothing can be + # done about that. 
note that regular files can't be non-blocking + try: + r = fp.read(size - len(data)) + data += r + if len(r) == 0 or len(data) == size: + break + except io.BlockingIOError: + pass + if len(data) != size: + msg = "EOF: reading %s, expected %d bytes got %d" + raise ValueError(msg % (error_template, size, len(data))) + else: + return data + + +@contextmanager +def _open_npy_file(path: str, arr_name: str): + with open(path, "rb") as f: + with zipfile.ZipFile(f, "r") as zip_f: + if f"{arr_name}.npy" not in zip_f.namelist(): + raise ValueError(f"missing {arr_name} in npz file") + with zip_f.open(f"{arr_name}.npy", "r") as arr_f: + yield arr_f + + +def _download_inception_model(): + if os.path.exists(INCEPTION_V3_PATH): + return + print("downloading InceptionV3 model...") + with requests.get(INCEPTION_V3_URL, stream=True) as r: + r.raise_for_status() + tmp_path = INCEPTION_V3_PATH + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in tqdm(r.iter_content(chunk_size=8192)): + f.write(chunk) + os.rename(tmp_path, INCEPTION_V3_PATH) + + +def _create_feature_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + pool3, spatial = tf.import_graph_def( + graph_def, + input_map={f"ExpandDims:0": input_batch}, + return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], + name=prefix, + ) + _update_shapes(pool3) + spatial = spatial[..., :7] + return pool3, spatial + + +def _create_softmax_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + (matmul,) = tf.import_graph_def( + graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix + ) + w = matmul.inputs[1] + logits = tf.matmul(input_batch, w) + return tf.nn.softmax(logits) + + +def 
_update_shapes(pool3): + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 + ops = pool3.graph.get_operations() + for op in ops: + for o in op.outputs: + shape = o.get_shape() + if shape._dims is not None: # pylint: disable=protected-access + # shape = [s.value for s in shape] TF 1.x + shape = [s for s in shape] # TF 2.x + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__["_shape_val"] = tf.TensorShape(new_shape) + return pool3 + + +def _numpy_partition(arr, kth, **kwargs): + num_workers = min(cpu_count(), len(arr)) + chunk_size = len(arr) // num_workers + extra = len(arr) % num_workers + + start_idx = 0 + batches = [] + for i in range(num_workers): + size = chunk_size + (1 if i < extra else 0) + batches.append(arr[start_idx : start_idx + size]) + start_idx += size + + with ThreadPool(num_workers) as pool: + return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/models.py b/Rectified_Noise/VP-Disp/models.py new file mode 100644 index 0000000000000000000000000000000000000000..50e817b8c9f7a1a39dc0b4cf0563777ff833a100 --- /dev/null +++ b/Rectified_Noise/VP-Disp/models.py @@ -0,0 +1,647 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
def modulate(x, shift, scale):
    """
    Apply adaLN-style affine modulation: scale then shift ``x``.

    x: (N, T, D) token tensor; shift / scale: (N, D), broadcast over the
    token dimension T.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (scale_b + 1) + shift_b
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core SiT Model # +################################################################################# + +class SiTBlock(nn.Module): + """ + A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
+ """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of SiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
+ """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.learn_sigma = True + self.in_channels = in_channels + self.out_channels = in_channels * 2 + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + 
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in SiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, y, return_act=False): + """ + Forward pass of SiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + return_act: if True, return activations from transformer blocks + """ + act = [] + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + if return_act: + act.append(x) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + if self.learn_sigma: + x, _ = x.chunk(2, dim=1) + if return_act: + return x, act + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim 
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension per position (must be even).
    pos: array of positions to encode; flattened to (M,).
    out: (M, embed_dim) — sin features in the first half, cos in the second.
    """
    assert embed_dim % 2 == 0
    n_freq = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(i / n_freq), i = 0..n_freq-1.
    freqs = 1.0 / 10000 ** (np.arange(n_freq, dtype=np.float64) / n_freq)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, n_freq) outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
+def SiT_S_4(**kwargs): + return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +def SiT_S_8(**kwargs): + return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) + + +SiT_models = { + 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8, + 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8, + 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8, + 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8, +} + +################################################################################# +# SiTF1, SiTF2, CombinedModel # +################################################################################# + +class SiTF1(nn.Module): + """ + SiTF1 Model + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + final_layer=None, + ): + super().__init__() + self.input_size = input_size + self.patch_size= patch_size + self.hidden_size = hidden_size + self.in_channels = in_channels + self.out_channels = in_channels * 2 + self.patch_size = patch_size + self.num_heads = num_heads + self.learn_sigma = learn_sigma + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = 
int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def forward(self, x, t, y): + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x_now = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x_now = self.unpatchify(x_now) # (N, out_channels, H, W) + x_now, _ = x_now.chunk(2, dim=1) + return x,x_now # patch token (N, T, D) + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass with classifier-free guidance for SiTF1. + Applies guidance consistently to both patch tokens and image output (x_now). + """ + # Take the first half (conditional inputs) and duplicate it so that + # it can be paired with conditional and unconditional labels in `y`. 
+ half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + patch_tokens, x_now = self.forward(combined, t, y) + + # Apply CFG on the image output channels (first 3 channels by default) + eps, rest = x_now[:, :3, ...], x_now[:, 3:, ...] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + x_now = torch.cat([eps, rest], dim=1) + + # Apply same guidance logic to patch tokens so downstream modules see + # a consistent guided representation. + cond_tok, uncond_tok = torch.split(patch_tokens, len(patch_tokens) // 2, dim=0) + half_tok = uncond_tok + cfg_scale * (cond_tok - uncond_tok) + patch_tokens = torch.cat([half_tok, half_tok], dim=0) + + return patch_tokens, x_now + + +class SiTF2(nn.Module): + """ + SiTF2: + """ + def __init__( + self, + input_size=32, + hidden_size=1152, + out_channels=8, + patch_size=2, + num_heads=16, + mlp_ratio=4.0, + depth=4, + learn_sigma=True, + final_layer=None, + num_classes=1000, + class_dropout_prob=0.1, + learn_mu=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.learn_mu = learn_mu + self.out_channels = out_channels + self.in_channels = 4 + self.patch_size = patch_size + self.num_heads = num_heads + self.blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.x_embedder = PatchEmbed(input_size, patch_size, self.in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.num_patches = num_patches # Save original num_patches for unpatchify + # pos_embed needs to support 2*num_patches for concatenated input + self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False) + # Initialize pos_embed with sin-cos embedding + pos_embed = 
get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5)) + # Repeat the pos_embed for both halves (or could use different embeddings) + pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0)) + + if final_layer is not None: + self.final_layer = final_layer + else: + self.final_layer = FinalLayer(hidden_size, patch_size, out_channels) + # if depth !=0: + # for p in self.final_layer.parameters(): + # if p is not None: + # torch.nn.init.constant_(p, 0) + + def unpatchify(self, x, patch_size, out_channels): + c = out_channels + p = patch_size + # x.shape[1] might be 2*num_patches when using concatenated input + # Use original num_patches to calculate h and w + h = w = int(self.num_patches ** 0.5) + # If input has 2*num_patches, we need to handle it + if x.shape[1] == 2 * self.num_patches: + # Take only the first half (or average, or other strategy) + # For now, we'll take the first half + x = x[:, :self.num_patches, :] + assert h * w == x.shape[1], f"Expected {h * w} patches, got {x.shape[1]}" + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, c, t, return_act=False): + act = [] + for block in self.blocks: + x = block(x, c) + if return_act: + act.append(x) + x = self.final_layer(x, c) + x = self.unpatchify(x, self.patch_size, self.out_channels) + if self.learn_sigma: + mean_pred, log_var_pred = x.chunk(2, dim=1) + variance_pred = torch.exp(log_var_pred) + std_dev_pred = torch.sqrt(variance_pred) + noise = torch.randn_like(mean_pred) + #uniform_noise = torch.rand_like(mean_pred) + #uniform_noise = uniform_noise.clamp(min=1e-5, max=1-1e-5) + #gumbel_noise = -torch.log(-torch.log(uniform_noise)) + + if self.learn_mu==True: + resampled_x = mean_pred + std_dev_pred * noise + else: + resampled_x = std_dev_pred * noise + x = resampled_x + else: + x, 
    def forward_noise(self, x, c):
        """
        Run the SiTF2 trunk and resample an output from the predicted
        distribution (same computation as forward() without activation
        collection).

        x: (N, T', D) patch tokens; T' may be 2 * num_patches, in which case
           unpatchify keeps only the first half — TODO confirm against caller.
        c: (N, D) conditioning vector (timestep + class embedding).
        Returns an (N, out_channels // 2, H, W) tensor.
        """
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        if self.learn_sigma:
            # Interpret the two channel halves as mean and log-variance, then
            # draw a Gaussian sample via reparameterization: mu + sigma * eps.
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            variance_pred = torch.exp(log_var_pred)
            std_dev_pred = torch.sqrt(variance_pred)
            noise = torch.randn_like(mean_pred)
            if self.learn_mu==True:
                resampled_x = mean_pred + std_dev_pred * noise
            else:
                # Zero-mean sample: only the learned scale shapes the noise.
                resampled_x = std_dev_pred * noise
            x = resampled_x
        else:
            # No sigma head: keep only the first half of the channels.
            x, _ = x.chunk(2, dim=1)
        return x
interpolated input (image format) back to patch token format (without pos_embed, will add later) + x_now_patches = self.x_embedder(x_interpolated) + # Concatenate patch_tokens and x_now_patches along the sequence dimension + concatenated_input = torch.cat([patch_tokens, x_now_patches], dim=1) # (N, 2*T, D) + # Add position embedding for the concatenated input + # Use the same pos_embed for both halves (or could use different embeddings) + concatenated_input = concatenated_input + self.pos_embed + t_emb = self.sitf1.t_embedder(t) + y_emb = self.sitf1.y_embedder(y, self.training) + c = t_emb + y_emb + return self.sitf2(concatenated_input, c, t, return_act=return_act) \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/run.sh b/Rectified_Noise/VP-Disp/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..141029ec65df46697a374637968bcaf55d43727b --- /dev/null +++ b/Rectified_Noise/VP-Disp/run.sh @@ -0,0 +1,14 @@ +nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=4 \ + --rdzv_endpoint=localhost:29761 \ + train_rectified_noise.py \ + --depth 2 \ + --results-dir results_256_vp_disp \ + --data-path /gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/ \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_VP/base.pt \ + --num-classes 1000 \ + --path-type VP \ + --prediction velocity \ + --disp \ + > w_training1.log 2>&1 & diff --git a/Rectified_Noise/VP-Disp/sample_ddp.py b/Rectified_Noise/VP-Disp/sample_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e91da98a11a17d8f59d10e9f79eb9711e34324 --- /dev/null +++ b/Rectified_Noise/VP-Disp/sample_ddp.py @@ -0,0 +1,233 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Samples a large number of images from a pre-trained SiT model using DDP. 
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # Samples are expected to be RGB (3-channel) images of a uniform size.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + if args.ckpt is None: + assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download." + assert args.image_size in [256, 512] + assert args.num_classes == 1000 + assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available + learn_sigma = args.image_size == 256 + else: + learn_sigma = False + + # Load model: + latent_size = args.image_size // 8 + model = SiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes, + learn_sigma=learn_sigma, + ).to(device) + # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py: + ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() # important! 
+ + + transport = create_transport( + args.path_type, + args.prediction, + args.loss_weight, + args.train_eps, + args.sample_eps + ) + sampler = Sampler(transport) + if mode == "ODE": + if args.likelihood: + assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" + sample_fn = sampler.sample_ode_likelihood( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + ) + else: + sample_fn = sampler.sample_ode( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + reverse=args.reverse + ) + elif mode == "SDE": + sample_fn = sampler.sample_sde( + sampling_method=args.sampling_method, + diffusion_form=args.diffusion_form, + diffusion_norm=args.diffusion_norm, + last_step=args.last_step, + last_step_size=args.last_step_size, + num_steps=args.num_sampling_steps, + ) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" + using_cfg = args.cfg_scale > 1.0 + + # Create folder to save samples: + model_string_name = args.model.replace("/", "-") + ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" + if mode == "ODE": + folder_name = f"{model_string_name}-{ckpt_string_name}-" \ + f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\ + f"{mode}-{args.num_sampling_steps}-{args.sampling_method}" + elif mode == "SDE": + folder_name = f"{model_string_name}-{ckpt_string_name}-" \ + f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\ + f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\ + f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Figure out how many samples we need to generate on 
each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: + num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)]) + total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) + if rank == 0: + print(f"Total number of images that will be sampled: {total_samples}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + done_iterations = int( int(num_samples // dist.get_world_size()) // n) + pbar = range(iterations) + pbar = tqdm(pbar) if rank == 0 else pbar + total = 0 + + for i in pbar: + # Sample inputs: + z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) + y = torch.randint(0, args.num_classes, (n,), device=device) + + # Setup classifier-free guidance: + if using_cfg: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + model_fn = model.forward_with_cfg + else: + model_kwargs = dict(y=y) + model_fn = model.forward + + samples = sample_fn(z, model_fn, **model_kwargs)[-1] + if using_cfg: + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + + samples = vae.decode(samples / 0.18215).sample + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + 
total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + total += global_batch_size + dist.barrier() + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + if len(sys.argv) < 2: + print("Usage: program.py [options]") + sys.exit(1) + + mode = sys.argv[1] + + assert mode[:2] != "--", "Usage: program.py [options]" + assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'" + + parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") + parser.add_argument("--sample-dir", type=str, default="samples") + parser.add_argument("--per-proc-batch-size", type=int, default=4) + parser.add_argument("--num-fid-samples", type=int, default=50_000) + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.0) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, + help="By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.") + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).") + + parse_transport_args(parser) + if mode == "ODE": + parse_ode_args(parser) + # Further processing for ODE + elif mode == "SDE": + parse_sde_args(parser) + # Further processing for SDE + + args = parser.parse_known_args()[0] + main(mode, args) diff --git a/Rectified_Noise/VP-Disp/sample_rectified_noise.py b/Rectified_Noise/VP-Disp/sample_rectified_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..769872830a06009b874d443aa274fea9d1bab0a8 --- /dev/null +++ b/Rectified_Noise/VP-Disp/sample_rectified_noise.py @@ -0,0 +1,380 @@ +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from models import SiT_models +from download import find_model +from transport import create_transport, Sampler +from diffusers.models import AutoencoderKL +from train_utils import parse_ode_args, parse_sde_args, parse_transport_args +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import math +import argparse +import sys + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def fix_state_dict_for_ddp(state_dict): + """ + Fix state dict keys to match DistributedDataParallel model keys. + Add "module." 
prefix to keys if they don't have it. + """ + # Check if this is a full checkpoint dict with "model", "ema", or "opt" keys + if isinstance(state_dict, dict) and ("model" in state_dict or "ema" in state_dict or "opt" in state_dict): + # This is a full checkpoint dict, extract the state dict we need + # Prefer "ema" then "model" then return as is + if "ema" in state_dict: + state_dict = state_dict["ema"] + elif "model" in state_dict: + state_dict = state_dict["model"] + else: + # If only "opt" or other keys exist, return original + state_dict = state_dict + + # Now fix the keys to match DDP format + fixed_state_dict = {} + for key, value in state_dict.items(): + if not key.startswith("module."): + new_key = "module." + key + else: + new_key = key + fixed_state_dict[new_key] = value + return fixed_state_dict + +def main(mode, args): + """ + Run sampling. + """ + torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + learn_mu = args.learn_mu + sitf2_depth = args.depth # Save SiTF2 depth before it gets overwritten + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + if args.ckpt is None: + assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download." + assert args.image_size in [256, 512] + assert args.num_classes == 1000 + assert args.image_size == 256, "512x512 models are not yet available for auto-download." 
# remove this line when 512x512 models are available + learn_sigma = args.image_size == 256 + else: + learn_sigma = False + + # Load SiTF1 and SiTF2 models and create CombinedModel + from models import SiTF1, SiTF2, CombinedModel + latent_size = args.image_size // 8 + + # Get model configuration based on args.model + model_name = args.model + if 'XL' in model_name: + hidden_size, depth, num_heads = 1152, 28, 16 + elif 'L' in model_name: + hidden_size, depth, num_heads = 1024, 24, 16 + elif 'B' in model_name: + hidden_size, depth, num_heads = 768, 12, 12 + elif 'S' in model_name: + hidden_size, depth, num_heads = 384, 12, 6 + else: + # Default fallback + hidden_size, depth, num_heads = 768, 12, 12 + + # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2 + patch_size = int(model_name.split('/')[-1]) + + # Load SiTF1 + sitf1 = SiTF1( + input_size=latent_size, + patch_size=patch_size, + in_channels=4, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=args.num_classes, + learn_sigma=False + ).to(device) + sitf1_state_raw = find_model(args.ckpt) + # find_model now returns ema if available, or the full checkpoint + # Extract the actual state_dict to use for both sitf1 and base_model + if isinstance(sitf1_state_raw, dict) and "model" in sitf1_state_raw: + sitf1_state = sitf1_state_raw["model"] + else: + # sitf1_state_raw is already a state_dict (either ema or direct model state) + sitf1_state = sitf1_state_raw + sitf1.load_state_dict(sitf1_state) + sitf1.eval() + + # For sampling, we can use sitf1 directly instead of creating a separate sit model + # since sitf1 and sit have the same architecture and weights + + # Load SiTF2 with the same architecture parameters as SiTF1 for compatibility + sitf2 = SiTF2( + input_size=latent_size, + hidden_size=hidden_size, # Use the same hidden_size as SiTF1 + out_channels=8, + patch_size=patch_size, # Use the same patch_size as SiTF1 + 
num_heads=num_heads, # Use the same num_heads as SiTF1 + mlp_ratio=4.0, + depth=sitf2_depth, # Use the depth specified by command line argument (not the model's default depth) + learn_sigma=True, + num_classes=args.num_classes, + learn_mu=learn_mu + ).to(device) + sitf2 = DDP(sitf2, device_ids=[device]) + sitf2_state = find_model(args.sitf2_ckpt) + # Fix state dict keys to match DDP model + sitf2_state_fixed = fix_state_dict_for_ddp(sitf2_state) + try: + sitf2.load_state_dict(sitf2_state_fixed) + except Exception as e: + print(f"Error loading state dict: {e}") + # Try loading with strict=False as fallback + sitf2.load_state_dict(sitf2_state_fixed, strict=False) + sitf2.eval() + # CombinedModel + + combined_model = CombinedModel(sitf1, sitf2).to(device) + sitf2.eval() + combined_model.eval() + + # Use SiT_models factory function to create the base model, same as in SiT_clean + # This ensures correct model configuration + # Use learn_sigma=False to match sitf1 configuration + base_model = SiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes, + learn_sigma=False, # Match sitf1's learn_sigma=False + ).to(device) + # Load the checkpoint (same as sitf1) - use the exact same state_dict + base_model.load_state_dict(sitf1_state) + base_model.eval() + + # Determine if CFG will be used (needed for combined_sampling_model function) + assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" + using_cfg = args.cfg_scale > 1.0 + + # There are repeated calculations in the middle, + # which will cause Flops to double. 
A simplified version will be released later + def combined_sampling_model(x, t, y=None, **kwargs): + with torch.no_grad(): + # Handle CFG same as in SiT_clean/sample_ddp.py + if using_cfg and 'cfg_scale' in kwargs: + # Use forward_with_cfg when CFG is enabled + sit_out = base_model.forward_with_cfg(x, t, y, kwargs['cfg_scale']) + else: + # Use regular forward when CFG is disabled + sit_out = base_model.forward(x, t, y) + # If use_sitf2_before_t05 is True, only use sitf2 when t < threshold + if args.use_sitf2: + if args.use_sitf2_before_t05: + # t is a tensor, check which samples have t < threshold + # Create a mask: 1.0 where t < threshold, 0.0 otherwise + mask = (t < args.sitf2_threshold).float() + # Compute sitf2 output for all samples + combined_out = combined_model.forward(x, t, y) + # Expand mask to match the spatial dimensions of combined_out + # combined_out shape is (batch, channels, height, width) + while len(mask.shape) < len(combined_out.shape): + mask = mask.unsqueeze(-1) + # Broadcast mask to match combined_out shape + mask = mask.expand_as(combined_out) + # Only use sitf2 output where t < threshold + combined_out = combined_out * mask + # Combine sit_out and masked combined_out + return sit_out + combined_out + else: + # Default behavior: only use base model output + return sit_out + else: + # Default behavior: only use base model output + return sit_out + + transport = create_transport( + args.path_type, + args.prediction, + args.loss_weight, + args.train_eps, + args.sample_eps + ) + sampler = Sampler(transport) + if mode == "ODE": + if args.likelihood: + assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" + sample_fn = sampler.sample_ode_likelihood( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + ) + else: + sample_fn = sampler.sample_ode( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + 
reverse=args.reverse
+        )
+    elif mode == "SDE":
+        sample_fn = sampler.sample_sde(
+            sampling_method=args.sampling_method,
+            diffusion_form=args.diffusion_form,
+            diffusion_norm=args.diffusion_norm,
+            last_step=args.last_step,
+            last_step_size=args.last_step_size,
+            num_steps=args.num_sampling_steps,
+        )
+    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
+
+    # Create folder to save samples:
+    model_string_name = args.model.replace("/", "-")
+    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
+    sitf2_ckpt_string_name = os.path.basename(args.sitf2_ckpt).replace(".pt", "") if args.sitf2_ckpt else "pretrained"
+    if mode == "ODE":
+        folder_name = f"{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
+                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
+                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
+    elif mode == "SDE":
+        # Add threshold info to folder name if use_sitf2_before_t05 is enabled
+        threshold_suffix = f"-threshold-{args.sitf2_threshold}" if args.use_sitf2_before_t05 else ""
+        if learn_mu:
+            folder_name = f"depth-mu-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
+                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
+                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
+                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
+        else:
+            folder_name = f"depth-sigma-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
+                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
+                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
+                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
+    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
+    if rank == 0:
+        os.makedirs(sample_folder_dir, exist_ok=True)
+        print(f"Saving .png samples at {sample_folder_dir}")
+    dist.barrier()
+
+    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
+    n =
args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: + num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)]) + total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) + if rank == 0: + print(f"Total number of images that will be sampled: {total_samples}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + done_iterations = int( int(num_samples // dist.get_world_size()) // n) + pbar = range(iterations) + pbar = tqdm(pbar) if rank == 0 else pbar + total = 0 + + for i in pbar: + # Sample inputs: + z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device) + y = torch.randint(0, args.num_classes, (n,), device=device) + # Setup classifier-free guidance: + if using_cfg: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + else: + model_kwargs = dict(y=y) + samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1] + if using_cfg: + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae.decode(samples / 0.18215).sample + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + total += global_batch_size + 
dist.barrier() + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + if len(sys.argv) < 2: + print("Usage: program.py [options]") + sys.exit(1) + + mode = sys.argv[1] + + assert mode[:2] != "--", "Usage: program.py [options]" + assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'" + + parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") + parser.add_argument("--sample-dir", type=str, default="samples") + parser.add_argument("--per-proc-batch-size", type=int, default=64) + parser.add_argument("--num-fid-samples", type=int, default=50_000) + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.0) + parser.add_argument("--num-sampling-steps", type=int, default=100) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, + help="By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.") + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a SiT checkpoint.") + parser.add_argument("--sitf2-ckpt", type=str, required=True, help="Path to SiTF2 checkpoint") + parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True, + help="Whether to learn mu parameter") + parser.add_argument("--depth", type=int, default=1, + help="Depth parameter for SiTF2 model") + parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True, + help="Only use SiTF2 output when t < threshold, otherwise use only SiT") + parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False, + help="Only use SiTF2 output when t < threshold, otherwise use only SiT") + parser.add_argument("--sitf2-threshold", type=float, default=0.5, + help="Time threshold for using SiTF2 output (default: 0.5). Only effective when --use-sitf2-before-t05 is True") + parse_transport_args(parser) + if mode == "ODE": + parse_ode_args(parser) + # Further processing for ODE + elif mode == "SDE": + parse_sde_args(parser) + # Further processing for SDE + + args = parser.parse_known_args()[0] + main(mode, args) diff --git a/Rectified_Noise/VP-Disp/test.sh b/Rectified_Noise/VP-Disp/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..ecf30e40b672f9e5e57d0369666e308e24da9870 --- /dev/null +++ b/Rectified_Noise/VP-Disp/test.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Execute all four commands in parallel +# Each command runs in the background using & + +echo "Starting all four sampling tasks in parallel..." 
+ +CUDA_VISIBLE_DEVICES=0 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29910 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir VP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2 False \ + --use-sitf2-before-t05 False \ + --sitf2-threshold 0.0 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_VP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss-vp/results_256_vp_disp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0175000.pt \ + > W_No.log 2>&1 & + +CUDA_VISIBLE_DEVICES=1 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29950 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir VP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 False \ + --sitf2-threshold 1.0 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_VP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss-vp/results_256_vp_disp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0175000.pt \ + > W_False.log 2>&1 & + + +CUDA_VISIBLE_DEVICES=2 nohup torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=localhost:29952 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir VP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 True \ + --sitf2-threshold 0.5 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_VP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss-vp/results_256_vp_disp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0175000.pt \ + > W_True_0.5.log 2>&1 & + +CUDA_VISIBLE_DEVICES=3 nohup torchrun \ + --nnodes=1 \ + 
--nproc_per_node=1 \ + --rdzv_endpoint=localhost:29921 \ + sample_rectified_noise.py SDE \ + --depth 2 \ + --sample-dir VP_samples \ + --model SiT-XL/2 \ + --num-fid-samples 3000 \ + --num-classes 1000 \ + --global-seed 0 \ + --use-sitf2-before-t05 True \ + --sitf2-threshold 0.15 \ + --ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/SiT_clean_256_VP/base.pt \ + --sitf2-ckpt /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise-Dispersive-Loss-vp/results_256_vp_disp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0175000.pt \ + > W_True_0.15.log 2>&1 & + +# Wait for all background jobs to complete +echo "All tasks started. Waiting for completion..." +wait + +echo "All tasks completed!" \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/train_rectified_noise.py b/Rectified_Noise/VP-Disp/train_rectified_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b10906c109a1e56a5003e5b3d86cabfcb604a6 --- /dev/null +++ b/Rectified_Noise/VP-Disp/train_rectified_noise.py @@ -0,0 +1,440 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for SiT using PyTorch DDP. 
+""" +import torch +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +import numpy as np +from collections import OrderedDict +from PIL import Image +from copy import deepcopy +from glob import glob +from time import time +import argparse +import logging +import os + +from models import SiT, SiTF1, SiTF2, CombinedModel +from models import SiT_models +from download import find_model +from transport import create_transport, Sampler +from diffusers.models import AutoencoderKL +from train_utils import parse_transport_args + + + +################################################################################# +# Training Helper Functions # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. 
+ """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + """ + Trains a new SiT model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + dist.init_process_group("nccl") + assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 
+ rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + local_batch_size = int(args.global_batch_size // dist.get_world_size()) + learn_mu = args.learn_mu + depth = args.depth + # Setup an experiment folder: + if rank == 0: + os.makedirs(args.results_dir, exist_ok=True) + experiment_index = len(glob(f"{args.results_dir}/*")) + model_string_name = args.model.replace("/", "-") + if learn_mu: + experiment_name = f"depth-mu-{args.depth}-{experiment_index:03d}-{model_string_name}-" \ + f"{args.path_type}-{args.prediction}-{args.loss_weight}" + else: + experiment_name = f"depth-sigma-{args.depth}-{experiment_index:03d}-{model_string_name}-" \ + f"{args.path_type}-{args.prediction}-{args.loss_weight}" + experiment_dir = f"{args.results_dir}/{experiment_name}" + checkpoint_dir = f"{experiment_dir}/checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + + else: + logger = create_logger(None) + + # Create models: + assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." + latent_size = args.image_size // 8 + + # Get model configuration based on args.model + model_config = SiT_models[args.model] + model_kwargs = model_config().__dict__ # Get the default parameters for this model + + # Extract parameters from the model configuration based on the model name + # Model names follow the format like 'SiT-XL/2', 'SiT-B/4', etc. 
+ model_name = args.model + if 'XL' in model_name: + hidden_size, depth, num_heads = 1152, 28, 16 + elif 'L' in model_name: + hidden_size, depth, num_heads = 1024, 24, 16 + elif 'B' in model_name: + hidden_size, depth, num_heads = 768, 12, 12 + elif 'S' in model_name: + hidden_size, depth, num_heads = 384, 12, 6 + else: + # Default fallback + hidden_size, depth, num_heads = 768, 12, 12 + + # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2 + patch_size = int(model_name.split('/')[-1]) + + sitf1 = SiTF1( + input_size=latent_size, + patch_size=patch_size, + in_channels=4, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=args.num_classes, + learn_sigma=False + ).to(device) + sit = SiT( + input_size=latent_size, + patch_size=patch_size, + in_channels=4, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=args.num_classes, + learn_sigma=False + ).to(device) + sitf2 = SiTF2( + input_size=latent_size, + hidden_size=hidden_size, + out_channels=8, + patch_size=patch_size, + num_heads=num_heads, + mlp_ratio=4.0, + depth=args.depth, # Use the depth for sitf2 as specified by command line + learn_sigma=True, + num_classes=args.num_classes, + learn_mu=learn_mu + ).to(device) + sitf2_ema = deepcopy(sitf2).to(device) + combined_model = CombinedModel(sitf1, sitf2).to(device) + + if args.ckpt is not None: + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + try: + sitf1.load_state_dict(state_dict["model"], strict=False) + sit.load_state_dict(state_dict["model"], strict=False) + except: + sitf1.load_state_dict(state_dict, strict=False) + sit.load_state_dict(state_dict, strict=False) + + + requires_grad(sitf1, False) + requires_grad(sit, False) + requires_grad(sitf2, True) + + opt = torch.optim.AdamW(sitf2.parameters(), lr=1e-4, weight_decay=0) + # Do NOT wrap sitf2 separately in DDP (avoids double-wrapping 
submodules); wrap only the combined model.
+    combined_model = DDP(combined_model, device_ids=[device], find_unused_parameters=True)
+
+    # Create transport object: path_type determines the loss form used in training_losses()
+    # path_type options: "Linear", "GVP", "VP" - each corresponds to a different loss calculation method
+    transport = create_transport(
+        args.path_type,  # This directly affects how loss is computed in training_losses()
+        args.prediction,
+        args.loss_weight,
+        args.train_eps,
+        args.sample_eps,
+        args.disp_loss_weight,
+        args.temperature
+    )
+    transport_sampler = Sampler(transport)
+    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
+    logger.info(f"Combined_model Parameters: {sum(p.numel() for p in combined_model.parameters()):,}")
+
+    grad_params = [(n, p.numel()) for n, p in combined_model.named_parameters() if p.requires_grad]
+    logger.info(f"Total trainable parameters: {sum(cnt for _, cnt in grad_params):,}")
+
+    # Setup data:
+    transform = transforms.Compose([
+        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+    ])
+    dataset = ImageFolder(args.data_path, transform=transform)
+    sampler = DistributedSampler(
+        dataset,
+        num_replicas=dist.get_world_size(),
+        rank=rank,
+        shuffle=True,
+        seed=args.global_seed
+    )
+    loader = DataLoader(
+        dataset,
+        batch_size=local_batch_size,
+        shuffle=False,
+        sampler=sampler,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=True
+    )
+    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
+    # Ensure EMA updates target the correct base model (whether sitf2 is wrapped or not)
+    base_sitf2 = sitf2.module if isinstance(sitf2, torch.nn.parallel.DistributedDataParallel) else sitf2
+    update_ema(sitf2_ema, base_sitf2, decay=0)
+    sitf1.eval()
+    sit.eval()
+    sitf2.train()
+
sitf2_ema.eval() + + train_steps = 0 + log_steps = 0 + running_loss = 0 + start_time = time() + ys = torch.randint(1000, size=(local_batch_size,), device=device) + use_cfg = args.cfg_scale > 1.0 + n = ys.size(0) + zs = torch.randn(n, 4, latent_size, latent_size, device=device) + if use_cfg: + zs = torch.cat([zs, zs], 0) + y_null = torch.tensor([1000] * n, device=device) + ys = torch.cat([ys, y_null], 0) + sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale) + model_fn = sitf1.forward_with_cfg + else: + sample_model_kwargs = dict(y=ys) + model_fn = sitf1.forward + def combined_sampling_model(x, t, y=None, **kwargs): + with torch.no_grad(): + sit_out = sit.forward(x, t, y) + combined_out = combined_model.forward(x, t, y) + return sit_out + combined_out + logger.info(f"Training for {args.epochs} epochs...") + for epoch in range(args.epochs): + sampler.set_epoch(epoch) + logger.info(f"Beginning epoch {epoch}...") + for x, y in loader: + x = x.to(device) + y = y.to(device) + with torch.no_grad(): + x_latent = vae.encode(x).latent_dist.sample().mul_(0.18215) + model_kwargs = dict(y=y, return_act=args.disp) + # Compute training loss: the loss form depends on args.path_type (Linear/GVP/VP) + # Each path_type uses a different mathematical formulation for the transport loss + loss_dict = transport.training_losses(sit, x_latent, model_noise=combined_model, model_kwargs=model_kwargs) + loss = loss_dict["loss"].mean() + + # Check for NaN/Inf loss before backward + if torch.isnan(loss) or torch.isinf(loss): + if rank == 0: + logger.warning(f"NaN/Inf loss detected at step {train_steps}, skipping this batch. 
Loss: {loss.item()}")
+                continue
+
+            opt.zero_grad()
+            loss.backward()
+
+            # Gradient clipping for numerical stability (especially important for VP path)
+            torch.nn.utils.clip_grad_norm_(sitf2.parameters(), max_norm=1.0)
+
+            opt.step()
+            # Update EMA of the trainable sitf2 base model
+            update_ema(sitf2_ema, base_sitf2)
+            running_loss += loss.item()
+            log_steps += 1
+            train_steps += 1
+            if train_steps % args.log_every == 0:
+                torch.cuda.synchronize()
+                end_time = time()
+                steps_per_sec = log_steps / (end_time - start_time)
+                avg_loss = torch.tensor(running_loss / log_steps, device=device)
+                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+                avg_loss = avg_loss.item() / dist.get_world_size()
+                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+                running_loss = 0
+                log_steps = 0
+                start_time = time()
+            if train_steps % args.ckpt_every == 0 and train_steps > 0:
+                print(train_steps)
+                if rank == 0:
+                    checkpoint = {
+                        "model": sitf2.state_dict(),
+                        "ema": sitf2_ema.state_dict(),
+                        "opt": opt.state_dict(),
+                        "args": args
+                    }
+                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+                    torch.save(checkpoint, checkpoint_path)
+                    logger.info(f"Saved checkpoint to {checkpoint_path}")
+                dist.barrier()
+
+            if (train_steps % args.sample_every == 0 )and train_steps > 0:
+                logger.info("Generating EMA samples...")
+        if epoch == args.epochs:
+            break
+
+    sitf1.eval()
+    sit.eval()
+    sitf2.eval()
+    logger.info("Final sampling done.")
+
+    logger.info("Done!")
+    cleanup()
+
+
+def save_samples_grid(out_samples, epoch, experiment_index, args, experiment_name, rank):
+    if rank == 0:
+        import os
+        import numpy as np
+        from PIL import Image
+        parent_dir = os.path.dirname(args.results_dir)
+        pic_dir = os.path.join(parent_dir, "pic")
+        os.makedirs(pic_dir, exist_ok=True)
+        experiment_pic_dir = os.path.join(pic_dir, experiment_name)
+        os.makedirs(experiment_pic_dir, exist_ok=True)
+        samples_np = torch.clamp(127.5 * out_samples + 128.0, 0,
255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + n_samples = samples_np.shape[0] + grid_size = int(np.ceil(np.sqrt(n_samples))) + canvas_size = grid_size * args.image_size + canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8) + for i, sample in enumerate(samples_np): + row = i // grid_size + col = i % grid_size + canvas[row*args.image_size:(row+1)*args.image_size, col*args.image_size:(col+1)*args.image_size] = sample + combined_image = Image.fromarray(canvas) + combined_image.save(os.path.join(experiment_pic_dir, f"epoch_{epoch:04d}_combined.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--results-dir", type=str, default="results_256_linear") + parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=3) + parser.add_argument("--epochs", type=int, default=100000) + parser.add_argument("--global-batch-size", type=int, default=256) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--ckpt-every", type=int, default=25000) + parser.add_argument("--sample-every", type=int, default=25192) + parser.add_argument("--cfg-scale", type=float, default=4.0) + parser.add_argument("--ckpt", type=str, default='/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/2000000.pt', + help="Optional path to a custom SiT checkpoint") + parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True, + help="Whether to learn mu parameter") + parser.add_argument("--depth", type=int, 
default=1, + help="Depth parameter for SiTF2 model") + parser.add_argument("--disp", action="store_true", + help="Toggle to enable Dispersive Loss") + parser.add_argument("--disp-loss-weight", type=float, default=0.5, + help="Weight λ for dispersive loss (default: 0.5)") + parser.add_argument("--temperature", type=float, default=1.0, + help="Temperature τ for dispersive loss (default: 1.0)") + + # Transport arguments (added by parse_transport_args): + # --path-type: Type of path for loss calculation (default: "GVP") + # Choices: "Linear" (linear interpolation), "GVP" (Geodesic Velocity Path), "VP" (Velocity Path) + # IMPORTANT: This parameter directly affects the loss form computed by transport.training_losses() + # The path_type determines how the transport loss is calculated during training. + # Make sure to use the correct path_type that matches your training objective. + # --prediction: Type of prediction (default: "velocity") + # --loss-weight: Loss weight type (default: None) + # --sample-eps, --train-eps: Epsilon values for sampling and training + parse_transport_args(parser) + args = parser.parse_args() + main(args) diff --git a/Rectified_Noise/VP-Disp/train_utils.py b/Rectified_Noise/VP-Disp/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f22303350f853df4dd8bd8e1736a338f706fa44 --- /dev/null +++ b/Rectified_Noise/VP-Disp/train_utils.py @@ -0,0 +1,35 @@ +def none_or_str(value): + if value == 'None': + return None + return value + +def parse_transport_args(parser): + group = parser.add_argument_group("Transport arguments") + group.add_argument("--path-type", type=str, default="VP", choices=["Linear", "GVP", "VP"], + help="Type of path for loss calculation. This parameter directly affects the loss form used during training. " + "Choices: Linear (linear interpolation path), GVP (Geodesic Velocity Path), VP (Velocity Path). 
" + "The path_type determines how the transport loss is computed in training_losses().") + group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"]) + group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"]) + group.add_argument("--sample-eps", type=float, default=0.0) + group.add_argument("--train-eps", type=float, default=0.0) + +def parse_ode_args(parser): + group = parser.add_argument_group("ODE arguments") + group.add_argument("--sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") + group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") + group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") + group.add_argument("--reverse", action="store_true") + group.add_argument("--likelihood", action="store_true") + +def parse_sde_args(parser): + group = parser.add_argument_group("SDE arguments") + group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) + group.add_argument("--diffusion-form", type=str, default="sigma", \ + choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\ + help="form of diffusion coefficient in the SDE") + group.add_argument("--diffusion-norm", type=float, default=1.0) + group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\ + help="form of last step taken in the SDE") + group.add_argument("--last-step-size", type=float, default=0.04, \ + help="size of the last step taken") \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/transport/__init__.py b/Rectified_Noise/VP-Disp/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bdcd3999442f5aef777a1f9ecc2889a0b99b2603 --- /dev/null +++ b/Rectified_Noise/VP-Disp/transport/__init__.py 
@@ -0,0 +1,71 @@ +from .transport import Transport, ModelType, WeightType, PathType, Sampler + +def create_transport( + path_type='Linear', + prediction="velocity", + loss_weight=None, + train_eps=None, + sample_eps=None, + disp_loss_weight=0.5, + temperature=1.0, +): + """function for creating Transport object + **Note**: model prediction defaults to velocity + Args: + - path_type: type of path to use; default to linear + - learn_score: set model prediction to score + - learn_noise: set model prediction to noise + - velocity_weighted: weight loss by velocity weight + - likelihood_weighted: weight loss by likelihood weight + - train_eps: small epsilon for avoiding instability during training + - sample_eps: small epsilon for avoiding instability during sampling + - disp_loss_weight: weight λ for dispersive loss (default: 0.5) + - temperature: temperature τ for dispersive loss (default: 1.0) + """ + + if prediction == "noise": + model_type = ModelType.NOISE + elif prediction == "score": + model_type = ModelType.SCORE + else: + model_type = ModelType.VELOCITY + + if loss_weight == "velocity": + loss_type = WeightType.VELOCITY + elif loss_weight == "likelihood": + loss_type = WeightType.LIKELIHOOD + else: + loss_type = WeightType.NONE + + path_choice = { + "Linear": PathType.LINEAR, + "GVP": PathType.GVP, + "VP": PathType.VP, + } + + path_type = path_choice[path_type] + + if (path_type in [PathType.VP]): + train_eps_new = 1e-5 if train_eps is None else train_eps + sample_eps_new = 1e-3 if sample_eps is None else sample_eps + train_eps, sample_eps = train_eps_new, sample_eps_new + elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): + train_eps_new = 1e-3 if train_eps is None else train_eps + sample_eps_new = 1e-3 if sample_eps is None else sample_eps + train_eps, sample_eps = train_eps_new, sample_eps_new + else: # velocity & [GVP, LINEAR] is stable everywhere + train_eps = 0 + sample_eps = 0 + + # create flow state + state = 
Transport( + model_type=model_type, + path_type=path_type, + loss_type=loss_type, + train_eps=train_eps, + sample_eps=sample_eps, + disp_loss_weight=disp_loss_weight, + temperature=temperature, + ) + + return state \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-312.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f863f6600f0c78d33dd5672dd914a73151be494 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-312.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-38.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..467c54269962d466b6384854d290ce00df6853c5 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/__init__.cpython-38.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-312.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9de8bb340f7a96efa82fc5bfbabfe6bdb4383d46 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-312.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-38.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..861d76cda1ec304efb7f94298dd9e6f9acfb32e6 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/integrators.cpython-38.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-312.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d3a2c0a168ed8329a72ac3f1c137e0acad9129ff Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-312.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-38.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317dc28dc6becf862c57dc2470c38c1e262df481 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/path.cpython-38.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-312.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8408363783bd8710709ec3fcb10242f4edfc5a26 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-312.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-38.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1db429fedcde7d5ae3138fec28324635c049ee4 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/transport.cpython-38.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-312.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5dfe9050ca5b5ca5e6561a003212dfbe8ab2e3 Binary files /dev/null and b/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-312.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-38.pyc b/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31c517795212841649007aae4fd637696b508984 Binary files /dev/null and 
b/Rectified_Noise/VP-Disp/transport/__pycache__/utils.cpython-38.pyc differ diff --git a/Rectified_Noise/VP-Disp/transport/integrators.py b/Rectified_Noise/VP-Disp/transport/integrators.py new file mode 100644 index 0000000000000000000000000000000000000000..adf7c7b4c50b6ff6c63973e0ddaa65b9759274c0 --- /dev/null +++ b/Rectified_Noise/VP-Disp/transport/integrators.py @@ -0,0 +1,117 @@ +import numpy as np +import torch as th +import torch.nn as nn +from torchdiffeq import odeint +from functools import partial +from tqdm import tqdm + +class sde: + """SDE solver class""" + def __init__( + self, + drift, + diffusion, + *, + t0, + t1, + num_steps, + sampler_type, + ): + assert t0 < t1, "SDE sampler has to be in forward time" + + self.num_timesteps = num_steps + self.t = th.linspace(t0, t1, num_steps) + self.dt = self.t[1] - self.t[0] + self.drift = drift + self.diffusion = diffusion + self.sampler_type = sampler_type + + def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + t = th.ones(x.size(0)).to(x) * t + dw = w_cur * th.sqrt(self.dt) + drift = self.drift(x, t, model, **model_kwargs) + diffusion = self.diffusion(x, t) + mean_x = x + drift * self.dt + x = mean_x + th.sqrt(2 * diffusion) * dw + return x, mean_x + + def __Heun_step(self, x, _, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + dw = w_cur * th.sqrt(self.dt) + t_cur = th.ones(x.size(0)).to(x) * t + diffusion = self.diffusion(x, t_cur) + xhat = x + th.sqrt(2 * diffusion) * dw + K1 = self.drift(xhat, t_cur, model, **model_kwargs) + xp = xhat + self.dt * K1 + K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) + return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step + + def __forward_fn(self): + """TODO: generalize here by adding all private functions ending with steps to it""" + sampler_dict = { + "Euler": self.__Euler_Maruyama_step, + "Heun": self.__Heun_step, + } + + try: + sampler = 
sampler_dict[self.sampler_type] + except: + raise NotImplementedError("Smapler type not implemented.") + + return sampler + + def sample(self, init, model, **model_kwargs): + """forward loop of sde""" + x = init + mean_x = init + samples = [] + sampler = self.__forward_fn() + for ti in self.t[:-1]: + with th.no_grad(): + x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) + samples.append(x) + + return samples + +class ode: + """ODE solver class""" + def __init__( + self, + drift, + *, + t0, + t1, + sampler_type, + num_steps, + atol, + rtol, + ): + assert t0 < t1, "ODE sampler has to be in forward time" + + self.drift = drift + self.t = th.linspace(t0, t1, num_steps) + self.atol = atol + self.rtol = rtol + self.sampler_type = sampler_type + + def sample(self, x, model, **model_kwargs): + + device = x[0].device if isinstance(x, tuple) else x.device + def _fn(t, x): + t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t + model_output = self.drift(x, t, model, **model_kwargs) + return model_output + + t = self.t.to(device) + atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] + rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] + samples = odeint( + _fn, + x, + t, + method=self.sampler_type, + atol=atol, + rtol=rtol + ) + return samples \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/transport/path.py b/Rectified_Noise/VP-Disp/transport/path.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3cf748669b4950c0c956d32c772ba4348b9ce9 --- /dev/null +++ b/Rectified_Noise/VP-Disp/transport/path.py @@ -0,0 +1,198 @@ +import torch as th +import numpy as np +from functools import partial + +def expand_t_like_x(t, x): + """Function to reshape time t to broadcastable dimension of x + Args: + t: [batch_dim,], time vector + x: [batch_dim,...], data point + """ + dims = [1] * (len(x.size()) - 1) + t = t.view(t.size(0), *dims) + return t + + 
+#################### Coupling Plans #################### + +class ICPlan: + """Linear Coupling Plan""" + def __init__(self, sigma=0.0): + self.sigma = sigma + + def compute_alpha_t(self, t): + """Compute the data coefficient along the path""" + return t, 1 + + def compute_sigma_t(self, t): + """Compute the noise coefficient along the path""" + return 1 - t, -1 + + def compute_d_alpha_alpha_ratio_t(self, t): + """Compute the ratio between d_alpha and alpha""" + return 1 / t + + def compute_drift(self, x, t): + """We always output sde according to score parametrization; """ + t = expand_t_like_x(t, x) + alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + drift = alpha_ratio * x + diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t + + return -drift, diffusion + + def compute_diffusion(self, x, t, form="constant", norm=1.0): + """Compute the diffusion term of the SDE + Args: + x: [batch_dim, ...], data point + t: [batch_dim,], time vector + form: str, form of the diffusion term + norm: float, norm of the diffusion term + """ + t = expand_t_like_x(t, x) + choices = { + "constant": norm, + "SBDM": norm * self.compute_drift(x, t)[1], + "sigma": norm * self.compute_sigma_t(t)[0], + "linear": norm * (1 - t), + "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, + "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, + } + + try: + diffusion = choices[form] + except KeyError: + raise NotImplementedError(f"Diffusion form {form} not implemented") + + return diffusion + + def get_score_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to score + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] 
shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_noise_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to denoiser + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = reverse_alpha_ratio * d_sigma_t - sigma_t + noise = (reverse_alpha_ratio * velocity - mean) / var + return noise + + def get_velocity_from_score(self, score, x, t): + """Wrapper function: transfrom score prediction model to velocity + Args: + score: [batch_dim, ...] shaped tensor; score model output + x: [batch_dim, ...] 
shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + drift, var = self.compute_drift(x, t) + velocity = var * score - drift + return velocity + + def compute_mu_t(self, t, x0, x1): + """Compute the mean of time-dependent density p_t""" + t = expand_t_like_x(t, x1) + alpha_t, _ = self.compute_alpha_t(t) + sigma_t, _ = self.compute_sigma_t(t) + return alpha_t * x1 + sigma_t * x0 + + def compute_xt(self, t, x0, x1): + """Sample xt from time-dependent density p_t; rng is required""" + xt = self.compute_mu_t(t, x0, x1) + return xt + + def compute_ut(self, t, x0, x1, xt): + """Compute the vector field corresponding to p_t""" + t = expand_t_like_x(t, x1) + _, d_alpha_t = self.compute_alpha_t(t) + _, d_sigma_t = self.compute_sigma_t(t) + return d_alpha_t * x1 + d_sigma_t * x0 + + def plan(self, t, x0, x1): + xt = self.compute_xt(t, x0, x1) + ut = self.compute_ut(t, x0, x1, xt) + return t, xt, ut + + +class VPCPlan(ICPlan): + """class for VP path flow matching""" + + def __init__(self, sigma_min=0.1, sigma_max=20.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min + self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min + + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = self.log_mean_coeff(t) + alpha_t = th.exp(alpha_t) + d_alpha_t = alpha_t * self.d_log_mean_coeff(t) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + p_sigma_t = 2 * self.log_mean_coeff(t) + # Clip p_sigma_t to prevent numerical issues + p_sigma_t = th.clamp(p_sigma_t, min=-50.0, max=0.0) + exp_p_sigma = th.exp(p_sigma_t) + # Ensure 1 - exp_p_sigma >= 0 for sqrt + one_minus_exp = th.clamp(1 - exp_p_sigma, min=1e-8) + sigma_t = th.sqrt(one_minus_exp) + # Add small epsilon to denominator to prevent 
division by zero + d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t + 1e-8) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return self.d_log_mean_coeff(t) + + def compute_drift(self, x, t): + """Compute the drift term of the SDE""" + t = expand_t_like_x(t, x) + beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) + return -0.5 * beta_t * x, beta_t / 2 + + +class GVPCPlan(ICPlan): + def __init__(self, sigma=0.0): + super().__init__(sigma) + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = th.sin(t * np.pi / 2) + d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + sigma_t = th.cos(t * np.pi / 2) + d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return np.pi / (2 * th.tan(t * np.pi / 2)) \ No newline at end of file diff --git a/Rectified_Noise/VP-Disp/transport/transport.py b/Rectified_Noise/VP-Disp/transport/transport.py new file mode 100644 index 0000000000000000000000000000000000000000..2424c3b989046242d74b5e93668a4b92f7cdb879 --- /dev/null +++ b/Rectified_Noise/VP-Disp/transport/transport.py @@ -0,0 +1,523 @@ +import torch as th +import numpy as np +import logging + +import enum + +from . import path +from .utils import EasyDict, log_state, mean_flat +from .integrators import ode, sde + +class ModelType(enum.Enum): + """ + Which type of output the model predicts. + """ + + NOISE = enum.auto() # the model predicts epsilon + SCORE = enum.auto() # the model predicts \nabla \log p(x) + VELOCITY = enum.auto() # the model predicts v(x) + +class PathType(enum.Enum): + """ + Which type of path to use. 
+ """ + + LINEAR = enum.auto() + GVP = enum.auto() + VP = enum.auto() + +class WeightType(enum.Enum): + """ + Which type of weighting to use. + """ + + NONE = enum.auto() + VELOCITY = enum.auto() + LIKELIHOOD = enum.auto() + + +class Transport: + + def __init__( + self, + *, + model_type, + path_type, + loss_type, + train_eps, + sample_eps, + disp_loss_weight=0.5, + temperature=1.0, + ): + path_options = { + PathType.LINEAR: path.ICPlan, + PathType.GVP: path.GVPCPlan, + PathType.VP: path.VPCPlan, + } + + self.loss_type = loss_type + self.model_type = model_type + self.path_sampler = path_options[path_type]() + self.train_eps = train_eps + self.sample_eps = sample_eps + self.disp_loss_weight = disp_loss_weight # λ: weight for dispersive loss + self.temperature = temperature # τ: temperature parameter + + def prior_logp(self, z): + ''' + Standard multivariate normal prior + Assume z is batched + ''' + shape = th.tensor(z.size()) + N = th.prod(shape[1:]) + _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. 
+ return th.vmap(_fn)(z) + + + def check_interval( + self, + train_eps, + sample_eps, + *, + diffusion_form="SBDM", + sde=False, + reverse=False, + eval=False, + last_step_size=0.0, + ): + t0 = 0 + t1 = 1 + eps = train_eps if not eval else sample_eps + if (type(self.path_sampler) in [path.VPCPlan]): + + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ + and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step + + t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + if reverse: + t0, t1 = 1 - t0, 1 - t1 + + return t0, t1 + + + def sample(self, x1): + """Sampling x0 & t based on shape of x1 (if needed) + Args: + x1 - data point; [batch, *dim] + """ + + x0 = th.randn_like(x1) + t0, t1 = self.check_interval(self.train_eps, self.sample_eps) + t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 + t = t.to(x1) + return t, x0, x1 + + def disp_loss(self, z): + """Dispersive Loss implementation (InfoNCE-L2 variant) + Args: + z: activation tensor from model layers + """ + # Clip activations to prevent numerical instability + z = z.reshape((z.shape[0], -1)) # flatten + # Normalize activations to prevent extreme values + z = z / (th.norm(z, dim=1, keepdim=True) + 1e-8) + + diff = th.nn.functional.pdist(z).pow(2) / z.shape[1] # pairwise distance + diff = th.cat((diff, diff, th.zeros(z.shape[0], device=z.device))) # match JAX implementation of full BxB matrix + # Apply temperature scaling: divide by temperature τ + diff = diff / (self.temperature + 1e-8) + + # Use log-sum-exp trick for numerical stability + # log(mean(exp(-diff))) = log(sum(exp(-diff))) - log(n) + # For numerical stability: log(sum(exp(-diff))) = -min(diff) + log(sum(exp(-(diff - min(diff))))) + diff_min = th.min(diff) + diff_shifted = diff - diff_min + 
exp_diff = th.exp(-diff_shifted) + # Clip exp values to prevent overflow + exp_diff = th.clamp(exp_diff, min=1e-20, max=1e20) + log_mean = -diff_min + th.log(exp_diff.mean() + 1e-20) + + # Clip final loss to prevent extreme values + return th.clamp(log_mean, min=-10.0, max=10.0) + + def training_losses( + self, + model, + x1, + model_noise=None, + model_kwargs=None + ): + """Loss for training the score model + Args: + - model: backbone model; could be score, noise, or velocity + - x1: datapoint + - model_kwargs: additional arguments for the model + """ + + + if model_kwargs == None: + model_kwargs = {} + + t, x0, x1 = self.sample(x1) + t, xt, ut = self.path_sampler.plan(t, x0, x1) + + # Handle return_act for dispersive loss + disp_loss = 0 + if model_noise==None: + model_output = model(xt, t, **model_kwargs) + # Check if model returns activations (for dispersive loss) + if "return_act" in model_kwargs and model_kwargs['return_act']: + model_output, act = model_output + if act is not None and len(act) > 0: + # Calculate dispersive loss for all blocks + for block_act in act: + disp_loss = disp_loss + self.disp_loss(block_act) + else: + model_output_pre = model(xt, t, **model_kwargs) + # Handle return_act for model_noise + if "return_act" in model_kwargs and model_kwargs['return_act']: + if isinstance(model_output_pre, tuple): + model_output_pre, act_pre = model_output_pre + else: + act_pre = None + else: + act_pre = None + + model_output_noise = model_noise(xt, t, **model_kwargs) + # Handle return_act for model_noise + if "return_act" in model_kwargs and model_kwargs['return_act']: + if isinstance(model_output_noise, tuple): + model_output_noise, act_noise = model_output_noise + else: + act_noise = None + # Calculate dispersive loss for all blocks in model_noise (sitf2) + if act_noise is not None and len(act_noise) > 0: + # Calculate dispersive loss for each block and sum them + for block_act in act_noise: + disp_loss = disp_loss + self.disp_loss(block_act) + 
model_output = model_output_pre + model_output_noise + + B, *_, C = xt.shape + assert model_output.size() == (B, *xt.size()[1:-1], C) + + terms = {} + terms['pred'] = model_output + if self.model_type == ModelType.VELOCITY: + terms['loss'] = mean_flat(((model_output - ut) ** 2)) + else: + _, drift_var = self.path_sampler.compute_drift(xt, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) + if self.loss_type in [WeightType.VELOCITY]: + weight = (drift_var / sigma_t) ** 2 + elif self.loss_type in [WeightType.LIKELIHOOD]: + weight = drift_var / (sigma_t ** 2) + elif self.loss_type in [WeightType.NONE]: + weight = 1 + else: + raise NotImplementedError() + + if self.model_type == ModelType.NOISE: + terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) + else: + terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) + + # Add dispersive loss to the total loss with weight λ + # Check for NaN/Inf before adding + if disp_loss != 0: + if th.isnan(disp_loss) or th.isinf(disp_loss): + # If dispersive loss is NaN/Inf, skip it and log a warning + import warnings + warnings.warn(f"Dispersive loss is NaN/Inf, skipping. 
Value: {disp_loss}") + else: + terms['loss'] = terms['loss'] + self.disp_loss_weight * disp_loss + + return terms + + + def get_drift( + self + ): + """member function for obtaining the drift of the probability flow ODE""" + def score_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + model_output = model(x, t, **model_kwargs) + return (-drift_mean + drift_var * model_output) # by change of variable + + def noise_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) + model_output = model(x, t, **model_kwargs) + score = model_output / -sigma_t + return (-drift_mean + drift_var * score) + + def velocity_ode(x, t, model, **model_kwargs): + model_output = model(x, t, **model_kwargs) + return model_output + + if self.model_type == ModelType.NOISE: + drift_fn = noise_ode + elif self.model_type == ModelType.SCORE: + drift_fn = score_ode + else: + drift_fn = velocity_ode + + def body_fn(x, t, model, **model_kwargs): + model_output = drift_fn(x, t, model, **model_kwargs) + assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" + return model_output + + return body_fn + + + def get_score( + self, + ): + """member function for obtaining score of + x_t = alpha_t * x + sigma_t * eps""" + if self.model_type == ModelType.NOISE: + score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] + elif self.model_type == ModelType.SCORE: + score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) + elif self.model_type == ModelType.VELOCITY: + score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) + else: + raise NotImplementedError() + + return score_fn + + +class Sampler: + """Sampler class for the transport model""" + def __init__( + self, + 
transport, + ): + """Constructor for a general sampler; supporting different sampling methods + Args: + - transport: an tranport object specify model prediction & interpolant type + """ + + self.transport = transport + self.drift = self.transport.get_drift() + self.score = self.transport.get_score() + + def __get_sde_diffusion_and_drift( + self, + *, + diffusion_form="SBDM", + diffusion_norm=1.0, + ): + + def diffusion_fn(x, t): + diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) + return diffusion + + sde_drift = \ + lambda x, t, model, **kwargs: \ + self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) + + sde_diffusion = diffusion_fn + + return sde_drift, sde_diffusion + + def __get_last_step( + self, + sde_drift, + *, + last_step, + last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + elif last_step == "Mean": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + sde_drift(x, t, model, **model_kwargs) * last_step_size + elif last_step == "Tweedie": + alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + self.drift(x, t, model, **model_kwargs) * last_step_size + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in 
solving the SDE; default to be Euler-Maruyama
        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
        - diffusion_norm: function magnitude of diffusion coefficient; default to 1
        - last_step: type of the last step; default to identity
        - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
        - num_steps: total integration step of SDE
        """

        if last_step is None:
            # No final correction step, so it must not shrink the integration interval.
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        # Effective integration interval [t0, t1]; check_interval accounts for the
        # train/sample epsilons and reserves room for the last step.
        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)


        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            # Broadcast the terminal time t1 across the batch for the last-step call.
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            # NOTE(review): assumes _sde.sample returns num_steps - 1 states so that
            # appending the corrected last step yields exactly num_steps -- confirm
            # against the sde integrator. The message grammar ("Samples does not
            # match") is a runtime string, left unchanged here.
            assert len(xs) == num_steps, "Samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether solving the ODE in reverse (data
to noise); default to False
        """
        if reverse:
            # Reverse-time sampling: remap t -> (1 - t) so the same drift network
            # is queried while integrating from data back toward noise.
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):

        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """
        def _likelihood_drift(x, t, model, **model_kwargs):
            # State is a (sample, accumulated log-prob) pair; only the sample
            # feeds the drift network.
            x, _ = x
            # Rademacher probe (entries +1/-1): randint(2) yields {0,1}, *2 - 1
            # maps to {-1,+1}. Used for the Skilling-Hutchinson trace estimate.
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            # Integrate in reverse time (data -> noise).
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                # eps^T J eps via one vector-Jacobian product: unbiased estimate
                # of the drift's divergence (trace of its Jacobian w.r.t. x).
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            # Negated drift: the ODE is solved backward in time.
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            # Accumulated log-det term starts at zero, on x's device/dtype.
            init_logp = th.zeros(x.size(0)).to(x)
            # NOTE(review): `input` shadows the builtin of the same name.
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            # Keep only the terminal state of each trajectory.
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            # Instantaneous change of variables: log p(x) = log p(z) minus the
            # integrated divergence accumulated along the trajectory.
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
\ No newline at end of file
diff --git a/Rectified_Noise/VP-Disp/transport/utils.py b/Rectified_Noise/VP-Disp/transport/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..44646035531326b81883727f973900edb4eac494
--- /dev/null
+++ b/Rectified_Noise/VP-Disp/transport/utils.py
@@ -0,0 +1,29 @@
import torch as th

class EasyDict:
    """Minimal attribute-access wrapper: keys of sub_dict become attributes,
    while __getitem__ keeps dict-style lookups working."""

    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        # Dict-style lookup delegates to the attributes set in __init__.
        return getattr(self, key)

def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))

def log_state(state):
    # Collects printable entries from `state` in sorted key order.
    # (Definition continues past the end of this chunk.)
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "