from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
| |
|
| | class ProteinBaseModule(ABC): |
| | """ |
| | Abstract base class for protein-language model modules. |
| | |
| | This class defines the interface that all protein-LLM implementations |
| | must follow, providing standardized methods for model loading, |
| | processing, and training integration. |
| | """ |
| | |
| | def __init__(self): |
| | """Initialize the protein module.""" |
| | super().__init__() |
| | |
| | @abstractmethod |
| | def get_protein_llm_key(self) -> str: |
| | """ |
| | Get the unique identifier for this protein-LLM implementation. |
| | |
| | Returns: |
| | String identifier for this module type |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type: |
| | """ |
| | Return the appropriate model class based on model ID. |
| | |
| | Args: |
| | model_id: Identifier for the model |
| | model_init_kwargs: Initialization arguments for the model |
| | |
| | Returns: |
| | The model class to instantiate |
| | """ |
| | pass |
| |
|
| | def post_model_init(self, model: Any, processing_class: Any) -> None: |
| | """ |
| | Perform any post-initialization setup on the model. |
| | |
| | Args: |
| | model: The initialized model |
| | processing_class: The processor for the model |
| | """ |
| | |
| | pass |
| |
|
| | def is_embeds_input(self) -> bool: |
| | """ |
| | Whether the model uses embeddings as input (instead of token IDs). |
| | |
| | Returns: |
| | Boolean indicating if the model takes embedding inputs |
| | """ |
| | |
| | return True |
| | |
| | @abstractmethod |
| | def get_processing_class(self) -> Type: |
| | """ |
| | Get the processing class to use with this protein-LLM model. |
| | |
| | Returns: |
| | The processing class |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def get_protein_llm_modules_keywords(self) -> List[str]: |
| | """ |
| | Get keywords to identify protein-specific modules in the model. |
| | |
| | Used to exclude protein modules from LoRA adaptation during training |
| | or to identify components for specific training strategies. |
| | |
| | Returns: |
| | List of keywords that identify protein modules |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def get_custom_multimodal_keywords(self) -> List[str]: |
| | """ |
| | Get keywords for multimodal inputs that should be passed to the model. |
| | |
| | Returns: |
| | List of input keywords for multimodal processing |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def get_non_generate_params(self) -> List[str]: |
| | """ |
| | Get parameter names that should be excluded from generation. |
| | |
| | Returns: |
| | List of parameter names to exclude from generation calls |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def get_custom_processing_keywords(self) -> List[tuple]: |
| | """ |
| | Get custom processing keywords for the processor. |
| | |
| | Returns: |
| | List of (component, parameter) tuples for custom processing |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def prepare_prompt( |
| | self, |
| | processing_class: Any, |
| | inputs: List[Dict[str, Union[torch.Tensor, Any]]] |
| | ) -> List[str]: |
| | """ |
| | Prepare prompts from input examples. |
| | |
| | Args: |
| | processing_class: The processor to use |
| | inputs: List of input examples |
| | |
| | Returns: |
| | List of prepared prompts |
| | """ |
| | pass |
| | |
| | @abstractmethod |
| | def prepare_model_inputs( |
| | self, |
| | processing_class: Any, |
| | model: Any, |
| | prompts_text: List[str], |
| | batch_protein_sequences: List[List[str]], |
| | return_tensors: str = "pt", |
| | padding: bool = True, |
| | padding_side: str = "left", |
| | add_special_tokens: bool = False, |
| | **kwargs |
| | ) -> Dict[str, Any]: |
| | """ |
| | Prepare inputs for the model. |
| | |
| | Args: |
| | processing_class: The processor to use |
| | model: The model to prepare inputs for |
| | prompts_text: List of text prompts |
| | batch_protein_sequences: List of lists of protein sequences |
| | return_tensors: Return format for tensors |
| | padding: Whether to pad inputs |
| | padding_side: Side to pad on |
| | add_special_tokens: Whether to add special tokens |
| | **kwargs: Additional arguments |
| | |
| | Returns: |
| | Processed inputs for the model |
| | """ |
| | pass |
| |
|
| | def get_reward_functions(self) -> Dict[str, callable]: |
| | """ |
| | Get available reward functions for this module. |
| | |
| | Returns: |
| | Dictionary mapping function names to callables |
| | """ |
| | return {} |
| |
|
| | def validate_model_config(self, config: Dict[str, Any]) -> bool: |
| | """ |
| | Validate model configuration parameters. |
| | |
| | Args: |
| | config: Configuration dictionary |
| | |
| | Returns: |
| | True if valid, False otherwise |
| | """ |
| | return True |
| |
|
| | def get_default_generation_config(self) -> Dict[str, Any]: |
| | """ |
| | Get default generation configuration for this model type. |
| | |
| | Returns: |
| | Dictionary of default generation parameters |
| | """ |
| | return { |
| | "max_new_tokens": 512, |
| | "temperature": 0.7, |
| | "do_sample": True, |
| | "top_p": 0.9, |
| | "pad_token_id": None, |
| | } |