File size: 688 Bytes
554dccb
6bfc4b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from inference_rfdetr import RFDETRInference


def get_model(model_name: str, version: str, pretrain_weights: str):
    """
    Construct the inference wrapper that matches ``model_name``.

    Only one backend is recognised: the exact, case-sensitive string
    ``'rfdetr'``. It yields an ``RFDETRInference`` built from ``version``
    and ``pretrain_weights``. Any other name is rejected.

    Args:
        model_name (str): Backend identifier, e.g. ``'rfdetr'``.
        version (str): Model variant passed through unchanged to the
            backend constructor (e.g. ``'small'``, ``'nano'``).
        pretrain_weights (str): Path to the weights file, passed through
            unchanged to the backend constructor.

    Returns:
        RFDETRInference: The newly constructed inference object.

    Raises:
        ValueError: If ``model_name`` is not a supported backend.
    """
    # Guard clause: reject unknown backends before touching any model class.
    if model_name != 'rfdetr':
        raise ValueError(f"Unsupported model: {model_name}")
    return RFDETRInference(version, pretrain_weights)