Commit
·
6ded867
1
Parent(s):
541c8d3
use tqdm to track the current step
Browse files
cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py
CHANGED
|
@@ -27,6 +27,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|
| 27 |
|
| 28 |
import attrs
|
| 29 |
import torch
|
|
|
|
| 30 |
|
| 31 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
| 32 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
|
@@ -204,7 +205,7 @@ def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_
|
|
| 204 |
The final result after all iterations.
|
| 205 |
"""
|
| 206 |
val = init_val
|
| 207 |
-
for i in range(lower, upper):
|
| 208 |
val = body_fun(i, val)
|
| 209 |
return val
|
| 210 |
|
|
@@ -251,7 +252,7 @@ def differential_equation_solver(
|
|
| 251 |
def step_fn(
|
| 252 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
| 253 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 254 |
-
log.
|
| 255 |
input_x_B_StateShape, x0_preds = state
|
| 256 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
| 257 |
|
|
|
|
| 27 |
|
| 28 |
import attrs
|
| 29 |
import torch
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
|
| 32 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
| 33 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
|
|
|
| 205 |
The final result after all iterations.
|
| 206 |
"""
|
| 207 |
val = init_val
|
| 208 |
+
for i in tqdm(range(lower, upper)):
|
| 209 |
val = body_fun(i, val)
|
| 210 |
return val
|
| 211 |
|
|
|
|
| 252 |
def step_fn(
|
| 253 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
| 254 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 255 |
+
log.debug(f"Step [{i_th}/{num_step}]")
|
| 256 |
input_x_B_StateShape, x0_preds = state
|
| 257 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
| 258 |
|