xizaoqu
commited on
Commit
·
4652b57
1
Parent(s):
0256e9a
update
Browse files
algorithms/worldmem/df_video.py
CHANGED
|
@@ -804,9 +804,12 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
| 807 |
-
return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
|
| 808 |
else:
|
| 809 |
last_frame = self_frames[-1].clone()
|
|
|
|
|
|
|
|
|
|
| 810 |
last_pose_condition = self_poses[-1].clone()
|
| 811 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 812 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
|
@@ -900,7 +903,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
| 900 |
|
| 901 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 902 |
|
| 903 |
-
return xs_pred[-1,0], self_frames, self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
| 904 |
|
| 905 |
|
| 906 |
def reset(self):
|
|
|
|
| 804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
| 807 |
+
return first_frame.cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
| 808 |
else:
|
| 809 |
last_frame = self_frames[-1].clone()
|
| 810 |
+
self_poses = self_poses.to(device)
|
| 811 |
+
self_memory_c2w = self_memory_c2w.to(device)
|
| 812 |
+
self_frame_idx = self_frame_idx.to(device)
|
| 813 |
last_pose_condition = self_poses[-1].clone()
|
| 814 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 815 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
|
|
|
| 903 |
|
| 904 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 905 |
|
| 906 |
+
return xs_pred[-1,0].cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
| 907 |
|
| 908 |
|
| 909 |
def reset(self):
|