# Source: Hugging Face Hub file page for script.py, commit 3d34cc8 ("Update script.py", verified), by JacobLinCool.
import csv
from pathlib import Path
import torch
import torch.nn as nn
from datasets import Dataset, Image
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from torch.utils.data import DataLoader
class DetectionHeads(nn.Module):
    """Parallel MLP heads that share one final classification layer.

    Every head receives the same feature vector and pushes it through its own
    ``input_dim -> 64 -> 32 -> 16 -> 8`` stack of linear layers, each followed
    by GELU. A single shared linear layer then maps each head's 8 features to
    ``class_num`` logits.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        class_num: Number of logits produced per head.
        num_heads: Number of parallel heads. Defaults to 4, the value that
            used to be hard-coded, so existing callers and checkpoints are
            unaffected.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, input_dim: int, class_num: int, num_heads: int = 4):
        super().__init__()
        if num_heads < 1:
            # Fail at construction time rather than with an opaque
            # ``torch.stack`` error on the first forward pass.
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        # Submodule names (``heads.<i>.<j>`` and ``proj``) and creation order
        # are kept as-is so state-dict keys and seeded init stay the same.
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_dim, 64),
                    nn.GELU(),
                    nn.Linear(64, 32),
                    nn.GELU(),
                    nn.Linear(32, 16),
                    nn.GELU(),
                    nn.Linear(16, 8),
                    nn.GELU(),
                )
                for _ in range(num_heads)
            ]
        )
        # One projection shared by all heads.
        self.proj = nn.Linear(8, class_num)

    def forward(self, x: Tensor) -> Tensor:
        """Run every head on ``x`` and stack the per-head logits.

        Args:
            x: Features of shape ``(batch_size, input_dim)``.

        Returns:
            Logits of shape ``(batch_size, num_heads, class_num)``.
        """
        # Iterate the ModuleList itself so the head count is never duplicated.
        return torch.stack([self.proj(head(x)) for head in self.heads], dim=1)
class Baseline2024(nn.Module, PyTorchModelHubMixin):
    """Six-stage convolutional backbone followed by ``DetectionHeads``.

    Each stage applies a 3x3 convolution, GELU, max pooling and batch norm,
    in that order. The channel width starts at ``n_channels`` and doubles at
    every stage, ending at ``n_channels * 32``. The first four stages pool
    2x2; the last two pool only along the width with a (1, 2) window.

    Args:
        class_num: Number of logits per head (default ``26 + 10 + 3``).
        n_channels: Output channels of the first convolution.
        p_dropout: Dropout probability applied to the flattened features.
    """

    def __init__(
        self,
        class_num: int = 26 + 10 + 3,
        n_channels: int = 32,
        p_dropout: float = 0.95,
    ):
        super().__init__()
        self.act = nn.GELU()
        self.pool2d = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Despite its name this is a 2-D pool; it halves the width only.
        self.pool1d = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)

        # Channel plan: 3 -> n -> 2n -> 4n -> 8n -> 16n -> 32n.
        widths = [3] + [n_channels * 2**k for k in range(6)]
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            # Attribute names conv1..conv6 / bn1..bn6 are the state-dict keys
            # of saved checkpoints, so they are set explicitly by name.
            setattr(
                self,
                f"conv{stage}",
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            )
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))

        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p_dropout)
        self.heads = DetectionHeads(widths[-1], class_num)

    def forward(self, x: Tensor) -> Tensor:
        """Compute per-head logits for a batch of images.

        Args:
            x: Images of shape ``(batch_size, 3, 30, 108)``.

        Returns:
            Logits of shape ``(batch_size, 4, class_num)``.
        """
        # Stages 1-4 pool in both directions, stages 5-6 along the width only.
        pools = (self.pool2d,) * 4 + (self.pool1d,) * 2
        for stage, pool in enumerate(pools, start=1):
            conv = getattr(self, f"conv{stage}")
            bn = getattr(self, f"bn{stage}")
            x = bn(pool(self.act(conv(x))))

        # The heads take n_channels * 32 features, so the feature map must be
        # 1x1 spatially by now (true for the 30x108 input noted above).
        features = self.dropout(self.flatten(x))
        return self.heads(features)
# Character -> class index. The index is the character's position in the
# alphabet string: digits 0-9, then "-", "+", "=", then lowercase a-z.
char_dict = {
    ch: idx for idx, ch in enumerate("0123456789-+=abcdefghijklmnopqrstuvwxyz")
}
# Class index -> character (inverse of ``char_dict``).
char_dict_rev = dict(zip(char_dict.values(), char_dict.keys()))
def tensor_to_text(tensor: torch.Tensor) -> str:
    """Decode one sample's scores into a string.

    For every entry along the first dimension of ``tensor`` the index of the
    largest value is looked up in ``char_dict_rev``; the resulting characters
    are concatenated in order.
    """
    return "".join(char_dict_rev[torch.argmax(row).item()] for row in tensor)
def tensors_to_texts(tensors: torch.Tensor) -> list[str]:
    """Decode a batch: apply ``tensor_to_text`` to each item along dim 0."""
    return [tensor_to_text(sample) for sample in tensors]
if __name__ == "__main__":
    # Inference entry point: load the model from the current directory, run
    # it on every JPEG in the test folder and write the decoded text of each
    # image to ``submission.csv``.
    model = Baseline2024.from_pretrained("./")
    model.eval()

    data_dir = Path("/tmp/data/test-data")
    # ``glob`` yields files in arbitrary order; sort so the submission rows
    # come out in a reproducible order.
    captchas = sorted(str(captcha) for captcha in data_dir.glob("*.jpg"))
    # The same path list feeds two columns: "image" is cast to decoded image
    # data, while "path" stays a plain string so the filename can be written
    # next to the prediction.
    dataset = (
        Dataset.from_dict({"image": captchas, "path": captchas})
        .cast_column("image", Image())
        .with_format("torch")
    )
    loader = DataLoader(dataset, batch_size=16)

    submission = "submission.csv"
    # ``newline=""`` is required by the csv module; without it, platforms that
    # translate "\n" (e.g. Windows) emit "\r\r\n" row endings.
    with open(submission, "w", newline="", encoding="utf-8") as f, torch.no_grad():
        writer = csv.writer(f)
        writer.writerow(["filename", "text"])
        for batch in loader:
            # Presumably 8-bit pixel values; scaled into [0, 1] for the model.
            image = batch["image"].float() / 255.0
            output = model(image)
            texts = tensors_to_texts(output)
            for path, text in zip(batch["path"], texts):
                writer.writerow([Path(path).name, text])