Spaces: Running on Zero
# Standard library imports
import os
import sys
from typing import Any

# Third-party imports
import torch
from dotenv import load_dotenv
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Local imports
from src.exception import CustomExceptionHandling
from src.logger import logging

# Pull variables from a local .env file into the process environment.
load_dotenv()

# Hugging Face access token used when downloading the model/processor.
access_token = os.environ.get("ACCESS_TOKEN")
def load_model_and_processor(model_name: str, device: str) -> tuple[Any, Any]:
    """
    Load the PaliGemma model and its processor from the Hugging Face Hub.

    Args:
        model_name (str): Hub repository name of the model to load.
        device (str): Device to move the model onto (e.g. "cuda", "cpu").

    Returns:
        tuple[Any, Any]: ``(model, processor)`` — the model in eval mode,
        moved to ``device``, and the matching processor.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while
        downloading or initializing the model or processor.
    """
    try:
        # bfloat16 halves the weight memory footprint versus float32;
        # eval() switches off training-only behavior (e.g. dropout).
        model = (
            PaliGemmaForConditionalGeneration.from_pretrained(
                model_name, dtype=torch.bfloat16, token=access_token
            )
            .eval()
            .to(device)
        )
        # use_fast=True requests the fast (tokenizers-backed) implementation.
        processor = PaliGemmaProcessor.from_pretrained(
            model_name, use_fast=True, token=access_token
        )

        # Log the successful loading of the model and processor.
        logging.info("Model and processor loaded successfully.")

        return model, processor

    except Exception as e:
        # Re-raise through the project's custom exception type, keeping
        # the original exception chained for traceback context.
        raise CustomExceptionHandling(e, sys) from e