| import os |
| import torch |
| from transformers import AutoModel, AutoTokenizer |
|
|
class ModelLoader:
    """Load pretrained models by short alias and cache them per instance.

    Each alias maps to an identifier that is passed to
    ``AutoModel.from_pretrained`` (Hugging Face Hub repo ids in the built-in
    registry). Loaded models are stored in ``loaded_models`` so repeated
    ``load_model`` calls for the same alias return the same object instead of
    loading the checkpoint again.
    """

    # Built-in alias -> checkpoint registry, used when no custom mapping is
    # supplied. Every instance works on its own copy, so mutating one
    # instance's ``available_models`` never affects other instances.
    DEFAULT_MODELS = {
        'vit-base': 'google/vit-base-patch16-224',
        'resnet-50': 'microsoft/resnet-50',
        'dinov2-base': 'facebook/dinov2-base',
        'flan-t5-large': 'google/flan-t5-large',
        'bert-base': 'bert-base-uncased',
    }

    def __init__(self, available_models: "dict[str, str] | None" = None):
        """Create a loader with an empty model cache.

        Args:
            available_models: Optional alias -> checkpoint mapping that
                replaces the built-in registry. It is copied, so later changes
                to the caller's mapping do not affect this loader. When
                omitted (``None``), ``DEFAULT_MODELS`` is used.
        """
        self.loaded_models = {}
        registry = self.DEFAULT_MODELS if available_models is None else available_models
        # Copy so the registry is private to this instance.
        self.available_models = dict(registry)

    def load_model(self, model_name: str):
        """Return the model registered under ``model_name``, loading it on first use.

        Args:
            model_name: An alias present in ``available_models``.

        Returns:
            The cached model if this alias was loaded before. Otherwise, the
            result of ``AutoModel.from_pretrained`` for the alias's checkpoint,
            which is then cached. Returns ``None`` if the alias is not
            registered.

        Raises:
            Propagates any exception raised by ``AutoModel.from_pretrained``.
            Nothing is cached in that case.
        """
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]

        model_path = self.available_models.get(model_name)
        if not model_path:
            # Unknown alias (or an empty path): signal "not available" with
            # None rather than raising.
            return None

        # NOTE(review): AutoModel builds the bare backbone (no task head) for
        # every entry, including the seq2seq flan-t5 checkpoint. Confirm this
        # is what callers expect.
        model = AutoModel.from_pretrained(model_path)
        self.loaded_models[model_name] = model
        return model

    def list_models(self) -> list:
        """Return the registered aliases, in registry order."""
        return list(self.available_models.keys())

    def get_loaded_models(self) -> list:
        """Return the aliases that are currently cached, in load order."""
        return list(self.loaded_models.keys())

    def unload_model(self, model_name: str) -> bool:
        """Drop ``model_name`` from this loader's cache.

        Only the loader's own reference is removed. Any caller still holding
        the model object keeps it alive.

        Returns:
            True if the alias was cached and has been removed, False otherwise.
        """
        try:
            del self.loaded_models[model_name]
        except KeyError:
            return False
        return True