| import math |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
|
|
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
|
|
| |
| |
| |
|
|
class LookThemLayer(nn.Module):
    """Token-mixing layer driven by pairwise ratios of per-token scores.

    Every token owns two small MLPs (``mod1`` and ``mod2``); each maps the
    token's feature vector to one scalar. For every ordered pair of tokens
    ``(i, j)`` the layer computes ``tanh(mod1 / (mod2 + 1e-5))`` in both
    directions, sends each ratio through an affine map indexed by ``j``
    (``trans_w`` / ``trans_b``), and uses the two results to weight the
    features of token ``i`` and token ``j`` respectively. The weighted
    features are then averaged over all partners ``j != i``.

    Args:
        num_tokens: Number of tokens ``N`` expected along dimension 1 of the
            input.
        in_features: Feature size of each token.
        hidden_dim: Hidden width of the per-token scoring MLPs.

    Shape:
        Input ``(batch, num_tokens, in_features)``; the output has the same
        shape.
    """

    def __init__(self, num_tokens: int, in_features: int, hidden_dim: int):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # Parameter names and creation order are part of the checkpoint
        # layout (state_dict keys) and of the RNG stream consumed during
        # construction, so they are kept exactly as they were.

        # Scoring MLP 1: (in_features -> hidden_dim -> 1), one per token.
        self.mod1_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Scoring MLP 2: same layout, independent weights.
        self.mod2_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token affine map applied to each pairwise ratio.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every weight tensor with Kaiming-uniform (a=sqrt(5)).

        Biases are left at the zeros they were created with.
        """
        for w in (
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ):
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _token_score(x, w1, b1, w2, b2):
        """Apply one per-token two-layer MLP (GELU in between) to ``x``."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Mix each token with every other token.

        Args:
            x: Tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, in_features)``. When the
            layer has a single token there are no partners to average over
            and the result is all zeros.
        """
        N = self.num_tokens

        # One scalar per token from each scoring MLP: shape (B, N, 1).
        out_m1 = self._token_score(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._token_score(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # NOTE(review): the offset only shifts the denominator; a score of
        # exactly -1e-5 still divides by zero. Left unchanged so existing
        # trained weights keep producing the same outputs.
        out_m2_safe = out_m2 + 1e-5

        # compare[b, i, j]  = tanh(m1[i] / m2[j])
        # compare2[b, i, j] = tanh(m1[j] / m2[i])
        compare = torch.tanh(
            out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)
        )
        compare2 = torch.tanh(
            out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)
        )

        # The affine map is indexed by the partner token j in both directions.
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )
        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )

        # (B, N, N, F): features of token i and of token j, each scaled by
        # its transformed ratio, then averaged.
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Zero out the i == j pairs. The mask is built in the interaction's
        # dtype so the result is not promoted to the default float type.
        mask = 1.0 - torch.eye(
            N,
            device=x.device,
            dtype=interaction.dtype,
        )
        interaction_masked = interaction * mask.view(1, N, N, 1)

        # Mean over the N - 1 partners; clamp the divisor so a single-token
        # layer yields zeros instead of 0 / 0 = NaN.
        partner_count = float(max(N - 1, 1))
        return interaction_masked.sum(dim=2) / partner_count
|
|
|
|
| |
| |
| |
|
|
class LookThemSTLV1(nn.Module):
    """Dual-stream convolutional classifier with LookThem token mixing.

    Two convolutional streams process the same image: stream A uses three
    stride-2 convolutions, stream B uses strides 1, 1 and 2. Each stream is
    max-pooled to an 8x8 map with 64 channels, which is read as 64 tokens of
    64 features. Every stream gets its own ``LookThemLayer``; the two token
    sets are concatenated along the feature axis (128 features), mixed once
    more, average-pooled to 32 features per token and classified into 10
    outputs by a three-layer MLP.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out, stride):
            # One 3x3 conv -> batch norm -> GELU stage, returned as a flat
            # list so the enclosing nn.Sequential keeps its layer indices.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]

        # Stream A: downsamples at every stage.
        self.stream_a = nn.Sequential(
            *conv_stage(3, 16, 2),
            *conv_stage(16, 32, 2),
            *conv_stage(32, 64, 2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Stream B: keeps full resolution for two stages, then downsamples.
        self.stream_b = nn.Sequential(
            *conv_stage(3, 16, 1),
            *conv_stage(16, 32, 1),
            *conv_stage(32, 64, 2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Per-stream token mixers (64 tokens x 64 features each).
        self.lookthemA = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        self.lookthemB = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )

        # Mixer over the concatenated streams (64 tokens x 128 features).
        self.lookthem = LookThemLayer(
            num_tokens=64, in_features=128, hidden_dim=32
        )

        # Shrinks each token from 128 to 32 features before the head.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for an image batch.

        Args:
            x: Image batch with 3 channels, shape ``(batch, 3, H, W)``.
        """
        n = x.size(0)

        # (B, 64, 8, 8) -> (B, 64 channels, 64 positions) -> tokens-first.
        tokens_a = self.stream_a(x).view(n, 64, 64).transpose(1, 2)
        mixed_a = self.lookthemA(tokens_a)

        tokens_b = self.stream_b(x).view(n, 64, 64).transpose(1, 2)
        mixed_b = self.lookthemB(tokens_b)

        # Join the two streams along the feature axis: (B, 64, 128).
        fused = torch.cat([mixed_a, mixed_b], dim=2)
        mixed = self.lookthem(fused)

        return self.classifier(self.compressor(mixed))
|
|
|
|
| |
| |
| |
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Label names reported to the UI; entry i names model output i, so the order
# must match the label indices the checkpoint was trained with.
classes = [
    "airplane",
    "bird",
    "car",
    "cat",
    "deer",
    "dog",
    "horse",
    "monkey",
    "ship",
    "truck",
]


# Preprocessing applied to every uploaded image: resize the short side to
# 112, take the central 96x96 crop, convert to a tensor and normalise.
# NOTE(review): these mean/std values look like the commonly quoted CIFAR-10
# statistics - confirm they are what the checkpoint was trained with.
transform = transforms.Compose([
    transforms.Resize(112),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2470, 0.2435, 0.2616),
    ),
])


# Fetch the checkpoint from the Hugging Face Hub (cached after the first run).
model_path = hf_hub_download(
    repo_id="ASomeoneWhoInterestedWithAI/LookThem_STL-10",
    filename="LookThem_STL.pth",
)


model = LookThemSTLV1().to(device)


# The checkpoint is a downloaded file, so restrict unpickling to tensors and
# plain containers instead of allowing arbitrary pickled objects to execute.
model.load_state_dict(
    torch.load(
        model_path,
        map_location=device,
        weights_only=True,
    )
)


# Inference only: freeze batch-norm statistics and disable dropout.
model.eval()
|
|
|
|
| |
| |
| |
|
|
def predict(image):
    """Classify one image and return a probability for every class.

    Args:
        image: PIL image handed over by the Gradio ``Image`` input, or
            ``None`` when the user submitted without providing one.

    Returns:
        Dict mapping each name in ``classes`` to its softmax probability as
        a Python float.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of an opaque failure.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    # Force three channels (drops alpha, expands greyscale), preprocess, and
    # add the batch dimension the model expects.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        # Single-image batch: take row 0 and convert to plain Python floats.
        probs = F.softmax(logits, dim=1)[0].tolist()

    return dict(zip(classes, probs))
|
|
|
|
| |
| |
| |
|
|
# Web UI: one image in, the five most probable class labels out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="LookThem STL-10",
    description="Relational dual-stream with ratio based attention image classifier trained on STL-10.",
)

# Start the Gradio server.
demo.launch()