File size: 2,544 Bytes
7155cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import shutil
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import time
import re

class CheckpointHandler(FileSystemEventHandler):
    """Watchdog event handler that prunes old ``global_step_<N>`` checkpoints.

    Only direct subdirectories of ``folder_path`` whose name is exactly
    ``global_step_<digits>`` are treated as checkpoints. Pruning keeps the
    ``max_checkpoints`` most recent ones (ordered by ``os.path.getctime``)
    and never deletes the steps listed in ``PROTECTED_STEPS``.

    NOTE(review): ``os.path.getctime`` is documented as the last
    metadata-change time on Unix rather than a true creation time, so the
    ordering is only as reliable as that timestamp on the host platform.
    """

    # Step numbers whose checkpoint directories are never removed; add more
    # as needed.
    PROTECTED_STEPS = (45, 90, 135, 180, 220)

    # Matched with fullmatch() so that look-alike directories such as
    # "global_step_10_backup" are not treated as disposable checkpoints.
    _CHECKPOINT_NAME = re.compile(r'global_step_\d+')

    def __init__(self, folder_path, max_checkpoints=2):
        """Store the folder to prune and how many recent checkpoints to keep.

        Args:
            folder_path: Directory that contains the checkpoint directories.
            max_checkpoints: Number of most recent checkpoints to keep.

        Raises:
            ValueError: If ``max_checkpoints`` is negative.
        """
        super().__init__()
        if max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")
        self.folder_path = folder_path
        self.max_checkpoints = max_checkpoints

    def on_created(self, event):
        """Handle a creation event; intentionally does nothing.

        Pruning is not triggered from events: ``cleanup_checkpoints`` is
        meant to be called periodically by whoever owns this handler.
        """
        if not event.is_directory:
            return
        # Directory creations are ignored as well; cleanup runs on a timer.

    def cleanup_checkpoints(self):
        """Delete all but the newest ``max_checkpoints`` checkpoint directories.

        Protected checkpoints (see ``PROTECTED_STEPS``) that fall in the
        removable range are skipped, so more than ``max_checkpoints``
        directories may remain afterwards. A failure to remove one
        directory is reported and does not stop the rest of the pass.

        Raises:
            OSError: If ``folder_path`` itself cannot be listed.
        """
        dated = []
        for name in os.listdir(self.folder_path):
            if not self._CHECKPOINT_NAME.fullmatch(name):
                continue
            path = os.path.join(self.folder_path, name)
            if not os.path.isdir(path):
                continue
            try:
                dated.append((os.path.getctime(path), path))
            except OSError:
                # The directory vanished between listing and stat; nothing
                # left to prune for this entry.
                continue
        dated.sort()  # Oldest first; ties fall back to the path string.

        # Computing the count explicitly avoids the `[:-0]` pitfall, which
        # yields an empty slice and would delete nothing when
        # max_checkpoints is 0.
        excess = len(dated) - self.max_checkpoints
        if excess <= 0:
            print(f"No need to remove any checkpoints, {len(dated)} checkpoints exist")
            return

        protected = {f"global_step_{step}" for step in self.PROTECTED_STEPS}
        for _, path in dated[:excess]:
            if os.path.basename(path) in protected:
                print(f"Skipped specific checkpoint: {path}")
                continue
            try:
                shutil.rmtree(path)
            except OSError as exc:
                # Keep going: one undeletable checkpoint must not block the
                # removal of the others or kill the polling loop.
                print(f"Failed to remove checkpoint {path}: {exc}")
                continue
            print(f"Removed old checkpoint: {path}")

def main(
    folder_path='/data/wuxinrui/easyr1_checkpoints/1_5B_TCMv2_long_short_regular_budget_modified',
    interval_seconds=300,
    max_checkpoints=2,
):
    """Watch a checkpoint folder and prune it periodically until interrupted.

    Args:
        folder_path: Directory containing the checkpoint directories.
            Change the default to your own path.
        interval_seconds: Seconds to sleep between cleanup passes
            (default 300, i.e. 5 minutes).
        max_checkpoints: Number of most recent checkpoints to keep.

    The loop runs until Ctrl+C (``KeyboardInterrupt``). The observer is
    stopped and joined on every exit path, including unexpected errors
    raised by a cleanup pass.
    """
    event_handler = CheckpointHandler(folder_path, max_checkpoints)
    observer = Observer()
    observer.schedule(event_handler, folder_path, recursive=False)
    observer.start()

    try:
        while True:
            # Cleanup is driven by this timer, not by filesystem events.
            event_handler.cleanup_checkpoints()
            time.sleep(interval_seconds)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the script; exit quietly.
        pass
    finally:
        observer.stop()
        observer.join()

# Run the watcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()