hit-detector / model.py
andrewromanenco's picture
Add Hugging Face-ready wrapper for HitDetector model
35290ca
import torch
from torch import nn
class SimpleCNN(nn.Module):
    """Small two-stage convolutional network producing one value per sample.

    The feature extractor is two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages
    (1 -> 16 -> 32 channels). The 3x3 convolutions use ``padding=1`` so they
    keep the spatial size; each pooling stage halves it (rounding down). The
    classifier flattens the feature map and maps it through a 64-unit hidden
    layer to a single raw output (no final activation is applied here).

    Because the first linear layer's width depends on the input resolution,
    the constructor needs an example input to size it.

    Args:
        sample_input: Example of ONE un-batched input, used only for its
            shape (its values, dtype and device are ignored). The first
            convolution expects a single channel, so this is presumably
            shaped ``(1, H, W)`` -- TODO confirm against callers.

    Attributes:
        features: Convolutional feature extractor.
        flattened_size: Number of feature values per sample after
            ``features``; the input width of the first linear layer.
        classifier: Flatten + two-layer MLP head with one output unit.
    """

    def __init__(self, sample_input: torch.Tensor) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Probe the extractor with a zero tensor that has the sample's shape
        # plus a leading batch axis. Building the probe with the conv
        # weights' dtype/device (instead of forwarding ``sample_input``
        # itself) keeps construction from failing when the sample is e.g.
        # uint8, float64, or lives on another device -- only its shape
        # matters for sizing the classifier.
        reference = self.features[0].weight
        with torch.no_grad():
            probe = torch.zeros(
                1,
                *sample_input.shape,
                dtype=reference.dtype,
                device=reference.device,
            )
            probe_output = self.features(probe)

        # The probe holds exactly one sample, so its total element count is
        # the per-sample flattened feature size.
        self.flattened_size = probe_output.numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor followed by the classifier head.

        Args:
            x: Batched input whose per-sample shape matches the
                ``sample_input`` given at construction, i.e. a leading batch
                axis followed by that shape.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the raw output of the
            final linear layer.
        """
        x = self.features(x)
        return self.classifier(x)