# LookThem STL V1 — Cleaned Training and Inference Code
## Training


```python
| import os |
| import math |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| |
| import torchvision |
| import torchvision.transforms as transforms |
| |
| from torch.utils.data import DataLoader |
| |
| |
# =========================================================
# 1. DATA PREPARATION
# =========================================================

# Per-channel normalization statistics shared by both pipelines below.
# NOTE(review): these values appear to be the widely quoted CIFAR-10
# channel mean/std rather than statistics computed on STL10 — confirm.
# They must stay identical to the values used by the inference script,
# and changing them would require retraining.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random horizontal flip as the only augmentation.
# STL10 images are already 96x96, so no resize is required.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: deterministic, tensor conversion + normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
| |
| |
# =========================================================
# 2. STL10 DATASET LOADING
# =========================================================

# Both splits are downloaded into ./data on first use.
# NOTE(review): the 'test' split doubles as the validation set below;
# there is no separate held-out validation split.
train_dataset, test_dataset = (
    torchvision.datasets.STL10(
        root='./data',
        split=split_name,
        download=True,
        transform=split_transform,
    )
    for split_name, split_transform in (
        ('train', transform_train),
        ('test', transform_test),
    )
)

# NOTE(review): num_workers=2 starts worker processes. On platforms whose
# multiprocessing start method is "spawn" (Windows, macOS), this module-level
# script would need an `if __name__ == "__main__":` guard — confirm target OS.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Training samples : {len(train_dataset)}")
print(f"Testing samples : {len(test_dataset)}")
| |
| |
# =========================================================
# 3. CORE RELATIONAL LAYER — LOOKTHEM LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Relational token-processing layer.

    Every token position owns a private pair of tiny two-layer MLPs
    ("mod1" and "mod2"), each mapping that token's feature vector to a
    single scalar. For every ordered token pair (i, j) the layer then:

    - forms tanh(mod1_i / mod2_j) and tanh(mod1_j / mod2_i);
    - rescales both maps with a per-column (j) scalar weight and bias;
    - uses them to weight the features of token i and token j;
    - averages the result over all partners j != i.

    Input shape:  [B, num_tokens, in_features]
    Output shape: [B, num_tokens, in_features]

    num_tokens must be > 1, otherwise the final average divides by zero.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # Two independent per-token MLP branches. Registration order is
        # mod1_w1, mod1_b1, mod1_w2, mod1_b2, mod2_w1, ..., mod2_b2. Keeping it
        # fixed preserves the RNG draw order and the parameter order that
        # optimizer state dicts are indexed by.
        for branch in ("mod1", "mod2"):
            setattr(self, f"{branch}_w1",
                    nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim)))
            setattr(self, f"{branch}_b1",
                    nn.Parameter(torch.zeros(num_tokens, hidden_dim)))
            setattr(self, f"{branch}_w2",
                    nn.Parameter(torch.randn(num_tokens, hidden_dim, 1)))
            setattr(self, f"{branch}_b2",
                    nn.Parameter(torch.zeros(num_tokens, 1)))

        # Per-token scalar weight and bias applied to the relational maps.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialize every weight tensor with ``kaiming_uniform_(a=sqrt(5))``.

        Biases keep their zero initialization. For these 3-D tensors PyTorch
        derives fan_in as ``size(1) * size(2)``.
        """
        for name in ("mod1_w1", "mod2_w1", "mod1_w2", "mod2_w2", "trans_w"):
            nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5))

    def _run_branch(self, x, w1, b1, w2, b2):
        """Apply one per-token MLP (linear -> GELU -> linear) to every token.

        ``x`` is [B, T, in_features]; the result is [B, T, 1].
        """
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Relate every token to every other token; returns [B, T, in_features]."""
        n_tok = self.num_tokens

        m1 = self._run_branch(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        m2 = self._run_branch(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # NOTE(review): the fixed +1e-5 only shifts the denominator. It does
        # not keep |m2| away from zero when m2 is negative (m2 == -1e-5 still
        # divides by zero), and gradients can blow up for tiny m2. It is left
        # unchanged so previously trained weights reproduce the same outputs.
        denom = m2 + 1e-5

        # ratio[b, i, j] = tanh(m1_i / m2_j); ratio_rev[b, i, j] = tanh(m1_j / m2_i)
        ratio = torch.tanh(m1.unsqueeze(2) / denom.unsqueeze(1))
        ratio_rev = torch.tanh(m1.unsqueeze(1) / denom.unsqueeze(2))

        # Per-column (index j) affine rescaling of both relational maps.
        col_bias = self.trans_b.view(1, 1, n_tok, 1)
        scaled, scaled_rev = (
            torch.einsum('bije,jef->bijf', rel_map, self.trans_w) + col_bias
            for rel_map in (ratio, ratio_rev)
        )

        # Forward map weights token i's features, reverse map weights token j's.
        pair = (scaled * x.unsqueeze(2) + scaled_rev * x.unsqueeze(1)) / 2

        # Zero the i == j diagonal, then average over the N - 1 other tokens.
        off_diag = 1.0 - torch.eye(n_tok, device=x.device)
        return (pair * off_diag.view(1, n_tok, n_tok, 1)).sum(dim=2) / (n_tok - 1.0)
| |
| |
# =========================================================
# 4. MAIN ARCHITECTURE — LOOKTHEM STL V1
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Dual-stream relational classifier producing 10 logits.

    Two convolutional stems process the same image with different amounts
    of spatial reduction. Each stem is max-pooled to an 8x8 grid of
    64-channel cells, i.e. 64 tokens of 64 features. The pipeline then:

    - passes each token set through its own LookThemLayer;
    - concatenates the two results feature-wise (64 tokens x 128 features);
    - fuses them with a third LookThemLayer;
    - average-pools the feature axis down to 32 per token;
    - classifies with an MLP head.
    """

    def __init__(self):
        super().__init__()

        # Stream A (macro structure): three stride-2 convolutions
        # (96 -> 48 -> 24 -> 12 for a 96x96 input), then max-pooled to 8x8.
        self.stream_a = self._make_stream(strides=(2, 2, 2))

        # Stream B (micro detail): only the last convolution downsamples
        # (96 -> 96 -> 96 -> 48 for a 96x96 input), then max-pooled to 8x8.
        self.stream_b = self._make_stream(strides=(1, 1, 2))

        # Stream-specific relational processors.
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)

        # Fusion processor over the concatenated [A | B] features.
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # Averages each token's 128 fused features down to 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # Dense head with dropout regularization.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _make_stream(strides):
        """Build 3x (Conv3x3 -> BatchNorm -> GELU), 3->16->32->64 channels, + 8x8 max pool.

        Sequential indices 0..9 form the state-dict keys (e.g. ``stream_a.0.weight``).
        Keep this layer order so existing checkpoints still load.
        """
        channels = (3, 16, 32, 64)
        layers = []
        for c_in, c_out, stride in zip(channels, channels[1:], strides):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        layers.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*layers)

    @staticmethod
    def _as_tokens(feature_map, batch_size):
        """Turn a [B, 64, 8, 8] map into [B, 64 tokens (cells), 64 features (channels)]."""
        return feature_map.view(batch_size, 64, 64).transpose(1, 2)

    def forward(self, x):
        """Map an image batch to class logits of shape [B, 10]."""
        batch_size = x.size(0)

        tokens_a = self.lookthemA(self._as_tokens(self.stream_a(x), batch_size))
        tokens_b = self.lookthemB(self._as_tokens(self.stream_b(x), batch_size))

        # Token count stays 64; feature width doubles to 128.
        fused = self.lookthem(torch.cat([tokens_a, tokens_b], dim=2))
        return self.classifier(self.compressor(fused))
| |
| |
# =========================================================
# 5. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================

device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

model = LookThemSTLV1().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

# CosineAnnealingLR is periodic: after T_max steps the learning rate climbs
# back up toward the base LR. T_max therefore equals the 100-epoch length of
# the training loop, so the LR decays monotonically over the whole run.
# The previous value of 40 made the LR rise again from epoch 40 to 80.
# A resumed checkpoint restores the T_max it was saved with.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100
)

start_epoch = 0
checkpoint_path = "lookthem_stl_checkpoint.pth"


# =========================================================
# CHECKPOINT RESUME
# =========================================================

if os.path.exists(checkpoint_path):

    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )

    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # 'epoch' holds the number of completed epochs, i.e. the 0-based index
    # of the next epoch to run.
    start_epoch = checkpoint['epoch']

    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )

print(
    f"Starting LookThem STL V1 training on {device}..."
)
| |
| |
# =========================================================
# TRAINING LOOP
# =========================================================

for epoch in range(start_epoch, 100):

    model.train()

    total_loss = 0
    correct = 0
    total = 0

    for data, target in train_loader:

        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # The scheduler is stepped once per epoch, so the LR printed below is
    # the rate that the *next* epoch will use.
    scheduler.step()

    acc = 100. * correct / total
    current_lr = optimizer.param_groups[0]['lr']

    print(
        f"Epoch {epoch+1:02d}/100 | "
        f"Train Loss: "
        f"{total_loss / len(train_loader):.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save (every 5 epochs)
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:

        # Write to a temporary file first, then atomically swap it into
        # place. An interruption during torch.save therefore cannot leave
        # a truncated file as the only checkpoint.
        tmp_checkpoint_path = checkpoint_path + ".tmp"

        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, tmp_checkpoint_path)

        os.replace(tmp_checkpoint_path, checkpoint_path)

        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )
| |
| |
# =========================================================
# 6. FINAL VALIDATION
# =========================================================

# Evaluation mode: BatchNorm uses running statistics, Dropout is disabled.
model.eval()

test_loss, test_correct, test_total = 0, 0, 0

print("\nStarting final validation...")

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)

        logits = model(images)

        # Accumulate the per-batch mean loss; averaged over batches below.
        test_loss += criterion(logits, labels).item()

        _, preds = logits.max(1)
        test_total += labels.size(0)
        test_correct += preds.eq(labels).sum().item()

final_test_acc = 100. * test_correct / test_total
mean_test_loss = test_loss / len(val_loader)

print("=== FINAL LOOKTHEM STL V1 RESULTS ===")
print(
    f"Test Loss: "
    f"{mean_test_loss:.4f} | "
    f"Test Accuracy: {final_test_acc:.2f}%"
)

# Persist only the trained weights (state dict) for the inference script.
final_weights_path = "LookThem_STL.pth"
torch.save(model.state_dict(), final_weights_path)

size_mb = os.path.getsize(final_weights_path) / (1024*1024)
print(
    f"Training complete! "
    f"Final model size: "
    f"{size_mb:.2f} MB"
)
```


