Spaces:
Sleeping
Sleeping
| """Pipeline orchestration for aerial car detection training. | |
| Configures and runs RF-DETR training on a YOLO-format dataset via a Pydantic | |
| model with CLI overrides. | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| from pydantic import BaseModel | |
| from training.train import run_training | |
class PipelineConfig(BaseModel):
    """Full configuration for the training pipeline."""

    # YOLO-format dataset location; run_pipeline resolves this relative to
    # the script's own directory, so the default works from any CWD.
    dataset_dir: Path = Path("./car_data/")
    epochs: int = 1              # training epochs, forwarded to run_training
    batch_size: int = 4          # per-step batch size
    lr: float = 1e-4             # learning rate
    resolution: int = 640        # input resolution, forwarded to run_training
    output_dir: str = "output"   # output location, forwarded to run_training
    model_size: str = "base"     # RF-DETR model variant (e.g. "base")
    grad_accum_steps: int = 1    # gradient accumulation steps
| def build_parser(config_cls: type[PipelineConfig]) -> argparse.ArgumentParser: | |
| """Build an argparse parser from PipelineConfig fields.""" | |
| parser = argparse.ArgumentParser( | |
| description="Run the full car-detection training pipeline (convert + train).", | |
| ) | |
| for name, field_info in config_cls.model_fields.items(): | |
| flag = f"--{name.replace('_', '-')}" | |
| field_type = field_info.annotation | |
| # Unwrap Optional / Union types to the core type | |
| origin = getattr(field_type, "__origin__", None) | |
| if origin is list: | |
| inner = field_type.__args__[0] | |
| parser.add_argument( | |
| flag, | |
| type=inner, | |
| nargs="+", | |
| default=field_info.default, | |
| help=f"(default: {field_info.default})", | |
| ) | |
| else: | |
| # Path fields are accepted as strings and converted by Pydantic | |
| arg_type = str if field_type is Path else field_type | |
| parser.add_argument( | |
| flag, | |
| type=arg_type, | |
| default=field_info.default, | |
| help=f"(default: {field_info.default})", | |
| ) | |
| return parser | |
def run_pipeline(cfg: PipelineConfig) -> None:
    """Execute the training pipeline."""
    # Paths are resolved relative to this file so the script works no matter
    # where it is invoked from (an absolute dataset_dir passes through as-is).
    base_dir = Path(__file__).resolve().parent
    dataset_dir = (base_dir / cfg.dataset_dir).resolve()

    banner = "=" * 60
    print(banner)
    print("Training RF-DETR model")
    for label, value in (
        ("dataset", dataset_dir),
        ("model_size", cfg.model_size),
        ("epochs", cfg.epochs),
        ("batch_size", cfg.batch_size),
        ("lr", cfg.lr),
        ("resolution", cfg.resolution),
    ):
        print(f"  {label}: {value}")
    print(banner)

    run_training(
        dataset_dir=dataset_dir,
        epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        lr=cfg.lr,
        resolution=cfg.resolution,
        output_dir=cfg.output_dir,
        model_size=cfg.model_size,
        grad_accum_steps=cfg.grad_accum_steps,
    )

    print()
    print(banner)
    print("Training complete.")
    print(banner)
def main() -> None:
    """CLI entry point: parse flags into a config and run the pipeline."""
    cli_args = build_parser(PipelineConfig).parse_args()
    run_pipeline(PipelineConfig(**vars(cli_args)))


if __name__ == "__main__":
    main()