Safetensors
thales_quant
finance
fintech
sparse-autoencoders
xai
File size: 2,659 Bytes
2955221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        return F.relu(x)

class SAE(nn.Module):
    """Single-hidden-layer autoencoder with a ReLU latent code.

    The only sparsity-related mechanism visible here is the ReLU on the code;
    presumably an L1 / sparsity penalty is applied by the training loop
    elsewhere — TODO confirm.

    Args:
        in_dim: Width of the input (and of the reconstruction).
        sae_dim: Number of latent features.
    """

    def __init__(self, in_dim=128, sae_dim=1024):
        super().__init__()
        # Creation order / names fix the state_dict layout; keep them stable.
        self.encoder = nn.Linear(in_dim, sae_dim)
        self.decoder = nn.Linear(sae_dim, in_dim)

    def forward(self, x):
        """Encode ``x`` and decode it back.

        Returns:
            ``(codes, reconstruction)``: the non-negative latent activations
            (last dim ``sae_dim``) and their decoding (last dim ``in_dim``).
        """
        codes = F.relu(self.encoder(x))
        reconstruction = self.decoder(codes)
        return codes, reconstruction

class ThalesModel(nn.Module, PyTorchModelHubMixin):
    """Option pricer: CNN surface encoder -> sparse autoencoder -> MLP head.

    A 2-channel surface is encoded by a small residual CNN into an ``in_dim``
    embedding, which is passed through an SAE. The SAE *reconstruction* (not
    the raw embedding) is concatenated with four rescaled contract scalars and
    fed to the pricing MLP, whose final ``Softplus`` keeps prices non-negative.

    Args:
        grid_size: Side length of the input surface grid; it determines the
            width of the flattened feature map entering the CNN's last Linear.
        in_dim: Width of the CNN embedding (and of the SAE input/output).
        sae_dim: Number of SAE latent features.
        pricing_hidden: Width of the pricing head's first hidden layer; the
            second hidden layer uses half of it.
    """

    def __init__(self, grid_size=11, in_dim=128, sae_dim=1024, pricing_hidden=256):
        super().__init__()
        # Plain-dict record of the constructor arguments.
        self.config = dict(
            grid_size=grid_size,
            in_dim=in_dim,
            sae_dim=sae_dim,
            pricing_hidden=pricing_hidden,
        )

        # NOTE: module creation order and Sequential indices fix both the RNG
        # draw order at init and the state_dict keys (cnn.0, cnn.1, ...); keep
        # them unchanged or existing checkpoints will no longer load.
        pooled = grid_size // 2  # MaxPool2d(2) floors odd spatial sizes
        surface_layers = [
            nn.Conv2d(2, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ResNetBlock(32),
            nn.MaxPool2d(2),
            ResNetBlock(32),
            nn.Flatten(),
            nn.Linear(32 * pooled * pooled, in_dim),
        ]
        self.cnn = nn.Sequential(*surface_layers)
        self.sae = SAE(in_dim=in_dim, sae_dim=sae_dim)

        half_hidden = pricing_hidden // 2
        self.pricing_head = nn.Sequential(
            nn.Linear(in_dim + 4, pricing_hidden),  # +4: rescaled S, K, T, r
            nn.SiLU(),
            nn.Linear(pricing_hidden, half_hidden),
            nn.SiLU(),
            nn.Linear(half_hidden, 1),
            nn.Softplus(),
        )

    def forward(self, surface, scalars, return_acts=False):
        """Price a batch of contracts.

        Args:
            surface: Input to the CNN; must have 2 channels, presumably shaped
                (batch, 2, grid_size, grid_size) — TODO confirm with callers.
            scalars: 2-D tensor whose columns 0..3 are S, K, T, r. S and K
                are divided by 100 before use; T and r are used as-is.
            return_acts: When True, also return the SAE activations.

        Returns:
            ``(price, scalars, cnn_out, recon)``, with ``sae_f`` appended when
            ``return_acts`` is True. ``price`` has a single column and is
            non-negative (Softplus); ``scalars`` is the same tensor object
            that was passed in.

        Note:
            In training mode this calls ``requires_grad_(True)`` on the
            caller's ``scalars`` tensor in place.
        """
        if self.training:
            # Sobolev training: presumably the caller differentiates `price`
            # w.r.t. the returned `scalars` for a derivative-matching loss.
            # This mutates the caller's tensor in place.
            scalars.requires_grad_(True)

        cnn_out = self.cnn(surface)
        sae_f, recon = self.sae(cnn_out)

        spot = scalars[:, 0]
        strike = scalars[:, 1]
        maturity = scalars[:, 2]
        rate = scalars[:, 3]
        scalars_norm = torch.stack(
            (spot / 100.0, strike / 100.0, maturity, rate), dim=1
        )

        # The head consumes the SAE reconstruction, not the raw CNN embedding.
        price = self.pricing_head(torch.cat((recon, scalars_norm), dim=1))

        outputs = (price, scalars, cnn_out, recon)
        if return_acts:
            outputs += (sae_f,)
        return outputs