|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
from rf100vl import get_rf100vl_projects
|
|
|
import roboflow
|
|
|
from rfdetr import RFDETRBase
|
|
|
import torch
|
|
|
import os
|
|
|
|
|
|
def download_dataset(rf_project: "roboflow.Project", dataset_version: "int | None") -> str:
    """Return the local directory of a project's dataset, downloading it if absent.

    Args:
        rf_project: Project object exposing ``versions()`` and ``name``.
        dataset_version: Version number to fetch. When ``None``, the version
            with the highest number is used.

    Returns:
        Path of the dataset directory. This is ``datasets/<name>_v<version>``
        when that path already exists; otherwise it is the ``location``
        reported by the download call.

    Raises:
        ValueError: If the requested version does not exist, or the project
            has no versions at all.
    """
    versions = rf_project.versions()

    if dataset_version is not None:
        # Version labels are strings, so compare against the stringified number.
        wanted = str(dataset_version)
        version = next((v for v in versions if v.version == wanted), None)
        if version is None:
            raise ValueError(f"Dataset version {dataset_version} not found")
    else:
        if not versions:
            raise ValueError(f"Project {rf_project.name} has no dataset versions")
        # Order numerically. Comparing string identifiers would be
        # lexicographic and rank version "9" above version "10".
        version = max(versions, key=lambda v: int(v.version))

    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
    # Skip the download when the target path already exists on disk.
    if not os.path.exists(location):
        location = version.download(
            model_format="coco", location=location, overwrite=False
        ).location

    return location
|
|
|
|
|
|
|
|
|
def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
    """Fetch a project's dataset and run a one-epoch RF-DETR training pass on it.

    Args:
        rf_project: Project whose dataset is obtained via ``download_dataset``.
        dataset_version: Version number forwarded to ``download_dataset``.
    """
    dataset_dir = download_dataset(rf_project, dataset_version)
    # Echo the dataset path so the run output shows what is being trained on.
    print(dataset_dir)

    model = RFDETRBase()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model.train(dataset_dir=dataset_dir, epochs=1, device=device)
|
|
|
|
|
|
|
|
|
def train_from_coco_dir(coco_dir: str):
    """Run a one-epoch RF-DETR training pass on a local dataset directory.

    Args:
        coco_dir: Directory passed to ``RFDETRBase.train`` as ``dataset_dir``.
            Presumably a COCO-format dataset, going by the parameter name;
            confirm against the RF-DETR documentation.
    """
    rf_detr = RFDETRBase()

    # Detect CUDA here: this function previously read `device_supports_cuda`
    # without ever defining it, so every call failed with NameError.
    device_supports_cuda = torch.cuda.is_available()

    rf_detr.train(
        dataset_dir=coco_dir,
        epochs=1,
        device="cuda" if device_supports_cuda else "cpu",
    )
|
|
|
|
|
|
|
|
|
def trainer():
    """Command-line entry point: choose a dataset source and start training.

    Source selection, in order:
      1. ``--coco_dir`` given: train directly on that directory.
      2. ``--workspace`` and ``--project_name`` given: train on that Roboflow
         project.
      3. Neither given: train on the first project returned by
         ``get_rf100vl_projects``.

    ``--dataset_version`` is forwarded to ``train_from_rf_project`` in cases
    2 and 3.

    Raises:
        ValueError: If only one of ``--workspace`` / ``--project_name`` is set.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--coco_dir", "--api_key", "--workspace", "--project_name"):
        cli.add_argument(flag, type=str, required=False, default=None)
    cli.add_argument("--dataset_version", type=int, required=False, default=None)
    opts = cli.parse_args()

    # A local dataset directory takes precedence over any Roboflow options.
    if opts.coco_dir is not None:
        train_from_coco_dir(opts.coco_dir)
        return

    has_workspace = opts.workspace is not None
    has_project_name = opts.project_name is not None
    # The two options only make sense together.
    if has_workspace != has_project_name:
        raise ValueError(
            "Either both workspace and project_name must be provided or none of them"
        )

    if not has_workspace:
        candidates = get_rf100vl_projects(api_key=opts.api_key)
        project = candidates[0].rf_project
    else:
        client = roboflow.Roboflow(api_key=opts.api_key)
        project = client.workspace(opts.workspace).project(opts.project_name)

    train_from_rf_project(project, opts.dataset_version)
|
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    trainer()
|
|
|
|