"""Interactive inference script for the LookThem V7.6 ImageNet-100 classifier.

Loads a trained ``FullModel`` checkpoint, then repeatedly asks for an image
path and prints the top predicted classes.
"""

import os
import io
import math

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# ============================================================
# CONFIG
# ============================================================

MODEL_PATH = "LookThem_V76_LiteResidualClassifier.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================
# IMAGENET-100 LABELS
# ============================================================

# Replace this list if you have the original label file. Its order must
# match the class indices the checkpoint was trained with, because
# predictions are looked up as CLASS_NAMES[output_index].
CLASS_NAMES = [
    "bonnet, poke bonnet",
    "green mamba",
    "langur",
    "Doberman, Doberman pinscher",
    "gyromitra",
    "Saluki, gazelle hound",
    "vacuum, vacuum cleaner",
    "window screen",
    "cocktail shaker",
    "garden spider, Aranea diademata",
    "garter snake, grass snake",
    "carbonara",
    "pineapple, ananas",
    "computer keyboard, keypad",
    "tripod",
    "komondor",
    "American lobster, Northern lobster, Maine lobster, Homarus americanus",
    "bannister, banister, balustrade, balusters, handrail",
    "honeycomb",
    "tile roof",
    "papillon",
    "boathouse",
    "stinkhorn, carrion fungus",
    "jean, blue jean, denim",
    "Chihuahua",
    "Chesapeake Bay retriever",
    "robin, American robin, Turdus migratorius",
    "tub, vat",
    "Great Dane",
    "rotisserie",
    "bottlecap",
    "throne",
    "little blue heron, Egretta caerulea",
    "rock crab, Cancer irroratus",
    "Rottweiler",
    "lorikeet",
    "Gila monster, Heloderma suspectum",
    "head cabbage",
    "car wheel",
    "coyote, prairie wolf, brush wolf, Canis latrans",
    "moped",
    "milk can",
    "mixing bowl",
    "toy terrier",
    "chocolate sauce, chocolate syrup",
    "rocking chair, rocker",
    "wing",
    "park bench",
    "ambulance",
    "football helmet",
    "leafhopper",
    "cauliflower",
    "pirate, pirate ship",
    "purse",
    "hare",
    "lampshade, lamp shade",
    "fiddler crab",
    "standard poodle",
    "Shih-Tzu",
    "pedestal, plinth, footstall",
    "gibbon, Hylobates lar",
    "safety pin",
    "English foxhound",
    "chime, bell, gong",
    "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
    "bassinet",
    "wild boar, boar, Sus scrofa",
    "theater curtain, theatre curtain",
    "dung beetle",
    "hognose snake, puff adder, sand viper",
    "Mexican hairless",
    "mortarboard",
    "Walker hound, Walker foxhound",
    "red fox, Vulpes vulpes",
    "modem",
    "slide rule, slipstick",
    "walking stick, walkingstick, stick insect",
    "cinema, movie theater, movie theatre, movie house, picture palace",
    "meerkat, mierkat",
    "kuvasz",
    "obelisk",
    "harmonica, mouth organ, harp, mouth harp",
    "sarong",
    "mousetrap",
    "hard disc, hard disk, fixed disk",
    "American coot, marsh hen, mud hen, water hen, Fulica americana",
    "reel",
    "pickup, pickup truck",
    "iron, smoothing iron",
    "tabby, tabby cat",
    "ski mask",
    "vizsla, Hungarian pointer",
    "laptop, laptop computer",
    "stretcher",
    "Dutch oven",
    "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
    "boxer",
    "gasmask, respirator, gas helmet",
    "goose",
    "borzoi, Russian wolfhound",
]

# Number of logits produced by the classifier head.
NUM_CLASSES = 100

# Fail fast if a replacement label list does not match the model head;
# otherwise predictions would be mislabeled or raise IndexError later.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"CLASS_NAMES has {len(CLASS_NAMES)} entries but the model "
        f"predicts {NUM_CLASSES} classes"
    )

# ============================================================
# TRANSFORM
# ============================================================

# Preprocessing for every input image: force 3-channel RGB, resize to a
# fixed 256x256 (aspect ratio is not preserved), convert to a tensor and
# normalize with the ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# ============================================================
# LOOKTHEM LAYER
# ============================================================


