xizaoqu
commited on
Commit
·
e128dab
1
Parent(s):
9c45273
update
Browse files- algorithms/worldmem/df_video.py +20 -12
- app.py +16 -9
algorithms/worldmem/df_video.py
CHANGED
|
@@ -792,39 +792,46 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 792 |
|
| 793 |
@torch.no_grad()
|
| 794 |
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
|
| 795 |
-
self_frames, self_poses, self_memory_c2w, self_frame_idx):
|
| 796 |
|
| 797 |
condition_similar_length = self.condition_similar_length
|
| 798 |
|
| 799 |
if self_frames is None:
|
|
|
|
|
|
|
|
|
|
| 800 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 801 |
self_frames = first_frame_encode.cpu()
|
| 802 |
-
|
| 803 |
self_poses = first_pose[None, None].to(device)
|
| 804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
| 807 |
-
return first_frame.cpu(), self_frames.cpu().numpy(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
| 808 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
last_frame = self_frames[-1].clone()
|
| 810 |
-
self_poses = self_poses.to(device)
|
| 811 |
-
self_memory_c2w = self_memory_c2w.to(device)
|
| 812 |
-
self_frame_idx = self_frame_idx.to(device)
|
| 813 |
last_pose_condition = self_poses[-1].clone()
|
| 814 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 815 |
-
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None]
|
| 816 |
|
| 817 |
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
| 818 |
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
| 819 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 820 |
new_pose_condition[:,3:] %= 360
|
| 821 |
-
|
| 822 |
-
self_poses = torch.cat([self_poses, new_pose_condition[None]
|
| 823 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 824 |
-
self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]
|
| 825 |
self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
|
| 826 |
|
| 827 |
-
conditions =
|
| 828 |
pose_conditions = self_poses.clone()
|
| 829 |
c2w_mat = self_memory_c2w .clone()
|
| 830 |
frame_idx = self_frame_idx.clone()
|
|
@@ -903,7 +910,8 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 903 |
|
| 904 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 905 |
|
| 906 |
-
return xs_pred[-1,0].cpu(), self_frames.cpu()
|
|
|
|
| 907 |
|
| 908 |
|
| 909 |
def reset(self):
|
|
|
|
| 792 |
|
| 793 |
@torch.no_grad()
|
| 794 |
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
|
| 795 |
+
self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
|
| 796 |
|
| 797 |
condition_similar_length = self.condition_similar_length
|
| 798 |
|
| 799 |
if self_frames is None:
|
| 800 |
+
first_frame = torch.from_numpy(first_frame)
|
| 801 |
+
curr_actions = torch.from_numpy(curr_actions)
|
| 802 |
+
first_pose = torch.from_numpy(first_pose)
|
| 803 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 804 |
self_frames = first_frame_encode.cpu()
|
| 805 |
+
self_actions = curr_actions[None, None].to(device)
|
| 806 |
self_poses = first_pose[None, None].to(device)
|
| 807 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 808 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
| 809 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
| 810 |
+
return first_frame.cpu(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
|
| 811 |
else:
|
| 812 |
+
self_frames = torch.from_numpy(self_frames)
|
| 813 |
+
self_actions = torch.from_numpy(self_actions).to(device)
|
| 814 |
+
self_poses = torch.from_numpy(self_poses).to(device)
|
| 815 |
+
self_memory_c2w = torch.from_numpy(self_memory_c2w).to(device)
|
| 816 |
+
self_frame_idx = torch.from_numpy(self_frame_idx).to(device)
|
| 817 |
+
curr_actions = curr_actions.to(device)
|
| 818 |
+
|
| 819 |
last_frame = self_frames[-1].clone()
|
|
|
|
|
|
|
|
|
|
| 820 |
last_pose_condition = self_poses[-1].clone()
|
| 821 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 822 |
+
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None], last_pose_condition)
|
| 823 |
|
| 824 |
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
| 825 |
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
| 826 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 827 |
new_pose_condition[:,3:] %= 360
|
| 828 |
+
self_actions = torch.cat([self_actions, curr_actions[None, None]])
|
| 829 |
+
self_poses = torch.cat([self_poses, new_pose_condition[None]])
|
| 830 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 831 |
+
self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]])
|
| 832 |
self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
|
| 833 |
|
| 834 |
+
conditions = self_actions.clone()
|
| 835 |
pose_conditions = self_poses.clone()
|
| 836 |
c2w_mat = self_memory_c2w .clone()
|
| 837 |
frame_idx = self_frame_idx.clone()
|
|
|
|
| 910 |
|
| 911 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 912 |
|
| 913 |
+
return xs_pred[-1,0].cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), \
|
| 914 |
+
self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
|
| 915 |
|
| 916 |
|
| 917 |
def reset(self):
|
app.py
CHANGED
|
@@ -177,30 +177,33 @@ load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.
|
|
| 177 |
worldmem.to("cuda").eval()
|
| 178 |
|
| 179 |
|
| 180 |
-
actions =
|
| 181 |
-
poses =
|
| 182 |
|
| 183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 184 |
|
| 185 |
self_frames = None
|
|
|
|
| 186 |
self_poses = None
|
| 187 |
self_memory_c2w = None
|
| 188 |
self_frame_idx = None
|
| 189 |
|
| 190 |
|
| 191 |
@spaces.GPU()
|
| 192 |
-
def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames,
|
| 193 |
-
|
|
|
|
| 194 |
action,
|
| 195 |
first_pose,
|
| 196 |
curr_frame,
|
| 197 |
device=device,
|
| 198 |
self_frames=self_frames,
|
|
|
|
| 199 |
self_poses=self_poses,
|
| 200 |
self_memory_c2w=self_memory_c2w,
|
| 201 |
self_frame_idx=self_frame_idx)
|
| 202 |
-
|
| 203 |
-
return self_frames
|
| 204 |
|
| 205 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
| 206 |
worldmem.sampling_timesteps = denoising_steps
|
|
@@ -215,6 +218,7 @@ def generate(keys):
|
|
| 215 |
global input_history
|
| 216 |
global memory_curr_frame
|
| 217 |
global self_frames
|
|
|
|
| 218 |
global self_poses
|
| 219 |
global self_memory_c2w
|
| 220 |
global self_frame_idx
|
|
@@ -222,12 +226,13 @@ def generate(keys):
|
|
| 222 |
for i in range(len(actions)):
|
| 223 |
memory_curr_frame += 1
|
| 224 |
|
| 225 |
-
new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
| 226 |
actions[i],
|
| 227 |
None,
|
| 228 |
memory_curr_frame,
|
| 229 |
device=device,
|
| 230 |
self_frames=self_frames,
|
|
|
|
| 231 |
self_poses=self_poses,
|
| 232 |
self_memory_c2w=self_memory_c2w,
|
| 233 |
self_frame_idx=self_frame_idx)
|
|
@@ -254,6 +259,7 @@ def reset():
|
|
| 254 |
global input_history
|
| 255 |
global memory_frames
|
| 256 |
global self_frames
|
|
|
|
| 257 |
global self_poses
|
| 258 |
global self_memory_c2w
|
| 259 |
global self_frame_idx
|
|
@@ -263,16 +269,17 @@ def reset():
|
|
| 263 |
self_memory_c2w = None
|
| 264 |
self_frame_idx = None
|
| 265 |
memory_frames = []
|
| 266 |
-
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 267 |
memory_curr_frame = 0
|
| 268 |
input_history = ""
|
| 269 |
|
| 270 |
-
self_frames = run_interactive(memory_frames[0],
|
| 271 |
actions[0],
|
| 272 |
poses[0],
|
| 273 |
memory_curr_frame,
|
| 274 |
device=device,
|
| 275 |
self_frames=self_frames,
|
|
|
|
| 276 |
self_poses=self_poses,
|
| 277 |
self_memory_c2w=self_memory_c2w,
|
| 278 |
self_frame_idx=self_frame_idx)
|
|
|
|
| 177 |
worldmem.to("cuda").eval()
|
| 178 |
|
| 179 |
|
| 180 |
+
actions = np.zeros((1, 25), dtype=np.float32)
|
| 181 |
+
poses = np.zeros((1, 5), dtype=np.float32)
|
| 182 |
|
| 183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
| 184 |
|
| 185 |
self_frames = None
|
| 186 |
+
self_actions = None
|
| 187 |
self_poses = None
|
| 188 |
self_memory_c2w = None
|
| 189 |
self_frame_idx = None
|
| 190 |
|
| 191 |
|
| 192 |
@spaces.GPU()
|
| 193 |
+
def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames, self_actions,
|
| 194 |
+
self_poses, self_memory_c2w, self_frame_idx):
|
| 195 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
|
| 196 |
action,
|
| 197 |
first_pose,
|
| 198 |
curr_frame,
|
| 199 |
device=device,
|
| 200 |
self_frames=self_frames,
|
| 201 |
+
self_actions=self_actions,
|
| 202 |
self_poses=self_poses,
|
| 203 |
self_memory_c2w=self_memory_c2w,
|
| 204 |
self_frame_idx=self_frame_idx)
|
| 205 |
+
|
| 206 |
+
return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
| 207 |
|
| 208 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
| 209 |
worldmem.sampling_timesteps = denoising_steps
|
|
|
|
| 218 |
global input_history
|
| 219 |
global memory_curr_frame
|
| 220 |
global self_frames
|
| 221 |
+
global self_actions
|
| 222 |
global self_poses
|
| 223 |
global self_memory_c2w
|
| 224 |
global self_frame_idx
|
|
|
|
| 226 |
for i in range(len(actions)):
|
| 227 |
memory_curr_frame += 1
|
| 228 |
|
| 229 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
| 230 |
actions[i],
|
| 231 |
None,
|
| 232 |
memory_curr_frame,
|
| 233 |
device=device,
|
| 234 |
self_frames=self_frames,
|
| 235 |
+
self_actions=self_actions,
|
| 236 |
self_poses=self_poses,
|
| 237 |
self_memory_c2w=self_memory_c2w,
|
| 238 |
self_frame_idx=self_frame_idx)
|
|
|
|
| 259 |
global input_history
|
| 260 |
global memory_frames
|
| 261 |
global self_frames
|
| 262 |
+
global self_actions
|
| 263 |
global self_poses
|
| 264 |
global self_memory_c2w
|
| 265 |
global self_frame_idx
|
|
|
|
| 269 |
self_memory_c2w = None
|
| 270 |
self_frame_idx = None
|
| 271 |
memory_frames = []
|
| 272 |
+
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE).numpy())
|
| 273 |
memory_curr_frame = 0
|
| 274 |
input_history = ""
|
| 275 |
|
| 276 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
| 277 |
actions[0],
|
| 278 |
poses[0],
|
| 279 |
memory_curr_frame,
|
| 280 |
device=device,
|
| 281 |
self_frames=self_frames,
|
| 282 |
+
self_actions=self_actions,
|
| 283 |
self_poses=self_poses,
|
| 284 |
self_memory_c2w=self_memory_c2w,
|
| 285 |
self_frame_idx=self_frame_idx)
|