File size: 6,502 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from enum import Enum
import json
import os
from typing import Dict, List, Optional

from lpm_kernel.api.domains.trainprocess.progress_enum import Status
from lpm_kernel.api.domains.trainprocess.train_progress import TrainProgress
from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep
from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()

class TrainProgressHolder:
    """Progress management class"""

    def __init__(self, model_name: str = None):
        progress_dir = os.path.join(os.getcwd(), "data", "progress")
        if not os.path.exists(progress_dir):
            os.makedirs(progress_dir)
        
        # Generate progress file name based on model name
        progress_file = "trainprocess_progress.json"  # Default name
        if model_name:
            progress_file = f"trainprocess_progress_{model_name}.json"
            
        self.progress_file = os.path.normpath(os.path.join(progress_dir, progress_file))
        if not self.progress_file.startswith(progress_dir):
            raise ValueError("Invalid progress file path")
        self.progress = TrainProgress()

        # Stage mapping for process steps
        self._stage_mapping = {
            ProcessStep.MODEL_DOWNLOAD: "downloading_the_base_model",

            ProcessStep.LIST_DOCUMENTS: "activating_the_memory_matrix",
            ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS: "activating_the_memory_matrix",
            ProcessStep.CHUNK_DOCUMENT: "activating_the_memory_matrix",
            ProcessStep.CHUNK_EMBEDDING: "activating_the_memory_matrix",

            ProcessStep.EXTRACT_DIMENSIONAL_TOPICS: "synthesize_your_life_narrative",
            ProcessStep.GENERATE_BIOGRAPHY: "synthesize_your_life_narrative",
            ProcessStep.MAP_ENTITY_NETWORK: "synthesize_your_life_narrative",

            ProcessStep.DECODE_PREFERENCE_PATTERNS: "prepare_training_data_for_deep_comprehension",
            ProcessStep.REINFORCE_IDENTITY: "prepare_training_data_for_deep_comprehension",
            ProcessStep.AUGMENT_CONTENT_RETENTION: "prepare_training_data_for_deep_comprehension",

            ProcessStep.TRAIN: "training_to_create_second_me",
            ProcessStep.MERGE_WEIGHTS: "training_to_create_second_me",
            ProcessStep.CONVERT_MODEL: "training_to_create_second_me",
        }
        
        self._load_progress()

    def _load_progress(self):
        """Load progress file"""
        if os.path.exists(self.progress_file):
            try:
                with open(self.progress_file, "r") as f:
                    saved_progress = json.load(f)
                    self.progress.data = saved_progress
                    
                    self.progress.stage_map = {}
                    for stage in self.progress.data["stages"]:
                        stage_name = stage["name"].lower().replace(" ", "_")
                        self.progress.stage_map[stage_name] = stage
                    
                    self.progress.steps_map = {}
                    for stage_name, stage in self.progress.stage_map.items():
                        self.progress.steps_map[stage_name] = {}
                        for step in stage["steps"]:
                            step_name = step["name"].lower().replace(" ", "_")
                            self.progress.steps_map[stage_name][step_name] = step                    
                    # Check and reset any in_progress status to failed
                    self._reset_in_progress_status()
            except Exception as e:
                logger.error(f"Error loading progress: {str(e)}")
                # Reset progress on any error
                self.progress = TrainProgress()
                
    def _reset_in_progress_status(self):
        """Reset any in_progress status to failed after loading from file"""
        need_save = False
        
        # Check overall status
        if self.progress.data["status"] == "in_progress":
            self.progress.data["status"] = "failed"
            need_save = True
            logger.info("Reset overall in_progress status to failed")
        
        # Check each stage
        for stage in self.progress.data["stages"]:
            if stage["status"] == "in_progress":
                stage["status"] = "failed"
                need_save = True
                logger.info(f"Reset stage '{stage['name']}' in_progress status to failed")
            
            # Check each step in the stage
            for step in stage["steps"]:
                if step["status"] == "in_progress":
                    step["status"] = "failed"
                    step["completed"] = False
                    need_save = True
                    logger.info(f"Reset step '{step['name']}' in_progress status to failed")
        
        # Save changes if any were made
        if need_save:
            progress_dict = self.progress.to_dict()
            with open(self.progress_file, "w") as f:
                json.dump(progress_dict, f, indent=2)
            logger.info("Saved progress after resetting in_progress statuses")

    def _save_progress(self):
        """Save progress"""
        progress_dict = self.progress.to_dict()
        with open(self.progress_file, "w") as f:
            json.dump(progress_dict, f, indent=2)

    def is_step_completed(self, step: ProcessStep) -> bool:
        """Check if a step is completed"""
        stage_name = self._stage_mapping[step]
        step_name = step.value
        step_info = self.progress.steps_map[stage_name][step_name]
        return step_info.get("completed", False)

    def mark_step_status(self, step: ProcessStep, status: Status):
        """Mark a step with the specified status
        
        Args:
            step: The process step to mark
            status: The status to set for the step
        """
        stage_name = self._stage_mapping[step]
        step_name = step.value
        self.progress.update_progress(stage_name, step_name, status)
        self._save_progress()

    def reset_progress(self):
        """Reset all progress"""
        self.progress = TrainProgress()
        self._save_progress()

    def get_last_successful_step(self) -> Optional[ProcessStep]:
        """Get the last successfully completed step"""
        ordered_steps = ProcessStep.get_ordered_steps()
        for step in reversed(ordered_steps):
            if self.is_step_completed(step):
                return step
        return None