vedatonuryilmaz commited on
Commit
a34e3f1
Β·
verified Β·
1 Parent(s): c7ad40e

Upload mulgit/drug_target.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mulgit/drug_target.py +438 -0
mulgit/drug_target.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chemical Genomics & Drug Target Identification Module
3
+
4
+ Integrates multi-omics data with chemical genomics and perturbation
5
+ genomics to identify molecular targets and pharmaceutical agents
6
+ associated with exceptional longevity.
7
+
8
+ Methods:
9
+ 1. Drug-Target Affinity Prediction (SSM-DTA inspired)
10
+ - Cross-attention between drug (SMILES) and protein target representations
11
+ - Semi-supervised training with masked language modeling
12
+
13
+ 2. Perturbation Response Prediction
14
+ - Predict gene expression changes after drug treatment
15
+ - Based on LINCS L1000 patterns + deep learning
16
+
17
+ 3. Drug Repurposing for Longevity
18
+ - Match drug-induced expression changes to anti-aging signatures
19
+ - Identify existing drugs that mimic longevity-associated patterns
20
+
21
+ Datasets:
22
+ - BALM/BALM-benchmark: Drug-target binding affinity
23
+ - LINCS L1000 (via pytdc): Perturbation gene expression signatures
24
+ - GDSC/CTRP (via pytdc): Drug sensitivity in cell lines
25
+
26
+ References:
27
+ - SSM-DTA (arxiv:2206.09818): Drug-target affinity with semi-supervised training
28
+ - PaccMann (arxiv:1909.05114): Drug design from transcriptomic data
29
+ - MAMMAL (arxiv:2410.22367): Multi-modal drug discovery foundation model
30
+ """
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from typing import Optional, Dict, List, Tuple
36
+
37
+
38
+ # ─── Molecular Encoders ──────────────────────────────────────────────────────
39
+
40
+ class DrugEncoder(nn.Module):
41
+ """
42
+ Encodes drug SMILES strings into molecular embeddings.
43
+
44
+ Uses a simple 1D CNN over character-level SMILES tokens.
45
+ For production: replace with ChemBERTa, MolFormer, or similar
46
+ pretrained molecular transformer.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_size: int = 64, # SMILES character vocabulary
52
+ embed_dim: int = 128,
53
+ hidden_dim: int = 256,
54
+ output_dim: int = 128,
55
+ num_layers: int = 3,
56
+ kernel_size: int = 5,
57
+ ):
58
+ super().__init__()
59
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
60
+
61
+ self.convs = nn.ModuleList([
62
+ nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, kernel_size, padding=kernel_size//2)
63
+ for i in range(num_layers)
64
+ ])
65
+
66
+ self.output = nn.Linear(hidden_dim, output_dim)
67
+ self.activation = nn.SELU()
68
+
69
+ def forward(self, smiles_tokens: torch.Tensor) -> torch.Tensor:
70
+ """
71
+ Args:
72
+ smiles_tokens: (B, L) tokenized SMILES strings
73
+ Returns:
74
+ drug_embedding: (B, output_dim)
75
+ """
76
+ x = self.embedding(smiles_tokens) # (B, L, E)
77
+ x = x.transpose(1, 2) # (B, E, L)
78
+
79
+ for conv in self.convs:
80
+ x = self.activation(conv(x)) # (B, H, L)
81
+
82
+ # Global average pooling
83
+ x = x.mean(dim=-1) # (B, H)
84
+ return self.output(x)
85
+
86
+
87
+ class ProteinTargetEncoder(nn.Module):
88
+ """
89
+ Encodes protein target sequences (amino acid strings) into embeddings.
90
+
91
+ For production: replace with ESM-2 or ProtBERT pretrained embeddings.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size: int = 26, # amino acid alphabet
97
+ embed_dim: int = 128,
98
+ hidden_dim: int = 256,
99
+ output_dim: int = 128,
100
+ num_layers: int = 3,
101
+ kernel_size: int = 7,
102
+ ):
103
+ super().__init__()
104
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
105
+
106
+ self.convs = nn.ModuleList([
107
+ nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, kernel_size, padding=kernel_size//2)
108
+ for i in range(num_layers)
109
+ ])
110
+
111
+ self.output = nn.Linear(hidden_dim, output_dim)
112
+ self.activation = nn.SELU()
113
+
114
+ def forward(self, aa_tokens: torch.Tensor) -> torch.Tensor:
115
+ x = self.embedding(aa_tokens)
116
+ x = x.transpose(1, 2)
117
+ for conv in self.convs:
118
+ x = self.activation(conv(x))
119
+ x = x.mean(dim=-1)
120
+ return self.output(x)
121
+
122
+
123
+ # ─── Drug-Target Affinity (DTA) Predictor ────────────────────────────────────
124
+
125
+ class DrugTargetAffinityPredictor(nn.Module):
126
+ """
127
+ Predicts binding affinity between drugs and protein targets.
128
+
129
+ Uses cross-attention between drug and target representations,
130
+ inspired by SSM-DTA architecture.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ drug_dim: int = 128,
136
+ target_dim: int = 128,
137
+ hidden_dim: int = 256,
138
+ dropout: float = 0.1,
139
+ ):
140
+ super().__init__()
141
+
142
+ # Cross-attention: drug attends to target, target attends to drug
143
+ self.drug_cross_attn = nn.MultiheadAttention(
144
+ embed_dim=drug_dim, num_heads=4, batch_first=True, dropout=dropout
145
+ )
146
+ self.target_cross_attn = nn.MultiheadAttention(
147
+ embed_dim=target_dim, num_heads=4, batch_first=True, dropout=dropout
148
+ )
149
+
150
+ # Fusion + prediction
151
+ fusion_dim = drug_dim + target_dim
152
+ self.fusion = nn.Sequential(
153
+ nn.Linear(fusion_dim, hidden_dim),
154
+ nn.SELU(),
155
+ nn.AlphaDropout(dropout),
156
+ nn.Linear(hidden_dim, hidden_dim // 2),
157
+ nn.SELU(),
158
+ nn.AlphaDropout(dropout),
159
+ nn.Linear(hidden_dim // 2, 1),
160
+ )
161
+
162
+ def forward(
163
+ self,
164
+ drug_embed: torch.Tensor,
165
+ target_embed: torch.Tensor,
166
+ ) -> torch.Tensor:
167
+ """
168
+ Args:
169
+ drug_embed: (B, D_d) drug molecular embeddings
170
+ target_embed: (B, D_t) protein target embeddings
171
+ Returns:
172
+ affinity: (B,) predicted binding affinity (pKd)
173
+ """
174
+ # Cross-attention (treat as single-token sequences)
175
+ drug_attended, _ = self.drug_cross_attn(
176
+ drug_embed.unsqueeze(1),
177
+ target_embed.unsqueeze(1),
178
+ target_embed.unsqueeze(1),
179
+ )
180
+ target_attended, _ = self.target_cross_attn(
181
+ target_embed.unsqueeze(1),
182
+ drug_embed.unsqueeze(1),
183
+ drug_embed.unsqueeze(1),
184
+ )
185
+
186
+ # Concatenate and predict
187
+ fused = torch.cat([drug_attended.squeeze(1), target_attended.squeeze(1)], dim=-1)
188
+ return self.fusion(fused).squeeze(-1)
189
+
190
+
191
+ # ─── Perturbation Response Predictor ─────────────────────────────────────────
192
+
193
+ class PerturbationResponsePredictor(nn.Module):
194
+ """
195
+ Predicts gene expression changes after drug perturbation.
196
+
197
+ Architecture: drug embedding β†’ conditioned decoder β†’ gene expression delta.
198
+ Maps from LINCS L1000-style data: drug treatment β†’ 978 landmark gene changes.
199
+
200
+ Reference: PaccMann, DeepProfile
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ drug_dim: int = 128,
206
+ num_output_genes: int = 978, # LINCS L1000 landmark genes
207
+ hidden_dim: int = 512,
208
+ dropout: float = 0.1,
209
+ ):
210
+ super().__init__()
211
+
212
+ self.condition_net = nn.Sequential(
213
+ nn.Linear(drug_dim, hidden_dim),
214
+ nn.SELU(),
215
+ nn.AlphaDropout(dropout),
216
+ nn.Linear(hidden_dim, hidden_dim),
217
+ nn.SELU(),
218
+ nn.AlphaDropout(dropout),
219
+ )
220
+
221
+ # Decoder: conditioned on drug embedding
222
+ self.decoder = nn.Sequential(
223
+ nn.Linear(hidden_dim + drug_dim, hidden_dim),
224
+ nn.SELU(),
225
+ nn.AlphaDropout(dropout),
226
+ nn.Linear(hidden_dim, hidden_dim // 2),
227
+ nn.SELU(),
228
+ nn.AlphaDropout(dropout),
229
+ nn.Linear(hidden_dim // 2, num_output_genes),
230
+ )
231
+
232
+ def forward(
233
+ self,
234
+ drug_embed: torch.Tensor,
235
+ baseline_expression: Optional[torch.Tensor] = None,
236
+ ) -> torch.Tensor:
237
+ """
238
+ Args:
239
+ drug_embed: (B, D_d) drug embeddings
240
+ baseline_expression: (B, G) baseline gene expression (optional)
241
+ Returns:
242
+ predicted_expression: (B, G) predicted post-perturbation expression
243
+ """
244
+ condition = self.condition_net(drug_embed)
245
+ combined = torch.cat([condition, drug_embed], dim=-1)
246
+ delta = self.decoder(combined)
247
+
248
+ if baseline_expression is not None:
249
+ return baseline_expression + delta
250
+ return delta
251
+
252
+
253
+ # ─── Longevity Drug Repurposing ──────────────────────────────────────────────
254
+
255
+ class LongevityDrugScreener(nn.Module):
256
+ """
257
+ Screens drugs for longevity potential by comparing drug-induced
258
+ expression changes to anti-aging gene expression signatures.
259
+
260
+ Core idea: if a drug's perturbation signature reverses aging-associated
261
+ expression changes, it's a candidate longevity therapeutic.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ dta_predictor: DrugTargetAffinityPredictor,
267
+ perturbation_predictor: PerturbationResponsePredictor,
268
+ gene_dim: int = 978,
269
+ ):
270
+ super().__init__()
271
+ self.dta_predictor = dta_predictor
272
+ self.perturbation_predictor = perturbation_predictor
273
+
274
+ # Aging signature: the gene expression pattern to target
275
+ # Learned during training from aging datasets
276
+ self.aging_signature = nn.Parameter(torch.zeros(gene_dim))
277
+ nn.init.normal_(self.aging_signature, std=0.01)
278
+
279
+ # Longevity target signature: what we want to achieve
280
+ self.longevity_signature = nn.Parameter(torch.zeros(gene_dim))
281
+ nn.init.normal_(self.longevity_signature, std=0.01)
282
+
283
+ def compute_longevity_score(
284
+ self,
285
+ drug_embed: torch.Tensor,
286
+ target_embed: Optional[torch.Tensor] = None,
287
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
288
+ """
289
+ Score a drug for longevity potential.
290
+
291
+ Returns:
292
+ longevity_score: scalar (higher = better longevity drug)
293
+ details: dict with intermediate computations
294
+ """
295
+ # Predict perturbation effect
296
+ delta = self.perturbation_predictor(drug_embed)
297
+
298
+ # How well does the perturbation reverse the aging signature?
299
+ # We want: delta β‰ˆ longevity_signature - aging_signature
300
+ target_delta = (self.longevity_signature - self.aging_signature).unsqueeze(0) # (1, G)
301
+ reversal_score = -F.mse_loss(delta, target_delta.expand_as(delta), reduction='none').mean(dim=-1)
302
+
303
+ # Drug-target affinity (if target provided)
304
+ affinity = None
305
+ if target_embed is not None:
306
+ affinity = self.dta_predictor(drug_embed, target_embed)
307
+
308
+ details = {
309
+ "predicted_delta": delta,
310
+ "reversal_score": reversal_score,
311
+ "affinity": affinity,
312
+ }
313
+
314
+ return reversal_score, details
315
+
316
+ def screen_drugs(
317
+ self,
318
+ drug_embeds: List[torch.Tensor],
319
+ drug_names: List[str],
320
+ top_k: int = 10,
321
+ ) -> List[Tuple[str, float]]:
322
+ """Screen a batch of drugs and return top-k longevity candidates."""
323
+ scores = []
324
+ for embed, name in zip(drug_embeds, drug_names):
325
+ score, _ = self.compute_longevity_score(embed.unsqueeze(0))
326
+ scores.append((name, score.item()))
327
+
328
+ scores.sort(key=lambda x: x[1], reverse=True)
329
+ return scores[:top_k]
330
+
331
+
332
+ # ─── End-to-End Drug Discovery Pipeline ──────────────────────────────────────
333
+
334
+ class DrugDiscoveryPipeline:
335
+ """
336
+ Complete pipeline: multi-omics β†’ drug targets β†’ drug screening β†’ validation.
337
+
338
+ Steps:
339
+ 1. Use MuLGIT causal module to identify longevity-associated genes
340
+ 2. Use DTA predictor to find drugs targeting those genes
341
+ 3. Use perturbation predictor to verify drug effects
342
+ 4. Rank drugs by longevity reversal potential
343
+ """
344
+
345
+ def __init__(
346
+ self,
347
+ dta_predictor: DrugTargetAffinityPredictor,
348
+ perturbation_predictor: PerturbationResponsePredictor,
349
+ screener: LongevityDrugScreener,
350
+ ):
351
+ self.dta = dta_predictor
352
+ self.perturbation = perturbation_predictor
353
+ self.screener = screener
354
+
355
+ def run(
356
+ self,
357
+ causal_gene_targets: List[str],
358
+ drug_pool: Dict[str, torch.Tensor], # drug_name β†’ embedding
359
+ target_pool: Dict[str, torch.Tensor], # gene_name β†’ embedding
360
+ top_k: int = 20,
361
+ ) -> Dict:
362
+ """
363
+ Full drug discovery run.
364
+
365
+ Args:
366
+ causal_gene_targets: genes identified as causal for longevity
367
+ drug_pool: dictionary of candidate drug embeddings
368
+ target_pool: dictionary of protein target embeddings
369
+ top_k: number of top drugs to return
370
+ """
371
+ results = []
372
+
373
+ for drug_name, drug_embed in drug_pool.items():
374
+ drug_scores = []
375
+
376
+ for gene in causal_gene_targets:
377
+ if gene in target_pool:
378
+ target_embed = target_pool[gene]
379
+ score, details = self.screener.compute_longevity_score(
380
+ drug_embed.unsqueeze(0),
381
+ target_embed.unsqueeze(0),
382
+ )
383
+ drug_scores.append({
384
+ "gene": gene,
385
+ "score": score.item(),
386
+ "affinity": details["affinity"].item() if details["affinity"] is not None else None,
387
+ })
388
+
389
+ if drug_scores:
390
+ avg_score = sum(d["score"] for d in drug_scores) / len(drug_scores)
391
+ results.append({
392
+ "drug": drug_name,
393
+ "avg_score": avg_score,
394
+ "gene_details": sorted(drug_scores, key=lambda x: x["score"], reverse=True),
395
+ })
396
+
397
+ results.sort(key=lambda x: x["avg_score"], reverse=True)
398
+ return {"top_drugs": results[:top_k], "all_results": results}
399
+
400
+
401
+ # ─── Molecular Tokenizers ────────────────────────────────────────────────────
402
+
403
+ # Simple SMILES tokenizer (for MVP; use DeepChem/RDKit in production)
404
+ SMILES_CHARS = sorted(set("ABCDEFGHIKLMNOPQRSTUVWXYZ[\\]^_abcdefghilmnopqrstuv=0123456789+-.()#@/\\%"))
405
+ SMILES_TO_IDX = {c: i + 1 for i, c in enumerate(SMILES_CHARS)} # 0 = padding
406
+
407
+ # Amino acid tokenizer
408
+ AA_CHARS = sorted(set("ACDEFGHIKLMNPQRSTVWY"))
409
+ AA_TO_IDX = {c: i + 1 for i, c in enumerate(AA_CHARS)}
410
+
411
+
412
+ def tokenize_smiles(smiles: str, max_len: int = 256) -> torch.Tensor:
413
+ """Tokenize a SMILES string."""
414
+ tokens = [SMILES_TO_IDX.get(c, 0) for c in smiles[:max_len]]
415
+ # Pad
416
+ tokens += [0] * (max_len - len(tokens))
417
+ return torch.tensor(tokens, dtype=torch.long)
418
+
419
+
420
+ def tokenize_protein(sequence: str, max_len: int = 1024) -> torch.Tensor:
421
+ """Tokenize a protein amino acid sequence."""
422
+ tokens = [AA_TO_IDX.get(c, 0) for c in sequence[:max_len]]
423
+ tokens += [0] * (max_len - len(tokens))
424
+ return torch.tensor(tokens, dtype=torch.long)
425
+
426
+
427
+ # ─── Model Factory ───────────────────────────────────────────────────────────
428
+
429
+ def create_drug_discovery_modules() -> Tuple[
430
+ DrugTargetAffinityPredictor,
431
+ PerturbationResponsePredictor,
432
+ LongevityDrugScreener,
433
+ ]:
434
+ """Create all drug discovery modules with default configs."""
435
+ dta = DrugTargetAffinityPredictor(drug_dim=128, target_dim=128)
436
+ perturbation = PerturbationResponsePredictor(drug_dim=128)
437
+ screener = LongevityDrugScreener(dta, perturbation)
438
+ return dta, perturbation, screener