# LookThem_STL-10 / app.py
# Committer: ASomeoneWhoInterestedWithAI
# Commit 1bb76d7 (verified): "Change resize mechanism"
import math
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
# =========================================================
# 1. LOOKTHEM CORE LAYER
# =========================================================
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer over a fixed number of token positions.

    Every token position owns its own weights: all parameters carry a leading
    ``num_tokens`` axis and are applied with per-token einsums. For an input
    ``x`` of shape (B, num_tokens, in_features):

    1. Two independent per-token scorers (in_features -> hidden_dim, GELU,
       -> 1) produce one scalar per token, ``m1`` and ``m2``.
    2. For every ordered pair (i, j) two ratios are squashed with tanh:
       ``tanh(m1[i] / m2[j])`` and ``tanh(m1[j] / m2[i])``.
    3. Each squashed ratio goes through a scalar affine map owned by token j
       (``trans_w[j]``, ``trans_b[j]``). The first result scales ``x[i]``, the
       second scales ``x[j]``, and the two products are averaged.
    4. Token i's output is the mean of those pair terms over all j != i
       (all zeros when there is only one token).

    The output has the same shape as the input.

    Args:
        num_tokens: number of token positions N the layer is built for; the
            second axis of the input must equal this.
        in_features: feature size of each token.
        hidden_dim: hidden width of the two scoring MLPs.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features
        # NOTE: attribute names below are the checkpoint's state_dict keys and
        # the torch.randn calls' order fixes the seeded initialization; keep both.
        # Scorer 1 -> m1 (ratio numerator).
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Scorer 2 -> m2 (ratio denominator).
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token scalar affine map applied to each tanh(ratio).
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Re-initialize all weight tensors with Kaiming-uniform (a=sqrt(5)).

        Biases keep their zero initialization from ``__init__``.
        """
        # Each call draws from the global RNG, so this order matters for
        # reproducibility under a fixed seed.
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        """Mix each token with every other token.

        Args:
            x: tensor of shape (B, num_tokens, in_features).

        Returns:
            Tensor of shape (B, num_tokens, in_features).
        """
        N = self.num_tokens

        # Per-token scorer 1: (B, N, in) -> (B, N, hidden) -> (B, N, 1).
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = (
            torch.einsum('btj,tjk->btk', F.gelu(h1), self.mod1_w2)
            + self.mod1_b2
        )
        # Per-token scorer 2, same shapes.
        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = (
            torch.einsum('btj,tjk->btk', F.gelu(h2), self.mod2_w2)
            + self.mod2_b2
        )

        # NOTE(review): the +1e-5 shift does not keep the denominator away from
        # zero when out_m2 is close to -1e-5; it is kept as-is because the
        # published checkpoint was trained with exactly this formula.
        out_m2_safe = out_m2 + 1e-5
        # compare[b, i, j]  = tanh(m1[i] / m2[j]); shape (B, N, N, 1)
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        # compare2[b, i, j] = tanh(m1[j] / m2[i]); shape (B, N, N, 1)
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Affine map owned by token j (the third axis) for both comparisons.
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )
        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )

        # (B, N, N, in): pair (i, j) mixes a scaled x[i] and a scaled x[j].
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Drop the self-pair (i == i) and average over the N - 1 other tokens.
        mask = 1.0 - torch.eye(N, device=x.device, dtype=x.dtype)
        interaction_masked = interaction * mask.view(1, N, N, 1)
        # With a single token there are no partners: the masked sum is all
        # zeros, and dividing by N - 1 == 0 would turn it into NaN.
        return interaction_masked.sum(dim=2) / max(N - 1.0, 1.0)
# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================
class LookThemSTLV1(nn.Module):
    """Dual-stream convolutional front end followed by LookThem token mixing.

    Pipeline for a batch of B images:

    * ``stream_a``: three Conv-BN-GELU stages, all stride 2
      (3 -> 16 -> 32 -> 64 channels), then adaptive max pooling to an 8x8
      grid -> (B, 64, 8, 8).
    * ``stream_b``: the same channel progression, but the first two convs run
      at stride 1 and only the last at stride 2, followed by the same 8x8
      adaptive max pooling.
    * Each stream's 8x8 grid becomes 64 tokens of 64 channel features and is
      refined by that stream's own ``LookThemLayer``.
    * The two token sets are concatenated along the feature axis
      (64 tokens x 128 features), mixed by a joint ``LookThemLayer``,
      average-pooled along the feature axis to 32 values per token, and fed
      to an MLP that emits 10 logits.

    Attribute names and module order define the ``state_dict`` keys
    (e.g. ``stream_a.0.weight``, ``classifier.7.bias``) that a saved
    checkpoint is matched against.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_gelu(c_in, c_out, stride):
            # Returned as a list so stages can be splatted into one flat
            # nn.Sequential, giving the same indices as writing them inline.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]

        self.stream_a = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=2),
            *conv_bn_gelu(16, 32, stride=2),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        self.stream_b = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=1),
            *conv_bn_gelu(16, 32, stride=1),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # 64 tokens (8x8 cells) x 64 channel features per stream.
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        # Joint mixing over the concatenated 64 + 64 = 128 features per token.
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # Pools the last (feature) axis: (B, 64, 128) -> (B, 64, 32).
        self.compressor = nn.AdaptiveAvgPool1d(32)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _to_tokens(fmap):
        """Turn a (B, C, 8, 8) map into (B, 64 tokens, C features)."""
        # (B, C, 8, 8) -> (B, C, 64) -> (B, 64, C): one token per grid cell.
        return fmap.flatten(2).transpose(1, 2)

    def forward(self, x):
        """Return class logits of shape (B, 10).

        ``x`` is presumably a (B, 3, H, W) normalized RGB batch; both streams
        start with a 3-input-channel conv.
        """
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))
        mixed = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))
        return self.classifier(self.compressor(mixed))
