| import os |
| import io |
| import math |
| from PIL import Image |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import torchvision.transforms as transforms |
|
|
| |
| |
| |
|
|
# Checkpoint file that the model-loading code reads its weights from.
MODEL_PATH = "LookThem_V76_LiteResidualClassifier.pth"

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
|
|
| |
| |
| |
|
|
| |
# Human-readable label for each classifier output index. The position in this
# list is the class index, so the order must not be changed.
CLASS_NAMES = [
    "bonnet, poke bonnet",
    "green mamba",
    "langur",
    "Doberman, Doberman pinscher",
    "gyromitra",
    "Saluki, gazelle hound",
    "vacuum, vacuum cleaner",
    "window screen",
    "cocktail shaker",
    "garden spider, Aranea diademata",
    "garter snake, grass snake",
    "carbonara",
    "pineapple, ananas",
    "computer keyboard, keypad",
    "tripod",
    "komondor",
    "American lobster, Northern lobster, Maine lobster, Homarus americanus",
    "bannister, banister, balustrade, balusters, handrail",
    "honeycomb",
    "tile roof",
    "papillon",
    "boathouse",
    "stinkhorn, carrion fungus",
    "jean, blue jean, denim",
    "Chihuahua",
    "Chesapeake Bay retriever",
    "robin, American robin, Turdus migratorius",
    "tub, vat",
    "Great Dane",
    "rotisserie",
    "bottlecap",
    "throne",
    "little blue heron, Egretta caerulea",
    "rock crab, Cancer irroratus",
    "Rottweiler",
    "lorikeet",
    "Gila monster, Heloderma suspectum",
    "head cabbage",
    "car wheel",
    "coyote, prairie wolf, brush wolf, Canis latrans",
    "moped",
    "milk can",
    "mixing bowl",
    "toy terrier",
    "chocolate sauce, chocolate syrup",
    "rocking chair, rocker",
    "wing",
    "park bench",
    "ambulance",
    "football helmet",
    "leafhopper",
    "cauliflower",
    "pirate, pirate ship",
    "purse",
    "hare",
    "lampshade, lamp shade",
    "fiddler crab",
    "standard poodle",
    "Shih-Tzu",
    "pedestal, plinth, footstall",
    "gibbon, Hylobates lar",
    "safety pin",
    "English foxhound",
    "chime, bell, gong",
    "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
    "bassinet",
    "wild boar, boar, Sus scrofa",
    "theater curtain, theatre curtain",
    "dung beetle",
    "hognose snake, puff adder, sand viper",
    "Mexican hairless",
    "mortarboard",
    "Walker hound, Walker foxhound",
    "red fox, Vulpes vulpes",
    "modem",
    "slide rule, slipstick",
    "walking stick, walkingstick, stick insect",
    "cinema, movie theater, movie theatre, movie house, picture palace",
    "meerkat, mierkat",
    "kuvasz",
    "obelisk",
    "harmonica, mouth organ, harp, mouth harp",
    "sarong",
    "mousetrap",
    "hard disc, hard disk, fixed disk",
    "American coot, marsh hen, mud hen, water hen, Fulica americana",
    "reel",
    "pickup, pickup truck",
    "iron, smoothing iron",
    "tabby, tabby cat",
    "ski mask",
    "vizsla, Hungarian pointer",
    "laptop, laptop computer",
    "stretcher",
    "Dutch oven",
    "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
    "boxer",
    "gasmask, respirator, gas helmet",
    "goose",
    "borzoi, Russian wolfhound",
]
|
|
| |
| |
| |
|
|
# Per-channel normalisation constants applied after ToTensor.
# NOTE(review): these values coincide with the widely used ImageNet channel
# statistics - presumably what the checkpoint was trained with; confirm.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Preprocessing pipeline: force three RGB channels, resize to 256x256,
# convert to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Lambda(lambda picture: picture.convert("RGB")),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
|
|
| |
| |
| |
|
|
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer.

    Two small MLPs, each with its own weights per token index, reduce every
    token's feature vector to one scalar. For every ordered pair of tokens the
    ratio of those scalars is squashed with ``tanh``, passed through an affine
    map indexed by the second token of the pair, and used to gate a mix of the
    two tokens' features. A token's output is the mean of its interactions
    with all *other* tokens (the diagonal is masked out).

    Args:
        num_tokens: Number of tokens ``N``; the input's second dimension.
        in_features: Feature size of each token; the input's last dimension.
        hidden_dim: Hidden width of the two scoring MLPs.

    Shapes:
        input  ``(batch, num_tokens, in_features)``
        output ``(batch, num_tokens, in_features)``

    Note:
        ``forward`` materialises a ``(batch, N, N, in_features)`` intermediate.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens

        # Registration order and names are kept as-is so state_dict keys (and
        # seeded initialisation) stay compatible with existing checkpoints.

        # Scoring MLP 1: in_features -> hidden_dim -> 1, separate per token.
        self.mod1_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Scoring MLP 2: same layout, independent weights.
        self.mod2_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token scalar affine map applied to the squashed ratios.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every weight tensor with Kaiming-uniform; biases stay zero."""
        # NOTE(review): for these 3-D tensors PyTorch derives fan-in from
        # dim 1 times the trailing dims, not from in_features alone - confirm
        # this scaling is intended.
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _token_score(x, w1, b1, w2, b2):
        """Run one per-token MLP: (B, N, in) -> (B, N, 1)."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Mix each token with every other token.

        Args:
            x: Tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        n = self.num_tokens

        out_m1 = self._token_score(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._token_score(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # Offset the denominator away from zero.
        # NOTE(review): a plain additive epsilon does not protect against
        # out_m2 == -1e-5; left unchanged so numerics match trained weights.
        out_m2_safe = out_m2 + 1e-5

        # compare[b, i, j]  = tanh(score1[i] / score2[j])
        # compare2[b, i, j] = tanh(score1[j] / score2[i])
        compare = torch.tanh(
            out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)
        )
        compare2 = torch.tanh(
            out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)
        )

        # The affine map is indexed by j, the second token of the pair.
        bias_reshaped = self.trans_b.view(1, 1, n, 1)
        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )
        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )

        # Gate token i's features with one ratio and token j's with the other.
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Zero the i == j entries. The mask is cast to the input dtype so a
        # half/bfloat16 input is not silently promoted to float32.
        mask = (1.0 - torch.eye(n, device=x.device)).to(x.dtype)
        masked = interaction * mask.view(1, n, n, 1)

        # Mean over the other N - 1 tokens. With a single token there are no
        # partners; clamp the divisor so the result is 0 rather than 0/0 = NaN.
        return masked.sum(dim=2) / max(n - 1.0, 1.0)
