| from dataclasses import dataclass |
| from typing import List, Optional, Dict, Any |
|
|
|
|
@dataclass
class ProteinInput:
    """
    Data class for protein-text input pairs.

    This class encapsulates a text input along with associated protein sequences
    for processing by the ESM2-LLM model.

    Attributes:
        text: Free-text prompt paired with the proteins.
        protein_sequences: Protein sequences as strings. Every character must,
            when upper-cased on its own, be one of the 20 standard one-letter
            amino-acid codes or ``X``; lower-case input is therefore accepted.
        metadata: Optional free-form dictionary; it is not validated.
    """
    text: str
    protein_sequences: List[str]
    metadata: Optional[Dict[str, Any]] = None

    # Allowed one-letter residue codes. Deliberately unannotated so that
    # @dataclass treats it as a plain class attribute, not a field.
    _VALID_AA = frozenset('ACDEFGHIKLMNPQRSTVWYX')

    @classmethod
    def _check_sequence(cls, seq: Any, label: str) -> None:
        """
        Validate a single protein sequence.

        Args:
            seq: Value to check.
            label: Name of the value, used as the prefix of the error message.

        Raises:
            TypeError: If ``seq`` is not a string.
            ValueError: If ``seq`` contains a disallowed character.
        """
        if not isinstance(seq, str):
            raise TypeError(f"{label} must be a string")
        # Upper-case character by character so that multi-character case
        # expansions (e.g. 'ß' -> 'SS') are rejected, not smuggled through.
        if not all(aa.upper() in cls._VALID_AA for aa in seq):
            raise ValueError(f"{label} contains invalid amino acid characters")

    def __post_init__(self):
        """
        Validate inputs after initialization.

        Raises:
            TypeError: If ``text`` is not a string, ``protein_sequences`` is not
                a list, or any element of it is not a string.
            ValueError: If any sequence contains a disallowed character.
        """
        if not isinstance(self.text, str):
            raise TypeError("text must be a string")

        if not isinstance(self.protein_sequences, list):
            raise TypeError("protein_sequences must be a list")

        for i, seq in enumerate(self.protein_sequences):
            self._check_sequence(seq, f"protein_sequences[{i}]")

    def __len__(self):
        """Return the number of protein sequences."""
        return len(self.protein_sequences)

    def __getitem__(self, index):
        """Get a protein sequence by index (slices return a list)."""
        return self.protein_sequences[index]

    def add_protein_sequence(self, sequence: str):
        """
        Validate and append a protein sequence.

        Raises:
            TypeError: If ``sequence`` is not a string.
            ValueError: If ``sequence`` contains a disallowed character.
        """
        self._check_sequence(sequence, "sequence")
        self.protein_sequences.append(sequence)

    def get_total_protein_length(self) -> int:
        """Get the total length of all protein sequences."""
        return sum(len(seq) for seq in self.protein_sequences)

    def get_protein_lengths(self) -> List[int]:
        """Get the length of each protein sequence."""
        return [len(seq) for seq in self.protein_sequences]

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary representation.

        ``protein_sequences`` is a shallow copy, so mutating the returned list
        cannot bypass validation on this instance. ``metadata`` is returned
        as-is (shared, not copied).
        """
        return {
            "text": self.text,
            "protein_sequences": list(self.protein_sequences),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProteinInput":
        """
        Create ProteinInput from dictionary.

        A list under ``"protein_sequences"`` is shallow-copied, so later calls
        to ``add_protein_sequence`` do not mutate the caller's ``data``.
        Non-list values are passed through unchanged so that ``__post_init__``
        still rejects them with ``TypeError``.

        Raises:
            KeyError: If ``"text"`` or ``"protein_sequences"`` is missing.
        """
        seqs = data["protein_sequences"]
        return cls(
            text=data["text"],
            protein_sequences=list(seqs) if isinstance(seqs, list) else seqs,
            metadata=data.get("metadata"),
        )
|
|
|
|
def validate_protein_sequence(sequence: str) -> bool:
    """
    Check whether ``sequence`` is a well-formed protein sequence.

    A value is well-formed when it is a ``str`` and each of its characters,
    upper-cased individually, is one of ``ACDEFGHIKLMNPQRSTVWYX``. The empty
    string therefore counts as valid.

    Args:
        sequence: Candidate value; non-string values are reported as invalid.

    Returns:
        True if valid, False otherwise.
    """
    if isinstance(sequence, str):
        allowed = frozenset('ACDEFGHIKLMNPQRSTVWYX')
        return not any(ch.upper() not in allowed for ch in sequence)
    return False
|
|
|
|
def clean_protein_sequence(sequence: str) -> str:
    """
    Normalise a raw protein sequence.

    The input is upper-cased, and every character that is not one of
    ``ACDEFGHIKLMNPQRSTVWYX`` (whitespace, digits, gap/stop symbols, ...) is
    dropped.

    Args:
        sequence: Raw protein sequence.

    Returns:
        The cleaned, upper-case protein sequence (possibly empty).

    Raises:
        TypeError: If ``sequence`` is not a string.
    """
    if isinstance(sequence, str):
        allowed = frozenset('ACDEFGHIKLMNPQRSTVWYX')
        # Whitespace is never in ``allowed``, so this single filter pass also
        # removes it; no separate split/join step is needed.
        upper = sequence.upper()
        return ''.join(filter(allowed.__contains__, upper))
    raise TypeError("sequence must be a string")
|
|
|
|
def load_protein_sequences_from_fasta(
    filepath: str,
    *,
    encoding: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Load protein sequences from a FASTA file.

    A line starting with ``>`` opens a new record; the rest of that (stripped)
    line becomes the record's ``id``. The following non-blank lines are
    concatenated and passed through ``clean_protein_sequence``.

    Parsing behaviour:
    - Blank lines are ignored.
    - Lines appearing before the first header are ignored.
    - A header with no sequence lines yields no record.
    - A record whose lines clean down to an empty string is still returned.

    Args:
        filepath: Path to FASTA file.
        encoding: Text encoding passed to ``open``. ``None`` (the default)
            keeps the platform's locale encoding, as before.

    Returns:
        List of dictionaries with 'id' and 'sequence' keys, in file order.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        Exception: For any other failure while reading. The original error is
            chained as ``__cause__``.
    """
    sequences: List[Dict[str, str]] = []
    current_id: Optional[str] = None
    current_seq: List[str] = []

    def _flush() -> None:
        # Emit the record currently being accumulated, if there is one.
        if current_id is not None and current_seq:
            sequences.append({
                'id': current_id,
                'sequence': clean_protein_sequence(''.join(current_seq)),
            })

    try:
        with open(filepath, 'r', encoding=encoding) as f:
            for raw_line in f:
                line = raw_line.strip()
                if line.startswith('>'):
                    _flush()
                    current_id = line[1:]
                    current_seq = []
                elif line:
                    current_seq.append(line)
            # The final record has no following header to trigger a flush.
            _flush()
    except FileNotFoundError as e:
        raise FileNotFoundError(f"FASTA file not found: {filepath}") from e
    except Exception as e:
        raise Exception(f"Error reading FASTA file: {e}") from e

    return sequences
