Safetensors
thales_quant
finance
fintech
sparse-autoencoders
xai
Thales / model.py
imbue2025's picture
Upload 3 files
2955221 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with an identity skip.

    Both convolutions map ``channels`` to ``channels`` with ``padding=1``, so
    the block's output has the same shape as its input and the two can be
    summed directly.
    """

    def __init__(self, channels):
        """Build the two conv/batch-norm stages for ``channels`` feature maps."""
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable so
        # previously saved weights still load.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        shortcut = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # In-place add on the tensor produced by bn2; the caller's input
        # tensor itself is never modified.
        out += shortcut
        return F.relu(out)
class SAE(nn.Module):
    """Single-hidden-layer autoencoder with a ReLU-activated code.

    NOTE(review): the name presumably stands for "sparse autoencoder"; no
    sparsity penalty is applied inside this module itself.
    """

    def __init__(self, in_dim=128, sae_dim=1024):
        """Create the ``in_dim -> sae_dim`` encoder and ``sae_dim -> in_dim`` decoder."""
        super().__init__()
        self.encoder = nn.Linear(in_features=in_dim, out_features=sae_dim)
        self.decoder = nn.Linear(in_features=sae_dim, out_features=in_dim)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(features, reconstruction)`` tuple, where ``features`` is the
            ReLU of the encoder output and ``reconstruction`` is the decoder
            applied to those features.
        """
        features = F.relu(self.encoder(x))
        reconstruction = self.decoder(features)
        return features, reconstruction
class ThalesModel(nn.Module, PyTorchModelHubMixin):
    """CNN encoder + autoencoder bottleneck + MLP head producing a positive price.

    Pipeline (see ``forward``): a 2-channel surface is encoded by a small
    residual CNN into an ``in_dim`` vector, passed through ``SAE``, and the
    SAE reconstruction is concatenated with four rescaled scalars before the
    pricing MLP. The final ``Softplus`` keeps the output strictly positive.
    """

    def __init__(self, grid_size=11, in_dim=128, sae_dim=1024, pricing_hidden=256):
        """Build the encoder, the autoencoder and the pricing head.

        Args:
            grid_size: side length used to size the flattened CNN output;
                after the single ``MaxPool2d(2)`` each side is ``grid_size // 2``.
            in_dim: width of the CNN embedding (and of the SAE input/output).
            sae_dim: width of the SAE hidden code.
            pricing_hidden: width of the first pricing-head layer; the second
                layer uses half of it.
        """
        super().__init__()

        # Hyper-parameters kept on the instance, keyed by constructor argument name.
        self.config = {
            "grid_size": grid_size,
            "in_dim": in_dim,
            "sae_dim": sae_dim,
            "pricing_hidden": pricing_hidden,
        }

        stem_channels = 32
        pooled_side = grid_size // 2
        flat_features = stem_channels * pooled_side ** 2

        # Layer order fixes the Sequential indices used in state_dict keys,
        # so it must not be rearranged.
        encoder_layers = [
            nn.Conv2d(2, stem_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(),
            ResNetBlock(stem_channels),
            nn.MaxPool2d(2),
            ResNetBlock(stem_channels),
            nn.Flatten(),
            nn.Linear(flat_features, in_dim),
        ]
        self.cnn = nn.Sequential(*encoder_layers)

        self.sae = SAE(in_dim=in_dim, sae_dim=sae_dim)

        half_hidden = pricing_hidden // 2
        head_layers = [
            # "+ 4" accounts for the four scalars appended in forward().
            nn.Linear(in_dim + 4, pricing_hidden),
            nn.SiLU(),
            nn.Linear(pricing_hidden, half_hidden),
            nn.SiLU(),
            nn.Linear(half_hidden, 1),
            nn.Softplus(),
        ]
        self.pricing_head = nn.Sequential(*head_layers)

    def forward(self, surface, scalars, return_acts=False):
        """Price from a surface and four scalars.

        Args:
            surface: input to the CNN; its first conv expects 2 channels
                (presumably ``(batch, 2, grid_size, grid_size)`` - confirm
                against the data pipeline).
            scalars: 2-D tensor whose first four columns are read as
                S, K, T, r (per the inline labels). In training mode it is
                flagged in place as requiring grad.
            return_acts: when true, the SAE hidden activations are appended
                to the returned tuple.

        Returns:
            ``(price, scalars, cnn_out, recon)``, or
            ``(price, scalars, cnn_out, recon, sae_f)`` when ``return_acts``
            is true.
        """
        # Sobolev: make the scalars differentiable during training.
        # NOTE(review): presumably so the caller can take d(price)/d(scalars)
        # from the returned tensor - confirm against the training loop.
        if self.training:
            scalars.requires_grad_(True)

        embedding = self.cnn(surface)
        sae_features, reconstruction = self.sae(embedding)

        spot = scalars[:, 0] / 100.0      # S
        strike = scalars[:, 1] / 100.0    # K
        maturity = scalars[:, 2]          # T
        rate = scalars[:, 3]              # r
        scaled_scalars = torch.stack([spot, strike, maturity, rate], dim=1)

        # The head consumes the SAE reconstruction, not the raw CNN embedding.
        head_input = torch.cat([reconstruction, scaled_scalars], dim=1)
        price = self.pricing_head(head_input)

        outputs = (price, scalars, embedding, reconstruction)
        if return_acts:
            return outputs + (sae_features,)
        return outputs