Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
e4cd0fb
1
Parent(s):
dffe378
- online_data_generation.py +10 -8
online_data_generation.py
CHANGED
|
@@ -533,14 +533,16 @@ def main():
|
|
| 533 |
"""Main function to run the data processing pipeline."""
|
| 534 |
|
| 535 |
# create a padding image first
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
|
|
|
|
|
|
| 544 |
# Initialize database
|
| 545 |
initialize_database()
|
| 546 |
|
|
|
|
| 533 |
"""Main function to run the data processing pipeline."""
|
| 534 |
|
| 535 |
# create a padding image first
|
| 536 |
+
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
| 537 |
+
logger.info("Creating padding image...")
|
| 538 |
+
padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.uint8)
|
| 539 |
+
padding_tensor = torch.tensor(padding_data).unsqueeze(0)
|
| 540 |
+
padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
|
| 541 |
+
posterior = autoencoder.encode(padding_tensor)
|
| 542 |
+
latent = posterior.sample()
|
| 543 |
+
latent = torch.zeros_like(latent).squeeze(0)
|
| 544 |
+
np.save(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), latent.cpu().numpy())
|
| 545 |
+
os.rename(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
| 546 |
# Initialize database
|
| 547 |
initialize_database()
|
| 548 |
|