|
|
| |
| |
| |
|
|
class LookThemBackbone(nn.Module):
    """Two-stream convolutional feature extractor with token-interaction layers.

    Both streams take the same image batch, run four Conv-BatchNorm-GELU stages
    (with different strides) and pool to an 8x8 grid of 64 channels. Each grid
    is treated as 64 tokens of 64 features and passed through its own
    ``LookThemLayer``. The two token sets are concatenated along the feature
    axis (128 features), mixed by a third ``LookThemLayer``, and compressed
    back to 64 channels by a 1x1 ``Conv1d``.

    Output shape: ``(batch, 64, 64)``.
    """

    def __init__(self):
        super().__init__()

        # Stream A halves the resolution at every stage; stream B only at the
        # third stage.
        self.stream_a = self._build_stream(strides=(2, 2, 2, 2))
        self.stream_b = self._build_stream(strides=(1, 1, 2, 1))

        self.lookthemA = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=32
        )
        self.lookthemB = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=32
        )
        self.lookthem = LookThemLayer(
            num_tokens=64, in_features=128, hidden_dim=32
        )

        self.compressor = nn.Conv1d(128, 64, kernel_size=1)

    @staticmethod
    def _build_stream(strides):
        """Build one stream: four Conv3x3-BatchNorm-GELU stages, then 8x8 max-pool."""
        widths = (3, 16, 32, 64, 64)
        layers = []
        # The flat layer list keeps the Sequential indices (0, 1, 2, 3, ...)
        # and therefore the state_dict keys.
        for c_in, c_out, stride in zip(widths, widths[1:], strides):
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.GELU())
        layers.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*layers)

    @staticmethod
    def _to_tokens(feature_map):
        """Reshape (B, 64, 8, 8) feature maps into (B, 64 positions, 64 channels)."""
        batch = feature_map.size(0)
        return feature_map.view(batch, 64, 64).transpose(1, 2)

    def forward(self, x):
        """Return compressed features of shape ``(batch, 64, 64)`` for image batch ``x``."""
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))

        # Concatenate along the feature axis: (B, 64, 64 + 64).
        fused = self.lookthem(torch.cat([tokens_a, tokens_b], dim=2))

        # Conv1d expects channels first, so move the 128 features to dim 1.
        return self.compressor(fused.transpose(1, 2))