class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token weights.

    Every token owns its own parameters (the leading ``num_tokens`` axis of
    each weight). Two per-token MLPs ("mod1" and "mod2") map each token to
    a scalar score. Ratios of those scores between token pairs gate how much
    of each token's features flows into the other.

    Args:
        num_tokens: number of tokens N; fixed, because weights are per token.
        in_features: feature size of each token.
        hidden_dim: hidden width of the per-token MLPs.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens

        # MLP 1: in_features -> hidden_dim -> 1, separate weights per token.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # MLP 2: same shapes as MLP 1, independent parameters.
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token scalar affine map applied to the pairwise gates.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialize weight tensors with Kaiming-uniform (a=sqrt(5))."""
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        """Mix each token with every other token through learned gates.

        Args:
            x: tensor of shape (batch, num_tokens, in_features). The einsum
                subscripts below require this layout.

        Returns:
            Tensor with the same shape as ``x``. For each token i, it is
            the mean over all other tokens j != i of
            ``(g[i, j] * x_i + g2[i, j] * x_j) / 2``.
        """
        N = self.num_tokens

        # ====================================================
        # MOD 1: per-token MLP -> scalar score, shape (B, N, 1)
        # ====================================================
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', F.gelu(h1), self.mod1_w2) + self.mod1_b2

        # ====================================================
        # MOD 2: second independent per-token score, shape (B, N, 1)
        # ====================================================
        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # ====================================================
        # COMPARISON: pairwise gates, shape (B, N, N, 1)
        #   compare[b, i, j]  = tanh(m1_i / (m2_j + eps))
        #   compare2[b, i, j] = tanh(m1_j / (m2_i + eps))
        # ====================================================
        # NOTE(review): adding 1e-5 does not protect against m2 values near
        # -1e-5. It is kept unchanged because the trained weights depend on
        # this exact formula.
        out_m2_safe = out_m2 + 1e-5
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # ====================================================
        # TRANSFORM: affine map indexed by the partner token j
        # ====================================================
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # ====================================================
        # INTERACTION: (B, N, N, F), gated blend of x_i and x_j
        # ====================================================
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Zero the diagonal so a token never interacts with itself.
        mask = 1.0 - torch.eye(N, device=x.device)
        interaction_masked = interaction * mask.view(1, N, N, 1)

        # Average over the N - 1 partner tokens. max() avoids 0/0 -> NaN
        # when N == 1; the divisor is unchanged for N > 1.
        return interaction_masked.sum(dim=2) / max(N - 1.0, 1.0)


# ============================================================
# BACKBONE
# ============================================================


class LookThemBackbone(nn.Module):
    """Two-stream conv feature extractor followed by LookThem token mixing.

    Stream A downsamples with four stride-2 convolutions. Stream B uses a
    single stride-2 convolution. Both end in an 8x8 adaptive max-pool with
    64 channels. Each 8x8 grid becomes 64 tokens of 64 features. The tokens
    are mixed per stream, concatenated (128 features), mixed again, and
    compressed back to 64 features with a 1x1 Conv1d.

    Output shape: (batch, 64 channels, 64 tokens).
    """

    def __init__(self):
        super().__init__()

        self.stream_a = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        self.stream_b = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        self.compressor = nn.Conv1d(128, 64, kernel_size=1)

    def forward(self, x):
        """Map an image batch (B, 3, H, W) to features of shape (B, 64, 64)."""
        B = x.size(0)

        # ====================================================
        # STREAM A: (B, 64, 8, 8) -> 64 tokens x 64 features
        # ====================================================
        feat_a = self.stream_a(x)
        feat_a = feat_a.view(B, 64, 64).transpose(1, 2)
        feat_a = self.lookthemA(feat_a)

        # ====================================================
        # STREAM B: (B, 64, 8, 8) -> 64 tokens x 64 features
        # ====================================================
        feat_b = self.stream_b(x)
        feat_b = feat_b.view(B, 64, 64).transpose(1, 2)
        feat_b = self.lookthemB(feat_b)

        # ====================================================
        # COMBINE: concat features -> (B, 64, 128), mix, compress
        # ====================================================
        combined = torch.cat([feat_a, feat_b], dim=2)
        out = self.lookthem(combined)

        # Conv1d expects (B, channels, length): channels=128 features.
        out = out.transpose(1, 2)
        compressed = self.compressor(out)

        return compressed


# ============================================================
# CLASSIFIER
# ============================================================


class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP block: LayerNorm(x + Linear(GELU(Linear(x))))."""

    def __init__(self, dim, dropout=0.05):
        super().__init__()

        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the residual MLP and normalize; output shape equals input."""
        residual = x
        x = self.block(x)
        x = x + residual
        x = self.norm(x)
        return x


class EfficientResidualClassifier(nn.Module):
    """MLP head: flatten 4096 features -> 256 -> 2 residual blocks -> logits."""

    def __init__(self):
        super().__init__()

        self.flatten = nn.Flatten()

        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )

        self.res1 = LiteResidualBlock(256)
        self.res2 = LiteResidualBlock(256)

        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, NUM_CLASSES),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (B, NUM_CLASSES)."""
        x = self.flatten(x)
        x = self.input_proj(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.head(x)
        return x


# ============================================================
# FULL MODEL
# ============================================================


class FullModel(nn.Module):
    """LookThem backbone followed by the residual classifier head."""

    def __init__(self):
        super().__init__()

        self.backbone = LookThemBackbone()
        self.classifier = EfficientResidualClassifier()

    def forward(self, x):
        """Map an image batch to class logits."""
        feat = self.backbone(x)
        out = self.classifier(feat)
        return out


# ============================================================
# LOAD MODEL
# ============================================================

print("🧠 Loading model...")

model = FullModel().to(device)

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load trusted files. Because the file is consumed as a
# plain state_dict, weights_only=True could be passed on torch >= 1.13.
state_dict = torch.load(MODEL_PATH, map_location=device)

model.load_state_dict(state_dict)
model.eval()

print("✅ Model loaded!")

# ============================================================
# PREDICTION FUNCTION
# ============================================================


def predict_image(image_path, top_k=5):
    """Classify one image file and print its top-k predictions.

    Args:
        image_path: path to an image file readable by PIL.
        top_k: number of highest-probability classes to report. Defaults
            to 5 and is capped at the number of classes.

    Returns:
        List of ``(class_name, probability_percent)`` tuples, highest first.

    Raises:
        ValueError: if ``top_k`` is less than 1.
        Any exception raised by ``PIL.Image.open`` for unreadable files.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # The context manager closes the file handle once the tensor is built.
    with Image.open(image_path) as img:
        x = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x)
        probs = torch.softmax(output, dim=1)

    k = min(top_k, probs.size(1))
    top_prob, top_idx = torch.topk(probs, k)

    print(f"\n🏆 TOP {k} PREDICTIONS:\n")

    results = []
    ranked = zip(top_idx[0].tolist(), top_prob[0].tolist())
    for rank, (idx, prob) in enumerate(ranked, start=1):
        name = CLASS_NAMES[idx]
        percent = prob * 100
        print(f"{rank}. {name} ({percent:.2f}%)")
        results.append((name, percent))

    return results


# ============================================================
# INTERACTIVE LOOP
# ============================================================


def main():
    """Prompt for image paths and classify each one until 'exit'.

    Surrounding whitespace and quote characters are stripped from each
    entry, since terminals often quote pasted or drag-and-dropped paths.
    EOF (Ctrl+D / closed stdin) and Ctrl+C exit cleanly instead of
    raising a traceback.
    """
    print("\n===================================")
    print("🧠 LookThem V7.6 Inference")
    print("Type image path")
    print("Type 'exit' to quit")
    print("===================================\n")

    while True:
        try:
            raw = input("📷 Image Path: ")
        except (EOFError, KeyboardInterrupt):
            print("\n👋 Exiting...")
            break

        image_path = raw.strip().strip("\"'")
        if not image_path:
            continue

        if image_path.lower() == "exit":
            print("\n👋 Exiting...")
            break

        # isfile (rather than exists) also rejects directories up front.
        if not os.path.isfile(image_path):
            print("❌ File not found!\n")
            continue

        try:
            predict_image(image_path)
        except Exception as e:  # top-level boundary: report and keep prompting
            print(f"\n❌ Error: {e}\n")


if __name__ == "__main__":
    main()