| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| import gradio as gr |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
# Hugging Face Hub coordinates of the checkpoint file to fetch.
REPO_ID = "ASomeoneWhoInterestedWithAI/LookThem_V7.6-ImageNet100"
FILENAME = "LookThem_V76_LiteResidualClassifier.pth"

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fetch the weights at import time; MODEL_PATH is the local path returned by
# hf_hub_download.
print("⏳ Downloading model weights from Hugging Face Hub...")
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
print(f"✅ Weights downloaded to: {MODEL_PATH}")
|
|
| |
| |
| |
# Human-readable label for each class index; position i names output logit i.
# One entry per line so that additions and removals diff cleanly.
CLASS_NAMES = [
    "bonnet, poke bonnet",
    "green mamba",
    "langur",
    "Doberman, Doberman pinscher",
    "gyromitra",
    "Saluki, gazelle hound",
    "vacuum, vacuum cleaner",
    "window screen",
    "cocktail shaker",
    "garden spider, Aranea diademata",
    "garter snake, grass snake",
    "carbonara",
    "pineapple, ananas",
    "computer keyboard, keypad",
    "tripod",
    "komondor",
    "American lobster, Northern lobster, Maine lobster, Homarus americanus",
    "bannister, banister, balustrade, balusters, handrail",
    "honeycomb",
    "tile roof",
    "papillon",
    "boathouse",
    "stinkhorn, carrion fungus",
    "jean, blue jean, denim",
    "Chihuahua",
    "Chesapeake Bay retriever",
    "robin, American robin, Turdus migratorius",
    "tub, vat",
    "Great Dane",
    "rotisserie",
    "bottlecap",
    "throne",
    "little blue heron, Egretta caerulea",
    "rock crab, Cancer irroratus",
    "Rottweiler",
    "lorikeet",
    "Gila monster, Heloderma suspectum",
    "head cabbage",
    "car wheel",
    "coyote, prairie wolf, brush wolf, Canis latrans",
    "moped",
    "milk can",
    "mixing bowl",
    "toy terrier",
    "chocolate sauce, chocolate syrup",
    "rocking chair, rocker",
    "wing",
    "park bench",
    "ambulance",
    "football helmet",
    "leafhopper",
    "cauliflower",
    "pirate, pirate ship",
    "purse",
    "hare",
    "lampshade, lamp shade",
    "fiddler crab",
    "standard poodle",
    "Shih-Tzu",
    "pedestal, plinth, footstall",
    "gibbon, Hylobates lar",
    "safety pin",
    "English foxhound",
    "chime, bell, gong",
    "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
    "bassinet",
    "wild boar, boar, Sus scrofa",
    "theater curtain, theatre curtain",
    "dung beetle",
    "hognose snake, puff adder, sand viper",
    "Mexican hairless",
    "mortarboard",
    "Walker hound, Walker foxhound",
    "red fox, Vulpes vulpes",
    "modem",
    "slide rule, slipstick",
    "walking stick, walkingstick, stick insect",
    "cinema, movie theater, movie theatre, movie house, picture palace",
    "meerkat, mierkat",
    "kuvasz",
    "obelisk",
    "harmonica, mouth organ, harp, mouth harp",
    "sarong",
    "mousetrap",
    "hard disc, hard disk, fixed disk",
    "American coot, marsh hen, mud hen, water hen, Fulica americana",
    "reel",
    "pickup, pickup truck",
    "iron, smoothing iron",
    "tabby, tabby cat",
    "ski mask",
    "vizsla, Hungarian pointer",
    "laptop, laptop computer",
    "stretcher",
    "Dutch oven",
    "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
    "boxer",
    "gasmask, respirator, gas helmet",
    "goose",
    "borzoi, Russian wolfhound",
]
|
|
| |
| |
| |
def _to_rgb(img):
    """Return *img* converted to RGB mode via its ``convert`` method."""
    return img.convert("RGB")


