"""llava.py.

File for providing the Llava model implementation.
"""

from transformers import LlavaForConditionalGeneration

from src.models.base import ModelBase
from src.models.config import Config


class LlavaModel(ModelBase):
    """Llava model implementation."""

    def __init__(self, config: Config) -> None:
        """Initialization of the llava model.

        Args:
            config (Config): Parsed config
        """
        # initialize the parent class
        super().__init__(config)

    def _load_specific_model(self) -> None:
        """Overridden function to populate self.model.

        Loads ``LlavaForConditionalGeneration`` from ``self.model_path``.
        If the config has a ``model`` attribute, its entries are forwarded
        as keyword arguments to ``from_pretrained``. If that attribute is
        missing, ``None``, or empty, the model is loaded with only the path.
        """
        # A ``model`` attribute that is present but set to ``None`` used to
        # pass the ``hasattr`` check and then crash on ``**None``; treat it
        # the same as "no extra kwargs".
        # NOTE(review): assumes ``config.model`` is a mapping of
        # from_pretrained keyword arguments when present — confirm in Config.
        model_kwargs = getattr(self.config, "model", None) or {}
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_path, **model_kwargs
        )