"""Factory that builds one of this package's tracking / optical-flow models."""

from .dense_optical_tracking import DenseOpticalTracker
from .optical_flow import OpticalFlow
from .point_tracking import PointTracker


def create_model(args):
    """Instantiate the model selected by ``args.model``.

    Args:
        args: Namespace-like object exposing ``model``, ``height``, ``width``
            and the ``*_config`` / ``*_path`` attributes read by the selected
            branch (presumably an argparse result -- confirm against callers).

    Supported values of ``args.model``:
        ``"dot"`` -> ``DenseOpticalTracker`` (tracker, estimator and refiner)
        ``"pt"``  -> ``PointTracker`` (tracker and estimator)
        ``"ofe"`` -> ``OpticalFlow`` from the estimator config / path
        ``"ofr"`` -> ``OpticalFlow`` from the refiner config / path

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``args.model`` is not one of the names listed above.
    """
    name = args.model

    if name == "dot":
        return DenseOpticalTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
            refiner_config=args.refiner_config,
            refiner_path=args.refiner_path,
        )

    if name == "pt":
        return PointTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
        )

    # "ofe" and "ofr" both build a standalone OpticalFlow model; they differ
    # only in which config / load-path pair (estimator vs refiner) is used.
    if name == "ofe":
        flow_config, flow_path = args.estimator_config, args.estimator_path
    elif name == "ofr":
        flow_config, flow_path = args.refiner_config, args.refiner_path
    else:
        raise ValueError(f"Unknown model name {args.model}")

    return OpticalFlow(
        height=args.height,
        width=args.width,
        config=flow_config,
        load_path=flow_path,
    )