Object_Detection_HUB / scripts /model_factory.py
Panagiota Moraiti
Correct import
554dccb
raw
history blame contribute delete
688 Bytes
from inference_rfdetr import RFDETRInference
def get_model(model_name, version, pretrain_weights):
    """Build the inference wrapper that matches a model name.

    Only the ``'rfdetr'`` family is recognised here. Any other name is
    rejected before a constructor is touched.

    Args:
        model_name (str): Model family identifier; currently only 'rfdetr'.
        version (str): Variant of the model (e.g. 'small', 'nano'). It is
            forwarded unchanged to the inference class constructor.
        pretrain_weights (str): Path to the weights file. It is forwarded
            unchanged to the inference class constructor.

    Returns:
        BaseInference: The inference object for the requested model. For
        'rfdetr' this is an ``RFDETRInference`` built from ``version`` and
        ``pretrain_weights``.

    Raises:
        ValueError: If ``model_name`` is not a supported model family.
    """
    # Guard clause: bail out on anything we do not know how to build.
    # ``not (x == y)`` is used instead of ``x != y`` so the comparison
    # dispatches through exactly the same ``__eq__`` as before.
    is_supported = model_name == 'rfdetr'
    if not is_supported:
        raise ValueError(f"Unsupported model: {model_name}")

    return RFDETRInference(version, pretrain_weights)