Spaces:
Sleeping
Sleeping
"""
Train or fine-tune the YOLO phone detector within the project's cleaned MLOps structure.

Training uses Ultralytics YOLO driven by a dataset yaml file.
Training metadata is also saved for MLflow, DAGsHub, and DVC-friendly artifact tracking.
Phone-model training is kept separate from posture-model training.
"""
| from pathlib import Path | |
| import os | |
| import mlflow | |
| from src.entity.config_entity import PathsConfig, PhoneDetectorConfig, TrainingConfig | |
| from src.utils.common import save_json, save_yaml | |
| from src.utils.logger import get_logger | |
class PhoneTrainer:
    """
    Train or fine-tune the YOLO phone detector and record the run.

    A call to :meth:`train` validates the dataset yaml and the start weights,
    configures MLflow, trains with Ultralytics YOLO inside an MLflow run, and
    then writes a JSON summary and a YAML manifest into the metrics directory.
    """

    def __init__(
        self,
        paths_config: PathsConfig,
        phone_detector_config: PhoneDetectorConfig,
        training_config: TrainingConfig,
        log_dir: Path | None = None,
        log_level: str = "INFO",
    ) -> None:
        """
        Store the configuration objects and create a logger named after the class.

        Args:
            paths_config: Provides ``metrics_dir`` and ``phone_weights_dir``.
            phone_detector_config: Provides ``weights.fallback_weight_file``.
            training_config: Provides the ``tracking`` and ``phone_training`` sections.
            log_dir: Forwarded to ``get_logger`` as ``log_dir``.
            log_level: Forwarded to ``get_logger`` as ``level``.
        """
        self.paths_config = paths_config
        self.phone_detector_config = phone_detector_config
        self.training_config = training_config
        self.logger = get_logger(
            self.__class__.__name__, log_dir=log_dir, level=log_level
        )

    def _setup_mlflow(self) -> None:
        """
        Configure MLflow tracking for phone detector training.

        The ``MLFLOW_TRACKING_URI`` environment variable takes precedence over
        the URI from the training config.
        """
        tracking = self.training_config.tracking
        # `or` instead of a getenv default: a variable that is set but empty
        # must also fall back to the configured URI rather than reach MLflow
        # as an empty string.
        tracking_uri = os.getenv("MLFLOW_TRACKING_URI") or tracking.tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(tracking.experiment_name_phone)

    def _resolve_start_weights(self) -> Path:
        """
        Return the weights file training starts from.

        Uses ``existing_weights_path`` when ``use_existing_weights`` is set,
        otherwise the detector's ``fallback_weight_file``.

        Raises:
            FileNotFoundError: If the selected weights file does not exist.
        """
        phone_training = self.training_config.phone_training
        if phone_training.use_existing_weights:
            candidate = phone_training.existing_weights_path
        else:
            candidate = self.phone_detector_config.weights.fallback_weight_file

        # Coerce so that a config value stored as a plain string works too.
        model_path = Path(candidate)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Phone detector start weights not found: {model_path}"
            )
        return model_path

    def _save_run_outputs(
        self, summary: dict, data_yaml_path: Path, model_path: Path
    ) -> None:
        """
        Write the training summary (JSON) and the DVC manifest (YAML).

        Both files go into ``paths_config.metrics_dir``.
        """
        outputs = self.training_config.phone_training.outputs
        metrics_dir = Path(self.paths_config.metrics_dir)
        # save_json / save_yaml are project helpers not visible here, so the
        # directory is created explicitly instead of assuming they do it.
        metrics_dir.mkdir(parents=True, exist_ok=True)

        save_json(metrics_dir / outputs.metrics_file_name, summary)

        dvc_meta = {
            "artifact_type": "phone_model",
            "data_yaml_path": str(data_yaml_path),
            "start_weights": str(model_path),
            "default_target_path": str(
                Path(self.paths_config.phone_weights_dir)
                / outputs.save_best_model_as
            ),
        }
        save_yaml(metrics_dir / "phone_dvc_manifest.yaml", dvc_meta)

    def train(self, data_yaml_path: Path) -> dict:
        """
        Train the YOLO phone detector using a dataset yaml file.

        Args:
            data_yaml_path: Path to the YOLO dataset yaml (``str`` is accepted too).

        Returns:
            A summary dict with the dataset path, start weights, epochs,
            batch size, image size and ``str()`` of the Ultralytics result.

        Raises:
            FileNotFoundError: If the dataset yaml or the start weights are missing.
            ImportError: If Ultralytics is not installed.
        """
        # Validate the cheap preconditions first, before the heavy Ultralytics
        # import and before MLflow is touched.
        data_yaml_path = Path(data_yaml_path)
        if not data_yaml_path.exists():
            raise FileNotFoundError(f"YOLO dataset yaml not found: {data_yaml_path}")
        model_path = self._resolve_start_weights()

        try:
            from ultralytics import YOLO
        except ImportError as exc:
            raise ImportError(
                "Ultralytics is not installed. Please install it before running phone training."
            ) from exc

        self._setup_mlflow()

        hyperparameters = self.training_config.phone_training.hyperparameters
        run_params = {
            "start_weights": str(model_path),
            "epochs": hyperparameters.epochs,
            "batch_size": hyperparameters.batch_size,
            "image_size": hyperparameters.image_size,
        }

        self.logger.info("Starting phone detector training from: %s", model_path)
        model = YOLO(str(model_path))

        with mlflow.start_run(run_name="phone_training_run"):
            # One batched call instead of one call per parameter.
            mlflow.log_params(run_params)
            results = model.train(
                data=str(data_yaml_path),
                epochs=hyperparameters.epochs,
                imgsz=hyperparameters.image_size,
                batch=hyperparameters.batch_size,
                # NOTE: training is pinned to CPU here;
                # training_config.general.device_preference is not consulted.
                device="cpu",
            )

        summary = {
            "data_yaml_path": str(data_yaml_path),
            **run_params,
            "results_repr": str(results),
        }
        self._save_run_outputs(summary, data_yaml_path, model_path)

        self.logger.info("Phone detector training completed.")
        return summary