| import clip |
| import torch |
| import joblib |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| from config import Config |
|
|
class ModelLoader:
    """
    Load and hold the machine learning models used by the detector.

    Both models are loaded eagerly in ``__init__``; constructing a single
    instance and reusing it ensures the models are loaded only once.

    Attributes:
        device: The device passed to ``clip.load``. Defaults to ``"cuda"``
            when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``,
            unless an explicit ``device`` is supplied to the constructor.
        clip_model: The CLIP model returned by ``clip.load``.
        clip_preprocess: The preprocessing callable returned by ``clip.load``.
        svm_model: The object deserialized by ``joblib.load`` from the file
            downloaded from the Hugging Face Hub.
    """

    def __init__(self, clip_model_name: str, svm_repo_id: str, svm_filename: str, device=None):
        """
        Initialize the ModelLoader and load both models.

        Args:
            clip_model_name (str): The name of the CLIP model to load (e.g., 'ViT-L/14').
            svm_repo_id (str): The repository ID on Hugging Face (e.g., 'rhnsa/ai_human_image_detector').
            svm_filename (str): The name of the model file in the repository (e.g., 'model.joblib').
            device (str or torch.device, optional): Explicit device for the CLIP
                model. When None (the default), "cuda" is used if available,
                otherwise "cpu" -- identical to the previous behaviour.

        Raises:
            Exception: Whatever ``clip.load``, ``hf_hub_download`` or
                ``joblib.load`` raised; the error is printed and re-raised.
        """
        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")

        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(repo_id=svm_repo_id, filename=svm_filename)
        print("Models loaded successfully.")

    @staticmethod
    def _resolve_device(device=None):
        """
        Pick the device to load models onto.

        Args:
            device: An explicit device, or None to auto-detect.

        Returns:
            ``device`` unchanged when it is not None; otherwise ``"cuda"`` if
            ``torch.cuda.is_available()`` is true, else ``"cpu"``.
        """
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_clip_model(self, model_name: str):
        """
        Load the specified CLIP model and its preprocessor onto ``self.device``.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A tuple ``(model, preprocess)`` as returned by ``clip.load``.

        Raises:
            Exception: Any error from ``clip.load`` is printed and re-raised unchanged.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise
        return model, preprocess

    def _load_svm_model(self, repo_id: str, filename: str):
        """
        Download the SVM model file from a Hugging Face Hub repository and load it.

        The download is authenticated with ``Config.HF_TOKEN``.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.

        Returns:
            The object deserialized by ``joblib.load``.

        Raises:
            Exception: Any error from the download or the load is printed and
                re-raised unchanged.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=Config.HF_TOKEN)
            print(f"SVM model downloaded to: {model_path}")
            # SECURITY: joblib.load is pickle-based and can execute arbitrary
            # code embedded in the file. The file comes from a remote repo, so
            # only point `repo_id` at a repository you control or trust.
            # NOTE(review): consider pinning an exact commit via
            # hf_hub_download(revision=...) to guard against upstream changes.
            svm_model = joblib.load(model_path)
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise
        return svm_model
|
|
| |
| |
# Identifiers of the models to load, read from the application Config.
CLIP_MODEL_NAME, SVM_REPO_ID, SVM_FILENAME = (
    Config.AI_HUMAN_CLIP_MODEL_NAME,
    Config.AI_HUMAN_SVM_REPO_ID,
    Config.AI_HUMAN_SVM_FILENAME,
)

# Shared instance built at import time: importing this module constructs a
# ModelLoader, which loads the CLIP model and downloads/loads the SVM model.
# Callers reuse `models` instead of constructing their own loader.
models = ModelLoader(
    svm_filename=SVM_FILENAME,
    svm_repo_id=SVM_REPO_ID,
    clip_model_name=CLIP_MODEL_NAME,
)
|
|