## Inference
```python
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| |
| from PIL import Image |
| import math |
| |
| |
# =========================================================
# 1. LOOKTHEM CORE LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Relational token-processing layer used by the LookThem STL architecture.

    Every token position owns a private pair of tiny two-layer MLPs
    ("mod1" and "mod2"), each mapping that token's feature vector to a
    single scalar. For every ordered token pair (i, j) the layer then:

    - forms tanh(mod1_i / mod2_j) and tanh(mod1_j / mod2_i);
    - rescales both maps with a per-column (j) scalar weight and bias;
    - uses them to weight the features of token i and token j;
    - averages the result over all partners j != i.

    Input shape:  [B, num_tokens, in_features]
    Output shape: [B, num_tokens, in_features]

    Must stay numerically identical to the training-time definition so
    the saved weights load and behave the same.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # Two independent per-token MLP branches. Registration order is
        # mod1_w1, mod1_b1, mod1_w2, mod1_b2, mod2_w1, ..., mod2_b2.
        # Parameter names must match the checkpoint's state-dict keys.
        for branch in ("mod1", "mod2"):
            setattr(self, f"{branch}_w1",
                    nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim)))
            setattr(self, f"{branch}_b1",
                    nn.Parameter(torch.zeros(num_tokens, hidden_dim)))
            setattr(self, f"{branch}_w2",
                    nn.Parameter(torch.randn(num_tokens, hidden_dim, 1)))
            setattr(self, f"{branch}_b2",
                    nn.Parameter(torch.zeros(num_tokens, 1)))

        # Per-token scalar weight and bias applied to the relational maps.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialize every weight tensor with ``kaiming_uniform_(a=sqrt(5))``.

        Biases keep their zero initialization. At inference time these
        values are overwritten by ``load_state_dict``.
        """
        for name in ("mod1_w1", "mod2_w1", "mod1_w2", "mod2_w2", "trans_w"):
            nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5))

    def _run_branch(self, x, w1, b1, w2, b2):
        """Apply one per-token MLP (linear -> GELU -> linear) to every token.

        ``x`` is [B, T, in_features]; the result is [B, T, 1].
        """
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """Relate every token to every other token; returns [B, T, in_features]."""
        n_tok = self.num_tokens

        m1 = self._run_branch(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        m2 = self._run_branch(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # NOTE(review): the fixed +1e-5 only shifts the denominator. It does
        # not keep |m2| away from zero when m2 is negative. It is left as-is
        # to match the training-time computation exactly.
        denom = m2 + 1e-5

        # ratio[b, i, j] = tanh(m1_i / m2_j); ratio_rev[b, i, j] = tanh(m1_j / m2_i)
        ratio = torch.tanh(m1.unsqueeze(2) / denom.unsqueeze(1))
        ratio_rev = torch.tanh(m1.unsqueeze(1) / denom.unsqueeze(2))

        # Per-column (index j) affine rescaling of both relational maps.
        col_bias = self.trans_b.view(1, 1, n_tok, 1)
        scaled, scaled_rev = (
            torch.einsum('bije,jef->bijf', rel_map, self.trans_w) + col_bias
            for rel_map in (ratio, ratio_rev)
        )

        # Forward map weights token i's features, reverse map weights token j's.
        pair = (scaled * x.unsqueeze(2) + scaled_rev * x.unsqueeze(1)) / 2

        # Zero the i == j diagonal, then average over the N - 1 other tokens.
        off_diag = 1.0 - torch.eye(n_tok, device=x.device)
        return (pair * off_diag.view(1, n_tok, n_tok, 1)).sum(dim=2) / (n_tok - 1.0)
| |
| |
# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Dual-stream relational classifier producing 10 logits.

    Two convolutional stems process the same image with different amounts
    of spatial reduction. Each stem is max-pooled to an 8x8 grid of
    64-channel cells, i.e. 64 tokens of 64 features. The pipeline then:

    - passes each token set through its own LookThemLayer;
    - concatenates the two results feature-wise (64 tokens x 128 features);
    - fuses them with a third LookThemLayer;
    - average-pools the feature axis down to 32 per token;
    - classifies with an MLP head.

    Must mirror the training-time definition so the saved state dict loads.
    """

    def __init__(self):
        super().__init__()

        # Stream A (macro structure): three stride-2 convolutions
        # (96 -> 48 -> 24 -> 12 for a 96x96 input), then max-pooled to 8x8.
        self.stream_a = self._make_stream(strides=(2, 2, 2))

        # Stream B (micro detail): only the last convolution downsamples
        # (96 -> 96 -> 96 -> 48 for a 96x96 input), then max-pooled to 8x8.
        self.stream_b = self._make_stream(strides=(1, 1, 2))

        # Stream-specific relational processors.
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)

        # Fusion processor over the concatenated [A | B] features.
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # Averages each token's 128 fused features down to 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # Dense head with dropout (inactive after model.eval()).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _make_stream(strides):
        """Build 3x (Conv3x3 -> BatchNorm -> GELU), 3->16->32->64 channels, + 8x8 max pool.

        Sequential indices 0..9 form the state-dict keys (e.g. ``stream_a.0.weight``).
        Keep this layer order so the saved weights load.
        """
        channels = (3, 16, 32, 64)
        layers = []
        for c_in, c_out, stride in zip(channels, channels[1:], strides):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        layers.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*layers)

    @staticmethod
    def _as_tokens(feature_map, batch_size):
        """Turn a [B, 64, 8, 8] map into [B, 64 tokens (cells), 64 features (channels)]."""
        return feature_map.view(batch_size, 64, 64).transpose(1, 2)

    def forward(self, x):
        """Map an image batch to class logits of shape [B, 10]."""
        batch_size = x.size(0)

        tokens_a = self.lookthemA(self._as_tokens(self.stream_a(x), batch_size))
        tokens_b = self.lookthemB(self._as_tokens(self.stream_b(x), batch_size))

        # Token count stays 64; feature width doubles to 128.
        fused = self.lookthem(torch.cat([tokens_a, tokens_b], dim=2))
        return self.classifier(self.compressor(fused))
| |
| |
# =========================================================
# 3. DEVICE SETUP
# =========================================================

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Using device: {device}")


# =========================================================
# 4. CLASS LABELS
# =========================================================

# List index == integer class label predicted by the model.
# NOTE(review): order assumed to match torchvision's STL10 label ids 0-9 — confirm.
classes = "airplane bird car cat deer dog horse monkey ship truck".split()


# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================

# Mirrors the training-time evaluation transform (same mean/std), plus a
# resize to the 96x96 input size used in training. Resize((96, 96)) does
# not preserve aspect ratio, so non-square images are stretched.
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2470, 0.2435, 0.2616),
    ),
])
| |
| |
# =========================================================
# 6. LOAD MODEL
# =========================================================

# Build the architecture on the target device, then fill in the trained
# weights. map_location lets GPU-saved weights load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load weight files you trust.
model = LookThemSTLV1().to(device)
trained_weights = torch.load("LookThem_STL.pth", map_location=device)
model.load_state_dict(trained_weights)

# Inference mode: BatchNorm uses running statistics, Dropout is disabled.
model.eval()

print("Model loaded successfully!")
| |
| |
# =========================================================
# 7. LOAD IMAGE
# =========================================================

# Replace with your image path
image_path = "test.jpg"

# Open the file in a context manager so its handle is closed right away.
# convert() returns a new, fully loaded image, so it stays usable after
# the file is closed. The original chained call left the handle open
# until garbage collection.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

input_tensor = transform(image)

# Add batch dimension: [3, 96, 96] -> [1, 3, 96, 96]
input_tensor = input_tensor.unsqueeze(0).to(device)
| |
| |
# =========================================================
# 8. INFERENCE
# =========================================================

# Gradients are not needed for a forward-only prediction.
with torch.no_grad():
    logits = model(input_tensor)

    # Convert logits to a probability distribution over the 10 classes.
    probabilities = F.softmax(logits, dim=1)

    # Highest-probability class and its probability (batch of one).
    confidence, predicted = probabilities.max(dim=1)

    predicted_class = classes[predicted.item()]
    confidence_score = 100 * confidence.item()
| |
| |
# =========================================================
# 9. RESULT
# =========================================================

print("\n===== INFERENCE RESULT =====")
print(f"Predicted Class : {predicted_class}")
print(f"Confidence : {confidence_score:.2f}%")

print("\n===== CLASS PROBABILITIES =====")

# Row 0 is the single image in the batch; pair each label with its probability.
for class_name, class_prob in zip(classes, probabilities[0].tolist()):
    print(f"{class_name:<10} : {class_prob * 100:.2f}%")
```