File size: 12,593 Bytes
5c43f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""

MolecularModule: Domain knowledge for chemistry and biology.

Element embeddings, SMILES understanding, bond types, amino acids.

"""

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List


class MolecularModule(nn.Module):
    """Domain knowledge module for chemistry and biology.

    Capabilities:
    - All 118 elements as learned embeddings, combined with a 12-value
      physical-property vector (atomic number, mass, electronegativity,
      valence electrons, ...).
    - Heuristic SMILES span detection with graph-style self-attention.
    - Bond type awareness (covalent, ionic, hydrogen, van der Waals).
    - Amino acid sequence embeddings for biology/zoology text.
    - Molecular formula -> embedding for property reasoning.
    """

    def __init__(self, d_model: int, num_elements: int = 118):
        """Initialize MolecularModule.

        Args:
            d_model: Model dimension. Must be divisible by 8, the number
                of attention heads used by ``mol_attention``.
            num_elements: Number of chemical elements (default 118).
        """
        super().__init__()
        self.d_model = d_model
        self.num_elements = num_elements

        # Element embeddings — one row per element plus row 0 for "unknown".
        self.element_embed = nn.Embedding(num_elements + 1, d_model)

        # Element property encoder (12 properties):
        # [atomic_number, mass, electronegativity, valence_e, period, group,
        #  atomic_radius, ionization_energy, electron_affinity, density,
        #  melting_point, boiling_point]
        self.property_proj = nn.Linear(12, d_model)

        # Bond type embeddings (8 types):
        # 0: none, 1: single, 2: double, 3: triple, 4: aromatic,
        # 5: ionic, 6: hydrogen, 7: van der waals
        self.bond_embed = nn.Embedding(8, d_model)

        # Amino acid embeddings: 20 standard + stop + start + unknown + special.
        self.amino_acid_vocab = 25
        self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)

        # Molecular graph attention (treats molecules as graphs of tokens).
        self.mol_attention = nn.MultiheadAttention(
            d_model,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Auxiliary head: predict the 12 element properties back from embeddings.
        self.property_head = nn.Linear(d_model, 12)

        self._initialize_weights()

        # Pre-compute the element property lookup table (simplified).
        self._init_element_properties()

    def _initialize_weights(self):
        """Initialize learnable weights: small normal weights, zero biases."""
        for module in [self.element_embed, self.property_proj, self.bond_embed,
                       self.amino_embed, self.property_head]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

        # nn.MultiheadAttention stores its parameters as in_proj_weight /
        # out_proj rather than a plain `.weight`, so the generic loop above
        # would silently skip it — initialize it explicitly.
        if self.mol_attention.in_proj_weight is not None:
            nn.init.normal_(self.mol_attention.in_proj_weight, mean=0.0, std=0.02)
        if self.mol_attention.in_proj_bias is not None:
            nn.init.zeros_(self.mol_attention.in_proj_bias)
        nn.init.normal_(self.mol_attention.out_proj.weight, mean=0.0, std=0.02)
        if self.mol_attention.out_proj.bias is not None:
            nn.init.zeros_(self.mol_attention.out_proj.bias)

    def _init_element_properties(self):
        """Initialize the element property table with approximate values.

        This is a simplified version — a real implementation would load a
        comprehensive chemistry database covering all 118 elements.
        """
        # Properties: [atomic_number, mass, electronegativity, valence_e,
        #              period, group, atomic_radius, ionization_energy,
        #              electron_affinity, density, melting_point, boiling_point]
        properties = torch.zeros(self.num_elements + 1, 12)

        # Known elements only (H, C, N, O); row 0 stays all-zero for "unknown".
        element_data = {
            1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
            6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
            7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
            8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
            # ... would fill all 118 elements
        }

        for z, props in element_data.items():
            properties[z] = torch.tensor(props)

        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("element_properties", properties)

    def detect_molecular_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """Detect molecular/chemical spans in text.

        Args:
            text: Input text string.

        Returns:
            List of (start_char, end_char, span_type) tuples, where
            span_type is one of "formula", "smiles", "amino_acid".
            Spans from different detectors may overlap.
        """
        spans = []

        # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl.
        # One or more element symbols, each optionally followed by a count.
        formula_pattern = r'\b(?:[A-Z][a-z]?\d*)+\b'
        for match in re.finditer(formula_pattern, text):
            span = match.group()
            # Heuristic filter; note a lone capital letter always passes the
            # isupper() check (kept so single-letter elements like "H" match).
            if len(span) > 1 or span.isupper():
                spans.append((match.start(), match.end(), "formula"))

        # SMILES patterns (simplified detection): whitespace-delimited tokens
        # containing structural characters (=, #, @, [], ()). Positions come
        # straight from finditer, so repeated tokens each get their own span
        # (a plain text.find() would always return the first occurrence).
        smiles_hints = ('=', '#', '@', '[', ']', '(', ')')
        for match in re.finditer(r'\S+', text):
            word = match.group()
            if len(word) > 3 and any(hint in word for hint in smiles_hints):
                spans.append((match.start(), match.end(), "smiles"))

        # Amino acid sequences: 6+ consecutive uppercase one-letter codes.
        # Matched case-sensitively against the original text so ordinary
        # lowercase words built from AA letters (e.g. "against") don't match.
        aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b'
        for match in re.finditer(aa_pattern, text):
            spans.append((match.start(), match.end(), "amino_acid"))

        return spans

    def encode_molecule(
        self,
        formula: str,
    ) -> torch.Tensor:
        """Encode a molecular formula into a single embedding.

        Each element's embedding (learned + projected properties) is averaged,
        weighted by its atom count in the formula.

        Args:
            formula: Chemical formula string (e.g., "C6H12O6").

        Returns:
            Molecule embedding of shape (d_model,); all-zero if no element
            symbols could be parsed from the formula.
        """
        # Parse formula into (element symbol, count) pairs.
        # Simplified parser — a real one would handle nested parentheses.
        pattern = r'([A-Z][a-z]?)(\d*)'
        matches = re.findall(pattern, formula)

        device = self.element_embed.weight.device
        embeddings = []
        weights = []

        for element, count_str in matches:
            # Symbol -> atomic number (first 20 elements; 0 means unknown).
            element_map = {
                'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
                'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
                'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
                # ... extend as needed
            }
            z = element_map.get(element, 0)  # 0 = unknown

            # A missing count (e.g. the "O" in "H2O") means one atom.
            count = int(count_str) if count_str else 1

            # Learned element embedding.
            elem_emb = self.element_embed(torch.tensor(z, device=device))

            # Physical properties projected into model space.
            props = self.element_properties[z].unsqueeze(0)  # (1, 12)
            props_emb = self.property_proj(props).squeeze(0)

            embeddings.append(elem_emb + props_emb)
            weights.append(count)

        if not embeddings:
            # Nothing parsed — return a neutral zero embedding.
            return torch.zeros(self.d_model, device=device)

        # Atom-count-weighted average over the parsed elements.
        embeddings = torch.stack(embeddings)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        weights = weights / weights.sum()

        return (embeddings * weights.unsqueeze(-1)).sum(dim=0)

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """Forward pass through the molecular module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch item.
            molecular_spans: Optional pre-computed token spans per batch item
                as (start_tok, end_tok, span_type) tuples. If omitted and
                `text` is given, spans are detected from the text.

        Returns:
            Molecular-enhanced representation (batch, seq_len, d_model).
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Detect molecular spans from text when none were supplied.
        if molecular_spans is None and text is not None:
            molecular_spans = []
            for b in range(batch):
                spans = self.detect_molecular_spans(text[b])
                # Convert char spans to token spans, assuming roughly 4
                # characters per token — a crude heuristic; a real mapping
                # would use the tokenizer's offsets. TODO(review): confirm.
                token_spans = []
                for start_char, end_char, span_type in spans:
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, span_type))
                molecular_spans.append(token_spans)

        # Enhance the detected spans on a copy of the input.
        output = x.clone()

        if molecular_spans:
            for b in range(batch):
                spans_b = molecular_spans[b] if b < len(molecular_spans) else []

                for start_tok, end_tok, span_type in spans_b:
                    if end_tok <= start_tok:
                        continue

                    span_slice = x[b, start_tok:end_tok, :]

                    if span_type == "formula":
                        if text:
                            # Rough char-range extraction via the same
                            # 4-chars-per-token assumption as above.
                            formula = text[b][start_tok*4:end_tok*4]
                            mol_emb = self.encode_molecule(formula)
                        else:
                            # No text available — random placeholder embedding
                            # (non-deterministic; simplified stand-in).
                            mol_emb = torch.randn(d_model, device=device)

                        # Add the molecular embedding to the span's first token.
                        output[b, start_tok, :] += mol_emb

                    elif span_type == "amino_acid":
                        # Simplified: random amino-acid ids stand in for the
                        # actual letters (non-deterministic placeholder).
                        seq_len_span = end_tok - start_tok
                        aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
                        aa_emb = self.amino_embed(aa_ids)  # (seq_len_span, d_model)
                        output[b, start_tok:end_tok, :] += aa_emb

                    elif span_type == "smiles":
                        # Graph-style self-attention over the span, treating
                        # each token as a node (simplified).
                        seq_len_span = end_tok - start_tok
                        if seq_len_span > 1:
                            attn_out, _ = self.mol_attention(
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                            )
                            output[b, start_tok:end_tok, :] += attn_out.squeeze(0)

        return output

    def compute_property_loss(
        self,
        x: torch.Tensor,
        element_ids: torch.Tensor,
        target_properties: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the auxiliary loss for element property prediction.

        Args:
            x: Input tensor (batch, seq_len, d_model). Currently unused;
                prediction runs off the element embeddings directly.
            element_ids: Element IDs (batch, seq_len).
            target_properties: Target property values (batch, seq_len, 12).

        Returns:
            Scalar MSE loss for property prediction.
        """
        # Look up element embeddings and decode properties from them.
        elem_emb = self.element_embed(element_ids)
        pred_props = self.property_head(elem_emb)

        loss = F.mse_loss(pred_props, target_properties)
        return loss


def test_molecular_module():
    """Smoke-test MolecularModule: output must keep the input's shape."""
    dim, n_batch, n_tokens = 512, 2, 128

    mol = MolecularModule(dim)
    hidden = torch.randn(n_batch, n_tokens, dim)
    samples = [
        "Water is H2O. The DNA sequence is ACGTACGTACGT.",
        "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
    ]

    enhanced = mol(hidden, text=samples)

    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == hidden.shape

    print("MolecularModule test passed!")


if __name__ == "__main__":
    test_molecular_module()