"""Gradio demo serving the LookThem V7.6 ImageNet-100 classifier.

On import this module downloads the checkpoint from the Hugging Face Hub,
rebuilds the model architecture, loads the weights, and builds a Gradio
``Interface``; running it as a script launches the web UI.
"""

import math
import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# ============================================================
# CONFIG & DOWNLOAD WEIGHTS FROM HF
# ============================================================
REPO_ID = "ASomeoneWhoInterestedWithAI/LookThem_V7.6-ImageNet100"
FILENAME = "LookThem_V76_LiteResidualClassifier.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("⏳ Downloading model weights from Hugging Face Hub...")
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
print(f"✅ Weights downloaded to: {MODEL_PATH}")

# ============================================================
# IMAGENET-100 LABELS
# ============================================================
# CLASS_NAMES[i] is the display label for output logit i (100 entries,
# matching the final nn.Linear(128, 100) of the classifier head).
CLASS_NAMES = [
    "bonnet, poke bonnet",
    "green mamba",
    "langur",
    "Doberman, Doberman pinscher",
    "gyromitra",
    "Saluki, gazelle hound",
    "vacuum, vacuum cleaner",
    "window screen",
    "cocktail shaker",
    "garden spider, Aranea diademata",
    "garter snake, grass snake",
    "carbonara",
    "pineapple, ananas",
    "computer keyboard, keypad",
    "tripod",
    "komondor",
    "American lobster, Northern lobster, Maine lobster, Homarus americanus",
    "bannister, banister, balustrade, balusters, handrail",
    "honeycomb",
    "tile roof",
    "papillon",
    "boathouse",
    "stinkhorn, carrion fungus",
    "jean, blue jean, denim",
    "Chihuahua",
    "Chesapeake Bay retriever",
    "robin, American robin, Turdus migratorius",
    "tub, vat",
    "Great Dane",
    "rotisserie",
    "bottlecap",
    "throne",
    "little blue heron, Egretta caerulea",
    "rock crab, Cancer irroratus",
    "Rottweiler",
    "lorikeet",
    "Gila monster, Heloderma suspectum",
    "head cabbage",
    "car wheel",
    "coyote, prairie wolf, brush wolf, Canis latrans",
    "moped",
    "milk can",
    "mixing bowl",
    "toy terrier",
    "chocolate sauce, chocolate syrup",
    "rocking chair, rocker",
    "wing",
    "park bench",
    "ambulance",
    "football helmet",
    "leafhopper",
    "cauliflower",
    "pirate, pirate ship",
    "purse",
    "hare",
    "lampshade, lamp shade",
    "fiddler crab",
    "standard poodle",
    "Shih-Tzu",
    "pedestal, plinth, footstall",
    "gibbon, Hylobates lar",
    "safety pin",
    "English foxhound",
    "chime, bell, gong",
    "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
    "bassinet",
    "wild boar, boar, Sus scrofa",
    "theater curtain, theatre curtain",
    "dung beetle",
    "hognose snake, puff adder, sand viper",
    "Mexican hairless",
    "mortarboard",
    "Walker hound, Walker foxhound",
    "red fox, Vulpes vulpes",
    "modem",
    "slide rule, slipstick",
    "walking stick, walkingstick, stick insect",
    "cinema, movie theater, movie theatre, movie house, picture palace",
    "meerkat, mierkat",
    "kuvasz",
    "obelisk",
    "harmonica, mouth organ, harp, mouth harp",
    "sarong",
    "mousetrap",
    "hard disc, hard disk, fixed disk",
    "American coot, marsh hen, mud hen, water hen, Fulica americana",
    "reel",
    "pickup, pickup truck",
    "iron, smoothing iron",
    "tabby, tabby cat",
    "ski mask",
    "vizsla, Hungarian pointer",
    "laptop, laptop computer",
    "stretcher",
    "Dutch oven",
    "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
    "boxer",
    "gasmask, respirator, gas helmet",
    "goose",
    "borzoi, Russian wolfhound",
]

