Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
98a4a00
1
Parent(s):
46f6899
main.py
CHANGED
|
@@ -131,10 +131,11 @@ def load_initial_images(width, height):
|
|
| 131 |
initial_images = []
|
| 132 |
if DEBUG_TEACHER_FORCING:
|
| 133 |
# Load the previous 7 frames for image_81
|
| 134 |
-
for i in range(
|
| 135 |
img = Image.open(f"record_100/image_{i}.png").resize((width, height))
|
| 136 |
initial_images.append(np.array(img))
|
| 137 |
else:
|
|
|
|
| 138 |
for i in range(7):
|
| 139 |
initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
| 140 |
return initial_images
|
|
@@ -229,7 +230,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 229 |
x, y = pos
|
| 230 |
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
| 231 |
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
| 232 |
-
if
|
| 233 |
norm_x = x + (1920 - 512) / 2
|
| 234 |
norm_y = y + (1080 - 512) / 2
|
| 235 |
#if DEBUG:
|
|
@@ -306,6 +307,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 306 |
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 307 |
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
| 308 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
previous_actions = []
|
| 310 |
for action in debug_actions[-8:]:
|
| 311 |
x, y, action_type = parse_action_string(action)
|
|
@@ -316,7 +327,18 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 316 |
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
| 317 |
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
| 318 |
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
| 319 |
-
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
#positions = positions[:4]
|
| 321 |
try:
|
| 322 |
while True:
|
|
@@ -340,12 +362,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 340 |
#mouse_position = position.split('~')
|
| 341 |
#mouse_position = [int(item) for item in mouse_position]
|
| 342 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 343 |
-
if
|
| 344 |
position = positions[0]
|
| 345 |
positions = positions[1:]
|
| 346 |
x, y, action_type = parse_action_string(position)
|
| 347 |
mouse_position = (x, y)
|
| 348 |
-
|
|
|
|
| 349 |
#previous_actions = [(action_type, mouse_position)]
|
| 350 |
|
| 351 |
# Log the start time
|
|
@@ -361,7 +384,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 361 |
if False and DEBUG_TEACHER_FORCING:
|
| 362 |
img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
|
| 363 |
previous_frames.append(img)
|
| 364 |
-
|
| 365 |
previous_frames.append(next_frame_append)
|
| 366 |
|
| 367 |
# Convert the numpy array to a base64 encoded image
|
|
|
|
| 131 |
initial_images = []
|
| 132 |
if DEBUG_TEACHER_FORCING:
|
| 133 |
# Load the previous 7 frames for image_81
|
| 134 |
+
for i in range(222-7, 222): # Load images 74-80
|
| 135 |
img = Image.open(f"record_100/image_{i}.png").resize((width, height))
|
| 136 |
initial_images.append(np.array(img))
|
| 137 |
else:
|
| 138 |
+
assert False
|
| 139 |
for i in range(7):
|
| 140 |
initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
| 141 |
return initial_images
|
|
|
|
| 230 |
x, y = pos
|
| 231 |
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
| 232 |
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
| 233 |
+
if True and DEBUG_TEACHER_FORCING:
|
| 234 |
norm_x = x + (1920 - 512) / 2
|
| 235 |
norm_y = y + (1080 - 512) / 2
|
| 236 |
#if DEBUG:
|
|
|
|
| 307 |
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 308 |
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
| 309 |
]
|
| 310 |
+
debug_actions = [
|
| 311 |
+
'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
|
| 312 |
+
'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
|
| 313 |
+
'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
|
| 314 |
+
'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 315 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 316 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 317 |
+
'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 318 |
+
'N + 1 1 9 7 : + 0 2 9 7'
|
| 319 |
+
]
|
| 320 |
previous_actions = []
|
| 321 |
for action in debug_actions[-8:]:
|
| 322 |
x, y, action_type = parse_action_string(action)
|
|
|
|
| 327 |
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
| 328 |
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
| 329 |
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
| 330 |
+
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
|
| 331 |
+
]
|
| 332 |
+
positions = [
|
| 333 |
+
#'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 334 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
| 335 |
+
'N + 1 1 7 9 : + 0 3 0 3', 'N + 1 1 4 2 : + 0 3 1 4',
|
| 336 |
+
'N + 1 1 0 6 : + 0 3 2 6', 'N + 1 0 6 9 : + 0 3 3 7',
|
| 337 |
+
'N + 1 0 5 1 : + 0 3 4 3', 'N + 1 0 1 4 : + 0 3 5 4',
|
| 338 |
+
'N + 0 9 7 8 : + 0 3 6 5', 'N + 0 9 4 2 : + 0 3 7 7',
|
| 339 |
+
'N + 0 9 0 5 : + 0 3 8 8', 'N + 0 8 6 8 : + 0 4 0 0',
|
| 340 |
+
'N + 0 8 3 2 : + 0 4 1 1'
|
| 341 |
+
]
|
| 342 |
#positions = positions[:4]
|
| 343 |
try:
|
| 344 |
while True:
|
|
|
|
| 362 |
#mouse_position = position.split('~')
|
| 363 |
#mouse_position = [int(item) for item in mouse_position]
|
| 364 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 365 |
+
if True and DEBUG_TEACHER_FORCING:
|
| 366 |
position = positions[0]
|
| 367 |
positions = positions[1:]
|
| 368 |
x, y, action_type = parse_action_string(position)
|
| 369 |
mouse_position = (x, y)
|
| 370 |
+
if False:
|
| 371 |
+
previous_actions.append((action_type, mouse_position))
|
| 372 |
#previous_actions = [(action_type, mouse_position)]
|
| 373 |
|
| 374 |
# Log the start time
|
|
|
|
| 384 |
if False and DEBUG_TEACHER_FORCING:
|
| 385 |
img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
|
| 386 |
previous_frames.append(img)
|
| 387 |
+
elif False:
|
| 388 |
previous_frames.append(next_frame_append)
|
| 389 |
|
| 390 |
# Convert the numpy array to a base64 encoded image
|