"""Gradio demo that serves the LookThem dual-stream classifier for STL-10."""

import math
from typing import Dict

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image


# =========================================================
# 1. LOOKTHEM CORE LAYER
# =========================================================
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer driven by ratios of per-token scores.

    Each token owns two independent two-layer GELU MLPs (``mod1`` and
    ``mod2``) that map its feature vector to one scalar each. For every
    ordered token pair the ratio ``mod1 / mod2`` is squashed with ``tanh``,
    passed through a per-token affine map (``trans_w`` / ``trans_b``) and
    used to gate the features of the two tokens involved. The gated
    features are averaged over all *other* tokens.

    Args:
        num_tokens: Number of tokens ``N`` the layer is built for. With
            ``N == 1`` the final division by ``N - 1`` is a division by
            zero.
        in_features: Feature size of every token.
        hidden_dim: Hidden width of the per-token MLPs.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # Attribute names are the checkpoint's state-dict keys, and the
        # creation order fixes how the RNG is consumed; keep both as they are.
        self.mod1_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.mod2_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every weight tensor with Kaiming-uniform values.

        Biases keep the zeros they were created with.
        """
        # NOTE(review): these are 3-D tensors, so the fan-in that
        # kaiming_uniform_ derives is presumably not the per-token
        # ``in_features`` -- confirm the resulting scale is intended.
        weights = (
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        )
        for weight in weights:
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

    @staticmethod
    def _token_mlp(x, w1, b1, w2, b2):
        """Apply a separate two-layer GELU MLP to each token of ``x``."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Mix every token with all other tokens.

        Args:
            x: Tensor indexed as (batch, token, feature) by the einsum
                patterns; the token axis must have ``num_tokens`` entries.

        Returns:
            Tensor with the same (batch, token, feature) layout as ``x``.
        """
        n = self.num_tokens

        # One scalar per token from each of the two MLP banks: (B, N, 1).
        out_m1 = self._token_mlp(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._token_mlp(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # NOTE(review): the offset shifts the denominator but cannot rule
        # out an exact zero (out_m2 == -1e-5). Left untouched so that the
        # published weights see identical numerics.
        out_m2_safe = out_m2 + 1e-5

        # forward_ratio[b, i, j] = tanh(m1_i / m2_j)
        # reverse_ratio[b, i, j] = tanh(m1_j / m2_i)
        forward_ratio = torch.tanh(
            out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)
        )
        reverse_ratio = torch.tanh(
            out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)
        )

        # Affine map whose weight and bias are selected by the second
        # token index (j).
        bias = self.trans_b.view(1, 1, n, 1)
        gate_fwd = (
            torch.einsum('bije,jef->bijf', forward_ratio, self.trans_w) + bias
        )
        gate_rev = (
            torch.einsum('bije,jef->bijf', reverse_ratio, self.trans_w) + bias
        )

        # Gate the features of token i and of token j, then average the two.
        interaction = (
            gate_fwd * x.unsqueeze(2) + gate_rev * x.unsqueeze(1)
        ) / 2

        # Zero the diagonal so a token never interacts with itself, then
        # average over the remaining N - 1 partners.
        off_diagonal = 1.0 - torch.eye(n, device=x.device)
        interaction = interaction * off_diagonal.view(1, n, n, 1)
        return interaction.sum(dim=2) / (n - 1.0)


# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================
class LookThemSTLV1(nn.Module):
    """Dual-stream CNN + LookThem classifier producing 10 logits.

    Two convolutional streams (``stream_a`` downsamples with stride 2 in
    every convolution, ``stream_b`` only in the last one) each end in an
    adaptive max-pool to 8x8 with 64 channels. The 64 spatial positions
    become tokens, each stream gets its own ``LookThemLayer``, the results
    are concatenated along the feature axis, mixed by a third
    ``LookThemLayer``, average-pooled to 32 features per token and
    classified by an MLP head.
    """

    def __init__(self):
        super().__init__()

        # Submodule names and the order of layers inside each Sequential
        # define the state-dict keys of the checkpoint; do not reorder.
        self.stream_a = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        self.stream_b = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        self.lookthemA = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        self.lookthemB = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        self.lookthem = LookThemLayer(
            num_tokens=64, in_features=128, hidden_dim=32
        )

        self.compressor = nn.AdaptiveAvgPool1d(32)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _to_tokens(feature_map):
        """Turn a (B, C, H, W) map into (B, H*W, C) tokens."""
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for an image batch.

        Args:
            x: Batch of 3-channel images (the first convolutions take 3
                input channels).
        """
        # Each stream yields (B, 64, 8, 8) -> 64 tokens of 64 features.
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))

        # Concatenate per-token features of both streams: (B, 64, 128).
        combined = torch.cat([tokens_a, tokens_b], dim=2)
        mixed = self.lookthem(combined)

        # Pool the feature axis down to 32 values per token before the head.
        return self.classifier(self.compressor(mixed))


# =========================================================
# 3. DEVICE
# =========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =========================================================
# 4. STL-10 CLASSES
# =========================================================
# Index i of the model output is reported under classes[i].
classes = [
    "airplane",
    "bird",
    "car",
    "cat",
    "deer",
    "dog",
    "horse",
    "monkey",
    "ship",
    "truck",
]


# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================
# NOTE(review): the normalisation constants are presumably the ones used
# during training; confirm before changing them, since the checkpoint
# depends on the exact input statistics.
transform = transforms.Compose([
    transforms.Resize(112),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2470, 0.2435, 0.2616),
    ),
])


# =========================================================
# 6. LOAD MODEL FROM HUGGING FACE
# =========================================================
model_path = hf_hub_download(
    repo_id="ASomeoneWhoInterestedWithAI/LookThem_STL-10",
    filename="LookThem_STL.pth",
)

model = LookThemSTLV1().to(device)
# SECURITY NOTE(review): torch.load unpickles the downloaded file. Consider
# passing ``weights_only=True`` once the minimum supported torch version
# and the checkpoint contents are confirmed to allow it.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


# =========================================================
# 7. INFERENCE FUNCTION
# =========================================================
def predict(image) -> Dict[str, float]:
    """Classify one image and return a probability per class name.

    Args:
        image: PIL image supplied by the Gradio input component, or
            ``None`` when no image is available.

    Returns:
        Mapping from each entry of ``classes`` to its softmax probability.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Without this guard ``image.convert`` raises an opaque
        # AttributeError; gr.Error surfaces a readable message in the UI.
        raise gr.Error("Please upload an image first.")

    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0].cpu().tolist()

    return dict(zip(classes, probabilities))


# =========================================================
# 8. GRADIO UI
# =========================================================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="LookThem STL-10",
    description=(
        "Relational dual-stream with ratio based attention image classifier "
        "trained on STL-10."
    ),
)

demo.launch()