import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with an identity skip.

    The channel count is the same on input and output, and ``padding=1``
    with a 3x3 kernel keeps the spatial size, so the input can be added
    back onto the result unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        skip = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # In-place add on the bn2 output (a fresh tensor), not on the input.
        out += skip
        return F.relu(out)


class SAE(nn.Module):
    """Autoencoder with a single linear encoder/decoder pair and ReLU codes.

    Args:
        in_dim: Size of the input (and of the reconstruction).
        sae_dim: Size of the hidden code.
    """

    def __init__(self, in_dim=128, sae_dim=1024):
        super().__init__()
        self.encoder = nn.Linear(in_dim, sae_dim)
        self.decoder = nn.Linear(sae_dim, in_dim)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(codes, reconstruction)`` tuple: the ReLU-activated encoder
            output and the decoder output computed from it.
        """
        codes = F.relu(self.encoder(x))
        reconstruction = self.decoder(codes)
        return codes, reconstruction


class ThalesModel(nn.Module, PyTorchModelHubMixin):
    """CNN encoder -> autoencoder -> MLP head producing a positive price.

    Args:
        grid_size: Side length used to size the flatten -> linear layer;
            the surface is expected to be ``grid_size`` x ``grid_size``.
        in_dim: Width of the CNN embedding / autoencoder input.
        sae_dim: Width of the autoencoder code.
        pricing_hidden: Width of the first hidden layer of the pricing head
            (the second hidden layer is half of this).
    """

    def __init__(self, grid_size=11, in_dim=128, sae_dim=1024, pricing_hidden=256):
        super().__init__()
        self.config = dict(
            grid_size=grid_size,
            in_dim=in_dim,
            sae_dim=sae_dim,
            pricing_hidden=pricing_hidden,
        )

        # MaxPool2d(2) floors the spatial side to grid_size // 2; 32 channels
        # remain when the feature map is flattened.
        flat_features = 32 * (grid_size // 2) ** 2
        cnn_layers = [
            nn.Conv2d(2, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ResNetBlock(32),
            nn.MaxPool2d(2),
            ResNetBlock(32),
            nn.Flatten(),
            nn.Linear(flat_features, in_dim),
        ]
        self.cnn = nn.Sequential(*cnn_layers)

        self.sae = SAE(in_dim=in_dim, sae_dim=sae_dim)

        # The head sees the reconstruction plus the 4 normalised scalars.
        half_hidden = pricing_hidden // 2
        head_layers = [
            nn.Linear(in_dim + 4, pricing_hidden),
            nn.SiLU(),
            nn.Linear(pricing_hidden, half_hidden),
            nn.SiLU(),
            nn.Linear(half_hidden, 1),
            nn.Softplus(),
        ]
        self.pricing_head = nn.Sequential(*head_layers)

    def forward(self, surface, scalars, return_acts=False):
        """Price a batch from a 2-channel surface and four scalar inputs.

        Args:
            surface: Input to the CNN; the first conv expects 2 channels.
            scalars: 2-D tensor whose columns 0..3 are read as S, K, T, r.
                In training mode it is flagged in place as requiring grad.
            return_acts: When true, also return the autoencoder codes.

        Returns:
            ``(price, scalars, cnn_out, recon)``, with the autoencoder codes
            appended as a fifth element when ``return_acts`` is true.
        """
        # Sobolev: track gradients w.r.t. the scalar inputs while training
        # (presumably so the caller can differentiate the price against the
        # returned ``scalars`` -- confirm against the training loop).
        if self.training:
            scalars.requires_grad_(True)

        embedding = self.cnn(surface)
        codes, reconstruction = self.sae(embedding)

        columns = (
            scalars[:, 0] / 100.0,  # S
            scalars[:, 1] / 100.0,  # K
            scalars[:, 2],  # T
            scalars[:, 3],  # r
        )
        scalars_scaled = torch.stack(columns, dim=1)

        head_input = torch.cat((reconstruction, scalars_scaled), dim=1)
        price = self.pricing_head(head_input)

        outputs = (price, scalars, embedding, reconstruction)
        if return_acts:
            outputs = outputs + (codes,)
        return outputs