# Per-channel statistics handed to transforms.Normalize.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every uploaded image, in order: force RGB, resize
# to 256x256, convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
|
|
| |
| |
| |
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer.

    Each token owns two small MLPs ("mod1" and "mod2") that reduce its feature
    vector to a scalar.  Tokens are then compared pairwise through the tanh of
    the ratio of those scalars, the comparison is passed through a per-token
    affine map (``trans_w`` / ``trans_b``), and the result gates the token
    features.  The gated features are averaged over all *other* tokens.

    Args:
        num_tokens: Number of tokens N; every weight has a leading axis of N.
        in_features: Feature size of each token.
        hidden_dim: Hidden width of the two per-token MLPs.
    """

    # Weights that are re-initialised with Kaiming-uniform, in this order.
    _KAIMING_PARAMS = ("mod1_w1", "mod2_w1", "mod1_w2", "mod2_w2", "trans_w")

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens

        def weight(*shape):
            return nn.Parameter(torch.randn(*shape))

        def bias(*shape):
            return nn.Parameter(torch.zeros(*shape))

        # First per-token MLP: in_features -> hidden_dim -> 1.
        self.mod1_w1 = weight(num_tokens, in_features, hidden_dim)
        self.mod1_b1 = bias(num_tokens, hidden_dim)
        self.mod1_w2 = weight(num_tokens, hidden_dim, 1)
        self.mod1_b2 = bias(num_tokens, 1)

        # Second per-token MLP with the same shapes.
        self.mod2_w1 = weight(num_tokens, in_features, hidden_dim)
        self.mod2_b1 = bias(num_tokens, hidden_dim)
        self.mod2_w2 = weight(num_tokens, hidden_dim, 1)
        self.mod2_b2 = bias(num_tokens, 1)

        # Per-token affine map applied to the pairwise comparison scalar.
        self.trans_w = weight(num_tokens, 1, 1)
        self.trans_b = bias(num_tokens, 1)

        self._init_weights()

    def _init_weights(self):
        """Overwrite the weight tensors (not the biases) with Kaiming-uniform."""
        negative_slope = math.sqrt(5)
        for param_name in self._KAIMING_PARAMS:
            nn.init.kaiming_uniform_(getattr(self, param_name), a=negative_slope)

    @staticmethod
    def _token_mlp(x, w1, b1, w2, b2):
        """Apply one per-token two-layer GELU MLP to ``x``."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Mix every token with all other tokens.

        Args:
            x: Tensor of shape (batch, num_tokens, in_features).

        Returns:
            Tensor of shape (batch, num_tokens, in_features).
        """
        n_tokens = self.num_tokens

        # One scalar per token from each MLP: (batch, num_tokens, 1).
        numer = self._token_mlp(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        denom = self._token_mlp(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # NOTE(review): the offset only shifts the denominator; it is still
        # zero whenever the second MLP outputs exactly -1e-5.
        denom = denom + 1e-5

        # Bounded pairwise ratios in both orientations: (batch, N, N, 1).
        ratio_fwd = torch.tanh(numer.unsqueeze(2) / denom.unsqueeze(1))
        ratio_bwd = torch.tanh(numer.unsqueeze(1) / denom.unsqueeze(2))

        # Affine map indexed by the second token axis.
        offset = self.trans_b.view(1, 1, n_tokens, 1)
        gate_fwd = torch.einsum('bije,jef->bijf', ratio_fwd, self.trans_w) + offset
        gate_bwd = torch.einsum('bije,jef->bijf', ratio_bwd, self.trans_w) + offset

        # Gate the features by both orientations and average the two.
        mixed = (gate_fwd * x.unsqueeze(2) + gate_bwd * x.unsqueeze(1)) / 2

        # Zero the diagonal so a token never interacts with itself.
        off_diagonal = 1.0 - torch.eye(n_tokens, device=x.device)
        mixed = mixed * off_diagonal.view(1, n_tokens, n_tokens, 1)

        # Mean over the N - 1 other tokens.
        # NOTE(review): this divides by zero when num_tokens == 1.
        return mixed.sum(dim=2) / (n_tokens - 1.0)
|
|
class LookThemBackbone(nn.Module):
    """Two-stream convolutional feature extractor with token mixing.

    Two convolutional streams that differ only in their strides each reduce
    the image to a 64-channel 8x8 map.  Every map is reshaped into 64 tokens
    of 64 features and mixed by its own ``LookThemLayer``.  The two token sets
    are concatenated on the feature axis, mixed once more, and a 1x1
    ``Conv1d`` compresses the 128 features back to 64.
    """

    # Channel widths shared by both streams: 3 -> 16 -> 32 -> 64 -> 64.
    _WIDTHS = (3, 16, 32, 64, 64)

    def __init__(self):
        super().__init__()
        self.stream_a = self._make_stream(strides=(2, 2, 2, 2))
        self.stream_b = self._make_stream(strides=(1, 1, 2, 1))
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)
        self.compressor = nn.Conv1d(128, 64, kernel_size=1)

    @classmethod
    def _make_stream(cls, strides):
        """Build four Conv-BatchNorm-GELU stages followed by an 8x8 max pool."""
        stages = []
        for c_in, c_out, stride in zip(cls._WIDTHS, cls._WIDTHS[1:], strides):
            stages.append(nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.GELU())
        stages.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*stages)

    @staticmethod
    def _as_tokens(feature_map, batch_size):
        """Reshape a (B, 64, 8, 8) map into (B, 64 positions, 64 channels)."""
        return feature_map.view(batch_size, 64, 64).transpose(1, 2)

    def forward(self, x):
        """Return a (batch, 64, 64) feature tensor for an image batch ``x``."""
        batch_size = x.size(0)

        tokens_a = self._as_tokens(self.stream_a(x), batch_size)
        tokens_a = self.lookthemA(tokens_a)

        tokens_b = self._as_tokens(self.stream_b(x), batch_size)
        tokens_b = self.lookthemB(tokens_b)

        # Concatenate on the feature axis: (B, 64, 128).
        fused = self.lookthem(torch.cat([tokens_a, tokens_b], dim=2))

        # Conv1d expects channels first, so move the 128 features to axis 1.
        return self.compressor(fused.transpose(1, 2))
|
|
class LiteResidualBlock(nn.Module):
    """Residual two-layer MLP followed by layer normalization.

    Computes ``LayerNorm(x + MLP(x))`` where the MLP is
    Linear -> GELU -> Dropout -> Linear and keeps the feature size ``dim``.

    Args:
        dim: Size of the last axis of the input and output.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the normalized sum of ``x`` and its MLP transform."""
        update = self.block(x)
        return self.norm(x + update)
|
|
class EfficientResidualClassifier(nn.Module):
    """MLP head that maps 4096 flattened features to 100 class logits.

    The input is flattened, projected to 256 features, refined by two
    ``LiteResidualBlock`` stages and classified by a two-layer head.
    """

    def __init__(self):
        super().__init__()
        in_dim, width, bottleneck, n_classes = 4096, 256, 128, 100

        self.flatten = nn.Flatten()
        self.input_proj = nn.Sequential(
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Dropout(0.08),
        )
        self.res1 = LiteResidualBlock(width)
        self.res2 = LiteResidualBlock(width)
        self.head = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, n_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, 100) for the feature tensor ``x``."""
        hidden = self.input_proj(self.flatten(x))
        for residual_stage in (self.res1, self.res2):
            hidden = residual_stage(hidden)
        return self.head(hidden)
