pvs_backend / src /components /phone_trainer.py
adnankhan-11's picture
PVD System - Initial deployment
d2885a7
"""
This file trains or fine-tunes your YOLO phone detector using the cleaned MLOps structure.
It follows your current project idea of using Ultralytics YOLO with a dataset yaml file.
It also saves training metadata for MLflow, DAGsHub, and DVC-friendly artifact tracking.
This keeps phone-model training separate from posture-model training.
"""
from pathlib import Path
import os
import mlflow
from src.entity.config_entity import PathsConfig, PhoneDetectorConfig, TrainingConfig
from src.utils.common import save_json, save_yaml
from src.utils.logger import get_logger
class PhoneTrainer:
"""
Train or fine-tune the phone detector.
"""
def __init__(
self,
paths_config: PathsConfig,
phone_detector_config: PhoneDetectorConfig,
training_config: TrainingConfig,
log_dir: Path | None = None,
log_level: str = "INFO",
) -> None:
self.paths_config = paths_config
self.phone_detector_config = phone_detector_config
self.training_config = training_config
self.logger = get_logger(
self.__class__.__name__, log_dir=log_dir, level=log_level
)
def _setup_mlflow(self) -> None:
"""
Configure MLflow tracking for phone detector training.
"""
tracking_uri = os.getenv(
"MLFLOW_TRACKING_URI",
self.training_config.tracking.tracking_uri,
)
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(self.training_config.tracking.experiment_name_phone)
def train(self, data_yaml_path: Path) -> dict:
"""
Train the YOLO phone detector using a dataset yaml file.
"""
try:
from ultralytics import YOLO
except ImportError as exc:
raise ImportError(
"Ultralytics is not installed. Please install it before running phone training."
) from exc
if not data_yaml_path.exists():
raise FileNotFoundError(f"YOLO dataset yaml not found: {data_yaml_path}")
self._setup_mlflow()
if self.training_config.phone_training.use_existing_weights:
model_path = self.training_config.phone_training.existing_weights_path
else:
model_path = self.phone_detector_config.weights.fallback_weight_file
if not model_path.exists():
raise FileNotFoundError(
f"Phone detector start weights not found: {model_path}"
)
self.logger.info("Starting phone detector training from: %s", model_path)
model = YOLO(str(model_path))
with mlflow.start_run(run_name="phone_training_run"):
mlflow.log_param("start_weights", str(model_path))
mlflow.log_param(
"epochs", self.training_config.phone_training.hyperparameters.epochs
)
mlflow.log_param(
"batch_size",
self.training_config.phone_training.hyperparameters.batch_size,
)
mlflow.log_param(
"image_size",
self.training_config.phone_training.hyperparameters.image_size,
)
results = model.train(
data=str(data_yaml_path),
epochs=self.training_config.phone_training.hyperparameters.epochs,
imgsz=self.training_config.phone_training.hyperparameters.image_size,
batch=self.training_config.phone_training.hyperparameters.batch_size,
device="cpu", # self.training_config.general.device_preference,
)
summary = {
"data_yaml_path": str(data_yaml_path),
"start_weights": str(model_path),
"epochs": self.training_config.phone_training.hyperparameters.epochs,
"batch_size": self.training_config.phone_training.hyperparameters.batch_size,
"image_size": self.training_config.phone_training.hyperparameters.image_size,
"results_repr": str(results),
}
save_json(
self.paths_config.metrics_dir
/ self.training_config.phone_training.outputs.metrics_file_name,
summary,
)
dvc_meta = {
"artifact_type": "phone_model",
"data_yaml_path": str(data_yaml_path),
"start_weights": str(model_path),
"default_target_path": str(
self.paths_config.phone_weights_dir
/ self.training_config.phone_training.outputs.save_best_model_as
),
}
save_yaml(
self.paths_config.metrics_dir / "phone_dvc_manifest.yaml", dvc_meta
)
self.logger.info("Phone detector training completed.")
return summary