xizaoqu
committed on
Commit
·
f373311
1
Parent(s):
eda3a61
update
Browse files
- algorithms/worldmem/df_video.py +29 -26
- app.py +23 -6
algorithms/worldmem/df_video.py
CHANGED
|
@@ -354,10 +354,10 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 354 |
|
| 355 |
self.is_interactive = cfg.get("is_interactive", False)
|
| 356 |
if self.is_interactive:
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
|
| 362 |
super().__init__(cfg)
|
| 363 |
|
|
@@ -791,21 +791,23 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 791 |
return
|
| 792 |
|
| 793 |
@torch.no_grad()
|
| 794 |
-
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device
|
|
|
|
|
|
|
| 795 |
condition_similar_length = self.condition_similar_length
|
| 796 |
|
| 797 |
-
if
|
| 798 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 799 |
-
|
| 800 |
self.actions = curr_actions[None, None].to(device)
|
| 801 |
-
|
| 802 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
return first_frame
|
| 806 |
else:
|
| 807 |
-
last_frame =
|
| 808 |
-
last_pose_condition =
|
| 809 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 810 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
| 811 |
|
|
@@ -814,15 +816,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 814 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 815 |
new_pose_condition[:,3:] %= 360
|
| 816 |
self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
|
| 817 |
-
|
| 818 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 819 |
-
|
| 820 |
-
|
| 821 |
|
| 822 |
conditions = self.actions.clone()
|
| 823 |
-
pose_conditions =
|
| 824 |
-
c2w_mat =
|
| 825 |
-
frame_idx =
|
| 826 |
|
| 827 |
|
| 828 |
curr_frame = 0
|
|
@@ -831,7 +833,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 831 |
n_frames = curr_frame + horizon
|
| 832 |
# context
|
| 833 |
n_context_frames = context_frames_idx // self.frame_stack
|
| 834 |
-
xs_pred =
|
| 835 |
curr_frame += n_context_frames
|
| 836 |
|
| 837 |
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
|
@@ -894,14 +896,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 894 |
curr_frame += horizon
|
| 895 |
pbar.update(horizon)
|
| 896 |
|
| 897 |
-
|
| 898 |
|
| 899 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 900 |
-
|
|
|
|
| 901 |
|
| 902 |
|
| 903 |
def reset(self):
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
|
|
|
| 354 |
|
| 355 |
self.is_interactive = cfg.get("is_interactive", False)
|
| 356 |
if self.is_interactive:
|
| 357 |
+
self_frames = None
|
| 358 |
+
self_poses = None
|
| 359 |
+
self_memory_c2w = None
|
| 360 |
+
self_frame_idx = None
|
| 361 |
|
| 362 |
super().__init__(cfg)
|
| 363 |
|
|
|
|
| 791 |
return
|
| 792 |
|
| 793 |
@torch.no_grad()
|
| 794 |
+
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
|
| 795 |
+
self_frames, self_poses, self_memory_c2w, self_frame_idx):
|
| 796 |
+
|
| 797 |
condition_similar_length = self.condition_similar_length
|
| 798 |
|
| 799 |
+
if self_frames is None:
|
| 800 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 801 |
+
self_frames = first_frame_encode.cpu()
|
| 802 |
self.actions = curr_actions[None, None].to(device)
|
| 803 |
+
self_poses = first_pose[None, None].to(device)
|
| 804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
+
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
+
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
| 807 |
+
return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
|
| 808 |
else:
|
| 809 |
+
last_frame = self_frames[-1].clone()
|
| 810 |
+
last_pose_condition = self_poses[-1].clone()
|
| 811 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 812 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
| 813 |
|
|
|
|
| 816 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 817 |
new_pose_condition[:,3:] %= 360
|
| 818 |
self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
|
| 819 |
+
self_poses = torch.cat([self_poses, new_pose_condition[None].to(device)])
|
| 820 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 821 |
+
self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None].to(device)])
|
| 822 |
+
self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
|
| 823 |
|
| 824 |
conditions = self.actions.clone()
|
| 825 |
+
pose_conditions = self_poses.clone()
|
| 826 |
+
c2w_mat = self_memory_c2w .clone()
|
| 827 |
+
frame_idx = self_frame_idx.clone()
|
| 828 |
|
| 829 |
|
| 830 |
curr_frame = 0
|
|
|
|
| 833 |
n_frames = curr_frame + horizon
|
| 834 |
# context
|
| 835 |
n_context_frames = context_frames_idx // self.frame_stack
|
| 836 |
+
xs_pred = self_frames[:n_context_frames].clone()
|
| 837 |
curr_frame += n_context_frames
|
| 838 |
|
| 839 |
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
|
|
|
| 896 |
curr_frame += horizon
|
| 897 |
pbar.update(horizon)
|
| 898 |
|
| 899 |
+
self_frames = torch.cat([self_frames, xs_pred[n_context_frames:]])
|
| 900 |
|
| 901 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 902 |
+
|
| 903 |
+
return xs_pred[-1,0], self_frames, self_poses, self_memory_c2w, self_frame_idx
|
| 904 |
|
| 905 |
|
| 906 |
def reset(self):
|
| 907 |
+
self_frames = None
|
| 908 |
+
self_poses = None
|
| 909 |
+
self_memory_c2w = None
|
| 910 |
+
self_frame_idx = None
|
app.py
CHANGED
|
@@ -182,15 +182,28 @@ poses = torch.zeros((1, 5))
|
|
| 182 |
|
| 183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
@spaces.GPU()
|
| 186 |
def run_interactive(first_frame, action, first_pose, curr_frame, device):
|
| 187 |
-
global
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
action,
|
| 190 |
first_pose,
|
| 191 |
curr_frame,
|
| 192 |
-
device=device
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
| 194 |
return new_frame
|
| 195 |
|
| 196 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
|
@@ -201,7 +214,7 @@ def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
|
| 201 |
return sampling_timesteps_state
|
| 202 |
|
| 203 |
def generate(keys):
|
| 204 |
-
print("algo frame:", len(worldmem.frames))
|
| 205 |
actions = parse_input_to_tensor(keys)
|
| 206 |
global input_history
|
| 207 |
global memory_curr_frame
|
|
@@ -236,7 +249,11 @@ def reset():
|
|
| 236 |
global input_history
|
| 237 |
global memory_frames
|
| 238 |
|
| 239 |
-
worldmem.reset()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
memory_frames = []
|
| 241 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 242 |
memory_curr_frame = 0
|
|
|
|
| 182 |
|
| 183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 184 |
|
| 185 |
+
self_frames = None
|
| 186 |
+
self_poses = None
|
| 187 |
+
self_memory_c2w = None
|
| 188 |
+
self_frame_idx = None
|
| 189 |
+
|
| 190 |
+
|
| 191 |
@spaces.GPU()
|
| 192 |
def run_interactive(first_frame, action, first_pose, curr_frame, device):
|
| 193 |
+
global self_frames
|
| 194 |
+
global self_poses
|
| 195 |
+
global self_memory_c2w
|
| 196 |
+
global self_frame_idx
|
| 197 |
+
|
| 198 |
+
new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
|
| 199 |
action,
|
| 200 |
first_pose,
|
| 201 |
curr_frame,
|
| 202 |
+
device=device,
|
| 203 |
+
self_frames=self_frames,
|
| 204 |
+
self_poses=self_poses,
|
| 205 |
+
self_memory_c2w=self_memory_c2w,
|
| 206 |
+
self_frame_idx=self_frame_idx)
|
| 207 |
return new_frame
|
| 208 |
|
| 209 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
|
|
|
| 214 |
return sampling_timesteps_state
|
| 215 |
|
| 216 |
def generate(keys):
|
| 217 |
+
# print("algo frame:", len(worldmem.frames))
|
| 218 |
actions = parse_input_to_tensor(keys)
|
| 219 |
global input_history
|
| 220 |
global memory_curr_frame
|
|
|
|
| 249 |
global input_history
|
| 250 |
global memory_frames
|
| 251 |
|
| 252 |
+
# worldmem.reset()
|
| 253 |
+
self_frames = None
|
| 254 |
+
self_poses = None
|
| 255 |
+
self_memory_c2w = None
|
| 256 |
+
self_frame_idx = None
|
| 257 |
memory_frames = []
|
| 258 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 259 |
memory_curr_frame = 0
|