Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
2163e7f
1
Parent(s):
5670558
main.py
CHANGED
|
@@ -201,9 +201,11 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 201 |
initial_images = load_initial_images(width, height)
|
| 202 |
|
| 203 |
# Prepare the image sequence for the model
|
|
|
|
| 204 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
| 205 |
while len(image_sequence) < 7:
|
| 206 |
-
image_sequence.insert(0, initial_images[len(image_sequence)])
|
|
|
|
| 207 |
|
| 208 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
| 209 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
|
@@ -219,6 +221,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 219 |
|
| 220 |
# Process initial actions if there are not enough previous actions
|
| 221 |
while len(previous_actions) < 8:
|
|
|
|
| 222 |
x, y = map(int, initial_actions.pop(0).split(':'))
|
| 223 |
previous_actions.insert(0, ("N", unnorm_coords(x, y)))
|
| 224 |
prev_x = 0
|
|
|
|
| 201 |
initial_images = load_initial_images(width, height)
|
| 202 |
|
| 203 |
# Prepare the image sequence for the model
|
| 204 |
+
assert len(initial_images) == 7
|
| 205 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
| 206 |
while len(image_sequence) < 7:
|
| 207 |
+
#image_sequence.insert(0, initial_images[len(image_sequence)])
|
| 208 |
+
image_sequence.append(initial_images[len(image_sequence)])
|
| 209 |
|
| 210 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
| 211 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
|
|
|
| 221 |
|
| 222 |
# Process initial actions if there are not enough previous actions
|
| 223 |
while len(previous_actions) < 8:
|
| 224 |
+
assert False
|
| 225 |
x, y = map(int, initial_actions.pop(0).split(':'))
|
| 226 |
previous_actions.insert(0, ("N", unnorm_coords(x, y)))
|
| 227 |
prev_x = 0
|