sheldonl committed on
Commit
5c48668
·
1 Parent(s): 15756c4

Fixed misc.to bug

Browse files
Files changed (2) hide show
  1. misc.py +41 -44
  2. world_generation_pipeline.py +6 -6
misc.py CHANGED
@@ -30,53 +30,50 @@ import termcolor
30
  import torch
31
 
32
  from .distributed import get_rank
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  class misc():
36
-
37
- @staticmethod
38
- def to(
39
- data: Any,
40
- device: str | torch.device | None = None,
41
- dtype: torch.dtype | None = None,
42
- memory_format: torch.memory_format = torch.preserve_format,
43
- ) -> Any:
44
- """Recursively cast data into the specified device, dtype, and/or memory_format.
45
-
46
- The input data can be a tensor, a list of tensors, a dict of tensors.
47
- See the documentation for torch.Tensor.to() for details.
48
-
49
- Args:
50
- data (Any): Input data.
51
- device (str | torch.device): GPU device (default: None).
52
- dtype (torch.dtype): data type (default: None).
53
- memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
54
-
55
- Returns:
56
- data (Any): Data cast to the specified device, dtype, and/or memory_format.
57
- """
58
- assert (
59
- device is not None or dtype is not None or memory_format is not None
60
- ), "at least one of device, dtype, memory_format should be specified"
61
- if isinstance(data, torch.Tensor):
62
- is_cpu = (isinstance(device, str) and device == "cpu") or (
63
- isinstance(device, torch.device) and device.type == "cpu"
64
- )
65
- data = data.to(
66
- device=device,
67
- dtype=dtype,
68
- memory_format=memory_format,
69
- non_blocking=(not is_cpu),
70
- )
71
- return data
72
- elif isinstance(data, collections.abc.Mapping):
73
- return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
74
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
75
- return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
76
- else:
77
- return data
78
-
79
-
80
  @staticmethod
81
  def serialize(data: Any) -> Any:
82
  """Serialize data by hierarchically traversing through iterables.
 
30
  import torch
31
 
32
  from .distributed import get_rank
33
+ def to(
34
+ data: Any,
35
+ device: str | torch.device | None = None,
36
+ dtype: torch.dtype | None = None,
37
+ memory_format: torch.memory_format = torch.preserve_format,
38
+ ) -> Any:
39
+ """Recursively cast data into the specified device, dtype, and/or memory_format.
40
+
41
+ The input data can be a tensor, a list of tensors, a dict of tensors.
42
+ See the documentation for torch.Tensor.to() for details.
43
+
44
+ Args:
45
+ data (Any): Input data.
46
+ device (str | torch.device): GPU device (default: None).
47
+ dtype (torch.dtype): data type (default: None).
48
+ memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
49
+
50
+ Returns:
51
+ data (Any): Data cast to the specified device, dtype, and/or memory_format.
52
+ """
53
+ assert (
54
+ device is not None or dtype is not None or memory_format is not None
55
+ ), "at least one of device, dtype, memory_format should be specified"
56
+ if isinstance(data, torch.Tensor):
57
+ is_cpu = (isinstance(device, str) and device == "cpu") or (
58
+ isinstance(device, torch.device) and device.type == "cpu"
59
+ )
60
+ data = data.to(
61
+ device=device,
62
+ dtype=dtype,
63
+ memory_format=memory_format,
64
+ non_blocking=(not is_cpu),
65
+ )
66
+ return data
67
+ elif isinstance(data, collections.abc.Mapping):
68
+ return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
69
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
70
+ return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
71
+ else:
72
+ return data
73
 
74
 
75
  class misc():
76
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @staticmethod
78
  def serialize(data: Any) -> Any:
79
  """Serialize data by hierarchically traversing through iterables.
world_generation_pipeline.py CHANGED
@@ -40,7 +40,7 @@ from .inference_utils import (
40
  load_tokenizer_model,
41
  )
42
  from .log import log
43
- from .misc import misc
44
 
45
 
46
  def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
@@ -311,7 +311,7 @@ class ARBaseGenerationPipeline(BaseWorldGenerationPipeline):
311
  log.info(f"Using input size of {context_used} frames")
312
 
313
  data_batch = {"video": inp_vid}
314
- data_batch = misc.to(data_batch, "cuda")
315
 
316
  T, H, W = self.latent_shape
317
  num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
@@ -508,13 +508,13 @@ class ARBaseGenerationPipeline(BaseWorldGenerationPipeline):
508
  context = data_batch.get("context", None) if task_condition != "video" else None
509
  context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None
510
  if context is not None:
511
- context = misc.to(context, "cuda").detach().clone()
512
  if context_mask is not None:
513
- context_mask = misc.to(context_mask, "cuda").detach().clone()
514
 
515
  # get the video tokens
516
  data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch)
517
- data_tokens = misc.to(data_tokens, "cuda").detach().clone()
518
  batch_size = data_tokens.shape[0]
519
 
520
  for sample_num in range(batch_size):
@@ -816,7 +816,7 @@ class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline):
816
  data_batch["video"] = inp_vid
817
  data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1)
818
 
819
- data_batch = misc.to(data_batch, "cuda")
820
 
821
  log.debug(f" num_tokens_to_generate: {num_gen_tokens}")
822
  log.debug(f" sampling_config: {sampling_config}")
 
40
  load_tokenizer_model,
41
  )
42
  from .log import log
43
+ from .misc import misc, to
44
 
45
 
46
  def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
 
311
  log.info(f"Using input size of {context_used} frames")
312
 
313
  data_batch = {"video": inp_vid}
314
+ data_batch = to(data_batch, "cuda")
315
 
316
  T, H, W = self.latent_shape
317
  num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
 
508
  context = data_batch.get("context", None) if task_condition != "video" else None
509
  context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None
510
  if context is not None:
511
+ context = to(context, "cuda").detach().clone()
512
  if context_mask is not None:
513
+ context_mask = to(context_mask, "cuda").detach().clone()
514
 
515
  # get the video tokens
516
  data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch)
517
+ data_tokens = to(data_tokens, "cuda").detach().clone()
518
  batch_size = data_tokens.shape[0]
519
 
520
  for sample_num in range(batch_size):
 
816
  data_batch["video"] = inp_vid
817
  data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1)
818
 
819
+ data_batch = to(data_batch, "cuda")
820
 
821
  log.debug(f" num_tokens_to_generate: {num_gen_tokens}")
822
  log.debug(f" sampling_config: {sampling_config}")