Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------------ | |
| # RF-DETR | |
| # Copyright (c) 2025 Roboflow. All Rights Reserved. | |
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |
| # ------------------------------------------------------------------------ | |
| # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) | |
| # Copyright (c) 2024 Baidu. All Rights Reserved. | |
| # ------------------------------------------------------------------------ | |
| import argparse | |
| from rf100vl import get_rf100vl_projects | |
| import roboflow | |
| from rfdetr import RFDETRBase | |
| import torch | |
| import os | |
def download_dataset(rf_project: "roboflow.Project", dataset_version: "int | None"):
    """Download a version of a Roboflow project in COCO format, reusing a local copy.

    The dataset is placed under ``datasets/<project name>_v<version>``. If that
    path already exists, no download is attempted and the path is returned as-is.

    Args:
        rf_project: Roboflow project whose versions are listed via ``versions()``.
        dataset_version: Version number to fetch. When ``None``, the highest
            numbered version of the project is used.

    Returns:
        Path of the local dataset directory.

    Raises:
        ValueError: If the requested version does not exist, or if
            ``dataset_version`` is ``None`` and the project has no versions.
    """
    versions = rf_project.versions()
    if dataset_version is not None:
        matches = [v for v in versions if v.version == str(dataset_version)]
        if not matches:
            raise ValueError(f"Dataset version {dataset_version} not found")
        version = matches[0]
    else:
        if not versions:
            raise ValueError(f"Project {rf_project.name} has no dataset versions")
        # ``version.version`` is the numeric version string (it is matched
        # against ``str(dataset_version)`` above). Compare it as an integer so
        # that version 10 ranks above version 9; a string comparison would not.
        version = max(versions, key=lambda v: int(v.version))

    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
    if not os.path.exists(location):
        # ``download`` returns an object whose ``location`` is the final path.
        location = version.download(
            model_format="coco", location=location, overwrite=False
        ).location
    return location
def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
    """Fetch a Roboflow project's dataset and train RF-DETR Base on it for one epoch.

    Args:
        rf_project: Roboflow project to pull the dataset from.
        dataset_version: Version number forwarded to ``download_dataset``;
            ``None`` lets that function pick the version.
    """
    dataset_path = download_dataset(rf_project, dataset_version)
    print(dataset_path)

    model = RFDETRBase()
    # Train on the GPU when CUDA is visible, otherwise fall back to the CPU.
    has_cuda = torch.cuda.is_available()
    target_device = "cuda" if has_cuda else "cpu"
    model.train(dataset_dir=dataset_path, epochs=1, device=target_device)
def train_from_coco_dir(coco_dir: str):
    """Train RF-DETR Base for one epoch on a local COCO-format dataset directory.

    Args:
        coco_dir: Path to the dataset directory, passed to ``RFDETRBase.train``
            as ``dataset_dir``.
    """
    rf_detr = RFDETRBase()
    # Resolve the device here: this function cannot see locals of other
    # functions. Use CUDA when available, otherwise fall back to the CPU.
    device_supports_cuda = torch.cuda.is_available()
    rf_detr.train(
        dataset_dir=coco_dir,
        epochs=1,
        device="cuda" if device_supports_cuda else "cpu",
    )
def trainer():
    """Command-line entry point: pick a dataset source and train RF-DETR Base.

    The dataset source is chosen from the parsed arguments, in this order:

    1. ``--coco_dir``: train on a local COCO-format directory (all Roboflow
       options are then ignored).
    2. ``--workspace`` together with ``--project_name``: train on a version of
       that Roboflow project.
    3. Neither: train on the first project returned by ``get_rf100vl_projects``.

    Raises:
        ValueError: If exactly one of ``--workspace`` / ``--project_name`` is
            given, or if no RF100-VL projects are returned.
    """
    parser = argparse.ArgumentParser(description="Train an RF-DETR Base model.")
    parser.add_argument(
        "--coco_dir", type=str, required=False,
        help="Local COCO-format dataset directory; overrides all Roboflow options.",
    )
    parser.add_argument(
        "--api_key", type=str, required=False, help="Roboflow API key.",
    )
    parser.add_argument(
        "--workspace", type=str, required=False, default=None,
        help="Roboflow workspace (must be given together with --project_name).",
    )
    parser.add_argument(
        "--project_name", type=str, required=False, default=None,
        help="Roboflow project (must be given together with --workspace).",
    )
    parser.add_argument(
        "--dataset_version", type=int, required=False, default=None,
        help="Project version to download (default: chosen automatically).",
    )
    args = parser.parse_args()

    if args.coco_dir is not None:
        train_from_coco_dir(args.coco_dir)
        return

    # Exactly one of the two being set is an error (logical XOR on "is set").
    if (args.workspace is None) != (args.project_name is None):
        raise ValueError(
            "Either both workspace and project_name must be provided or none of them"
        )

    if args.workspace is not None:
        rf = roboflow.Roboflow(api_key=args.api_key)
        project = rf.workspace(args.workspace).project(args.project_name)
    else:
        projects = get_rf100vl_projects(api_key=args.api_key)
        # Fail with a clear message instead of an IndexError on projects[0].
        if not projects:
            raise ValueError("No RF100-VL projects were returned")
        project = projects[0].rf_project

    train_from_rf_project(project, args.dataset_version)
# Run the command-line trainer only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    trainer()