Fixed misc.to bug
Browse files- misc.py +41 -44
- world_generation_pipeline.py +6 -6
misc.py
CHANGED
|
@@ -30,53 +30,50 @@ import termcolor
|
|
| 30 |
import torch
|
| 31 |
|
| 32 |
from .distributed import get_rank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class misc():
|
| 36 |
-
|
| 37 |
-
@staticmethod
|
| 38 |
-
def to(
|
| 39 |
-
data: Any,
|
| 40 |
-
device: str | torch.device | None = None,
|
| 41 |
-
dtype: torch.dtype | None = None,
|
| 42 |
-
memory_format: torch.memory_format = torch.preserve_format,
|
| 43 |
-
) -> Any:
|
| 44 |
-
"""Recursively cast data into the specified device, dtype, and/or memory_format.
|
| 45 |
-
|
| 46 |
-
The input data can be a tensor, a list of tensors, a dict of tensors.
|
| 47 |
-
See the documentation for torch.Tensor.to() for details.
|
| 48 |
-
|
| 49 |
-
Args:
|
| 50 |
-
data (Any): Input data.
|
| 51 |
-
device (str | torch.device): GPU device (default: None).
|
| 52 |
-
dtype (torch.dtype): data type (default: None).
|
| 53 |
-
memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
data (Any): Data cast to the specified device, dtype, and/or memory_format.
|
| 57 |
-
"""
|
| 58 |
-
assert (
|
| 59 |
-
device is not None or dtype is not None or memory_format is not None
|
| 60 |
-
), "at least one of device, dtype, memory_format should be specified"
|
| 61 |
-
if isinstance(data, torch.Tensor):
|
| 62 |
-
is_cpu = (isinstance(device, str) and device == "cpu") or (
|
| 63 |
-
isinstance(device, torch.device) and device.type == "cpu"
|
| 64 |
-
)
|
| 65 |
-
data = data.to(
|
| 66 |
-
device=device,
|
| 67 |
-
dtype=dtype,
|
| 68 |
-
memory_format=memory_format,
|
| 69 |
-
non_blocking=(not is_cpu),
|
| 70 |
-
)
|
| 71 |
-
return data
|
| 72 |
-
elif isinstance(data, collections.abc.Mapping):
|
| 73 |
-
return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
|
| 74 |
-
elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
|
| 75 |
-
return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
|
| 76 |
-
else:
|
| 77 |
-
return data
|
| 78 |
-
|
| 79 |
-
|
| 80 |
@staticmethod
|
| 81 |
def serialize(data: Any) -> Any:
|
| 82 |
"""Serialize data by hierarchically traversing through iterables.
|
|
|
|
| 30 |
import torch
|
| 31 |
|
| 32 |
from .distributed import get_rank
|
| 33 |
+
def to(
|
| 34 |
+
data: Any,
|
| 35 |
+
device: str | torch.device | None = None,
|
| 36 |
+
dtype: torch.dtype | None = None,
|
| 37 |
+
memory_format: torch.memory_format = torch.preserve_format,
|
| 38 |
+
) -> Any:
|
| 39 |
+
"""Recursively cast data into the specified device, dtype, and/or memory_format.
|
| 40 |
+
|
| 41 |
+
The input data can be a tensor, a list of tensors, a dict of tensors.
|
| 42 |
+
See the documentation for torch.Tensor.to() for details.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
data (Any): Input data.
|
| 46 |
+
device (str | torch.device): GPU device (default: None).
|
| 47 |
+
dtype (torch.dtype): data type (default: None).
|
| 48 |
+
memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
data (Any): Data cast to the specified device, dtype, and/or memory_format.
|
| 52 |
+
"""
|
| 53 |
+
assert (
|
| 54 |
+
device is not None or dtype is not None or memory_format is not None
|
| 55 |
+
), "at least one of device, dtype, memory_format should be specified"
|
| 56 |
+
if isinstance(data, torch.Tensor):
|
| 57 |
+
is_cpu = (isinstance(device, str) and device == "cpu") or (
|
| 58 |
+
isinstance(device, torch.device) and device.type == "cpu"
|
| 59 |
+
)
|
| 60 |
+
data = data.to(
|
| 61 |
+
device=device,
|
| 62 |
+
dtype=dtype,
|
| 63 |
+
memory_format=memory_format,
|
| 64 |
+
non_blocking=(not is_cpu),
|
| 65 |
+
)
|
| 66 |
+
return data
|
| 67 |
+
elif isinstance(data, collections.abc.Mapping):
|
| 68 |
+
return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
|
| 69 |
+
elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
|
| 70 |
+
return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
|
| 71 |
+
else:
|
| 72 |
+
return data
|
| 73 |
|
| 74 |
|
| 75 |
class misc():
|
| 76 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
@staticmethod
|
| 78 |
def serialize(data: Any) -> Any:
|
| 79 |
"""Serialize data by hierarchically traversing through iterables.
|
world_generation_pipeline.py
CHANGED
|
@@ -40,7 +40,7 @@ from .inference_utils import (
|
|
| 40 |
load_tokenizer_model,
|
| 41 |
)
|
| 42 |
from .log import log
|
| 43 |
-
from .misc import misc
|
| 44 |
|
| 45 |
|
| 46 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|
|
@@ -311,7 +311,7 @@ class ARBaseGenerationPipeline(BaseWorldGenerationPipeline):
|
|
| 311 |
log.info(f"Using input size of {context_used} frames")
|
| 312 |
|
| 313 |
data_batch = {"video": inp_vid}
|
| 314 |
-
data_batch =
|
| 315 |
|
| 316 |
T, H, W = self.latent_shape
|
| 317 |
num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
|
|
@@ -508,13 +508,13 @@ class ARBaseGenerationPipeline(BaseWorldGenerationPipeline):
|
|
| 508 |
context = data_batch.get("context", None) if task_condition != "video" else None
|
| 509 |
context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None
|
| 510 |
if context is not None:
|
| 511 |
-
context =
|
| 512 |
if context_mask is not None:
|
| 513 |
-
context_mask =
|
| 514 |
|
| 515 |
# get the video tokens
|
| 516 |
data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch)
|
| 517 |
-
data_tokens =
|
| 518 |
batch_size = data_tokens.shape[0]
|
| 519 |
|
| 520 |
for sample_num in range(batch_size):
|
|
@@ -816,7 +816,7 @@ class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline):
|
|
| 816 |
data_batch["video"] = inp_vid
|
| 817 |
data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1)
|
| 818 |
|
| 819 |
-
data_batch =
|
| 820 |
|
| 821 |
log.debug(f" num_tokens_to_generate: {num_gen_tokens}")
|
| 822 |
log.debug(f" sampling_config: {sampling_config}")
|
|
|
|
| 40 |
load_tokenizer_model,
|
| 41 |
)
|
| 42 |
from .log import log
|
| 43 |
+
from .misc import misc, to
|
| 44 |
|
| 45 |
|
| 46 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|
|
|
|
| 311 |
log.info(f"Using input size of {context_used} frames")
|
| 312 |
|
| 313 |
data_batch = {"video": inp_vid}
|
| 314 |
+
data_batch = to(data_batch, "cuda")
|
| 315 |
|
| 316 |
T, H, W = self.latent_shape
|
| 317 |
num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
|
|
|
|
| 508 |
context = data_batch.get("context", None) if task_condition != "video" else None
|
| 509 |
context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None
|
| 510 |
if context is not None:
|
| 511 |
+
context = to(context, "cuda").detach().clone()
|
| 512 |
if context_mask is not None:
|
| 513 |
+
context_mask = to(context_mask, "cuda").detach().clone()
|
| 514 |
|
| 515 |
# get the video tokens
|
| 516 |
data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch)
|
| 517 |
+
data_tokens = to(data_tokens, "cuda").detach().clone()
|
| 518 |
batch_size = data_tokens.shape[0]
|
| 519 |
|
| 520 |
for sample_num in range(batch_size):
|
|
|
|
| 816 |
data_batch["video"] = inp_vid
|
| 817 |
data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1)
|
| 818 |
|
| 819 |
+
data_batch = to(data_batch, "cuda")
|
| 820 |
|
| 821 |
log.debug(f" num_tokens_to_generate: {num_gen_tokens}")
|
| 822 |
log.debug(f" sampling_config: {sampling_config}")
|