Spaces:
Sleeping
Sleeping
File size: 6,502 Bytes
01d5a5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from enum import Enum
import json
import os
from typing import Dict, List, Optional
from lpm_kernel.api.domains.trainprocess.progress_enum import Status
from lpm_kernel.api.domains.trainprocess.train_progress import TrainProgress
from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep
from lpm_kernel.configs.logging import get_train_process_logger
# Module-level logger used by TrainProgressHolder below; obtained from the
# project's train-process logging helper (lpm_kernel.configs.logging).
logger = get_train_process_logger()
class TrainProgressHolder:
    """Progress management class.

    Wraps a ``TrainProgress`` object and keeps it in sync with a JSON file
    stored under ``<current working directory>/data/progress``. On
    construction any previously saved progress is loaded, and every status
    still marked ``"in_progress"`` in that file is rewritten to ``"failed"``.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Create the holder and load any previously saved progress.

        Args:
            model_name: Optional model name. When given, the progress file is
                ``trainprocess_progress_<model_name>.json``; otherwise the
                default ``trainprocess_progress.json`` is used.

        Raises:
            ValueError: If the resulting file path would fall outside the
                progress directory.
        """
        progress_dir = os.path.join(os.getcwd(), "data", "progress")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(progress_dir, exist_ok=True)

        self.progress_file = self._resolve_progress_file(progress_dir, model_name)
        self.progress = TrainProgress()

        # Stage mapping for process steps: each step belongs to exactly one
        # stage, identified by the stage key used in the progress maps.
        self._stage_mapping = {
            ProcessStep.MODEL_DOWNLOAD: "downloading_the_base_model",
            ProcessStep.LIST_DOCUMENTS: "activating_the_memory_matrix",
            ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS: "activating_the_memory_matrix",
            ProcessStep.CHUNK_DOCUMENT: "activating_the_memory_matrix",
            ProcessStep.CHUNK_EMBEDDING: "activating_the_memory_matrix",
            ProcessStep.EXTRACT_DIMENSIONAL_TOPICS: "synthesize_your_life_narrative",
            ProcessStep.GENERATE_BIOGRAPHY: "synthesize_your_life_narrative",
            ProcessStep.MAP_ENTITY_NETWORK: "synthesize_your_life_narrative",
            ProcessStep.DECODE_PREFERENCE_PATTERNS: "prepare_training_data_for_deep_comprehension",
            ProcessStep.REINFORCE_IDENTITY: "prepare_training_data_for_deep_comprehension",
            ProcessStep.AUGMENT_CONTENT_RETENTION: "prepare_training_data_for_deep_comprehension",
            ProcessStep.TRAIN: "training_to_create_second_me",
            ProcessStep.MERGE_WEIGHTS: "training_to_create_second_me",
            ProcessStep.CONVERT_MODEL: "training_to_create_second_me",
        }

        self._load_progress()

    @staticmethod
    def _resolve_progress_file(progress_dir: str, model_name: Optional[str] = None) -> str:
        """Build the normalized progress-file path and confirm it stays inside progress_dir.

        Args:
            progress_dir: Directory that must contain the progress file.
            model_name: Optional model name embedded in the file name.

        Returns:
            The normalized path of the progress file.

        Raises:
            ValueError: If the normalized path is not located under
                ``progress_dir`` (e.g. ``model_name`` contains ``..`` parts).
        """
        # Generate progress file name based on model name
        file_name = "trainprocess_progress.json"  # Default name
        if model_name:
            file_name = f"trainprocess_progress_{model_name}.json"

        base_dir = os.path.normpath(progress_dir)
        candidate = os.path.normpath(os.path.join(base_dir, file_name))

        # Compare whole path components. A plain startswith() prefix test
        # would accept sibling directories such as "<base_dir>_evil/...".
        if os.path.commonpath([base_dir, candidate]) != base_dir:
            raise ValueError("Invalid progress file path")
        return candidate

    def _load_progress(self):
        """Load the progress file, if present, and rebuild the lookup maps.

        Any error while reading or interpreting the file is logged and the
        in-memory progress is replaced by a fresh ``TrainProgress``.
        """
        if not os.path.exists(self.progress_file):
            return

        try:
            with open(self.progress_file, "r", encoding="utf-8") as f:
                saved_progress = json.load(f)

            self.progress.data = saved_progress

            # Index stages by their normalized name (lower case, "_" for " ").
            self.progress.stage_map = {}
            for stage in self.progress.data["stages"]:
                stage_name = stage["name"].lower().replace(" ", "_")
                self.progress.stage_map[stage_name] = stage

            # Index each stage's steps the same way.
            self.progress.steps_map = {}
            for stage_name, stage in self.progress.stage_map.items():
                self.progress.steps_map[stage_name] = {}
                for step in stage["steps"]:
                    step_name = step["name"].lower().replace(" ", "_")
                    self.progress.steps_map[stage_name][step_name] = step

            # Check and reset any in_progress status to failed
            self._reset_in_progress_status()
        except Exception as e:
            # Deliberate best-effort: an unreadable or malformed file must not
            # prevent construction, so fall back to empty progress.
            logger.error(f"Error loading progress: {str(e)}")
            self.progress = TrainProgress()

    def _reset_in_progress_status(self):
        """Reset any in_progress status to failed after loading from file.

        Covers the overall status, every stage and every step. Steps that are
        reset also get ``completed`` set to ``False``. The file is rewritten
        only when at least one status was changed.
        """
        need_save = False

        # Check overall status
        if self.progress.data["status"] == "in_progress":
            self.progress.data["status"] = "failed"
            need_save = True
            logger.info("Reset overall in_progress status to failed")

        # Check each stage
        for stage in self.progress.data["stages"]:
            if stage["status"] == "in_progress":
                stage["status"] = "failed"
                need_save = True
                logger.info(f"Reset stage '{stage['name']}' in_progress status to failed")

            # Check each step in the stage
            for step in stage["steps"]:
                if step["status"] == "in_progress":
                    step["status"] = "failed"
                    step["completed"] = False
                    need_save = True
                    logger.info(f"Reset step '{step['name']}' in_progress status to failed")

        # Save changes if any were made
        if need_save:
            self._save_progress()
            logger.info("Saved progress after resetting in_progress statuses")

    def _save_progress(self):
        """Write the current progress (``self.progress.to_dict()``) to the progress file as JSON."""
        progress_dict = self.progress.to_dict()
        with open(self.progress_file, "w", encoding="utf-8") as f:
            json.dump(progress_dict, f, indent=2)

    def is_step_completed(self, step: ProcessStep) -> bool:
        """Check if a step is completed.

        Args:
            step: The process step to look up.

        Returns:
            The step's ``completed`` flag, or ``False`` when the step has no
            stage mapping or is absent from the loaded progress data (for
            example a progress file that lacks the step).
        """
        stage_name = self._stage_mapping.get(step)
        if stage_name is None:
            return False
        step_info = self.progress.steps_map.get(stage_name, {}).get(step.value)
        if step_info is None:
            return False
        return step_info.get("completed", False)

    def mark_step_status(self, step: ProcessStep, status: Status):
        """Mark a step with the specified status and persist the change.

        Args:
            step: The process step to mark
            status: The status to set for the step

        Raises:
            KeyError: If ``step`` has no entry in the stage mapping.
        """
        stage_name = self._stage_mapping[step]
        step_name = step.value
        self.progress.update_progress(stage_name, step_name, status)
        self._save_progress()

    def reset_progress(self):
        """Reset all progress to a fresh ``TrainProgress`` and save it."""
        self.progress = TrainProgress()
        self._save_progress()

    def get_last_successful_step(self) -> Optional[ProcessStep]:
        """Get the last successfully completed step.

        Walks ``ProcessStep.get_ordered_steps()`` from the end and returns the
        first step reported as completed, or ``None`` if there is none.
        """
        ordered_steps = ProcessStep.get_ordered_steps()
        for step in reversed(ordered_steps):
            if self.is_step_completed(step):
                return step
        return None
|