|
|
import os
import re
import shutil
import threading
import time

from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
|
|
|
|
|
class CheckpointHandler(FileSystemEventHandler):
    """Prune old ``global_step_<N>`` checkpoint directories from one folder.

    Only the ``max_checkpoints`` most recent checkpoint directories (ordered
    by ``os.path.getctime``) are kept. Older ones are deleted unless their
    step number is listed in ``protected_steps``.

    The handler can be driven in two ways: by calling
    ``cleanup_checkpoints()`` directly (polling), or by scheduling it on a
    watchdog observer, in which case every directory-created event triggers
    a cleanup pass.
    """

    # Steps that are kept by default no matter how old they are.
    DEFAULT_PROTECTED_STEPS = (45, 90, 135, 180, 220)

    # A checkpoint directory name is exactly "global_step_" plus digits.
    _CHECKPOINT_NAME = re.compile(r"global_step_\d+")

    def __init__(self, folder_path, max_checkpoints=2, protected_steps=None):
        """Configure the handler.

        Args:
            folder_path: Directory that directly contains the checkpoint
                directories.
            max_checkpoints: Number of most recent checkpoints to keep.
                ``0`` keeps none (apart from protected steps).
            protected_steps: Iterable of step numbers that must never be
                deleted. ``None`` selects ``DEFAULT_PROTECTED_STEPS``.

        Raises:
            ValueError: If ``max_checkpoints`` is negative.
        """
        super().__init__()
        if max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")
        self.folder_path = folder_path
        self.max_checkpoints = max_checkpoints
        if protected_steps is None:
            protected_steps = self.DEFAULT_PROTECTED_STEPS
        self.protected_steps = frozenset(protected_steps)
        # Cleanup may be entered from the observer thread (on_created) and
        # from a polling loop at the same time; serialize the passes so two
        # threads never delete the same directory concurrently.
        self._cleanup_lock = threading.Lock()

    def on_created(self, event):
        """Run a cleanup pass whenever a new directory appears in the folder.

        File-created events are ignored; checkpoints are directories.
        """
        if not event.is_directory:
            return
        self.cleanup_checkpoints()

    def cleanup_checkpoints(self):
        """Delete all but the newest ``max_checkpoints`` checkpoint directories.

        Protected steps are skipped rather than deleted. A missing folder,
        an entry that disappears mid-scan, or a failed deletion is reported
        on stdout and does not raise, so a long-running caller keeps going.
        """
        with self._cleanup_lock:
            try:
                entries = os.listdir(self.folder_path)
            except FileNotFoundError:
                print(f"Checkpoint folder not found: {self.folder_path}")
                return

            dated = []
            for name in entries:
                # fullmatch: names such as "global_step_10_tmp" are not
                # checkpoints and must not be candidates for deletion.
                if not self._CHECKPOINT_NAME.fullmatch(name):
                    continue
                path = os.path.join(self.folder_path, name)
                if not os.path.isdir(path):
                    continue
                try:
                    dated.append((os.path.getctime(path), path))
                except OSError:
                    # The entry vanished between listing and stat; skip it.
                    continue
            dated.sort()

            if len(dated) <= self.max_checkpoints:
                print(f"No need to remove any checkpoints, {len(dated)} checkpoints exist")
                return

            protected_names = {f"global_step_{step}" for step in self.protected_steps}
            # Use an explicit count instead of a negative slice: [:-0] would
            # be empty and silently keep everything when max_checkpoints is 0.
            surplus = len(dated) - self.max_checkpoints
            for _, path in dated[:surplus]:
                if os.path.basename(path) in protected_names:
                    print(f"Skipped specific checkpoint: {path}")
                    continue
                try:
                    shutil.rmtree(path)
                except OSError as err:
                    print(f"Failed to remove checkpoint {path}: {err}")
                else:
                    print(f"Removed old checkpoint: {path}")
|
|
|
|
|
def main(
    folder_path='/data/wuxinrui/easyr1_checkpoints/1_5B_TCMv2_long_short_regular_budget_modified',
    poll_interval=300,
):
    """Watch a checkpoint folder and prune it periodically until interrupted.

    Args:
        folder_path: Directory that directly contains the checkpoint
            directories.
        poll_interval: Seconds to sleep between two cleanup passes.

    The loop runs until Ctrl-C. The observer is stopped and joined on every
    exit path, including unexpected exceptions, which are re-raised after
    the shutdown.
    """
    event_handler = CheckpointHandler(folder_path)
    observer = Observer()
    observer.schedule(event_handler, folder_path, recursive=False)
    observer.start()

    try:
        while True:
            event_handler.cleanup_checkpoints()
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the watcher; exit quietly.
        pass
    finally:
        observer.stop()
        observer.join()
|
|
|
|
|
# Start the watcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()