# Vortex-7b-V1 — models/science_modules/molecular_module.py
# Uploaded by Zandy-Wandy ("Upload Vortex model", commit bf64b03, verified)
"""
MolecularModule: Domain knowledge for chemistry and biology.
Element embeddings, SMILES understanding, bond types, amino acids.
"""
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
class MolecularModule(nn.Module):
    """
    Domain knowledge for chemistry and biology.

    - All 118 elements as learned embeddings with properties
      (atomic number, mass, electronegativity, valence electrons)
    - SMILES string understanding for molecular structures
    - Bond type awareness (covalent, ionic, hydrogen, van der Waals)
    - Amino acid sequence understanding for biology/zoology
    - Molecular formula -> property reasoning
    """

    # Element symbol -> atomic number (first 20 elements; 0 = unknown).
    # Hoisted to class level so the table is built once, not rebuilt on
    # every element of every encode_molecule() call as before.
    _ELEMENT_MAP = {
        'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
        'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
        'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
        # ... extend as needed
    }

    # Pre-compiled regexes, hoisted out of the per-call hot path.
    # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl.
    _FORMULA_RE = re.compile(r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b')
    # Runs of >= 6 single-letter amino acid codes.
    _AA_RE = re.compile(r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b')
    # One element symbol plus optional count inside a formula.
    _FORMULA_TOKEN_RE = re.compile(r'([A-Z][a-z]?)(\d*)')
    # Whitespace-delimited words (used for SMILES candidate scanning).
    _WORD_RE = re.compile(r'\S+')
    # Characters whose presence hints that a word is a SMILES string.
    _SMILES_HINTS = ('=', '#', '@', '[', ']', '(', ')')

    def __init__(self, d_model: int, num_elements: int = 118):
        """
        Initialize MolecularModule.

        Args:
            d_model: Model dimension (must be divisible by 8 for the
                8-head molecular attention below).
            num_elements: Number of chemical elements (default 118).
        """
        super().__init__()
        self.d_model = d_model
        self.num_elements = num_elements

        # Element embeddings — num_elements entries plus index 0 reserved
        # for "unknown element".
        self.element_embed = nn.Embedding(num_elements + 1, d_model)

        # Element property encoder (12 properties):
        # [atomic_number, mass, electronegativity, valence_e, period, group,
        #  atomic_radius, ionization_energy, electron_affinity, density,
        #  melting_point, boiling_point]
        self.property_proj = nn.Linear(12, d_model)

        # Bond type embeddings (8 types):
        # 0: none, 1: single, 2: double, 3: triple, 4: aromatic,
        # 5: ionic, 6: hydrogen, 7: van der waals
        self.bond_embed = nn.Embedding(8, d_model)

        # Amino acid embeddings (20 standard + stop + start + unknown + special).
        self.amino_acid_vocab = 25
        self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)

        # Molecular graph attention (treats molecules as graphs).
        self.mol_attention = nn.MultiheadAttention(
            d_model,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Property prediction head (for auxiliary tasks).
        self.property_head = nn.Linear(d_model, 12)

        self._initialize_weights()
        # Pre-compute element property table (simplified).
        self._init_element_properties()

    def _initialize_weights(self):
        """Initialize embedding/linear weights with N(0, 0.02); zero biases.

        NOTE(review): nn.MultiheadAttention exposes no ``.weight``/``.bias``
        attributes (its parameters are ``in_proj_weight`` etc.), so the
        hasattr checks skip it and it keeps PyTorch's default init.  That
        matches the original behavior and is preserved here.
        """
        for module in [self.element_embed, self.property_proj, self.bond_embed,
                       self.amino_embed, self.mol_attention, self.property_head]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_element_properties(self):
        """Build the (num_elements + 1, 12) element property lookup table.

        Simplified: only H, C, N and O are filled in; a real implementation
        would load all 118 elements from a chemistry database.  Row 0 (the
        unknown element) and every missing element stay all-zero.
        """
        # Properties: [atomic_number, mass, electronegativity, valence_e,
        #              period, group, atomic_radius, ionization_energy,
        #              electron_affinity, density, melting_point, boiling_point]
        properties = torch.zeros(self.num_elements + 1, 12)
        element_data = {
            1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
            6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
            7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
            8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
            # ... would fill all 118 elements
        }
        for z, props in element_data.items():
            properties[z] = torch.tensor(props)
        # Buffer (not a Parameter): moves with .to(device) but is untrained.
        self.register_buffer("element_properties", properties)

    def detect_molecular_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """
        Detect molecular/chemical spans in text.

        Args:
            text: Input text string

        Returns:
            List of (start_char, end_char, span_type) with span_type one of
            "formula", "smiles", "amino_acid".  Spans from the three
            detectors may overlap; callers must tolerate that.
        """
        spans = []

        # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl.
        for match in self._FORMULA_RE.finditer(text):
            span = match.group()
            # Fix: a lone capital letter (e.g. the pronoun "I") is not a
            # formula.  The old filter `len(span) > 1 or span.isupper()`
            # let single uppercase letters through because a one-letter
            # uppercase string satisfies isupper().
            if len(span) > 1:
                spans.append((match.start(), match.end(), "formula"))

        # SMILES candidates (simplified detection): any word longer than 3
        # chars containing structural characters such as =, #, @, [], ().
        for match in self._WORD_RE.finditer(text):
            word = match.group()
            if len(word) > 3 and any(hint in word for hint in self._SMILES_HINTS):
                # Fix: use the match's own offsets.  The old code called
                # text.find(word), which mapped every repeated word to its
                # FIRST occurrence, producing wrong spans for duplicates.
                spans.append((match.start(), match.end(), "smiles"))

        # Amino acid sequences (single-letter codes, length >= 6).
        # Matching on text.upper() keeps character offsets valid because
        # upper-casing does not change the string length here.
        for match in self._AA_RE.finditer(text.upper()):
            spans.append((match.start(), match.end(), "amino_acid"))

        return spans

    def encode_molecule(
        self,
        formula: str,
    ) -> torch.Tensor:
        """
        Encode a molecular formula into a single embedding.

        Each element contributes its learned embedding plus a projection of
        its physical properties; contributions are averaged weighted by the
        atom counts parsed from the formula.

        Args:
            formula: Chemical formula string (e.g., "C6H12O6").  This
                simplified parser does not handle nested parentheses.

        Returns:
            Molecule embedding of shape (d_model,); all-zero if the formula
            contains no parseable element symbols.
        """
        device = self.element_embed.weight.device
        embeddings = []
        weights = []

        for element, count_str in self._FORMULA_TOKEN_RE.findall(formula):
            z = self._ELEMENT_MAP.get(element, 0)  # 0 = unknown element
            count = int(count_str) if count_str else 1  # implicit count of 1

            # Learned embedding + projected physical properties.
            elem_emb = self.element_embed(torch.tensor(z, device=device))
            props = self.element_properties[z].unsqueeze(0)  # (1, 12)
            props_emb = self.property_proj(props).squeeze(0)

            embeddings.append(elem_emb + props_emb)
            weights.append(count)

        if not embeddings:
            # Nothing parseable: return a zero embedding.
            return torch.zeros(self.d_model, device=device)

        # Count-weighted average over the constituent elements.
        embeddings = torch.stack(embeddings)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        weights = weights / weights.sum()
        return (embeddings * weights.unsqueeze(-1)).sum(dim=0)

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through molecular module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings, one per batch item
            molecular_spans: Optional pre-computed token-level spans per
                batch item; when None they are derived from ``text``.

        Returns:
            Molecular-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Derive molecular spans from raw text when not supplied.
        if molecular_spans is None and text is not None:
            molecular_spans = []
            for b in range(batch):
                spans = self.detect_molecular_spans(text[b])
                # Convert char spans to token spans with a rough
                # ~4-chars-per-token heuristic (no tokenizer available here).
                token_spans = []
                for start_char, end_char, span_type in spans:
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, span_type))
                molecular_spans.append(token_spans)

        # Enhance the detected spans on a copy so x itself stays untouched.
        output = x.clone()
        if molecular_spans:
            for b in range(batch):
                spans_b = molecular_spans[b] if b < len(molecular_spans) else []
                for start_tok, end_tok, span_type in spans_b:
                    if end_tok <= start_tok:
                        continue  # degenerate span after clamping
                    span_slice = x[b, start_tok:end_tok, :]
                    if span_type == "formula":
                        if text:
                            # Roughly reconstruct the formula text via the
                            # same 4-chars-per-token heuristic as above.
                            formula = text[b][start_tok*4:end_tok*4]
                            mol_emb = self.encode_molecule(formula)
                        else:
                            # NOTE(review): random placeholder when no text
                            # is available — non-deterministic across calls.
                            mol_emb = torch.randn(d_model, device=device)
                        # Add the molecule embedding to the span's first token.
                        output[b, start_tok, :] += mol_emb
                    elif span_type == "amino_acid":
                        # Simplified: random amino-acid ids stand in for a
                        # real per-letter lookup (non-deterministic).
                        seq_len_span = end_tok - start_tok
                        aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
                        aa_emb = self.amino_embed(aa_ids)  # (span, d_model)
                        output[b, start_tok:end_tok, :] += aa_emb
                    elif span_type == "smiles":
                        # Self-attention over the span: a crude stand-in for
                        # molecular graph attention (each char as a node).
                        seq_len_span = end_tok - start_tok
                        if seq_len_span > 1:
                            attn_out, _ = self.mol_attention(
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                            )
                            output[b, start_tok:end_tok, :] += attn_out.squeeze(0)
        return output

    def compute_property_loss(
        self,
        x: torch.Tensor,
        element_ids: torch.Tensor,
        target_properties: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary MSE loss for element property prediction.

        Args:
            x: Input tensor (batch, seq_len, d_model).
                NOTE(review): currently unused — prediction is made from the
                element embeddings alone; kept for interface stability.
            element_ids: Element IDs (batch, seq_len)
            target_properties: Target property values (batch, seq_len, 12)

        Returns:
            Scalar MSE loss for property prediction.
        """
        elem_emb = self.element_embed(element_ids)
        pred_props = self.property_head(elem_emb)
        return F.mse_loss(pred_props, target_properties)
def test_molecular_module():
    """Smoke-test MolecularModule: enhanced output keeps the input shape."""
    d_model, batch_size, seq_len = 512, 2, 128
    module = MolecularModule(d_model)
    samples = [
        "Water is H2O. The DNA sequence is ACGTACGTACGT.",
        "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
    ]
    hidden = torch.randn(batch_size, seq_len, d_model)
    enhanced = module(hidden, text=samples)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == hidden.shape
    print("MolecularModule test passed!")


if __name__ == "__main__":
    test_molecular_module()