import math
from itertools import product

import torch
import torch.nn as nn

# =============================================================================
# Biology Data
# =============================================================================

GENETIC_CODE = {
    'ATA':'I', 'ATC':'I', 'ATT':'I', 'ATG':'M',
    'ACA':'T', 'ACC':'T', 'ACG':'T', 'ACT':'T',
    'AAC':'N', 'AAT':'N', 'AAA':'K', 'AAG':'K',
    'AGC':'S', 'AGT':'S', 'AGA':'R', 'AGG':'R',
    'CTA':'L', 'CTC':'L', 'CTG':'L', 'CTT':'L',
    'CCA':'P', 'CCC':'P', 'CCG':'P', 'CCT':'P',
    'CAC':'H', 'CAT':'H', 'CAA':'Q', 'CAG':'Q',
    'CGA':'R', 'CGC':'R', 'CGG':'R', 'CGT':'R',
    'GTA':'V', 'GTC':'V', 'GTG':'V', 'GTT':'V',
    'GCA':'A', 'GCC':'A', 'GCG':'A', 'GCT':'A',
    'GAC':'D', 'GAT':'D', 'GAA':'E', 'GAG':'E',
    'GGA':'G', 'GGC':'G', 'GGG':'G', 'GGT':'G',
    'TCA':'S', 'TCC':'S', 'TCG':'S', 'TCT':'S',
    'TTC':'F', 'TTT':'F', 'TTA':'L', 'TTG':'L',
    'TAC':'Y', 'TAT':'Y', 'TAA':'*', 'TAG':'*',
    'TGC':'C', 'TGT':'C', 'TGA':'*', 'TGG':'W',
}

BASES = ['A', 'C', 'G', 'T']
CODON_TO_INDEX = {''.join(codon): i for i, codon in enumerate(product(BASES, repeat=3))}
INDEX_TO_CODON = {v: k for k, v in CODON_TO_INDEX.items()}
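
# Quick sanity checks on the static tables (illustrative; safe to remove).
# The standard genetic code has exactly 64 codons, 3 of which are stops ('*').
assert set(GENETIC_CODE) == set(CODON_TO_INDEX) and len(CODON_TO_INDEX) == 64
assert sum(aa == '*' for aa in GENETIC_CODE.values()) == 3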

# =============================================================================
# Hyperbolic Utilities
# =============================================================================

def exp_map_zero(x: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """Exponential map at the origin of the Poincaré ball with curvature c:
    exp_0(x) = tanh(sqrt(c) * ||x||) * x / (sqrt(c) * ||x||).

    Maps a tangent vector at the origin onto the ball; the norm is clamped
    to avoid division by zero for the zero vector.
    """
    sqrt_c = math.sqrt(c)
    norm_x = torch.norm(x, p=2, dim=-1, keepdim=True)
    norm_x = torch.clamp(norm_x, min=1e-15)
    return torch.tanh(sqrt_c * norm_x) * x / (sqrt_c * norm_x)

def project_to_poincare(z: torch.Tensor, max_norm: float = 0.95, c: float = 1.0) -> torch.Tensor:
    """Clip points whose Euclidean norm exceeds max_norm back onto that radius,
    keeping embeddings strictly inside the ball for numerical stability.

    Note: c is accepted for interface symmetry with exp_map_zero; the clipping
    radius is controlled directly by max_norm.
    """
    norm = torch.norm(z, p=2, dim=-1, keepdim=True)
    # Clamp the denominator so points at (or near) the origin do not produce
    # NaNs in the unselected branch of torch.where (which would poison gradients).
    projected = z * (max_norm / torch.clamp(norm, min=1e-15))
    return torch.where(norm > max_norm, projected, z)
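
# A minimal usage sketch of the two maps together (assumes the defaults
# c=1.0 and max_norm=0.95):
#
#     v = torch.randn(8, 16) * 10.0             # arbitrarily large tangent vectors
#     z = project_to_poincare(exp_map_zero(v))  # norms land strictly inside the ball
#     assert torch.all(torch.norm(z, dim=-1) < 0.95 + 1e-6)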

# =============================================================================
# Codon Encoder
# =============================================================================

class CodonEncoderMLP(nn.Module):
    """Small MLP mapping a 12-dim codon one-hot (3 positions x 4 bases) to a
    tangent-space latent vector."""

    def __init__(self, latent_dim=16, hidden_dim=64, dropout=0.1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(12, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, x):
        return self.encoder(x)

class TrainableCodonEncoder(nn.Module):
    """Embeds codons on the Poincaré ball: one-hot -> MLP (tangent space) ->
    exp map at the origin -> projection inside radius max_radius."""

    def __init__(self, latent_dim=16, hidden_dim=64, curvature=1.0, max_radius=0.9, dropout=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.curvature = curvature
        self.max_radius = max_radius
        self.encoder = CodonEncoderMLP(latent_dim, hidden_dim, dropout)

        # Precompute the 12-dim one-hot encoding of all 64 codons
        # (3 positions x 4 bases; 'U' maps to the same slot as 'T').
        onehots = torch.zeros(64, 12)
        base_to_idx = {'A':0, 'C':1, 'G':2, 'T':3, 'U':3}
        for i in range(64):
            codon = INDEX_TO_CODON[i]
            for pos, base in enumerate(codon):
                onehots[i, pos*4 + base_to_idx[base]] = 1.0
        self.register_buffer('codon_onehots', onehots)

    def encode_all(self):
        """Return the hyperbolic embeddings of all 64 codons, shape (64, latent_dim)."""
        z_tangent = self.encoder(self.codon_onehots)
        z_hyp = exp_map_zero(z_tangent, c=self.curvature)
        return project_to_poincare(z_hyp, max_norm=self.max_radius, c=self.curvature)

    def forward(self, codon_indices):
        """Embed integer codon indices; output shape is (*codon_indices.shape, latent_dim)."""
        flat_indices = codon_indices.flatten()
        onehots = self.codon_onehots[flat_indices]
        z_tangent = self.encoder(onehots)
        z_hyp = exp_map_zero(z_tangent, c=self.curvature)
        z_hyp = project_to_poincare(z_hyp, max_norm=self.max_radius, c=self.curvature)
        # Restore the input's batch shape (the flatten above handles any rank).
        return z_hyp.view(*codon_indices.shape, self.latent_dim)
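
# Minimal usage sketch (an illustrative entry point, not part of the model;
# hyperparameters below are just the module defaults).
if __name__ == "__main__":
    torch.manual_seed(0)
    enc = TrainableCodonEncoder(latent_dim=16)
    enc.eval()  # disable dropout for a deterministic demo
    # A batch of 2 sequences of 5 codons each, sampled uniformly from 0..63.
    idx = torch.randint(0, 64, (2, 5))
    z = enc(idx)                      # shape (2, 5, 16)
    print(z.shape, float(torch.norm(z, dim=-1).max()))  # max norm <= max_radius
    z_all = enc.encode_all()          # all 64 codon embeddings, shape (64, 16)
    print(z_all.shape)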