DRAPEa / models /model_loader.py
andevs's picture
Upload 18 files
9c1053c verified
Raw
History Blame Contribute Delete
1.21 kB
import os
import torch
from transformers import AutoModel, AutoTokenizer
class ModelLoader:
    """Lazily load and cache Hugging Face models by short alias.

    Each alias maps to an identifier passed to
    ``transformers.AutoModel.from_pretrained``. A model is loaded the first
    time its alias is requested and then kept in an in-memory cache until
    ``unload_model`` is called.
    """

    # Built-in registry: alias -> Hugging Face Hub model identifier.
    # Instances copy this dict, so per-instance edits never mutate it.
    DEFAULT_MODELS = {
        'vit-base': 'google/vit-base-patch16-224',
        'resnet-50': 'microsoft/resnet-50',
        'dinov2-base': 'facebook/dinov2-base',
        'flan-t5-large': 'google/flan-t5-large',
        'bert-base': 'bert-base-uncased',
    }

    def __init__(self, available_models=None):
        """Create a loader with an empty model cache.

        Args:
            available_models: Optional mapping of alias -> identifier
                (anything ``AutoModel.from_pretrained`` accepts) to use
                instead of ``DEFAULT_MODELS``. The mapping is copied, so
                later changes to the caller's dict do not affect this loader.
        """
        # alias -> model instance returned by AutoModel.from_pretrained
        self.loaded_models = {}
        # alias -> identifier passed to AutoModel.from_pretrained
        source = self.DEFAULT_MODELS if available_models is None else available_models
        self.available_models = dict(source)

    def load_model(self, model_name: str):
        """Return the model registered under *model_name*, loading it on first use.

        Args:
            model_name: Alias from ``available_models``.

        Returns:
            The cached model if already loaded; otherwise the freshly loaded
            model (which is then cached). Returns None when *model_name* is
            not a registered alias.

        Any exception raised by ``AutoModel.from_pretrained`` (for example a
        missing or unreachable checkpoint) propagates to the caller, and
        nothing is cached in that case.
        """
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        model_path = self.available_models.get(model_name)
        if not model_path:
            return None
        model = AutoModel.from_pretrained(model_path)
        self.loaded_models[model_name] = model
        return model

    def list_models(self) -> list:
        """Return the registered aliases, in registry order."""
        return list(self.available_models.keys())

    def get_loaded_models(self) -> list:
        """Return the aliases currently held in the cache, in load order."""
        return list(self.loaded_models.keys())

    def unload_model(self, model_name: str):
        """Drop *model_name* from the cache.

        Only the loader's reference is removed. The model object can be
        garbage-collected once no caller still holds it.

        Returns:
            True if the alias was cached and has been removed, else False.
        """
        if model_name not in self.loaded_models:
            return False
        del self.loaded_models[model_name]
        return True