# nas/BioReason/bioreason/dna_modules/protein_module.py
# yuccaaa's picture
# Add files using upload-large-folder tool
# ffcfc75 verified
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
class ProteinBaseModule(ABC):
    """
    Abstract base class for protein-language model modules.

    This class defines the interface that all protein-LLM implementations
    must follow, providing standardized methods for model loading,
    processing, and training integration.

    Subclasses must implement every method decorated with
    ``@abstractmethod``; the remaining methods are hooks with default
    behaviour that subclasses may override.
    """

    def __init__(self):
        """Initialize the protein module."""
        super().__init__()

    @abstractmethod
    def get_protein_llm_key(self) -> str:
        """
        Get the unique identifier for this protein-LLM implementation.

        Returns:
            String identifier for this module type
        """

    @abstractmethod
    def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
        """
        Return the appropriate model class based on model ID.

        Args:
            model_id: Identifier for the model
            model_init_kwargs: Initialization arguments for the model

        Returns:
            The model class to instantiate
        """

    def post_model_init(self, model: Any, processing_class: Any) -> None:
        """
        Perform any post-initialization setup on the model.

        The default implementation is a no-op; override it when setup is
        required after the model has been created.

        Args:
            model: The initialized model
            processing_class: The processor for the model
        """

    def is_embeds_input(self) -> bool:
        """
        Whether the model uses embeddings as input (instead of token IDs).

        Returns:
            Boolean indicating if the model takes embedding inputs
        """
        # Default for protein-LLM models is True due to Q-Former integration
        return True

    @abstractmethod
    def get_processing_class(self) -> Type:
        """
        Get the processing class to use with this protein-LLM model.

        Returns:
            The processing class
        """

    @abstractmethod
    def get_protein_llm_modules_keywords(self) -> List[str]:
        """
        Get keywords to identify protein-specific modules in the model.

        Used to exclude protein modules from LoRA adaptation during training
        or to identify components for specific training strategies.

        Returns:
            List of keywords that identify protein modules
        """

    @abstractmethod
    def get_custom_multimodal_keywords(self) -> List[str]:
        """
        Get keywords for multimodal inputs that should be passed to the model.

        Returns:
            List of input keywords for multimodal processing
        """

    @abstractmethod
    def get_non_generate_params(self) -> List[str]:
        """
        Get parameter names that should be excluded from generation.

        Returns:
            List of parameter names to exclude from generation calls
        """

    @abstractmethod
    def get_custom_processing_keywords(self) -> List[tuple]:
        """
        Get custom processing keywords for the processor.

        Returns:
            List of (component, parameter) tuples for custom processing
        """

    @abstractmethod
    def prepare_prompt(
        self,
        processing_class: Any,
        inputs: List[Dict[str, Union[torch.Tensor, Any]]]
    ) -> List[str]:
        """
        Prepare prompts from input examples.

        Args:
            processing_class: The processor to use
            inputs: List of input examples

        Returns:
            List of prepared prompts
        """

    @abstractmethod
    def prepare_model_inputs(
        self,
        processing_class: Any,
        model: Any,
        prompts_text: List[str],
        batch_protein_sequences: List[List[str]],
        return_tensors: str = "pt",
        padding: bool = True,
        padding_side: str = "left",
        add_special_tokens: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Prepare inputs for the model.

        Args:
            processing_class: The processor to use
            model: The model to prepare inputs for
            prompts_text: List of text prompts
            batch_protein_sequences: List of lists of protein sequences
            return_tensors: Return format for tensors
            padding: Whether to pad inputs
            padding_side: Side to pad on
            add_special_tokens: Whether to add special tokens
            **kwargs: Additional arguments

        Returns:
            Processed inputs for the model
        """

    def get_reward_functions(self) -> Dict[str, Callable[..., Any]]:
        """
        Get available reward functions for this module.

        The default implementation exposes no reward functions.

        Returns:
            Dictionary mapping function names to callables
        """
        return {}

    def validate_model_config(self, config: Dict[str, Any]) -> bool:
        """
        Validate model configuration parameters.

        The default implementation performs no checks and accepts every
        configuration; override it to enforce model-specific constraints.

        Args:
            config: Configuration dictionary

        Returns:
            True if valid, False otherwise
        """
        return True

    def get_default_generation_config(self) -> Dict[str, Any]:
        """
        Get default generation configuration for this model type.

        A new dictionary is built on every call, so callers may mutate the
        result without affecting later calls.

        Returns:
            Dictionary of default generation parameters
        """
        return {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "do_sample": True,
            "top_p": 0.9,
            "pad_token_id": None,  # Will be set by processor
        }