File size: 688 Bytes
554dccb 6bfc4b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | from inference_rfdetr import RFDETRInference
def get_model(model_name, version, pretrain_weights):
    """Build the inference wrapper registered under ``model_name``.

    Only ``'rfdetr'`` is recognised. The comparison is an exact,
    case-sensitive string match, so e.g. ``'RFDETR'`` is rejected.

    Args:
        model_name (str): Identifier of the model family (e.g. 'rfdetr').
        version (str): Variant of the model, forwarded unchanged to the
            inference class (e.g. 'small', 'nano').
        pretrain_weights (str): Path to the weights file, forwarded
            unchanged to the inference class.

    Returns:
        RFDETRInference: constructed as ``RFDETRInference(version,
        pretrain_weights)``. Presumably a ``BaseInference`` subclass;
        verify against ``inference_rfdetr``.

    Raises:
        ValueError: If ``model_name`` is not a supported identifier.
    """
    # Guard clause: reject unknown names before touching any model class,
    # so the error path never depends on the inference module.
    if model_name != 'rfdetr':
        raise ValueError(f"Unsupported model: {model_name}")

    return RFDETRInference(version, pretrain_weights)
|