|
|
class FullModel(nn.Module):
    """End-to-end network: ``LookThemBackbone`` features fed to the classifier.

    The submodule attribute names ``backbone`` and ``classifier`` are the
    prefixes of this module's state-dict keys, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.backbone = LookThemBackbone()
        self.classifier = EfficientResidualClassifier()

    def forward(self, x):
        """Return the classifier output for an image batch ``x``."""
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits
|
|
| |
| |
| |
# Build the network on the selected device and restore the downloaded weights.
model = FullModel().to(device)

# The checkpoint comes from a remote repository, so it is loaded with
# weights_only=True: unpickling is restricted to tensors and plain containers,
# and a tampered file cannot execute arbitrary code at load time.  The loaded
# object is passed straight to load_state_dict, so nothing else is needed.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Inference only: put the Dropout and BatchNorm layers into evaluation mode.
model.eval()
print("🧠 Model loaded into Gradio Space safely!")
|
|
| |
| |
| |
def predict_image(pil_img):
    """Classify one image and return its five most probable labels.

    Args:
        pil_img: The uploaded image, or ``None`` when nothing was provided.

    Returns:
        A dict mapping the class name to its softmax probability for the five
        highest-scoring classes, in descending order of probability.  An empty
        dict is returned when ``pil_img`` is ``None``.
    """
    if pil_img is None:
        return {}

    # Preprocess and add the leading batch axis the model expects.
    batch = transform(pil_img).unsqueeze(0).to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(batch)
        top_probs, top_indices = torch.topk(torch.softmax(logits, dim=1), 5)

    # Row 0 is the only sample in the batch.
    return {
        CLASS_NAMES[class_idx]: probability
        for class_idx, probability in zip(
            top_indices[0].tolist(), top_probs[0].tolist()
        )
    }
|
|
| |
| |
| |
# Input and output widgets of the web UI.
_image_input = gr.Image(type="pil", label="Upload Input Image")
_label_output = gr.Label(
    num_top_classes=5,
    label="Top 5 ImageNet-100 Predictions",
)

# Gradio interface that calls predict_image on every upload; flagging is off.
demo = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_label_output,
    title="LookThem V7.6 - ImageNet100 Classifier",
    description="Drop or upload an image to evaluate it using the LookThem LiteResidualClassifier pipeline.",
    flagging_mode="never",
)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|