|
|
|
|
def create_protein_inputs_from_fasta(
    fasta_filepath: str,
    text_template: str = "Analyze this protein: {protein_id}",
    include_sequence_in_text: bool = False,
) -> List[ProteinInput]:
    """
    Build one ProteinInput per record of a FASTA file.

    The whole file is loaded with ``load_protein_sequences_from_fasta`` before
    any ProteinInput is constructed.

    Args:
        fasta_filepath: Path to FASTA file.
        text_template: ``str.format`` template for the text. It is formatted
            with the single keyword ``protein_id``.
        include_sequence_in_text: If True, append ``" Sequence: <sequence>"``
            to the formatted text.

    Returns:
        List of ProteinInput objects, each holding exactly one sequence and
        ``metadata={'protein_id': <record id>}``.
    """
    def _to_input(record: Dict[str, str]) -> ProteinInput:
        # Convert one loaded FASTA record into a ProteinInput.
        pid, seq = record['id'], record['sequence']
        prompt = text_template.format(protein_id=pid)
        if include_sequence_in_text:
            prompt = f"{prompt} Sequence: {seq}"
        return ProteinInput(
            text=prompt,
            protein_sequences=[seq],
            metadata={'protein_id': pid},
        )

    records = load_protein_sequences_from_fasta(fasta_filepath)
    return [_to_input(record) for record in records]
|
|
|
|
| |
if __name__ == "__main__":
    # Demo 1: build a valid input and report its sequence statistics.
    demo = ProteinInput(
        text="What is the function of this protein?",
        protein_sequences=[
            "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
            "ACDEFGHIKLMNPQRSTVWY",
        ],
        metadata={"source": "example"},
    )
    summary = (
        f"Number of protein sequences: {len(demo)}",
        f"Total protein length: {demo.get_total_protein_length()}",
        f"Individual lengths: {demo.get_protein_lengths()}",
    )
    for message in summary:
        print(message)

    # Demo 2: a sequence containing digits is rejected at construction time.
    try:
        ProteinInput(text="Test", protein_sequences=["INVALID123"])
    except ValueError as err:
        print(f"Validation caught invalid sequence: {err}")

    # Demo 3: strip whitespace and non-residue characters from a raw string.
    raw = " AC DE FG HI 123 KL "
    cleaned = clean_protein_sequence(raw)
    print(f"Cleaned sequence: '{cleaned}'")