# Source path: nas/BioReason/bioreason/models/protein_utils.py
# Uploaded by yuccaaa ("Add files using upload-large-folder tool"), commit ffcfc75.
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
@dataclass
class ProteinInput:
    """
    Data class for protein-text input pairs.

    This class encapsulates a text input along with associated protein sequences
    for processing by the ESM2-LLM model.

    Attributes:
        text: Free-form text paired with the proteins.
        protein_sequences: Amino-acid sequences. Every character must be one of
            the 20 standard one-letter amino-acid codes or ``X`` (unknown), in
            upper or lower case. The list is copied on construction, so later
            changes to the caller's list do not affect this object.
        metadata: Optional metadata, stored as given.

    Raises:
        TypeError: If ``text`` is not a ``str``, ``protein_sequences`` is not a
            ``list``, or an element of ``protein_sequences`` is not a ``str``.
        ValueError: If a sequence contains a character outside the alphabet.
    """

    text: str
    protein_sequences: List[str]
    metadata: Optional[Dict[str, Any]] = None

    # Un-annotated, so @dataclass does not treat it as a field.
    # Both cases are listed explicitly rather than testing ``ch.upper()``:
    # some non-ASCII characters uppercase to ASCII letters (U+0131 -> 'I',
    # U+017F -> 'S') and would otherwise be accepted as residues.
    _VALID_AA = frozenset("ACDEFGHIKLMNPQRSTVWYX" "acdefghiklmnpqrstvwyx")

    @classmethod
    def _is_valid_sequence(cls, sequence: str) -> bool:
        """Return True if every character of ``sequence`` is an allowed code (empty is valid)."""
        return cls._VALID_AA.issuperset(sequence)

    def __post_init__(self):
        """Validate inputs after initialization."""
        if not isinstance(self.text, str):
            raise TypeError("text must be a string")
        if not isinstance(self.protein_sequences, list):
            raise TypeError("protein_sequences must be a list")
        # Defensive copy: add_protein_sequence() appends to this list, and the
        # caller's list must neither be mutated by nor mutate this object.
        self.protein_sequences = list(self.protein_sequences)
        for i, seq in enumerate(self.protein_sequences):
            if not isinstance(seq, str):
                raise TypeError(f"protein_sequences[{i}] must be a string")
            if not self._is_valid_sequence(seq):
                raise ValueError(f"protein_sequences[{i}] contains invalid amino acid characters")

    def __len__(self):
        """Return the number of protein sequences."""
        return len(self.protein_sequences)

    def __getitem__(self, index):
        """Get a specific protein sequence by index (slices return a list)."""
        return self.protein_sequences[index]

    def add_protein_sequence(self, sequence: str):
        """
        Append one protein sequence after validating it.

        Raises:
            TypeError: If ``sequence`` is not a string.
            ValueError: If ``sequence`` contains a character outside the alphabet.
        """
        if not isinstance(sequence, str):
            raise TypeError("sequence must be a string")
        if not self._is_valid_sequence(sequence):
            raise ValueError("sequence contains invalid amino acid characters")
        self.protein_sequences.append(sequence)

    def get_total_protein_length(self) -> int:
        """Get the total length of all protein sequences."""
        return sum(len(seq) for seq in self.protein_sequences)

    def get_protein_lengths(self) -> List[int]:
        """Get the length of each protein sequence."""
        return [len(seq) for seq in self.protein_sequences]

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary representation.

        ``protein_sequences`` is returned as a copy, so editing the returned
        dict cannot add unvalidated sequences to this object. ``metadata`` is
        returned as-is (not copied).
        """
        return {
            "text": self.text,
            "protein_sequences": list(self.protein_sequences),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProteinInput":
        """
        Create ProteinInput from dictionary.

        ``data`` must contain ``"text"`` and ``"protein_sequences"`` (a missing
        key raises KeyError); ``"metadata"`` is optional.
        """
        return cls(
            text=data["text"],
            protein_sequences=data["protein_sequences"],
            metadata=data.get("metadata"),
        )
def validate_protein_sequence(sequence: str) -> bool:
    """
    Validate if a string is a valid protein sequence.

    A valid sequence is a ``str`` whose characters are all one-letter codes for
    the 20 standard amino acids or ``X`` (unknown), in upper or lower case. The
    empty string counts as valid.

    Args:
        sequence: Value to validate. Non-``str`` values are reported as
            invalid rather than raising.

    Returns:
        True if valid, False otherwise.
    """
    if not isinstance(sequence, str):
        return False
    # List both cases explicitly instead of testing ``ch.upper()``: some
    # non-ASCII characters uppercase to ASCII letters (U+0131 -> 'I',
    # U+017F -> 'S') and would otherwise be accepted.
    valid_aa = frozenset("ACDEFGHIKLMNPQRSTVWYX" "acdefghiklmnpqrstvwyx")
    return valid_aa.issuperset(sequence)
def clean_protein_sequence(sequence: str) -> str:
    """
    Clean a protein sequence by removing invalid characters and whitespace.

    Every character that is not one of the 20 standard amino-acid codes or
    ``X`` (either case, ASCII only) is dropped. This removes whitespace,
    digits and punctuation such as ``*`` or ``-``. It also removes the rarer
    codes U, O, B and Z, which are not in the accepted alphabet. The surviving
    characters are uppercased.

    Args:
        sequence: Raw protein sequence.

    Returns:
        Cleaned, uppercase protein sequence (possibly empty).

    Raises:
        TypeError: If ``sequence`` is not a string.
    """
    if not isinstance(sequence, str):
        raise TypeError("sequence must be a string")
    # Filter BEFORE uppercasing: str.upper() can turn a non-ASCII character
    # into ASCII letters ('ß' -> 'SS', 'ı' -> 'I'), which would inject residues
    # that were never present in the input. Whitespace is not in the set, so
    # this single pass also removes it.
    valid_aa = frozenset("ACDEFGHIKLMNPQRSTVWYX" "acdefghiklmnpqrstvwyx")
    return ''.join(ch for ch in sequence if ch in valid_aa).upper()
def load_protein_sequences_from_fasta(filepath: str) -> List[Dict[str, str]]:
    """
    Load protein sequences from a FASTA file.

    Each line is stripped of surrounding whitespace. A line starting with
    ``>`` begins a new record, and everything after the ``>`` becomes the
    record ``id``. The non-blank lines that follow, up to the next header,
    are concatenated and passed through ``clean_protein_sequence``.

    Input that is silently skipped:
      * a header with no sequence lines after it;
      * sequence lines that appear before the first header.

    Args:
        filepath: Path to FASTA file.

    Returns:
        List of dictionaries with 'id' and 'sequence' keys, in file order.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        Exception: For any other error while reading or parsing. The
            underlying error is attached as ``__cause__``.
    """
    sequences: List[Dict[str, str]] = []
    current_id: Optional[str] = None
    current_seq: List[str] = []

    def _emit(record_id: Optional[str], seq_lines: List[str]) -> None:
        # Append a finished record, but only if it has both a header and data.
        if record_id is not None and seq_lines:
            sequences.append({
                'id': record_id,
                'sequence': clean_protein_sequence(''.join(seq_lines)),
            })

    try:
        with open(filepath, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if line.startswith('>'):
                    _emit(current_id, current_seq)
                    current_id = line[1:]  # Remove '>'
                    current_seq = []
                elif line:
                    current_seq.append(line)
        # The last record has no following header to trigger its emission.
        _emit(current_id, current_seq)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"FASTA file not found: {filepath}") from e
    except Exception as e:
        raise Exception(f"Error reading FASTA file: {e}") from e
    return sequences
def create_protein_inputs_from_fasta(
    fasta_filepath: str,
    text_template: str = "Analyze this protein: {protein_id}",
    include_sequence_in_text: bool = False,
) -> List[ProteinInput]:
    """
    Build one single-sequence ProteinInput per record of a FASTA file.

    Args:
        fasta_filepath: Path to the FASTA file, read with
            ``load_protein_sequences_from_fasta``.
        text_template: ``str.format`` template for each record's text. The
            only field substituted is ``{protein_id}`` (the record id).
        include_sequence_in_text: If True, `` Sequence: <sequence>`` is
            appended to the formatted text.

    Returns:
        ProteinInput objects in file order. Each holds exactly that record's
        sequence and ``metadata={'protein_id': <record id>}``.
    """
    def _record_to_input(record: Dict[str, str]) -> ProteinInput:
        # Turn one {'id', 'sequence'} record into a ProteinInput.
        prompt = text_template.format(protein_id=record['id'])
        if include_sequence_in_text:
            prompt += f" Sequence: {record['sequence']}"
        return ProteinInput(
            text=prompt,
            protein_sequences=[record['sequence']],
            metadata={'protein_id': record['id']},
        )

    records = load_protein_sequences_from_fasta(fasta_filepath)
    return [_record_to_input(record) for record in records]
# Example usage
if __name__ == "__main__":
    # Construct an input holding two valid sequences and report their lengths.
    protein_input = ProteinInput(
        text="What is the function of this protein?",
        protein_sequences=[
            "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
            "ACDEFGHIKLMNPQRSTVWY",
        ],
        metadata={"source": "example"},
    )
    print(f"Number of protein sequences: {len(protein_input)}")
    print(f"Total protein length: {protein_input.get_total_protein_length()}")
    print(f"Individual lengths: {protein_input.get_protein_lengths()}")

    # Test validation: digits are not amino-acid codes, so construction must fail.
    try:
        ProteinInput(
            text="Test",
            protein_sequences=["INVALID123"],  # Contains numbers
        )
    except ValueError as e:
        print(f"Validation caught invalid sequence: {e}")
    else:
        # Make a validation regression visible instead of passing silently.
        print("Validation unexpectedly accepted an invalid sequence")

    # Test cleaning: whitespace and digits are removed.
    dirty_sequence = " AC DE FG HI 123 KL "
    clean_sequence = clean_protein_sequence(dirty_sequence)
    print(f"Cleaned sequence: '{clean_sequence}'")