File size: 1,460 Bytes
ef296aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from .dense_optical_tracking import DenseOpticalTracker
from .optical_flow import OpticalFlow
from .point_tracking import PointTracker
def create_model(args):
    """Construct the model named by ``args.model``.

    Recognised names and what they build:

    * ``"dot"`` -- ``DenseOpticalTracker`` from the tracker, estimator and
      refiner config/path options.
    * ``"pt"``  -- ``PointTracker`` from the tracker and estimator
      config/path options.
    * ``"ofe"`` -- ``OpticalFlow`` fed the estimator config and path.
    * ``"ofr"`` -- ``OpticalFlow`` fed the refiner config and path.

    Args:
        args: Options object (presumably an ``argparse.Namespace`` -- confirm
            against the caller) exposing ``model``, ``height``, ``width`` and
            the ``*_config`` / ``*_path`` attributes used by the selected
            model. Only the attributes the chosen model needs are read.

    Returns:
        The freshly constructed model instance.

    Raises:
        ValueError: If ``args.model`` is none of the names listed above.
    """
    kind = args.model

    if kind == "dot":
        return DenseOpticalTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
            refiner_config=args.refiner_config,
            refiner_path=args.refiner_path,
        )

    if kind == "pt":
        return PointTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
        )

    # "ofe" and "ofr" build the same class; they differ only in which
    # config/path pair is handed to it.
    if kind == "ofe":
        flow_config, flow_path = args.estimator_config, args.estimator_path
    elif kind == "ofr":
        flow_config, flow_path = args.refiner_config, args.refiner_path
    else:
        raise ValueError(f"Unknown model name {args.model}")

    return OpticalFlow(
        height=args.height,
        width=args.width,
        config=flow_config,
        load_path=flow_path,
    )