|
|
| |
| |
| |
|
|
class LiteResidualBlock(nn.Module):
    """Residual MLP block: ``LayerNorm(x + MLP(x))``.

    The MLP is Linear -> GELU -> Dropout -> Linear, all of width ``dim``, so
    input and output shapes are identical.

    Args:
        dim: Feature size of the input's last dimension.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()

        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the MLP, add the skip connection, then normalise."""
        return self.norm(self.block(x) + x)
|
|
class EfficientResidualClassifier(nn.Module):
    """Classification head: flatten -> 4096-to-256 projection -> two residual blocks -> 100 logits."""

    def __init__(self):
        super().__init__()

        self.flatten = nn.Flatten()

        # Project the flattened 4096-value feature vector down to 256.
        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )

        self.res1 = LiteResidualBlock(256)
        self.res2 = LiteResidualBlock(256)

        # Final MLP producing one logit per class (100 outputs).
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 100),
        )

    def forward(self, x):
        """Return raw logits of shape ``(batch, 100)``."""
        stages = (
            self.flatten,
            self.input_proj,
            self.res1,
            self.res2,
            self.head,
        )
        for stage in stages:
            x = stage(x)
        return x
|
|
| |
| |
| |
|
|
class FullModel(nn.Module):
    """End-to-end model: ``LookThemBackbone`` features fed to ``EfficientResidualClassifier``."""

    def __init__(self):
        super().__init__()

        self.backbone = LookThemBackbone()
        self.classifier = EfficientResidualClassifier()

    def forward(self, x):
        """Return the classifier's logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))
|
|
| |
| |
| |
|
|
# Build the network, load the trained weights, and switch to inference mode.
# This runs at import time, so the checkpoint must exist at MODEL_PATH.
print("🧠 Loading model...")

model = FullModel().to(device)

# The checkpoint is used directly as a state dict (tensors in a plain mapping),
# so restrict unpickling to that: weights_only=True stops a tampered .pth file
# from executing arbitrary code during load.
state_dict = torch.load(
    MODEL_PATH,
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)

# eval() disables dropout and makes BatchNorm use its running statistics.
model.eval()

print("✅ Model loaded!")
|
|
| |
| |
| |
|
|
def predict_image(image_path):
    """Classify one image file and print the highest-probability classes.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        None. Results are written to stdout.

    Raises:
        Whatever ``Image.open``, the preprocessing transform or the model
        raise (e.g. for an unreadable or non-image file); the caller is
        expected to handle these.
    """
    # Open inside a context manager so the file handle is released promptly.
    # The transform's RGB conversion reads the pixel data while the file is
    # still open.
    with Image.open(image_path) as img:
        x = transform(img)

    # Add the batch dimension and move to the model's device.
    x = x.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x)
        probs = torch.softmax(output, dim=1)
        # Never ask for more entries than there are classes.
        k = min(5, probs.size(1))
        top_prob, top_idx = torch.topk(probs, k)

    print("\n🏆 TOP 5 PREDICTIONS:\n")

    ranked = zip(top_idx[0].tolist(), top_prob[0].tolist())
    for rank, (idx, prob) in enumerate(ranked, start=1):
        print(
            f"{rank}. "
            f"{CLASS_NAMES[idx]} "
            f"({prob * 100:.2f}%)"
        )
|
|
| |
| |
| |
|
|
# Interactive prompt: read image paths from stdin until the user types "exit"
# or closes the input stream.
# NOTE(review): this runs at module import; consider an
# `if __name__ == "__main__":` guard if the module is ever imported elsewhere.
print("\n===================================")
print("🧠 LookThem V7.6 Inference")
print("Type image path")
print("Type 'exit' to quit")
print("===================================\n")

while True:
    try:
        # Strip surrounding whitespace so " exit " and padded paths work.
        image_path = input("📷 Image Path: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C or a closed stdin: leave cleanly, not via traceback.
        print("\n👋 Exiting...")
        break

    if image_path.lower() == "exit":
        print("\n👋 Exiting...")
        break

    if not os.path.exists(image_path):
        print("❌ File not found!\n")
        continue

    # Top-level boundary: report any failure and keep the prompt alive.
    try:
        predict_image(image_path)
    except Exception as e:
        print(f"\n❌ Error: {e}\n")
|
|