from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch


class ProteinBaseModule(ABC):
    """
    Abstract base class for protein-language model modules.

    This class defines the interface that all protein-LLM implementations
    must follow, providing standardized methods for model loading,
    processing, and training integration. Subclasses must implement every
    method decorated with ``@abstractmethod``; the remaining methods provide
    overridable defaults.
    """

    def __init__(self):
        """Initialize the protein module."""
        super().__init__()

    @abstractmethod
    def get_protein_llm_key(self) -> str:
        """
        Get the unique identifier for this protein-LLM implementation.

        Returns:
            String identifier for this module type
        """
        pass

    @abstractmethod
    def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
        """
        Return the appropriate model class based on model ID.

        Args:
            model_id: Identifier for the model
            model_init_kwargs: Initialization arguments for the model

        Returns:
            The model class to instantiate
        """
        pass

    def post_model_init(self, model: Any, processing_class: Any) -> None:
        """
        Perform any post-initialization setup on the model.

        The default implementation is a no-op; override when the model needs
        extra wiring after construction.

        Args:
            model: The initialized model
            processing_class: The processor for the model
        """
        # Default implementation does nothing.
        pass

    def is_embeds_input(self) -> bool:
        """
        Whether the model uses embeddings as input (instead of token IDs).

        Returns:
            Boolean indicating if the model takes embedding inputs
        """
        # Default for protein-LLM models is True due to Q-Former integration.
        return True

    @abstractmethod
    def get_processing_class(self) -> Type:
        """
        Get the processing class to use with this protein-LLM model.

        Returns:
            The processing class
        """
        pass

    @abstractmethod
    def get_protein_llm_modules_keywords(self) -> List[str]:
        """
        Get keywords to identify protein-specific modules in the model.

        Used to exclude protein modules from LoRA adaptation during training
        or to identify components for specific training strategies.

        Returns:
            List of keywords that identify protein modules
        """
        pass

    @abstractmethod
    def get_custom_multimodal_keywords(self) -> List[str]:
        """
        Get keywords for multimodal inputs that should be passed to the model.

        Returns:
            List of input keywords for multimodal processing
        """
        pass

    @abstractmethod
    def get_non_generate_params(self) -> List[str]:
        """
        Get parameter names that should be excluded from generation.

        Returns:
            List of parameter names to exclude from generation calls
        """
        pass

    @abstractmethod
    def get_custom_processing_keywords(self) -> List[tuple]:
        """
        Get custom processing keywords for the processor.

        Returns:
            List of (component, parameter) tuples for custom processing
        """
        pass

    @abstractmethod
    def prepare_prompt(
        self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
    ) -> List[str]:
        """
        Prepare prompts from input examples.

        Args:
            processing_class: The processor to use
            inputs: List of input examples

        Returns:
            List of prepared prompts
        """
        pass

    @abstractmethod
    def prepare_model_inputs(
        self,
        processing_class: Any,
        model: Any,
        prompts_text: List[str],
        batch_protein_sequences: List[List[str]],
        return_tensors: str = "pt",
        padding: bool = True,
        padding_side: str = "left",
        add_special_tokens: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for the model.

        Args:
            processing_class: The processor to use
            model: The model to prepare inputs for
            prompts_text: List of text prompts
            batch_protein_sequences: List of lists of protein sequences
                (one inner list per prompt)
            return_tensors: Return format for tensors
            padding: Whether to pad inputs
            padding_side: Side to pad on
            add_special_tokens: Whether to add special tokens
            **kwargs: Additional arguments

        Returns:
            Processed inputs for the model
        """
        pass

    def get_reward_functions(self) -> Dict[str, Callable[..., Any]]:
        """
        Get available reward functions for this module.

        The default implementation returns a new, empty dict on every call.

        Returns:
            Dictionary mapping function names to callables
        """
        return {}

    def validate_model_config(self, config: Dict[str, Any]) -> bool:
        """
        Validate model configuration parameters.

        The default implementation accepts every configuration.

        Args:
            config: Configuration dictionary

        Returns:
            True if valid, False otherwise
        """
        return True

    def get_default_generation_config(self) -> Dict[str, Any]:
        """
        Get default generation configuration for this model type.

        A new dict is built on each call, so callers may mutate the result
        without affecting later calls.

        Returns:
            Dictionary of default generation parameters
        """
        return {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "do_sample": True,
            "top_p": 0.9,
            # Placeholder: presumably filled in from the processor/tokenizer
            # by the caller — nothing in this class sets it.
            "pad_token_id": None,
        }