# ============================================================
# TRANSFORM
# ============================================================
# PIL image -> normalized (3, 256, 256) float tensor. The mean/std are the
# usual ImageNet channel statistics; presumably they match the training
# preprocessing -- confirm against the training script.
transform = transforms.Compose([
    # Force 3 channels so grayscale / RGBA / palette uploads work.
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


# ============================================================
# LOOKTHEM LAYER MODEL ARCHITECTURE
# ============================================================
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token MLP weights.

    Every one of the ``num_tokens`` positions owns two small two-layer MLPs
    (``mod1`` and ``mod2``) that map its ``in_features`` vector to a scalar.
    For each ordered pair of tokens (i, j) the scalars are compared through
    ``tanh(m1 / m2)`` in both directions. The result goes through an affine
    transform indexed by the partner token j. It then weights a blend of
    token i's and token j's features. Self-pairs (i == j) are masked out
    and the remaining N - 1 contributions are averaged.

    Args:
        num_tokens: Number of tokens N (must be >= 2).
        in_features: Feature size of each token.
        hidden_dim: Hidden width of the per-token MLPs.

    Shapes:
        ``x``: ``(batch, num_tokens, in_features)`` -> output of the same shape.

    Raises:
        ValueError: If ``num_tokens < 2`` (the average over partners would
            divide by zero).
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        if num_tokens < 2:
            raise ValueError(f"num_tokens must be >= 2, got {num_tokens}")
        self.num_tokens = num_tokens

        # Per-token MLP #1: in_features -> hidden_dim -> 1.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token MLP #2: same shape, used as the denominator of the comparison.
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token scalar affine transform applied to the comparison scores.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialize the weight tensors with Kaiming-uniform (a=sqrt(5))."""
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        N = self.num_tokens

        # Per-token MLPs -> one scalar per token, shape (B, N, 1).
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Small offset so the denominator is not exactly 0 in the common case
        # (it does not protect against out_m2 being close to -1e-5).
        out_m2_safe = out_m2 + 1e-5

        # compare[b, i, j]  = tanh(m1[i] / m2[j])
        # compare2[b, i, j] = tanh(m1[j] / m2[i])      -> both (B, N, N, 1)
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Affine transform indexed by the partner token j.
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # Blend token i's and token j's features: (B, N, N, in_features).
        interaction = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # Drop self-pairs, then average over the N - 1 partner tokens.
        # The mask follows the activation dtype so half precision is not upcast.
        mask = 1.0 - torch.eye(N, device=x.device, dtype=interaction.dtype)
        interaction_masked = interaction * mask.view(1, N, N, 1)
        return interaction_masked.sum(dim=2) / (N - 1.0)


class LookThemBackbone(nn.Module):
    """Two-stream CNN whose 8x8 feature maps are mixed by LookThem layers.

    Input: image batch ``(B, 3, H, W)``. Each stream ends in 64 channels and an
    ``AdaptiveMaxPool2d((8, 8))``. Its output is read as 64 tokens (spatial
    positions) of 64 features. The two token sets are concatenated
    feature-wise (128) and mixed by a third LookThem layer. A 1x1 Conv1d then
    compresses them back to 64 channels.

    Output: ``(B, 64, 64)``.
    """

    def __init__(self):
        super().__init__()
        # Stream A: four stride-2 convolutions (coarse, fast downsampling).
        self.stream_a = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        # Stream B: mostly stride-1 convolutions (keeps more resolution
        # before pooling).
        self.stream_b = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)
        self.compressor = nn.Conv1d(128, 64, kernel_size=1)

    def forward(self, x):
        B = x.size(0)
        # (B, 64, 8, 8) -> (B, 64 channels, 64 positions) -> (B, 64 tokens, 64 features)
        feat_a = self.stream_a(x).view(B, 64, 64).transpose(1, 2)
        feat_a = self.lookthemA(feat_a)
        feat_b = self.stream_b(x).view(B, 64, 64).transpose(1, 2)
        feat_b = self.lookthemB(feat_b)
        combined = torch.cat([feat_a, feat_b], dim=2)          # (B, 64, 128)
        out = self.lookthem(combined).transpose(1, 2)          # (B, 128, 64)
        return self.compressor(out)                            # (B, 64, 64)


class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP block: ``LayerNorm(x + MLP(x))``."""

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.block(x))


class EfficientResidualClassifier(nn.Module):
    """MLP head: flatten (B, 64, 64) -> 4096 -> 256 -> 2 residual blocks -> 100 logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )
        self.res1 = LiteResidualBlock(256)
        self.res2 = LiteResidualBlock(256)
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 100),
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.input_proj(x)
        x = self.res1(x)
        x = self.res2(x)
        return self.head(x)


class FullModel(nn.Module):
    """Backbone + classifier: image batch ``(B, 3, H, W)`` -> logits ``(B, 100)``."""

    def __init__(self):
        super().__init__()
        self.backbone = LookThemBackbone()
        self.classifier = EfficientResidualClassifier()

    def forward(self, x):
        return self.classifier(self.backbone(x))


# ============================================================
# INITIALIZE & SETUP MODEL
# ============================================================
model = FullModel().to(device)
# The checkpoint comes from a remote repository. weights_only=True stops
# torch.load from unpickling arbitrary Python objects. The loaded object is
# passed straight to load_state_dict, so it is expected to be a plain
# state_dict of tensors, which weights_only supports.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("🧠 Model loaded into Gradio Space safely!")

# ============================================================
# GRADIO INFERENCE FUNCTION
# ============================================================
# Number of predictions returned and displayed by the UI.
TOP_K = 5


def predict_image(pil_img):
    """Classify one PIL image and return its top-``TOP_K`` labels.

    Args:
        pil_img: Image supplied by the Gradio ``Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        ``{class_name: probability}`` for the ``TOP_K`` most probable classes,
        in descending probability order (the format ``gr.Label`` expects).
        An empty dict when ``pil_img`` is ``None``.
    """
    if pil_img is None:
        return {}

    # Preprocess the PIL image directly from Gradio and add a batch dimension.
    x = transform(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
        top_prob, top_idx = torch.topk(probs, min(TOP_K, probs.numel()))

    # One host transfer per tensor instead of a .item() call per element.
    return {
        CLASS_NAMES[idx]: prob
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    }


# ============================================================
# GRADIO INTERFACE DESIGN
# ============================================================
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Input Image"),
    outputs=gr.Label(
        num_top_classes=TOP_K,
        label=f"Top {TOP_K} ImageNet-100 Predictions",
    ),
    title="LookThem V7.6 - ImageNet100 Classifier",
    description="Drop or upload an image to evaluate it using the LookThem LiteResidualClassifier pipeline.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()