File size: 2,723 Bytes
414b96e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# src/features/interaction.py
#
# Interaction feature block — FIXED design.
#
# Previous bug: joint PCA on [ESM_640d || ECFP_1024d]
#   ESM: dense floats, semantic meaning, range ~[-3,3]
#   ECFP: sparse binary {0,1}
#   Joint PCA on incompatible spaces → garbage projections
#
# Fixed design: SEPARATE PCA per modality, THEN interact
#   Protein:  ESM_multi [N, n_layers*480] → PCA → [N, 128]
#   Ligand:   ECFP+MACCS+AtomPair+Torsion [N, 5287] → PCA → [N, 128]
#   Interact: [P⊙L, |P-L|] → [N, 256]
#   Final:    [P_proj, L_proj, P⊙L, |P-L|] → [N, 512]
#
# Biological rationale for hadamard + diff:
#   P⊙L captures co-activation: which latent dimensions
#        are simultaneously high in both protein and ligand
#   |P-L| captures complementarity: which dimensions differ
#        (shape complementarity is about what doesn't match)

import numpy as np
from sklearn.decomposition import PCA
import joblib
from pathlib import Path


def _pad_columns(mat, width):
    """Right-pad ``mat`` with zero columns until it has ``width`` columns."""
    missing = width - mat.shape[1]
    if missing <= 0:
        return mat
    return np.pad(mat, ((0, 0), (0, missing)))


def build_interaction_features(prot_emb, lig_concat,
                                dim=128,
                                prot_pca=None, lig_pca=None,
                                fit=False):
    """
    Project protein and ligand features with separate PCAs, then combine them.

    Args:
        prot_emb:   [N, prot_dim]  — ESM multi+attention concatenated
        lig_concat: [N, lig_dim]   — ECFP+MACCS+AtomPair+Torsion concatenated
                                     (NOT including scaled RDKit phys)
        dim:        projection dimension
        prot_pca:   fitted protein projector (needs ``.transform``);
                    required when ``fit`` is False, ignored otherwise
        lig_pca:    fitted ligand projector (needs ``.transform``);
                    required when ``fit`` is False, ignored otherwise
        fit:        if True, fit new PCAs on this data

    Returns:
        interaction [N, 4*dim], columns ordered as
            [P_proj, L_proj, P*L (hadamard), |P-L|]
        prot_pca, lig_pca  (the newly fitted or the passed-in projectors)

    Raises:
        ValueError: if ``fit`` is False and either projector is missing.

    Note:
        Projections narrower than ``dim`` are zero-padded up to ``dim``.
        Projections wider than ``dim`` (e.g. a projector fitted with a
        larger dimension) are NOT truncated.
    """
    if fit:
        # PCA cannot extract more components than min(n_samples, n_features),
        # so clamp against BOTH axes; otherwise a dataset with fewer than
        # `dim` rows fails inside PCA before the padding below can help.
        prot_pca = PCA(n_components=min(dim, *prot_emb.shape),
                       random_state=42)
        lig_pca = PCA(n_components=min(dim, *lig_concat.shape),
                      random_state=42)
        p_proj = prot_pca.fit_transform(prot_emb)
        l_proj = lig_pca.fit_transform(lig_concat)
    else:
        if prot_pca is None or lig_pca is None:
            raise ValueError(
                "prot_pca and lig_pca must both be provided when fit=False"
            )
        p_proj = prot_pca.transform(prot_emb)
        l_proj = lig_pca.transform(lig_concat)

    # Pad if PCA gave fewer components than dim (small datasets), so both
    # projections share a width and the element-wise ops below line up.
    p_proj = _pad_columns(p_proj, dim)
    l_proj = _pad_columns(l_proj, dim)

    hadamard = p_proj * l_proj          # co-activation term
    diff = np.abs(p_proj - l_proj)      # complementarity term

    interaction = np.concatenate([p_proj, l_proj, hadamard, diff], axis=1)
    return interaction, prot_pca, lig_pca


def save_pcas(prot_pca, lig_pca, out_dir):
    """
    Write both projectors to ``out_dir`` with joblib.

    Args:
        prot_pca: protein projector, written to ``prot_pca.pkl``
        lig_pca:  ligand projector, written to ``lig_pca.pkl``
        out_dir:  target directory (``str`` or ``Path``); created,
                  including parents, if it does not exist yet
    """
    # Accept plain strings too: `str / str` would raise TypeError.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(prot_pca, out_dir / "prot_pca.pkl")
    joblib.dump(lig_pca, out_dir / "lig_pca.pkl")


def load_pcas(out_dir):
    """
    Read the two projectors stored as ``prot_pca.pkl`` / ``lig_pca.pkl``.

    Args:
        out_dir: directory holding both files (``str`` or ``Path``)

    Returns:
        (prot_pca, lig_pca)
    """
    # Accept plain strings too: `str / str` would raise TypeError.
    out_dir = Path(out_dir)
    # SECURITY: joblib.load unpickles, which can execute arbitrary code.
    # Only point this at directories whose contents you trust.
    return (joblib.load(out_dir / "prot_pca.pkl"),
            joblib.load(out_dir / "lig_pca.pkl"))