| from inference_rfdetr import RFDETRInference | |
def get_model(model_name, version, pretrain_weights):
    """Build the inference wrapper that matches ``model_name``.

    Only the ``'rfdetr'`` family is recognised; the name comparison is
    exact and case-sensitive.

    Args:
        model_name (str): Model family identifier, e.g. ``'rfdetr'``.
        version (str): Variant of the model, e.g. ``'small'`` or ``'nano'``;
            forwarded unchanged to the inference class.
        pretrain_weights (str): Path to the weights file; forwarded
            unchanged to the inference class.

    Returns:
        RFDETRInference: constructed as
        ``RFDETRInference(version, pretrain_weights)``.

    Raises:
        ValueError: If ``model_name`` is anything other than ``'rfdetr'``.
    """
    # Guard clause: reject unknown model families before touching any
    # model-specific code.
    if model_name != 'rfdetr':
        raise ValueError(f"Unsupported model: {model_name}")
    return RFDETRInference(version, pretrain_weights)