Spaces:
Configuration error
Configuration error
| from inference.core.env import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, LAMBDA | |
| from inference.core.models.classification_base import ( | |
| ClassificationBaseOnnxRoboflowInferenceModel, | |
| ) | |
class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """Classification inference for Vision Transformer (ViT) models using ONNX.

    Builds on ClassificationBaseOnnxRoboflowInferenceModel, adding a
    multiclass flag and the selection of the weights file name.

    Attributes:
        multiclass: Value found under the "MULTICLASS" key of
            ``self.environment``, or False when the key is missing.
            Presumably a bool -- confirm against whatever populates
            ``environment``.
    """

    def __init__(self, *args, **kwargs):
        """Set up the base model, then record the multiclass setting.

        Args:
            *args: Positional arguments passed through unchanged to the
                base class initializer.
            **kwargs: Keyword arguments passed through unchanged to the
                base class initializer.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): ``self.environment`` is not assigned in this class;
        # presumably the base initializer provides it -- confirm.
        multiclass_setting = self.environment.get("MULTICLASS", False)
        self.multiclass = multiclass_setting

    def weights_file(self) -> str:
        """Pick the name of the ONNX weights file for this model.

        Returns:
            "weights.onnx" when AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
            LAMBDA are all truthy; "best.onnx" in every other case.
        """
        # All three settings must be truthy to select "weights.onnx";
        # AWS credentials alone are not enough, LAMBDA is required as well.
        if not (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA):
            return "best.onnx"
        return "weights.onnx"