| | """
|
| | MolecularModule: Domain knowledge for chemistry and biology.
|
| | Element embeddings, SMILES understanding, bond types, amino acids.
|
| | """
|
| |
|
| | import re
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple, List
|
| |
|
| |
|
class MolecularModule(nn.Module):
    """
    Domain knowledge for chemistry and biology.

    - All 118 elements as learned embeddings with properties
      (atomic number, mass, electronegativity, valence electrons)
    - SMILES string understanding for molecular structures
    - Bond type awareness (covalent, ionic, hydrogen, van der Waals)
    - Amino acid sequence understanding for biology/zoology
    - Molecular formula -> property reasoning
    """

    # Element symbol -> atomic number for the symbols encode_molecule can
    # parse. Class-level constant so the table is built once, not rebuilt
    # on every iteration of the formula-parsing loop (as it used to be).
    _ELEMENT_MAP = {
        'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
        'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
        'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
        # TODO: extend to all 118 elements.
    }

    def __init__(self, d_model: int, num_elements: int = 118):
        """
        Initialize MolecularModule.

        Args:
            d_model: Model dimension
            num_elements: Number of chemical elements (default 118)
        """
        super().__init__()
        self.d_model = d_model
        self.num_elements = num_elements

        # +1 slot: index 0 is reserved for unknown/unmapped elements.
        self.element_embed = nn.Embedding(num_elements + 1, d_model)

        # Projects the 12-dim per-element property row into model space.
        self.property_proj = nn.Linear(12, d_model)

        # 8 bond categories (covalent, ionic, hydrogen, van der Waals, ...).
        # NOTE(review): not used anywhere in this class yet.
        self.bond_embed = nn.Embedding(8, d_model)

        # 20 standard amino acids plus extra/ambiguity codes.
        self.amino_acid_vocab = 25
        self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)

        # Self-attention applied over detected SMILES spans in forward().
        self.mol_attention = nn.MultiheadAttention(
            d_model,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Predicts the 12 element properties back from an embedding;
        # used by compute_property_loss as an auxiliary objective.
        self.property_head = nn.Linear(d_model, 12)

        self._initialize_weights()
        self._init_element_properties()

    def _initialize_weights(self):
        """Initialize embedding/linear weights with N(0, 0.02), zero biases.

        NOTE(review): nn.MultiheadAttention stores its parameters as
        in_proj_weight / out_proj, so the `hasattr(module, 'weight')`
        check skips it and it keeps PyTorch's default initialization.
        """
        for module in [self.element_embed, self.property_proj, self.bond_embed,
                       self.amino_embed, self.mol_attention, self.property_head]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_element_properties(self):
        """Initialize element property table with approximate values.

        Each row holds 12 scalars. Judging from the hydrogen row they
        appear to be: atomic number, atomic mass, Pauling
        electronegativity, valence electrons, period, group, atomic
        radius (pm), first ionization energy (kJ/mol), electron
        affinity (kJ/mol), density (g/cm^3), melting point (K),
        boiling point (K) -- TODO confirm column order.
        """
        # Row 0 (unknown element) and every unlisted element stay all-zero.
        properties = torch.zeros(self.num_elements + 1, 12)

        element_data = {
            1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
            6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
            7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
            8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
            # TODO: fill in the remaining elements.
        }

        for z, props in element_data.items():
            properties[z] = torch.tensor(props)

        # Buffer (not a Parameter): moves with .to(device), not trained.
        self.register_buffer("element_properties", properties)

    def detect_molecular_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """
        Detect molecular/chemical spans in text.

        Args:
            text: Input text string

        Returns:
            List of (start_char, end_char, span_type)
            span_type: "formula", "smiles", "amino_acid"
        """
        spans = []

        # Chemical formulas: runs of element-symbol-shaped tokens
        # (e.g. "H2O", "C6H12O6"). The trailing \b rejects mixed-case
        # words like "Water"; single capital letters ("A", "I") still
        # match -- accepted noise for downstream span enhancement.
        formula_pattern = r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b'
        for match in re.finditer(formula_pattern, text):
            span = match.group()
            if len(span) > 1 or span.isupper():
                spans.append((match.start(), match.end(), "formula"))

        # SMILES heuristic: words containing structural punctuation.
        # finditer gives each occurrence its true offset; the previous
        # text.find(word) pinned repeated words to the FIRST occurrence.
        smiles_hints = ['=', '#', '@', '[', ']', '(', ')']
        for match in re.finditer(r'\S+', text):
            word = match.group()
            if any(hint in word for hint in smiles_hints) and len(word) > 3:
                spans.append((match.start(), match.end(), "smiles"))

        # Amino-acid sequences: 6+ consecutive one-letter residue codes.
        # Matching on text.upper() keeps offsets valid (same length).
        aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b'
        for match in re.finditer(aa_pattern, text.upper()):
            spans.append((match.start(), match.end(), "amino_acid"))

        return spans

    def encode_molecule(
        self,
        formula: str,
    ) -> torch.Tensor:
        """
        Encode a molecular formula into embedding.

        Args:
            formula: Chemical formula string (e.g., "C6H12O6")

        Returns:
            Molecule embedding (d_model,)
        """
        # Split into (element symbol, optional count) pairs.
        pattern = r'([A-Z][a-z]?)(\d*)'
        matches = re.findall(pattern, formula)

        device = self.element_embed.weight.device
        embeddings = []
        weights = []

        for element, count_str in matches:
            # Unknown symbols map to index 0 (the reserved unknown slot,
            # whose property row is all-zero).
            z = self._ELEMENT_MAP.get(element, 0)

            # Missing count means one atom (e.g. the "O" in "H2O").
            count = int(count_str) if count_str else 1

            elem_emb = self.element_embed(torch.tensor(z, device=device))

            props = self.element_properties[z].unsqueeze(0)
            props_emb = self.property_proj(props).squeeze(0)

            combined = elem_emb + props_emb
            embeddings.append(combined)
            weights.append(count)

        if not embeddings:
            # Empty / unparseable formula -> neutral zero embedding.
            return torch.zeros(self.d_model, device=device)

        # Stoichiometry-weighted average over the constituent atoms.
        embeddings = torch.stack(embeddings)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        weights = weights / weights.sum()

        return (embeddings * weights.unsqueeze(-1)).sum(dim=0)

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through molecular module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings
            molecular_spans: Optional pre-computed molecular spans per batch

        Returns:
            Molecular-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Derive token-level spans from the raw text when not supplied.
        if molecular_spans is None and text is not None:
            molecular_spans = []
            for b in range(batch):
                spans = self.detect_molecular_spans(text[b])

                # Crude char->token mapping assuming ~4 chars per token.
                # TODO(review): confirm against the actual tokenizer.
                token_spans = []
                for start_char, end_char, span_type in spans:
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, span_type))
                molecular_spans.append(token_spans)

        output = x.clone()

        if molecular_spans:
            for b in range(batch):
                spans_b = molecular_spans[b] if b < len(molecular_spans) else []

                for start_tok, end_tok, span_type in spans_b:
                    if end_tok <= start_tok:
                        continue  # clamped or degenerate span

                    span_slice = x[b, start_tok:end_tok, :]

                    if span_type == "formula":
                        if text:
                            # Recover approximate source chars via the
                            # same 4-chars-per-token heuristic.
                            formula = text[b][start_tok*4:end_tok*4]
                            mol_emb = self.encode_molecule(formula)
                        else:
                            # Was torch.randn: injecting fresh noise each
                            # call made forward() nondeterministic. Use a
                            # deterministic zero-vector no-op fallback.
                            mol_emb = torch.zeros(d_model, device=device)

                        # Add the molecule summary at the span's first token.
                        output[b, start_tok, :] += mol_emb

                    elif span_type == "amino_acid":
                        # Placeholder: random residue ids stand in for a
                        # real amino-acid tokenizer, so this branch is
                        # intentionally stochastic. TODO: wire real ids.
                        seq_len_span = end_tok - start_tok
                        aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
                        aa_emb = self.amino_embed(aa_ids)
                        output[b, start_tok:end_tok, :] += aa_emb

                    elif span_type == "smiles":
                        # Mix structural context across the span with
                        # self-attention; needs >1 token to be meaningful.
                        seq_len_span = end_tok - start_tok
                        if seq_len_span > 1:
                            attn_out, _ = self.mol_attention(
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                            )
                            output[b, start_tok:end_tok, :] += attn_out.squeeze(0)

        return output

    def compute_property_loss(
        self,
        x: torch.Tensor,
        element_ids: torch.Tensor,
        target_properties: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for property prediction.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            element_ids: Element IDs (batch, seq_len)
            target_properties: Target property values (batch, seq_len, 12)

        Returns:
            MSE loss for property prediction

        NOTE(review): `x` is currently unused; the prediction is made
        from the element embeddings alone.
        """
        elem_emb = self.element_embed(element_ids)

        pred_props = self.property_head(elem_emb)

        loss = F.mse_loss(pred_props, target_properties)
        return loss
|
| |
|
| |
|
def test_molecular_module():
    """Smoke-test MolecularModule: forward pass preserves tensor shape."""
    d_model = 512
    batch_size = 2
    seq_len = 128

    module = MolecularModule(d_model)

    inputs = torch.randn(batch_size, seq_len, d_model)
    sample_text = [
        "Water is H2O. The DNA sequence is ACGTACGTACGT.",
        "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
    ]

    enhanced = module(inputs, text=sample_text)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == inputs.shape

    print("MolecularModule test passed!")
|
| |
|
| |
|
# Run the smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_molecular_module()
|
| |
|