IceClear committed on
Commit
63837ca
·
1 Parent(s): 341bd76
projects/video_diffusion_sr/infer.py CHANGED
@@ -75,13 +75,13 @@ class VideoDiffusionInfer():
75
  # For fast init & resume,
76
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
77
  # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
78
- if self.config.dit.get("init_with_meta_device", False):
79
- init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
80
- else:
81
- init_device = "cpu"
82
 
83
  # Create dit model.
84
- with torch.device(init_device):
85
  self.dit = create_object(self.config.dit.model)
86
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
87
 
@@ -92,8 +92,8 @@ class VideoDiffusionInfer():
92
  print(f"Loading info: {loading_info}")
93
  self.dit = meta_non_persistent_buffer_init_fn(self.dit)
94
 
95
- if device in [get_device(), "cuda"]:
96
- self.dit.to(get_device())
97
 
98
  # Print model size.
99
  num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
@@ -106,11 +106,11 @@ class VideoDiffusionInfer():
106
  dtype = getattr(torch, self.config.vae.dtype)
107
  self.vae = create_object(self.config.vae.model)
108
  self.vae.requires_grad_(False).eval()
109
- self.vae.to(device=get_device(), dtype=dtype)
110
 
111
  # Load vae checkpoint.
112
  state = torch.load(
113
- self.config.vae.checkpoint, map_location=get_device(), mmap=True
114
  )
115
  self.vae.load_state_dict(state)
116
 
 
75
  # For fast init & resume,
76
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
77
  # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
78
+ # if self.config.dit.get("init_with_meta_device", False):
79
+ # init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
80
+ # else:
81
+ # init_device = "cpu"
82
 
83
  # Create dit model.
84
+ with torch.device("cpu"):
85
  self.dit = create_object(self.config.dit.model)
86
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
87
 
 
92
  print(f"Loading info: {loading_info}")
93
  self.dit = meta_non_persistent_buffer_init_fn(self.dit)
94
 
95
+ # if device in [get_device(), "cuda"]:
96
+ self.dit.to("cuda")
97
 
98
  # Print model size.
99
  num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
 
106
  dtype = getattr(torch, self.config.vae.dtype)
107
  self.vae = create_object(self.config.vae.model)
108
  self.vae.requires_grad_(False).eval()
109
+ self.vae.to(device="cuda", dtype=dtype)
110
 
111
  # Load vae checkpoint.
112
  state = torch.load(
113
+ self.config.vae.checkpoint, map_location="cuda", mmap=True
114
  )
115
  self.vae.load_state_dict(state)
116