| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
class ResNetBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block input is added back onto the second batch-norm output before the
    final ReLU, so the output has the same shape as the input.

    Args:
        channels: Number of feature channels, used for both input and output.
    """

    def __init__(self, channels):
        super().__init__()
        # A 3x3 kernel with padding 1 leaves the spatial size untouched, which
        # is what lets the skip connection be a plain element-wise addition.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        skip = x
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        out = self.bn2(self.conv2(hidden))
        # In-place add on the freshly computed tensor; the caller's ``x`` is
        # not modified because ``out`` no longer aliases it.
        out += skip
        return F.relu(out)
|
|
class SAE(nn.Module):
    """Autoencoder with one ReLU feature layer between two linear maps.

    NOTE(review): the name suggests a sparse autoencoder; no sparsity penalty
    is applied inside this class, so any such term must live in the training
    code -- confirm.

    Args:
        in_dim: Size of the input vectors and of the reconstruction.
        sae_dim: Size of the hidden feature layer.
    """

    def __init__(self, in_dim=128, sae_dim=1024):
        super().__init__()
        self.encoder = nn.Linear(in_features=in_dim, out_features=sae_dim)
        self.decoder = nn.Linear(in_features=sae_dim, out_features=in_dim)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(features, reconstruction)`` tuple: the ReLU-activated encoder
            output, and the decoder applied to those features.
        """
        features = F.relu(self.encoder(x))
        reconstruction = self.decoder(features)
        return features, reconstruction
|
|
class ThalesModel(nn.Module, PyTorchModelHubMixin):
    """CNN encoder, autoencoder bottleneck and MLP head that outputs a price.

    Data flow: ``surface`` goes through the CNN to an ``in_dim`` vector, the
    ``SAE`` reconstructs that vector, and the reconstruction (not the raw CNN
    output) is concatenated with four rescaled scalar inputs and fed to the
    pricing head. The head ends in ``Softplus``, so the price is positive.

    Args:
        grid_size: Side length of the input grid. The flattened CNN size is
            computed as ``32 * (grid_size // 2) ** 2``, i.e. it assumes a
            square ``grid_size x grid_size`` input halved by the max-pool.
        in_dim: Width of the CNN output / SAE input.
        sae_dim: Width of the SAE hidden feature layer.
        pricing_hidden: Width of the first pricing-head layer; the second
            layer is half as wide.
    """

    def __init__(self, grid_size=11, in_dim=128, sae_dim=1024, pricing_hidden=256):
        super().__init__()
        # Constructor arguments kept on the instance as a plain dict.
        self.config = dict(
            grid_size=grid_size,
            in_dim=in_dim,
            sae_dim=sae_dim,
            pricing_hidden=pricing_hidden,
        )

        conv_channels = 32
        # MaxPool2d(2) halves each spatial side, rounding down.
        pooled_side = grid_size // 2
        flat_features = conv_channels * pooled_side * pooled_side
        encoder_layers = [
            nn.Conv2d(2, conv_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),
            ResNetBlock(conv_channels),
            nn.MaxPool2d(2),
            ResNetBlock(conv_channels),
            nn.Flatten(),
            nn.Linear(flat_features, in_dim),
        ]
        self.cnn = nn.Sequential(*encoder_layers)

        self.sae = SAE(in_dim=in_dim, sae_dim=sae_dim)

        # The head consumes the SAE reconstruction plus the four scalars.
        half_hidden = pricing_hidden // 2
        head_layers = [
            nn.Linear(in_dim + 4, pricing_hidden),
            nn.SiLU(),
            nn.Linear(pricing_hidden, half_hidden),
            nn.SiLU(),
            nn.Linear(half_hidden, 1),
            nn.Softplus(),
        ]
        self.pricing_head = nn.Sequential(*head_layers)

    def forward(self, surface, scalars, return_acts=False):
        """Run the full model.

        Args:
            surface: Input to the CNN; the first convolution expects 2
                channels. Presumably ``(batch, 2, grid_size, grid_size)`` --
                confirm against the caller.
            scalars: 2-D tensor whose columns 0-3 are read. In training mode
                it is flagged ``requires_grad`` IN PLACE -- presumably so the
                caller can differentiate ``price`` with respect to it; verify.
            return_acts: When true, also return the SAE hidden features.

        Returns:
            ``(price, scalars, cnn_out, recon)``, with ``sae_f`` appended as a
            fifth element when ``return_acts`` is true. ``scalars`` is the
            same tensor object that was passed in.
        """
        if self.training:
            scalars.requires_grad_(True)

        cnn_out = self.cnn(surface)
        sae_f, recon = self.sae(cnn_out)

        # Only the first two scalar columns are rescaled (divided by 100);
        # columns 2 and 3 are passed through unchanged.
        rescaled = [scalars[:, idx] / 100.0 for idx in range(2)]
        untouched = [scalars[:, idx] for idx in range(2, 4)]
        scalars_norm = torch.stack(rescaled + untouched, dim=1)

        head_input = torch.cat([recon, scalars_norm], dim=1)
        price = self.pricing_head(head_input)

        outputs = (price, scalars, cnn_out, recon)
        if return_acts:
            outputs += (sae_f,)
        return outputs