car-detection / training /pipeline.py
socks22's picture
first
f3f6f5d
"""Pipeline orchestration for aerial car detection training.
Configures and runs RF-DETR training on a YOLO-format dataset via a Pydantic
model with CLI overrides.
"""
import argparse
import types
import typing
from pathlib import Path

from pydantic import BaseModel

from training.train import run_training
class PipelineConfig(BaseModel):
    """Full configuration for the training pipeline.

    Each field is exposed by ``build_parser`` as a ``--kebab-case`` CLI flag
    (underscores replaced by hyphens) whose default is the value declared
    here. The validated model is consumed by ``run_pipeline``.
    """

    # Dataset root; run_pipeline joins it onto this file's directory and
    # resolves it, so a relative value is relative to the script, not the CWD.
    dataset_dir: Path = Path("./car_data/")
    # Number of training epochs, forwarded to run_training.
    epochs: int = 1
    # Batch size, forwarded to run_training.
    batch_size: int = 4
    # Learning rate, forwarded to run_training.
    lr: float = 1e-4
    # Input resolution, forwarded to run_training (presumably the image side
    # length in pixels — confirm in training.train).
    resolution: int = 640
    # Output location, forwarded to run_training unchanged (unlike
    # dataset_dir, it is NOT resolved against this file's directory).
    output_dir: str = "output"
    # Model variant name; not validated here — accepted values are whatever
    # run_training supports.
    model_size: str = "base"
    # Gradient-accumulation steps, forwarded to run_training.
    grad_accum_steps: int = 1
def _core_type(annotation: typing.Any) -> typing.Any:
    """Return ``T`` for an ``Optional[T]`` / ``T | None`` annotation.

    Only a union with exactly one non-``None`` member is unwrapped; any other
    annotation (including ``Union[A, B]``) is returned unchanged.
    """
    # types.UnionType (PEP 604 ``X | Y``) only exists on Python 3.10+.
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    if typing.get_origin(annotation) in union_origins:
        members = [
            arg for arg in typing.get_args(annotation) if arg is not type(None)
        ]
        if len(members) == 1:
            return members[0]
    return annotation


def build_parser(config_cls: "type[PipelineConfig]") -> argparse.ArgumentParser:
    """Build an argparse parser with one ``--flag`` per field of *config_cls*.

    A field ``some_name`` becomes ``--some-name``. Its argparse default is
    the field's declared default, so ``vars(parser.parse_args())`` always
    contains every field and can be passed straight to ``config_cls(**...)``.

    Annotation handling:
      * ``Optional[T]`` / ``T | None`` is treated as ``T``.
      * ``list[T]`` becomes ``nargs="+"`` with element type ``T``.
      * ``bool`` becomes ``--flag`` / ``--no-flag``
        (``argparse.BooleanOptionalAction``).
      * ``Path`` is accepted as ``str`` and converted by Pydantic.
      * Any other type is used directly as the argparse ``type`` callable.

    Args:
        config_cls: A Pydantic model class; only its ``model_fields`` mapping
            (each entry's ``annotation`` and ``default``) is read.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Run the full car-detection training pipeline (convert + train).",
    )
    for name, field_info in config_cls.model_fields.items():
        flag = f"--{name.replace('_', '-')}"
        default = field_info.default
        help_text = f"(default: {default})"
        field_type = _core_type(field_info.annotation)

        if typing.get_origin(field_type) is list:
            # list[T] -> one or more space-separated values, each parsed as T.
            element_args = typing.get_args(field_type)
            parser.add_argument(
                flag,
                type=element_args[0] if element_args else str,
                nargs="+",
                default=default,
                help=help_text,
            )
        elif field_type is bool:
            # type=bool would turn any non-empty string (even "False") into
            # True, so expose an explicit on/off flag pair instead.
            parser.add_argument(
                flag,
                action=argparse.BooleanOptionalAction,
                default=default,
                help=help_text,
            )
        else:
            # Path fields are accepted as strings and converted by Pydantic
            arg_type = str if field_type is Path else field_type
            parser.add_argument(
                flag,
                type=arg_type,
                default=default,
                help=help_text,
            )
    return parser
def run_pipeline(cfg: PipelineConfig) -> None:
    """Run RF-DETR training for *cfg*, printing a banner before and after.

    ``cfg.dataset_dir`` is joined onto the directory containing this file and
    then resolved. A relative value is therefore interpreted relative to the
    script rather than the current working directory, while an absolute value
    wins the ``Path`` join unchanged. All other settings are forwarded to
    ``run_training`` as-is.
    """
    rule = "=" * 60
    dataset_dir = (Path(__file__).resolve().parent / cfg.dataset_dir).resolve()

    header_lines = (
        rule,
        "Training RF-DETR model",
        f" dataset: {dataset_dir}",
        f" model_size: {cfg.model_size}",
        f" epochs: {cfg.epochs}",
        f" batch_size: {cfg.batch_size}",
        f" lr: {cfg.lr}",
        f" resolution: {cfg.resolution}",
        rule,
    )
    print("\n".join(header_lines))

    run_training(
        dataset_dir=dataset_dir,
        epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        lr=cfg.lr,
        resolution=cfg.resolution,
        output_dir=cfg.output_dir,
        model_size=cfg.model_size,
        grad_accum_steps=cfg.grad_accum_steps,
    )

    # Leading newline emits the blank separator line before the footer.
    print(f"\n{rule}\nTraining complete.\n{rule}")
def main() -> None:
    """Command-line entry point: parse flags into a ``PipelineConfig`` and run it.

    ``build_parser`` gives every flag the field's declared default, so all
    config fields are always supplied. Pydantic then validates and coerces
    the values (e.g. the ``--dataset-dir`` string into a ``Path``).
    """
    cli_args = build_parser(PipelineConfig).parse_args()
    run_pipeline(PipelineConfig(**vars(cli_args)))
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()