Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
a92ddb8
1
Parent(s):
29a0aca
- online_data_generation.py +88 -38
online_data_generation.py
CHANGED
|
@@ -21,6 +21,7 @@ import pandas as pd
|
|
| 21 |
import ast
|
| 22 |
import pickle
|
| 23 |
from moviepy.editor import VideoFileClip
|
|
|
|
| 24 |
|
| 25 |
# Import the existing functions
|
| 26 |
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
|
|
@@ -44,6 +45,7 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
| 44 |
SCREEN_WIDTH = 512
|
| 45 |
SCREEN_HEIGHT = 384
|
| 46 |
MEMORY_LIMIT = "2g"
|
|
|
|
| 47 |
|
| 48 |
# load autoencoder
|
| 49 |
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
|
|
@@ -51,6 +53,19 @@ autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_
|
|
| 51 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 52 |
autoencoder = autoencoder.to(device)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def initialize_database():
|
| 55 |
"""Initialize the SQLite database if it doesn't exist."""
|
| 56 |
conn = sqlite3.connect(DB_FILE)
|
|
@@ -531,7 +546,8 @@ def generate_comparison_video(client_id, trajectory, output_file, start_time, en
|
|
| 531 |
|
| 532 |
def main():
|
| 533 |
"""Main function to run the data processing pipeline."""
|
| 534 |
-
|
|
|
|
| 535 |
# create a padding image first
|
| 536 |
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
| 537 |
logger.info("Creating padding image...")
|
|
@@ -543,52 +559,86 @@ def main():
|
|
| 543 |
latent = torch.zeros_like(latent).squeeze(0)
|
| 544 |
np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
|
| 545 |
os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
|
|
|
| 546 |
# Initialize database
|
| 547 |
initialize_database()
|
| 548 |
|
| 549 |
-
# Initialize clean Docker state
|
| 550 |
logger.info("Initializing clean container state...")
|
| 551 |
clean_state = initialize_clean_state()
|
| 552 |
logger.info(f"Clean state initialized: {clean_state}")
|
| 553 |
|
| 554 |
-
#
|
| 555 |
-
|
| 556 |
-
logger.info(f"Found {len(log_files)} log files")
|
| 557 |
-
|
| 558 |
-
# Filter for complete sessions
|
| 559 |
-
complete_sessions = [f for f in log_files if is_session_complete(f)]
|
| 560 |
-
logger.info(f"Found {len(complete_sessions)} complete sessions")
|
| 561 |
-
|
| 562 |
-
# Filter for sessions not yet processed
|
| 563 |
-
conn = sqlite3.connect(DB_FILE)
|
| 564 |
-
cursor = conn.cursor()
|
| 565 |
-
cursor.execute("SELECT log_file FROM processed_sessions")
|
| 566 |
-
processed_files = set(row[0] for row in cursor.fetchall())
|
| 567 |
-
conn.close()
|
| 568 |
-
|
| 569 |
-
new_sessions = [f for f in complete_sessions if f not in processed_files]
|
| 570 |
-
logger.info(f"Found {len(new_sessions)} new sessions to process")
|
| 571 |
|
| 572 |
-
|
| 573 |
-
valid_sessions = [f for f in new_sessions if is_session_valid(f)]
|
| 574 |
-
logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
|
| 575 |
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
|
| 593 |
|
| 594 |
if __name__ == "__main__":
|
|
|
|
| 21 |
import ast
|
| 22 |
import pickle
|
| 23 |
from moviepy.editor import VideoFileClip
|
| 24 |
+
import signal
|
| 25 |
|
| 26 |
# Import the existing functions
|
| 27 |
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
|
|
|
|
| 45 |
SCREEN_WIDTH = 512
|
| 46 |
SCREEN_HEIGHT = 384
|
| 47 |
MEMORY_LIMIT = "2g"
|
| 48 |
+
CHECK_INTERVAL = 60 # Check for new data every 60 seconds
|
| 49 |
|
| 50 |
# load autoencoder
|
| 51 |
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
|
|
|
|
| 53 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 54 |
autoencoder = autoencoder.to(device)
|
| 55 |
|
| 56 |
+
# Global flag for graceful shutdown
|
| 57 |
+
running = True
|
| 58 |
+
|
| 59 |
+
def signal_handler(sig, frame):
|
| 60 |
+
"""Handle Ctrl+C and other termination signals"""
|
| 61 |
+
global running
|
| 62 |
+
logger.info("Shutdown signal received. Finishing current processing and exiting...")
|
| 63 |
+
running = False
|
| 64 |
+
|
| 65 |
+
# Register signal handlers
|
| 66 |
+
signal.signal(signal.SIGINT, signal_handler)
|
| 67 |
+
signal.signal(signal.SIGTERM, signal_handler)
|
| 68 |
+
|
| 69 |
def initialize_database():
|
| 70 |
"""Initialize the SQLite database if it doesn't exist."""
|
| 71 |
conn = sqlite3.connect(DB_FILE)
|
|
|
|
| 546 |
|
| 547 |
def main():
|
| 548 |
"""Main function to run the data processing pipeline."""
|
| 549 |
+
global running
|
| 550 |
+
|
| 551 |
# create a padding image first
|
| 552 |
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
| 553 |
logger.info("Creating padding image...")
|
|
|
|
| 559 |
latent = torch.zeros_like(latent).squeeze(0)
|
| 560 |
np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
|
| 561 |
os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
| 562 |
+
|
| 563 |
# Initialize database
|
| 564 |
initialize_database()
|
| 565 |
|
| 566 |
+
# Initialize clean Docker state
|
| 567 |
logger.info("Initializing clean container state...")
|
| 568 |
clean_state = initialize_clean_state()
|
| 569 |
logger.info(f"Clean state initialized: {clean_state}")
|
| 570 |
|
| 571 |
+
# Ensure output directory exists
|
| 572 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
+
logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)")
|
|
|
|
|
|
|
| 575 |
|
| 576 |
+
try:
|
| 577 |
+
# Main monitoring loop
|
| 578 |
+
while running:
|
| 579 |
+
try:
|
| 580 |
+
# Find all log files
|
| 581 |
+
log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
|
| 582 |
+
logger.info(f"Found {len(log_files)} log files")
|
| 583 |
+
|
| 584 |
+
# Filter for complete sessions
|
| 585 |
+
complete_sessions = [f for f in log_files if is_session_complete(f)]
|
| 586 |
+
logger.info(f"Found {len(complete_sessions)} complete sessions")
|
| 587 |
+
|
| 588 |
+
# Filter for sessions not yet processed
|
| 589 |
+
conn = sqlite3.connect(DB_FILE)
|
| 590 |
+
cursor = conn.cursor()
|
| 591 |
+
cursor.execute("SELECT log_file FROM processed_sessions")
|
| 592 |
+
processed_files = set(row[0] for row in cursor.fetchall())
|
| 593 |
+
conn.close()
|
| 594 |
+
|
| 595 |
+
new_sessions = [f for f in complete_sessions if f not in processed_files]
|
| 596 |
+
logger.info(f"Found {len(new_sessions)} new sessions to process")
|
| 597 |
+
|
| 598 |
+
# Filter for valid sessions
|
| 599 |
+
valid_sessions = [f for f in new_sessions if is_session_valid(f)]
|
| 600 |
+
logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
|
| 601 |
+
|
| 602 |
+
# Process each valid session
|
| 603 |
+
total_trajectories = 0
|
| 604 |
+
for log_file in valid_sessions:
|
| 605 |
+
if not running:
|
| 606 |
+
logger.info("Shutdown in progress, stopping processing")
|
| 607 |
+
break
|
| 608 |
+
|
| 609 |
+
logger.info(f"Processing session file: {log_file}")
|
| 610 |
+
processed_ids = process_session_file(log_file, clean_state)
|
| 611 |
+
total_trajectories += len(processed_ids)
|
| 612 |
+
|
| 613 |
+
if total_trajectories > 0:
|
| 614 |
+
# Get next ID for reporting
|
| 615 |
+
conn = sqlite3.connect(DB_FILE)
|
| 616 |
+
cursor = conn.cursor()
|
| 617 |
+
cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
|
| 618 |
+
next_id = int(cursor.fetchone()[0])
|
| 619 |
+
conn.close()
|
| 620 |
+
|
| 621 |
+
logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
|
| 622 |
+
logger.info(f"Next ID will be {next_id}")
|
| 623 |
+
else:
|
| 624 |
+
logger.info("No new trajectories processed in this cycle")
|
| 625 |
+
|
| 626 |
+
# Sleep until next check, but with periodic wake-ups to check running flag
|
| 627 |
+
remaining_sleep = CHECK_INTERVAL
|
| 628 |
+
while remaining_sleep > 0 and running:
|
| 629 |
+
sleep_chunk = min(5, remaining_sleep) # Check running flag every 5 seconds max
|
| 630 |
+
time.sleep(sleep_chunk)
|
| 631 |
+
remaining_sleep -= sleep_chunk
|
| 632 |
+
|
| 633 |
+
except Exception as e:
|
| 634 |
+
logger.error(f"Error in processing cycle: {e}")
|
| 635 |
+
# Sleep briefly to avoid rapid error loops
|
| 636 |
+
time.sleep(10)
|
| 637 |
+
|
| 638 |
+
except KeyboardInterrupt:
|
| 639 |
+
logger.info("Keyboard interrupt received, shutting down")
|
| 640 |
+
finally:
|
| 641 |
+
logger.info("Shutting down trajectory processor")
|
| 642 |
|
| 643 |
|
| 644 |
if __name__ == "__main__":
|