# Copyright 2024-2025 AI Whisperers (https://github.com/Ai-Whisperers)
#
# Licensed under the PolyForm Noncommercial License 1.0.0
# See LICENSE file in the repository root for full license text.
#
# For commercial licensing inquiries: support@aiwhisperers.com
"""PeptideEncoder: Biologically-Grounded AMP Activity Predictor.
This module implements a learned peptide encoder for antimicrobial peptide (AMP)
activity prediction. Following the successful TrainableCodonEncoder pattern
(Spearman 0.60 for DDG), it uses multi-component embeddings and hyperbolic
projections to learn biologically meaningful representations.
Architecture:
Input: Peptide Sequence (10-50 AA)
→ PeptideInputProcessor (tokenize, pad, position encode)
→ MultiComponentEmbedding (AA + 5-adic group + properties = 56D)
→ Transformer Encoder (2 layers, 4 heads)
→ Dual Pooling (mean + attention = 112D)
→ HyperbolicProjection (16D Poincaré ball)
→ MIC Prediction Head (16D → 1)
Decoder Path:
→ Hyperbolic → Euclidean (inverse projection)
→ Transformer Decoder (2 layers, 4 heads, causal mask)
→ Sequence Output (vocab size 22)
Loss Components (6):
1. Reconstruction (sequence cross-entropy)
2. MIC Prediction (Smooth L1)
3. Property Alignment (embed dist ~ property dist)
4. Radial Hierarchy (low MIC → center)
5. Cohesion (same pathogen clusters)
6. Separation (different pathogens separate)
Usage:
from src.encoders.peptide_encoder import PeptideVAE
model = PeptideVAE(latent_dim=16)
z_hyp = model.encode(sequences) # (batch, 16) on Poincaré ball
mic_pred = model.predict_mic(z_hyp) # (batch, 1)
decoded = model.decode(z_hyp) # (batch, seq_len, vocab_size)
"""
from __future__ import annotations
import math
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from src.encoders.padic_amino_acid_encoder import (
AA_TO_GROUP,
AA_TO_INDEX,
AA_PROPERTIES,
INDEX_TO_AA,
AminoAcidGroup,
)
from src.geometry import (
exp_map_zero,
log_map_zero,
poincare_distance,
project_to_poincare,
)
from src.models.hyperbolic_projection import HyperbolicProjection
# =============================================================================
# Constants
# =============================================================================
MAX_SEQ_LEN = 50 # Maximum peptide length (padded)
VOCAB_SIZE = 22 # 20 AA + stop + unknown/pad
PAD_IDX = 21 # Index for padding token (X)
# =============================================================================
# Input Processing
# =============================================================================
class PeptideInputProcessor(nn.Module):
"""Process peptide sequences into model inputs.
Handles:
- Tokenization (AA → index 0-21)
- Padding to MAX_SEQ_LEN
- Positional encoding (sinusoidal)
- N/C-terminal distance features
"""
def __init__(
self,
max_seq_len: int = MAX_SEQ_LEN,
embedding_dim: int = 56,
):
"""Initialize processor.
Args:
max_seq_len: Maximum sequence length
embedding_dim: Position embedding dimension
"""
super().__init__()
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
# Precompute sinusoidal positional encoding
pe = torch.zeros(max_seq_len, embedding_dim)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
if embedding_dim % 2 == 1:
pe[:, 1::2] = torch.cos(position * div_term[:-1])
else:
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('positional_encoding', pe)
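    # The buffer computed above is the standard Transformer sinusoidal encoding
    # (Vaswani et al., 2017):
    #   PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
    #   PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))
    # so nearby positions receive similar vectors and relative offsets remain
    # recoverable by linear operations.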
def tokenize(self, sequence: str) -> Tensor:
"""Convert sequence string to token indices.
Args:
sequence: Amino acid sequence (uppercase)
Returns:
Token indices tensor (seq_len,)
"""
indices = []
for aa in sequence.upper():
idx = AA_TO_INDEX.get(aa, PAD_IDX)
indices.append(idx)
return torch.tensor(indices, dtype=torch.long)
def pad_sequence(self, tokens: Tensor) -> Tuple[Tensor, Tensor]:
"""Pad sequence to max_seq_len.
Args:
tokens: Token indices (seq_len,)
Returns:
Tuple of (padded_tokens, attention_mask)
"""
seq_len = tokens.shape[0]
if seq_len > self.max_seq_len:
# Truncate
padded = tokens[:self.max_seq_len]
mask = torch.ones(self.max_seq_len, dtype=torch.bool)
else:
# Pad
padded = F.pad(tokens, (0, self.max_seq_len - seq_len), value=PAD_IDX)
mask = torch.zeros(self.max_seq_len, dtype=torch.bool)
mask[:seq_len] = True
return padded, mask
def get_position_embeddings(self, seq_len: int, device: torch.device) -> Tensor:
"""Get positional embeddings for sequence.
Args:
            seq_len: Actual sequence length (currently unused; the full
                max_seq_len table is returned)
device: Target device
Returns:
Position embeddings (max_seq_len, embedding_dim)
"""
return self.positional_encoding[:self.max_seq_len].to(device)
def get_terminal_features(self, seq_len: int, device: torch.device) -> Tensor:
"""Get N/C-terminal distance features.
Args:
seq_len: Actual sequence length
device: Target device
Returns:
Terminal features (max_seq_len, 2) - [n_term_dist, c_term_dist]
"""
features = torch.zeros(self.max_seq_len, 2, device=device)
if seq_len > 0:
positions = torch.arange(self.max_seq_len, device=device).float()
# N-terminal distance (0 at start)
features[:, 0] = positions / max(seq_len - 1, 1)
# C-terminal distance (0 at end)
features[:, 1] = (seq_len - 1 - positions).clamp(min=0) / max(seq_len - 1, 1)
return features
def forward(
self,
sequences: List[str],
) -> Dict[str, Tensor]:
"""Process batch of sequences.
Args:
sequences: List of AA sequences
Returns:
Dictionary with tokens, mask, positions, terminal_features
"""
batch_size = len(sequences)
device = self.positional_encoding.device
all_tokens = []
all_masks = []
all_lengths = []
for seq in sequences:
tokens = self.tokenize(seq)
padded, mask = self.pad_sequence(tokens)
all_tokens.append(padded)
all_masks.append(mask)
all_lengths.append(len(seq))
tokens_batch = torch.stack(all_tokens).to(device)
masks_batch = torch.stack(all_masks).to(device)
# Position embeddings (shared across batch)
positions = self.get_position_embeddings(self.max_seq_len, device)
# Terminal features per sequence
terminal_features = torch.stack([
self.get_terminal_features(length, device)
for length in all_lengths
])
return {
'tokens': tokens_batch,
'mask': masks_batch,
'positions': positions,
'terminal_features': terminal_features,
'lengths': torch.tensor(all_lengths, device=device),
}
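# Example (illustrative sketch, assuming the src package imports above resolve):
# tokenizing and batching two peptides. Unknown residues map to PAD_IDX, and the
# returned mask is True at real (non-padding) positions.
#
#   proc = PeptideInputProcessor()
#   tokens = proc.tokenize("GIGKFLHSAKKFGKAFVGEIMNS")       # (23,) long tensor
#   padded, mask = proc.pad_sequence(tokens)                # (50,), (50,) bool
#   batch = proc(["GIGKFLHSAKKFGKAFVGEIMNS", "KWKLFKKIGAVLKVL"])
#   batch['tokens'].shape   # torch.Size([2, 50])
#   batch['mask'].shape     # torch.Size([2, 50])
#   batch['lengths']        # tensor([23, 15])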
# =============================================================================
# Multi-Component Embedding
# =============================================================================
class PropertyEncoder(nn.Module):
"""Encode amino acid physicochemical properties to learned embeddings."""
def __init__(
self,
output_dim: int = 8,
n_properties: int = 4,
):
"""Initialize property encoder.
Args:
output_dim: Output embedding dimension
n_properties: Number of input properties (hydrophobicity, MW, pI, flexibility)
"""
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(n_properties, output_dim * 2),
nn.LayerNorm(output_dim * 2),
nn.GELU(),
nn.Linear(output_dim * 2, output_dim),
)
# Register normalized AA properties as buffer
props = torch.zeros(VOCAB_SIZE, n_properties)
for aa, idx in AA_TO_INDEX.items():
if idx < VOCAB_SIZE and aa in AA_PROPERTIES:
p = AA_PROPERTIES[aa]
# Normalize to ~[0, 1]
props[idx] = torch.tensor([
(p[0] + 5) / 10, # hydrophobicity: [-4.5, 4.5] → [0, 1]
p[1] / 250, # molecular weight: [75, 204] → ~[0.3, 0.8]
p[2] / 14, # isoelectric point: [2.77, 10.76] → ~[0.2, 0.8]
p[3], # flexibility: already [0, 1]
])
self.register_buffer('aa_properties', props)
def forward(self, token_indices: Tensor) -> Tensor:
"""Encode token properties.
Args:
token_indices: Token indices (batch, seq_len)
Returns:
Property embeddings (batch, seq_len, output_dim)
"""
props = self.aa_properties[token_indices]
return self.encoder(props)
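# Worked example of the normalization above (values are illustrative; the exact
# AA_PROPERTIES entries live in padic_amino_acid_encoder): a residue with
# Kyte-Doolittle hydrophobicity 3.8, molecular weight 131.2 Da, pI 5.98 and
# flexibility 0.37 would be encoded as
#   [(3.8 + 5) / 10, 131.2 / 250, 5.98 / 14, 0.37] ≈ [0.88, 0.52, 0.43, 0.37]
# before passing through the two-layer MLP to an 8-D property embedding.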
class MultiComponentEmbedding(nn.Module):
"""Multi-component embedding combining AA, group, and property information.
Total dimension: aa_dim + group_dim + property_dim = 32 + 16 + 8 = 56
"""
def __init__(
self,
aa_dim: int = 32,
group_dim: int = 16,
property_dim: int = 8,
dropout: float = 0.1,
):
"""Initialize multi-component embedding.
Args:
aa_dim: AA embedding dimension
group_dim: 5-adic group embedding dimension
property_dim: Property encoding dimension
dropout: Dropout rate
"""
super().__init__()
self.total_dim = aa_dim + group_dim + property_dim
# AA embedding (22 tokens)
self.aa_embedding = nn.Embedding(VOCAB_SIZE, aa_dim, padding_idx=PAD_IDX)
# 5-adic group embedding (5 groups)
self.group_embedding = nn.Embedding(5, group_dim)
# Property encoder
self.property_encoder = PropertyEncoder(output_dim=property_dim)
# Normalization and dropout
self.norm = nn.LayerNorm(self.total_dim)
self.dropout = nn.Dropout(dropout)
# Register AA to group mapping
groups = torch.zeros(VOCAB_SIZE, dtype=torch.long)
for aa, idx in AA_TO_INDEX.items():
if idx < VOCAB_SIZE:
groups[idx] = AA_TO_GROUP.get(aa, AminoAcidGroup.SPECIAL)
self.register_buffer('aa_to_group', groups)
def forward(self, token_indices: Tensor) -> Tensor:
"""Get multi-component embeddings.
Args:
token_indices: Token indices (batch, seq_len)
Returns:
Combined embeddings (batch, seq_len, total_dim)
"""
# AA embeddings
aa_emb = self.aa_embedding(token_indices)
# Group embeddings
group_indices = self.aa_to_group[token_indices]
group_emb = self.group_embedding(group_indices)
# Property embeddings
prop_emb = self.property_encoder(token_indices)
# Concatenate
combined = torch.cat([aa_emb, group_emb, prop_emb], dim=-1)
combined = self.norm(combined)
combined = self.dropout(combined)
return combined
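# Shape sketch with the default dims: for token_indices of shape (batch, 50)
# the three components are
#   aa_emb    (batch, 50, 32)   learned per-residue embedding
#   group_emb (batch, 50, 16)   embedding of the residue's 5-adic group
#   prop_emb  (batch, 50, 8)    MLP over normalized physicochemical properties
# and the concatenation is LayerNorm'd and dropped out to (batch, 50, 56).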
# =============================================================================
# Attention Pooling
# =============================================================================
class AttentionPooling(nn.Module):
"""Learned attention pooling for sequence aggregation."""
def __init__(
self,
input_dim: int,
n_heads: int = 4,
):
"""Initialize attention pooling.
Args:
input_dim: Input feature dimension
n_heads: Number of attention heads
"""
super().__init__()
# Learned query for attention pooling
self.query = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
# Multi-head attention
self.attention = nn.MultiheadAttention(
embed_dim=input_dim,
num_heads=n_heads,
batch_first=True,
)
def forward(
self,
x: Tensor,
mask: Optional[Tensor] = None,
) -> Tensor:
"""Apply attention pooling.
Args:
x: Sequence features (batch, seq_len, dim)
mask: Attention mask (batch, seq_len), True for valid positions
Returns:
Pooled features (batch, dim)
"""
batch_size = x.shape[0]
# Expand query to batch
query = self.query.expand(batch_size, -1, -1)
# Create key padding mask (True = ignore)
if mask is not None:
key_padding_mask = ~mask # Invert: True means padding
else:
key_padding_mask = None
# Attention pooling
pooled, _ = self.attention(
query, x, x,
key_padding_mask=key_padding_mask,
)
return pooled.squeeze(1)
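# Usage sketch: a single learned query attends over the residue features, so a
# variable-length sequence collapses to one vector per peptide.
#
#   pool = AttentionPooling(input_dim=128, n_heads=4)
#   x = torch.randn(2, 50, 128)                 # (batch, seq_len, dim)
#   mask = torch.ones(2, 50, dtype=torch.bool)  # True at valid positions
#   pooled = pool(x, mask)                      # (2, 128)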
# =============================================================================
# Peptide Encoder
# =============================================================================
class PeptideEncoderTransformer(nn.Module):
"""Transformer-based peptide encoder to hyperbolic space."""
def __init__(
self,
embedding_dim: int = 56,
hidden_dim: int = 128,
latent_dim: int = 16,
n_layers: int = 2,
n_heads: int = 4,
dropout: float = 0.1,
max_radius: float = 0.95,
curvature: float = 1.0,
):
"""Initialize peptide encoder.
Args:
embedding_dim: Input embedding dimension (from MultiComponentEmbedding)
hidden_dim: Transformer hidden dimension
latent_dim: Output latent dimension (Poincaré ball)
n_layers: Number of transformer layers
n_heads: Number of attention heads
dropout: Dropout rate
max_radius: Maximum radius in Poincaré ball
curvature: Hyperbolic curvature
"""
super().__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.curvature = curvature
self.max_radius = max_radius
# Project embedding to hidden dim
self.input_proj = nn.Linear(embedding_dim, hidden_dim)
# Positional encoding (in hidden_dim space)
pe = torch.zeros(MAX_SEQ_LEN, hidden_dim)
position = torch.arange(0, MAX_SEQ_LEN, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
if hidden_dim % 2 == 1:
pe[:, 1::2] = torch.cos(position * div_term[:-1])
else:
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('positional_encoding', pe)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=n_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
activation='gelu',
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# Dual pooling
self.mean_pool_proj = nn.Linear(hidden_dim, hidden_dim)
self.attention_pool = AttentionPooling(hidden_dim, n_heads=n_heads)
# Fusion layer (mean + attention = 2 * hidden_dim → hidden_dim)
self.fusion = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
)
# Hyperbolic projection
self.hyperbolic_proj = HyperbolicProjection(
latent_dim=latent_dim,
hidden_dim=hidden_dim,
max_radius=max_radius,
curvature=curvature,
n_layers=1,
dropout=dropout,
)
# Pre-projection from fusion to latent
self.pre_projection = nn.Linear(hidden_dim, latent_dim)
def forward(
self,
embeddings: Tensor,
mask: Optional[Tensor] = None,
positions: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
"""Encode peptide embeddings to hyperbolic space.
Args:
embeddings: Multi-component embeddings (batch, seq_len, embedding_dim)
mask: Attention mask (batch, seq_len), True for valid
            positions: Position embeddings (seq_len, embedding_dim); currently
                unused, as the encoder applies its own sinusoidal encoding in
                hidden_dim space
Returns:
Dictionary with z_hyp, z_euclidean, direction, radius
"""
batch_size = embeddings.shape[0]
# Project to hidden dim
x = self.input_proj(embeddings)
# Add positional encoding (use internal PE in hidden_dim space)
x = x + self.positional_encoding[:x.shape[1]].unsqueeze(0)
# Create transformer mask (True = ignore)
if mask is not None:
src_key_padding_mask = ~mask
else:
src_key_padding_mask = None
# Transformer encoding
x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
# Dual pooling
# Mean pooling (masked)
if mask is not None:
mask_expanded = mask.unsqueeze(-1).float()
mean_pooled = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)
else:
mean_pooled = x.mean(dim=1)
mean_pooled = self.mean_pool_proj(mean_pooled)
# Attention pooling
attn_pooled = self.attention_pool(x, mask)
# Fuse pooled representations
fused = self.fusion(torch.cat([mean_pooled, attn_pooled], dim=-1))
# Project to latent dimension
z_euclidean = self.pre_projection(fused)
# Project to Poincaré ball with components
z_hyp, direction, radius = self.hyperbolic_proj.forward_with_components(z_euclidean)
return {
'z_hyp': z_hyp,
'z_euclidean': z_euclidean,
'direction': direction,
'radius': radius,
'transformer_output': x,
}
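# Output sketch for the default configuration (hidden_dim=128, latent_dim=16):
#   z_hyp               (batch, 16)       point inside the Poincaré ball (norm < max_radius)
#   z_euclidean         (batch, 16)       pre-projection Euclidean latent
#   direction, radius   unit direction and radial component from HyperbolicProjection
#   transformer_output  (batch, 50, 128)  per-residue contextual features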
# =============================================================================
# Peptide Decoder
# =============================================================================
class PeptideDecoder(nn.Module):
"""Transformer-based decoder for sequence reconstruction."""
def __init__(
self,
latent_dim: int = 16,
hidden_dim: int = 128,
embedding_dim: int = 56,
n_layers: int = 2,
n_heads: int = 4,
dropout: float = 0.1,
max_seq_len: int = MAX_SEQ_LEN,
curvature: float = 1.0,
):
"""Initialize peptide decoder.
Args:
latent_dim: Input latent dimension
hidden_dim: Transformer hidden dimension
embedding_dim: Target embedding dimension
n_layers: Number of transformer layers
n_heads: Number of attention heads
dropout: Dropout rate
max_seq_len: Maximum sequence length
curvature: Hyperbolic curvature (for inverse projection)
"""
super().__init__()
self.latent_dim = latent_dim
self.hidden_dim = hidden_dim
self.max_seq_len = max_seq_len
self.curvature = curvature
# Inverse hyperbolic projection: Poincaré → Euclidean
self.inverse_proj = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
)
# Initial sequence embedding (for autoregressive decoding start)
self.start_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
# Target embedding (for teacher forcing)
self.target_embedding = nn.Embedding(VOCAB_SIZE, hidden_dim, padding_idx=PAD_IDX)
# Positional encoding
pe = torch.zeros(max_seq_len, hidden_dim)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
if hidden_dim % 2 == 1:
pe[:, 1::2] = torch.cos(position * div_term[:-1])
else:
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('positional_encoding', pe)
# Transformer decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=hidden_dim,
nhead=n_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
activation='gelu',
batch_first=True,
)
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
# Output projection to vocabulary
self.output_proj = nn.Linear(hidden_dim, VOCAB_SIZE)
# Register causal mask
causal_mask = torch.triu(
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool),
diagonal=1,
)
self.register_buffer('causal_mask', causal_mask)
def forward(
self,
z_hyp: Tensor,
target_tokens: Optional[Tensor] = None,
target_mask: Optional[Tensor] = None,
) -> Tensor:
"""Decode from hyperbolic latent to sequence logits.
Args:
z_hyp: Hyperbolic latent (batch, latent_dim)
target_tokens: Target tokens for teacher forcing (batch, seq_len)
target_mask: Target mask (batch, seq_len)
Returns:
Logits (batch, seq_len, vocab_size)
"""
batch_size = z_hyp.shape[0]
device = z_hyp.device
# Apply log map to get tangent space representation
z_tangent = log_map_zero(z_hyp, c=self.curvature)
# Inverse projection
memory = self.inverse_proj(z_tangent)
memory = memory.unsqueeze(1) # (batch, 1, hidden_dim)
if target_tokens is not None:
# Teacher forcing mode
seq_len = target_tokens.shape[1]
# Embed targets
tgt = self.target_embedding(target_tokens)
tgt = tgt + self.positional_encoding[:seq_len].unsqueeze(0)
# Create masks
tgt_mask = self.causal_mask[:seq_len, :seq_len].to(device)
tgt_key_padding_mask = ~target_mask if target_mask is not None else None
# Decode
output = self.transformer(
tgt, memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
)
else:
# Autoregressive mode (for inference)
# Start with start token
tgt = self.start_token.expand(batch_size, -1, -1)
outputs = []
for i in range(self.max_seq_len):
# Add positional encoding
tgt_pos = tgt + self.positional_encoding[:tgt.shape[1]].unsqueeze(0)
# Create causal mask
tgt_mask = self.causal_mask[:tgt.shape[1], :tgt.shape[1]].to(device)
# Decode one step
output = self.transformer(tgt_pos, memory, tgt_mask=tgt_mask)
# Get last token prediction
last_output = output[:, -1:, :]
outputs.append(last_output)
# Predict next token
logits = self.output_proj(last_output)
next_token = logits.argmax(dim=-1)
# Embed and append
next_emb = self.target_embedding(next_token)
tgt = torch.cat([tgt, next_emb], dim=1)
output = torch.cat(outputs, dim=1)
# Project to vocabulary
logits = self.output_proj(output)
return logits
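# Usage sketch: the decoder runs in two modes. With target_tokens it does a single
# teacher-forced pass; without them it decodes greedily, one residue at a time, up
# to max_seq_len. Shapes assume the defaults (latent_dim=16, MAX_SEQ_LEN=50).
#
#   dec = PeptideDecoder()
#   z = torch.zeros(2, 16)                        # latent codes (origin of the ball)
#   # tokens, mask: (2, 50) outputs of PeptideInputProcessor
#   logits_tf = dec(z, target_tokens=tokens, target_mask=mask)   # (2, 50, 22)
#   logits_ar = dec(z)                                           # (2, 50, 22)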
# =============================================================================
# MIC Prediction Head
# =============================================================================
class MICPredictionHead(nn.Module):
"""Prediction head for MIC (Minimum Inhibitory Concentration)."""
def __init__(
self,
latent_dim: int = 16,
hidden_dim: int = 32,
dropout: float = 0.1,
):
"""Initialize MIC prediction head.
Args:
latent_dim: Input dimension (from hyperbolic space)
hidden_dim: Hidden layer dimension
dropout: Dropout rate
"""
super().__init__()
self.predictor = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
)
def forward(self, z_hyp: Tensor) -> Tensor:
"""Predict log10(MIC) from hyperbolic embedding.
Args:
z_hyp: Hyperbolic embeddings (batch, latent_dim)
Returns:
Predicted log10(MIC) (batch, 1)
"""
return self.predictor(z_hyp)
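# The head regresses log10(MIC) directly, so a predicted value of -0.3 corresponds
# to an MIC of 10**(-0.3) ≈ 0.5 in whatever concentration units (e.g. µM or µg/mL)
# the training labels used; lower values indicate higher antimicrobial potency.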
# =============================================================================
# Full PeptideVAE Model
# =============================================================================
class PeptideVAE(nn.Module):
"""Full Peptide VAE with encoder, decoder, and MIC prediction.
This is the main model class integrating all components for
antimicrobial peptide activity prediction.
"""
def __init__(
self,
latent_dim: int = 16,
hidden_dim: int = 128,
embedding_dim: int = 56,
n_layers: int = 2,
n_heads: int = 4,
dropout: float = 0.1,
max_radius: float = 0.95,
curvature: float = 1.0,
max_seq_len: int = MAX_SEQ_LEN,
):
"""Initialize PeptideVAE.
Args:
latent_dim: Latent dimension in Poincaré ball
hidden_dim: Transformer hidden dimension
embedding_dim: Multi-component embedding dimension
n_layers: Transformer layers
n_heads: Attention heads
dropout: Dropout rate
max_radius: Maximum Poincaré ball radius
curvature: Hyperbolic curvature
max_seq_len: Maximum sequence length
"""
super().__init__()
self.latent_dim = latent_dim
self.hidden_dim = hidden_dim
self.curvature = curvature
self.max_radius = max_radius
self.max_seq_len = max_seq_len
# Input processing
self.input_processor = PeptideInputProcessor(
max_seq_len=max_seq_len,
embedding_dim=embedding_dim,
)
# Multi-component embedding
self.embedding = MultiComponentEmbedding(
aa_dim=32,
group_dim=16,
property_dim=8,
dropout=dropout,
)
# Encoder
self.encoder = PeptideEncoderTransformer(
embedding_dim=embedding_dim,
hidden_dim=hidden_dim,
latent_dim=latent_dim,
n_layers=n_layers,
n_heads=n_heads,
dropout=dropout,
max_radius=max_radius,
curvature=curvature,
)
# Decoder
self.decoder = PeptideDecoder(
latent_dim=latent_dim,
hidden_dim=hidden_dim,
embedding_dim=embedding_dim,
n_layers=n_layers,
n_heads=n_heads,
dropout=dropout,
max_seq_len=max_seq_len,
curvature=curvature,
)
# MIC prediction head
self.mic_head = MICPredictionHead(
latent_dim=latent_dim,
hidden_dim=hidden_dim // 4,
dropout=dropout,
)
def encode(
self,
sequences: List[str],
) -> Dict[str, Tensor]:
"""Encode peptide sequences to hyperbolic space.
Args:
sequences: List of amino acid sequences
Returns:
Dictionary with z_hyp, z_euclidean, direction, radius, etc.
"""
# Process inputs
inputs = self.input_processor(sequences)
# Get multi-component embeddings
embeddings = self.embedding(inputs['tokens'])
# Encode to hyperbolic space
encoder_output = self.encoder(
embeddings,
mask=inputs['mask'],
positions=inputs['positions'],
)
# Add input info to output
encoder_output['tokens'] = inputs['tokens']
encoder_output['mask'] = inputs['mask']
encoder_output['lengths'] = inputs['lengths']
return encoder_output
def decode(
self,
z_hyp: Tensor,
target_tokens: Optional[Tensor] = None,
target_mask: Optional[Tensor] = None,
) -> Tensor:
"""Decode from hyperbolic latent to sequence.
Args:
z_hyp: Hyperbolic latent (batch, latent_dim)
target_tokens: Target for teacher forcing (batch, seq_len)
target_mask: Target mask (batch, seq_len)
Returns:
Logits (batch, seq_len, vocab_size)
"""
return self.decoder(z_hyp, target_tokens, target_mask)
def predict_mic(self, z_hyp: Tensor) -> Tensor:
"""Predict MIC from hyperbolic embedding.
Args:
z_hyp: Hyperbolic embedding (batch, latent_dim)
Returns:
Predicted log10(MIC) (batch, 1)
"""
return self.mic_head(z_hyp)
def forward(
self,
sequences: List[str],
teacher_forcing: bool = True,
) -> Dict[str, Tensor]:
"""Full forward pass.
Args:
sequences: List of peptide sequences
teacher_forcing: Use teacher forcing for decoder
Returns:
Dictionary with all model outputs
"""
# Encode
encoder_output = self.encode(sequences)
# Decode with teacher forcing
if teacher_forcing:
logits = self.decode(
encoder_output['z_hyp'],
target_tokens=encoder_output['tokens'],
target_mask=encoder_output['mask'],
)
else:
logits = self.decode(encoder_output['z_hyp'])
# Predict MIC
mic_pred = self.predict_mic(encoder_output['z_hyp'])
return {
**encoder_output,
'logits': logits,
'mic_pred': mic_pred,
}
def get_hyperbolic_radii(self, z_hyp: Tensor) -> Tensor:
"""Get hyperbolic radii (distance from origin).
Args:
z_hyp: Hyperbolic embeddings (batch, latent_dim)
Returns:
Radii tensor (batch,)
"""
origin = torch.zeros(1, self.latent_dim, device=z_hyp.device)
radii = poincare_distance(z_hyp, origin.expand(z_hyp.shape[0], -1), c=self.curvature)
return radii
def generate(
self,
z_hyp: Tensor,
temperature: float = 1.0,
max_len: Optional[int] = None,
) -> List[str]:
"""Generate sequences from latent codes.
Args:
z_hyp: Hyperbolic latent codes (batch, latent_dim)
temperature: Sampling temperature
max_len: Maximum generation length
Returns:
List of generated sequences
"""
self.eval()
max_len = max_len or self.max_seq_len
with torch.no_grad():
logits = self.decode(z_hyp)
            if temperature != 1.0:
                logits = logits / temperature
            # Greedy decoding: rescaling by temperature does not change the
            # argmax, so temperature only matters if sampling replaces the
            # argmax below.
            tokens = logits.argmax(dim=-1)
# Convert to sequences
sequences = []
for token_seq in tokens:
seq = []
for idx in token_seq.cpu().numpy():
if idx == PAD_IDX:
break
aa = INDEX_TO_AA.get(idx, 'X')
if aa == '*':
break
seq.append(aa)
sequences.append(''.join(seq))
return sequences
# =============================================================================
# Exports
# =============================================================================
__all__ = [
'PeptideInputProcessor',
'PropertyEncoder',
'MultiComponentEmbedding',
'AttentionPooling',
'PeptideEncoderTransformer',
'PeptideDecoder',
'MICPredictionHead',
'PeptideVAE',
'MAX_SEQ_LEN',
'VOCAB_SIZE',
'PAD_IDX',
]
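# =============================================================================
# Smoke test (illustrative sketch; assumes the repository's ``src`` package is
# importable; with untrained weights the outputs only serve as shape checks)
# =============================================================================
if __name__ == "__main__":
    torch.manual_seed(0)

    model = PeptideVAE(latent_dim=16)
    model.eval()

    demo_sequences = [
        "GIGKFLHSAKKFGKAFVGEIMNS",   # 23-residue example peptide
        "KWKLFKKIGAVLKVL",           # 15-residue example peptide
    ]

    with torch.no_grad():
        outputs = model(demo_sequences)

    print("z_hyp    :", tuple(outputs['z_hyp'].shape))       # (2, 16)
    print("logits   :", tuple(outputs['logits'].shape))      # (2, 50, 22)
    print("mic_pred :", tuple(outputs['mic_pred'].shape))    # (2, 1)
    print("radii    :", model.get_hyperbolic_radii(outputs['z_hyp']).shape)  # (2,)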