Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
50eea75
1
Parent(s):
a10a91e
main.py
CHANGED
|
@@ -13,6 +13,7 @@ import os
|
|
| 13 |
import time
|
| 14 |
|
| 15 |
DEBUG = True
|
|
|
|
| 16 |
app = FastAPI()
|
| 17 |
|
| 18 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
|
@@ -128,15 +129,14 @@ model = model.to(device)
|
|
| 128 |
|
| 129 |
def load_initial_images(width, height):
|
| 130 |
initial_images = []
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
#
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
| 140 |
return initial_images
|
| 141 |
|
| 142 |
def normalize_images(images, target_range=(-1, 1)):
|
|
@@ -156,13 +156,15 @@ def denormalize_image(image, source_range=(-1, 1)):
|
|
| 156 |
else:
|
| 157 |
raise ValueError(f"Unsupported source range: {source_range}")
|
| 158 |
|
| 159 |
-
def format_action(action_str, is_padding=False):
|
| 160 |
if is_padding:
|
| 161 |
return "N N N N N N : N N N N N"
|
| 162 |
|
| 163 |
# Split the x~y coordinates
|
| 164 |
x, y = map(int, action_str.split('~'))
|
| 165 |
prefix = 'N'
|
|
|
|
|
|
|
| 166 |
# Convert numbers to padded strings and add spaces between digits
|
| 167 |
x_str = f"{abs(x):04d}"
|
| 168 |
y_str = f"{abs(y):04d}"
|
|
@@ -200,6 +202,22 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 200 |
prev_x = 0
|
| 201 |
prev_y = 0
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
for action_type, pos in previous_actions: #[-8:]:
|
| 205 |
if action_type == "move":
|
|
@@ -217,7 +235,17 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 217 |
prev_x = norm_x
|
| 218 |
prev_y = norm_y
|
| 219 |
elif action_type == "left_click":
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
elif action_type == "right_click":
|
| 222 |
action_descriptions.append("right_click")
|
| 223 |
|
|
|
|
| 13 |
import time
|
| 14 |
|
| 15 |
DEBUG = True
|
| 16 |
+
DEBUG_TEACHER_FORCING = True
|
| 17 |
app = FastAPI()
|
| 18 |
|
| 19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
|
|
|
| 129 |
|
| 130 |
def load_initial_images(width, height):
|
| 131 |
initial_images = []
|
| 132 |
+
if DEBUG_TEACHER_FORCING:
|
| 133 |
+
# Load the previous 7 frames for image_81
|
| 134 |
+
for i in range(74, 81): # Load images 74-80
|
| 135 |
+
img = Image.open(f"record_100/image_{i}.png").resize((width, height))
|
| 136 |
+
initial_images.append(np.array(img))
|
| 137 |
+
else:
|
| 138 |
+
for i in range(7):
|
| 139 |
+
initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
|
|
|
| 140 |
return initial_images
|
| 141 |
|
| 142 |
def normalize_images(images, target_range=(-1, 1)):
|
|
|
|
| 156 |
else:
|
| 157 |
raise ValueError(f"Unsupported source range: {source_range}")
|
| 158 |
|
| 159 |
+
def format_action(action_str, is_padding=False, is_leftclick=False):
|
| 160 |
if is_padding:
|
| 161 |
return "N N N N N N : N N N N N"
|
| 162 |
|
| 163 |
# Split the x~y coordinates
|
| 164 |
x, y = map(int, action_str.split('~'))
|
| 165 |
prefix = 'N'
|
| 166 |
+
if is_leftclick:
|
| 167 |
+
prefix = 'L'
|
| 168 |
# Convert numbers to padded strings and add spaces between digits
|
| 169 |
x_str = f"{abs(x):04d}"
|
| 170 |
y_str = f"{abs(y):04d}"
|
|
|
|
| 202 |
prev_x = 0
|
| 203 |
prev_y = 0
|
| 204 |
|
| 205 |
+
if DEBUG_TEACHER_FORCING:
|
| 206 |
+
# Use the predefined actions for image_81
|
| 207 |
+
debug_actions = [
|
| 208 |
+
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
| 209 |
+
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
| 210 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 211 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 212 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 213 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 214 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 215 |
+
'N + 0 9 2 7 : + 0 5 0 1'
|
| 216 |
+
]
|
| 217 |
+
previous_actions = []
|
| 218 |
+
for action in debug_actions:
|
| 219 |
+
x, y, action_type = parse_action_string(action)
|
| 220 |
+
previous_actions.append((action_type, (x, y)))
|
| 221 |
|
| 222 |
for action_type, pos in previous_actions: #[-8:]:
|
| 223 |
if action_type == "move":
|
|
|
|
| 235 |
prev_x = norm_x
|
| 236 |
prev_y = norm_y
|
| 237 |
elif action_type == "left_click":
|
| 238 |
+
x, y = pos
|
| 239 |
+
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
| 240 |
+
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
| 241 |
+
norm_x = x + (1920 - 512) / 2
|
| 242 |
+
norm_y = y + (1080 - 512) / 2
|
| 243 |
+
#if DEBUG:
|
| 244 |
+
# norm_x = x
|
| 245 |
+
# norm_y = y
|
| 246 |
+
#action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
| 247 |
+
#action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
|
| 248 |
+
action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0, True))
|
| 249 |
elif action_type == "right_click":
|
| 250 |
action_descriptions.append("right_click")
|
| 251 |
|