# =========================================================
# 3. DEVICE
# =========================================================
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# =========================================================
# 4. STL-10 CLASSES
# =========================================================
# STL-10 label names; position i names the model's i-th output logit.
classes = "airplane bird car cat deer dog horse monkey ship truck".split()
# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================
# Preprocessing applied to each uploaded image before inference:
# scale the shorter side to 112 px, keep the central 96x96 window,
# convert to a float CHW tensor, then standardize each channel.
# NOTE(review): these mean/std values look like the commonly quoted CIFAR-10
# channel statistics rather than STL-10's; presumably they match what the
# checkpoint was trained with, so confirm before changing them.
transform = transforms.Compose(
    [
        transforms.Resize(size=112),
        transforms.CenterCrop(size=96),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2470, 0.2435, 0.2616),
        ),
    ]
)
# =========================================================
# 6. LOAD MODEL FROM HUGGING FACE
# =========================================================
# Fetch the trained checkpoint from the Hub (cached locally after the first run).
model_path = hf_hub_download(
    repo_id="ASomeoneWhoInterestedWithAI/LookThem_STL-10",
    filename="LookThem_STL.pth",
)

model = LookThemSTLV1().to(device)
# The checkpoint is a plain state_dict fetched over the network. torch.load is
# pickle-based, so weights_only=True restricts it to tensors and primitive
# containers and a tampered file cannot run code on load. Recent PyTorch
# releases already default to this; the argument needs PyTorch >= 1.13.
model.load_state_dict(
    torch.load(
        model_path,
        map_location=device,
        weights_only=True,
    )
)
# Switch BatchNorm to running statistics and disable Dropout for inference.
model.eval()
# =========================================================
# 7. INFERENCE FUNCTION
# =========================================================
def predict(image):
    """Classify one image and return a probability for every STL-10 class.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input was left empty.

    Returns:
        dict mapping each class name to its softmax probability as a Python
        float, which is the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard the call fails with an opaque AttributeError on
        # ``None.convert``; gr.Error shows a readable message in the UI.
        raise gr.Error("Please upload an image.")

    # Force 3 channels so grayscale/RGBA/palette uploads match the model input.
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)  # add batch axis -> (1, C, H, W)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0].cpu().tolist()

    # probs[i] belongs to classes[i].
    return dict(zip(classes, probs))
# =========================================================
# 8. GRADIO UI
# =========================================================
# Gradio front end: one PIL image in, the five most probable classes out.
image_input = gr.Image(type="pil")
top_classes_output = gr.Label(num_top_classes=5)

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=top_classes_output,
    title="LookThem STL-10",
    description="Relational dual-stream with ratio based attention image classifier trained on STL-10.",
)

demo.launch()