""" MolecularModule: Domain knowledge for chemistry and biology. Element embeddings, SMILES understanding, bond types, amino acids. """ import re import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, List class MolecularModule(nn.Module): """ Domain knowledge for chemistry and biology. - All 118 elements as learned embeddings with properties (atomic number, mass, electronegativity, valence electrons) - SMILES string understanding for molecular structures - Bond type awareness (covalent, ionic, hydrogen, van der Waals) - Amino acid sequence understanding for biology/zoology - Molecular formula → property reasoning """ def __init__(self, d_model: int, num_elements: int = 118): """ Initialize MolecularModule. Args: d_model: Model dimension num_elements: Number of chemical elements (default 118) """ super().__init__() self.d_model = d_model self.num_elements = num_elements # Element embeddings — 118 elements self.element_embed = nn.Embedding(num_elements + 1, d_model) # +1 for unknown # Element property encoder (12 properties) # [atomic_number, mass, electronegativity, valence_e, period, group, # atomic_radius, ionization_energy, electron_affinity, density, # melting_point, boiling_point] self.property_proj = nn.Linear(12, d_model) # Bond type embeddings (8 types) # 0: none, 1: single, 2: double, 3: triple, 4: aromatic, # 5: ionic, 6: hydrogen, 7: van der waals self.bond_embed = nn.Embedding(8, d_model) # Amino acid embeddings (20 standard + special) self.amino_acid_vocab = 25 # 20 standard + stop + start + unknown + special self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model) # Molecular graph attention (treats molecules as graphs) self.mol_attention = nn.MultiheadAttention( d_model, num_heads=8, batch_first=True, dropout=0.1, ) # Property prediction head (for auxiliary tasks) self.property_head = nn.Linear(d_model, 12) # Initialize weights self._initialize_weights() # Pre-compute element properties (simplified) self._init_element_properties() def _initialize_weights(self): """Initialize weights.""" for module in [self.element_embed, self.property_proj, self.bond_embed, self.amino_embed, self.mol_attention, self.property_head]: if hasattr(module, 'weight'): nn.init.normal_(module.weight, mean=0.0, std=0.02) if hasattr(module, 'bias') and module.bias is not None: nn.init.zeros_(module.bias) def _init_element_properties(self): """Initialize element property table with approximate values.""" # This is a simplified version - in practice would load from database # Properties: [atomic_number, mass, electronegativity, valence_e, period, group, # atomic_radius, ionization_energy, electron_affinity, density, # melting_point, boiling_point] properties = torch.zeros(self.num_elements + 1, 12) # Fill in known elements (simplified data for first 20 + some common ones) # Real implementation would use a comprehensive chemistry database element_data = { 1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20], 6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027], 7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77], 8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90], # ... would fill all 118 elements } for z, props in element_data.items(): properties[z] = torch.tensor(props) self.register_buffer("element_properties", properties) def detect_molecular_spans( self, text: str, ) -> List[Tuple[int, int, str]]: """ Detect molecular/chemical spans in text. Args: text: Input text string Returns: List of (start_char, end_char, span_type) span_type: "formula", "smiles", "amino_acid" """ spans = [] # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl formula_pattern = r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b' for match in re.finditer(formula_pattern, text): # Filter out single letters that are not formulas span = match.group() if len(span) > 1 or span.isupper(): spans.append((match.start(), match.end(), "formula")) # SMILES patterns (simplified detection) # Contains: =, #, @, [], (), numbers in sequence smiles_hints = ['=', '#', '@', '[', ']', '(', ')'] words = re.findall(r'\S+', text) for word in words: if any(hint in word for hint in smiles_hints) and len(word) > 3: # Find position in text pos = text.find(word) if pos >= 0: spans.append((pos, pos + len(word), "smiles")) # Amino acid sequences (single letters, length > 5) aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b' for match in re.finditer(aa_pattern, text.upper()): spans.append((match.start(), match.end(), "amino_acid")) return spans def encode_molecule( self, formula: str, ) -> torch.Tensor: """ Encode a molecular formula into embedding. Args: formula: Chemical formula string (e.g., "C6H12O6") Returns: Molecule embedding (d_model,) """ # Parse formula into elements and counts # Simplified parser - real would handle nested parentheses pattern = r'([A-Z][a-z]?)(\d*)' matches = re.findall(pattern, formula) device = self.element_embed.weight.device embeddings = [] weights = [] for element, count_str in matches: # Get element atomic number (simplified mapping) element_map = { 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20, # ... extend as needed } z = element_map.get(element, 0) # 0 = unknown count = int(count_str) if count_str else 1 # Get element embedding elem_emb = self.element_embed(torch.tensor(z, device=device)) # Get properties and project props = self.element_properties[z].unsqueeze(0) # (1, 12) props_emb = self.property_proj(props).squeeze(0) # Combine combined = elem_emb + props_emb embeddings.append(combined) weights.append(count) if not embeddings: # Return zero embedding return torch.zeros(self.d_model, device=device) # Weighted average embeddings = torch.stack(embeddings) weights = torch.tensor(weights, dtype=torch.float32, device=device) weights = weights / weights.sum() return (embeddings * weights.unsqueeze(-1)).sum(dim=0) def forward( self, x: torch.Tensor, text: Optional[List[str]] = None, molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None, ) -> torch.Tensor: """ Forward pass through molecular module. Args: x: Input tensor (batch, seq_len, d_model) text: Optional original text strings molecular_spans: Optional pre-computed molecular spans per batch Returns: Molecular-enhanced representation (batch, seq_len, d_model) """ batch, seq_len, d_model = x.shape device = x.device # Detect molecular spans if molecular_spans is None and text is not None: molecular_spans = [] for b in range(batch): spans = self.detect_molecular_spans(text[b]) # Convert char spans to token spans token_spans = [] for start_char, end_char, span_type in spans: start_tok = max(0, start_char // 4) end_tok = min(seq_len, end_char // 4 + 1) token_spans.append((start_tok, end_tok, span_type)) molecular_spans.append(token_spans) # Enhance molecular spans output = x.clone() if molecular_spans: for b in range(batch): spans_b = molecular_spans[b] if b < len(molecular_spans) else [] for start_tok, end_tok, span_type in spans_b: if end_tok <= start_tok: continue span_slice = x[b, start_tok:end_tok, :] if span_type == "formula": # Extract formula from text if available if text: formula = text[b][start_tok*4:end_tok*4] # rough extraction mol_emb = self.encode_molecule(formula) else: mol_emb = torch.randn(d_model, device=device) # Add molecular embedding to first token output[b, start_tok, :] += mol_emb elif span_type == "amino_acid": # Encode as amino acid sequence # Simplified: treat each letter as amino acid seq_len_span = end_tok - start_tok aa_ids = torch.randint(0, 20, (seq_len_span,), device=device) aa_emb = self.amino_embed(aa_ids) # (seq_len_span, d_model) output[b, start_tok:end_tok, :] += aa_emb elif span_type == "smiles": # For SMILES, apply graph attention (simplified) # Treat each character as a node seq_len_span = end_tok - start_tok if seq_len_span > 1: # Self-attention over the span attn_out, _ = self.mol_attention( span_slice.unsqueeze(0), span_slice.unsqueeze(0), span_slice.unsqueeze(0), ) output[b, start_tok:end_tok, :] += attn_out.squeeze(0) return output def compute_property_loss( self, x: torch.Tensor, element_ids: torch.Tensor, target_properties: torch.Tensor, ) -> torch.Tensor: """ Compute auxiliary loss for property prediction. Args: x: Input tensor (batch, seq_len, d_model) element_ids: Element IDs (batch, seq_len) target_properties: Target property values (batch, seq_len, 12) Returns: MSE loss for property prediction """ # Get element embeddings elem_emb = self.element_embed(element_ids) # Predict properties pred_props = self.property_head(elem_emb) # Compute loss loss = F.mse_loss(pred_props, target_properties) return loss def test_molecular_module(): """Test MolecularModule.""" d_model = 512 batch_size = 2 seq_len = 128 module = MolecularModule(d_model) x = torch.randn(batch_size, seq_len, d_model) text = [ "Water is H2O. The DNA sequence is ACGTACGTACGT.", "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6." ] output = module(x, text=text) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") assert output.shape == x.shape print("MolecularModule test passed!") if __name__ == "__main__": test_molecular_module()