# Cleaned code

## Training

```python
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# =========================================================
# 1. DATA PREPARATION
# =========================================================

# Training pipeline: random horizontal flip (the only augmentation),
# tensor conversion, then per-channel normalization.
# STL10 images are already 96x96, so no resize is required.
# NOTE(review): these mean/std triples match the widely used CIFAR-10
# statistics rather than values measured on STL10. The inference script
# hard-codes the same numbers, so change both together if you recompute
# them (and retrain, since existing weights assume these values).
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),   # per-channel mean
        (0.2470, 0.2435, 0.2616)    # per-channel std
    )
])

# Evaluation pipeline: no augmentation, only tensor conversion and the
# same normalization as training so both splits share one input scale.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2470, 0.2435, 0.2616)
    )
])

# =========================================================
# 2. STL10 DATASET LOADING
# =========================================================

# download=True lets torchvision fetch STL10 into ./data when it is
# not already present there.
train_dataset = torchvision.datasets.STL10(
    root='./data',
    split='train',
    download=True,
    transform=transform_train
)

test_dataset = torchvision.datasets.STL10(
    root='./data',
    split='test',
    download=True,
    transform=transform_test
)

# Batch loaders. "val_loader" iterates the STL10 *test* split; no
# separate validation split is carved out of the training data.
# NOTE(review): num_workers > 0 created at module level usually needs an
# `if __name__ == "__main__":` guard when run as a script on platforms
# whose multiprocessing start method is "spawn" (e.g. Windows, macOS);
# confirm how this script is launched (notebook vs. script).
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2
)

print(f"Training samples : {len(train_dataset)}")
print(f"Testing samples : {len(test_dataset)}")

# =========================================================
# 3. CORE RELATIONAL LAYER — LOOKTHEM LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Relational token-processing layer.

    Every token position owns its own pair of tiny two-layer networks
    (``mod1`` and ``mod2``), each mapping that token's feature vector to
    a single scalar. For every ordered token pair (i, j) the layer forms
    the ratios mod1_i / mod2_j and mod1_j / mod2_i, squashes them with
    tanh, applies a per-partner (j) scalar affine transform, and uses
    the results to weight the feature vectors of tokens i and j. The
    weighted vectors are averaged over all partners j != i, giving one
    output vector per token.

    Args:
        num_tokens: number of token positions N. Must be > 1, because
            the final average divides by N - 1.
        in_features: feature width of each input token.
        hidden_dim: hidden width of each per-token mini-network.

    Shapes:
        input:  [B, num_tokens, in_features]
        output: [B, num_tokens, in_features]
    """

    def __init__(self, num_tokens: int, in_features: int, hidden_dim: int) -> None:
        super(LookThemLayer, self).__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # NOTE: the parameter names below are the state_dict keys used when
        # checkpoints are reloaded; keep names and creation order unchanged.

        # -------------------------------------------------
        # Branch 1 parameters: per-token MLP
        #   in_features -> hidden_dim -> 1
        # -------------------------------------------------
        self.mod1_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod1_w2 = nn.Parameter(
            torch.randn(num_tokens, hidden_dim, 1)
        )
        self.mod1_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # -------------------------------------------------
        # Branch 2 parameters: same shapes, independent weights
        # -------------------------------------------------
        self.mod2_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod2_w2 = nn.Parameter(
            torch.randn(num_tokens, hidden_dim, 1)
        )
        self.mod2_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # -------------------------------------------------
        # Relational transformation parameters: one scalar weight and
        # one scalar bias per token position, indexed by the partner j
        # in forward().
        # -------------------------------------------------
        self.trans_w = nn.Parameter(
            torch.randn(num_tokens, 1, 1)
        )
        self.trans_b = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # Replaces the randn draws above for the weight tensors.
        self._init_weights()

    def _init_weights(self) -> None:
        """
        Re-initialize the weight tensors with Kaiming-uniform (a=sqrt(5)).

        Only ``mod*_w1``, ``mod*_w2`` and ``trans_w`` are touched; the
        biases keep their zero initialization.

        NOTE(review): for these 3-D tensors PyTorch computes
        fan_in = size(1) * size(2) (e.g. in_features * hidden_dim for
        ``mod*_w1``) rather than a per-token fan-in; confirm that this
        scale is intended.
        """
        for w in [
            self.mod1_w1, self.mod2_w1,
            self.mod1_w2, self.mod2_w2,
            self.trans_w
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply pairwise relational mixing to a batch of token sequences.

        Args:
            x: tensor of shape [B, num_tokens, in_features].

        Returns:
            Tensor of shape [B, num_tokens, in_features]. For each token i,
            the mean over all other tokens j of the pairwise interaction.
        """
        N = self.num_tokens

        # =================================================
        # Branch 1: per-token MLP -> one scalar per token, [B, N, 1]
        # =================================================
        h1 = (
            torch.einsum('bti,tij->btj', x, self.mod1_w1)
            + self.mod1_b1
        )

        out_m1 = (
            torch.einsum(
                'btj,tjk->btk',
                F.gelu(h1),
                self.mod1_w2
            )
            + self.mod1_b2
        )

        # =================================================
        # Branch 2: same structure, separate weights, [B, N, 1]
        # =================================================
        h2 = (
            torch.einsum('bti,tij->btj', x, self.mod2_w1)
            + self.mod2_b1
        )

        out_m2 = (
            torch.einsum(
                'btj,tjk->btk',
                F.gelu(h2),
                self.mod2_w2
            )
            + self.mod2_b2
        )

        # Small offset intended to avoid division by zero.
        # NOTE(review): out_m2 is an unconstrained linear output and can be
        # negative, so +1e-5 does not keep the denominator away from zero
        # (e.g. out_m2 == -1e-5). tanh bounds the forward value, but the
        # gradient can still become inf/NaN at such points.
        out_m2_safe = out_m2 + 1e-5

        # =================================================
        # Pairwise relational comparison, both [B, N, N, 1]
        # =================================================

        # compare[b, i, j] = tanh(out_m1[b, i] / out_m2[b, j])
        compare = torch.tanh(
            out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)
        )

        # compare2[b, i, j] = tanh(out_m1[b, j] / out_m2[b, i])
        compare2 = torch.tanh(
            out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)
        )

        # =================================================
        # Per-partner affine transform:
        #   value[b, i, j] * trans_w[j] + trans_b[j]
        # =================================================
        bias_reshaped = self.trans_b.view(1, 1, N, 1)

        trans_compare = (
            torch.einsum(
                'bije,jef->bijf',
                compare,
                self.trans_w
            )
            + bias_reshaped
        )

        trans_compare2 = (
            torch.einsum(
                'bije,jef->bijf',
                compare2,
                self.trans_w
            )
            + bias_reshaped
        )

        # =================================================
        # Bidirectional interaction fusion, [B, N, N, in_features]:
        #   interaction[b, i, j] = (trans_compare[b, i, j] * x[b, i]
        #                           + trans_compare2[b, i, j] * x[b, j]) / 2
        # =================================================
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Zero the diagonal (i == j) so a token does not interact with itself.
        mask = 1.0 - torch.eye(N, device=x.device)

        interaction_masked = (
            interaction * mask.view(1, N, N, 1)
        )

        # Mean over the N - 1 partner tokens -> [B, N, in_features]
        return interaction_masked.sum(dim=2) / (N - 1.0)


# =========================================================
# 4. MAIN ARCHITECTURE — LOOKTHEM STL V1
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Dual-stream relational image classifier with 10 output logits.

    Stream A:
        Three stride-2 conv stages (aggressive downsampling, macro
        structure).

    Stream B:
        Two stride-1 stages followed by one stride-2 stage (slower
        reduction, finer detail).

    Each stream is max-pooled to an 8x8 grid and turned into 64 tokens
    (grid positions) of 64 features (channels). Each token set is processed
    by its own LookThemLayer. The two results are concatenated along the
    feature axis (64 tokens x 128 features) and mixed by a third
    LookThemLayer. Each token is then average-pooled to 32 features, and an
    MLP head classifies the result.

    Input:  [B, 3, H, W] normalized images (96x96 for STL10)
    Output: [B, 10] logits
    """

    def __init__(self) -> None:
        super(LookThemSTLV1, self).__init__()

        # NOTE: attribute names and layer order inside each nn.Sequential
        # define the state_dict keys; the inference script mirrors them.

        # =================================================
        # STREAM A — MACRO STRUCTURE STREAM
        # =================================================
        #
        # Every conv halves the resolution
        # (96 -> 48 -> 24 -> 12 for 96x96 inputs).
        #
        self.stream_a = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),

            # Final spatial alignment to an 8x8 token grid
            nn.AdaptiveMaxPool2d((8, 8))
        )

        # =================================================
        # STREAM B — MICRO DETAIL STREAM
        # =================================================
        #
        # Two full-resolution convs, then one halving
        # (96 -> 96 -> 96 -> 48 for 96x96 inputs).
        #
        self.stream_b = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),

            # Match Stream A token resolution
            nn.AdaptiveMaxPool2d((8, 8))
        )

        # =================================================
        # STREAM-SPECIFIC RELATIONAL PROCESSORS
        # (64 tokens x 64 channel features each)
        # =================================================
        self.lookthemA = LookThemLayer(
            num_tokens=64,
            in_features=64,
            hidden_dim=16
        )

        self.lookthemB = LookThemLayer(
            num_tokens=64,
            in_features=64,
            hidden_dim=16
        )

        # =================================================
        # FUSION RELATIONAL PROCESSOR
        # =================================================
        #
        # Receives the feature-wise concatenation of both
        # streams (64 tokens x 128 features).
        #
        self.lookthem = LookThemLayer(
            num_tokens=64,
            in_features=128,
            hidden_dim=32
        )

        # =================================================
        # TOKEN COMPRESSOR
        # =================================================
        #
        # AdaptiveAvgPool1d pools the last axis, so each token's
        # 128 features become 32 before classification.
        #
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # =================================================
        # CLASSIFIER HEAD
        # =================================================
        #
        # Flattened 64 tokens x 32 features -> 10 logits,
        # with dropout between the dense layers.
        #
        self.classifier = nn.Sequential(
            nn.Flatten(),

            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(256, 10)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return [B, 10] class logits for a batch of [B, 3, H, W] images."""
        batch_size = x.size(0)

        # =================================================
        # STREAM A FORWARD PASS
        # =================================================
        feat_a = self.stream_a(x)  # [B, 64, 8, 8]

        # [B, 64 ch, 8, 8] -> [B, 64 ch, 64 pos] -> [B, 64 tokens, 64 features]
        feat_a_flat = feat_a.view(batch_size, 64, 64)
        feat_a_tokens = feat_a_flat.transpose(1, 2)

        # Relational processing
        feat_a_lt = self.lookthemA(feat_a_tokens)

        # =================================================
        # STREAM B FORWARD PASS (same tokenization)
        # =================================================
        feat_b = self.stream_b(x)

        feat_b_tokens = (
            feat_b
            .view(batch_size, 64, 64)
            .transpose(1, 2)
        )

        feat_b_lt = self.lookthemB(feat_b_tokens)

        # =================================================
        # FEATURE-LEVEL FUSION
        # =================================================
        #
        # Token count stays 64; features grow 64 + 64 -> 128.
        #
        tokens_combined = torch.cat(
            [feat_a_lt, feat_b_lt],
            dim=2
        )

        # =================================================
        # FINAL RELATIONAL MIXING
        # =================================================
        out_lookthem = self.lookthem(tokens_combined)

        # [B, 64, 128] -> [B, 64, 32]
        compressed = self.compressor(out_lookthem)

        # Flatten to [B, 2048] -> logits [B, 10]
        return self.classifier(compressed)
TRAINING RUNTIME + CHECKPOINT SYSTEM # ========================================================= device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) model = LookThemSTLV1().to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam( model.parameters(), lr=0.001, weight_decay=1e-4 ) scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=40 ) start_epoch = 0 checkpoint_path = "lookthem_stl_checkpoint.pth" # ========================================================= # CHECKPOINT RESUME # ========================================================= if os.path.exists(checkpoint_path): print( "Checkpoint detected. " "Resuming previous experiment..." ) checkpoint = torch.load(checkpoint_path) model.load_state_dict( checkpoint['model_state_dict'] ) optimizer.load_state_dict( checkpoint['optimizer_state_dict'] ) scheduler.load_state_dict( checkpoint['scheduler_state_dict'] ) start_epoch = checkpoint['epoch'] print( f"Successfully resumed from " f"epoch {start_epoch + 1}" ) print( f"Starting LookThem STL V1 training on {device}..." ) # ========================================================= # TRAINING LOOP # ========================================================= for epoch in range(start_epoch, 100): model.train() total_loss = 0 correct = 0 total = 0 for data, target in train_loader: data = data.to(device) target = target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() total_loss += loss.item() _, predicted = output.max(1) total += target.size(0) correct += predicted.eq(target).sum().item() scheduler.step() acc = 100. 
* correct / total current_lr = optimizer.param_groups[0]['lr'] print( f"Epoch {epoch+1:02d}/100 | " f"Train Loss: " f"{total_loss / len(train_loader):.4f} | " f"Train Acc: {acc:.2f}% | " f"LR: {current_lr:.6f}" ) # ----------------------------------------------------- # Periodic checkpoint save # ----------------------------------------------------- if (epoch + 1) % 5 == 0: torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), }, checkpoint_path) print( f"[CHECKPOINT] " f"Epoch {epoch+1} saved successfully." ) # ========================================================= # 6. FINAL VALIDATION # ========================================================= model.eval() test_loss = 0 test_correct = 0 test_total = 0 print("\nStarting final validation...") with torch.no_grad(): for data, target in val_loader: data = data.to(device) target = target.to(device) output = model(data) loss = criterion(output, target) test_loss += loss.item() _, predicted = output.max(1) test_total += target.size(0) test_correct += predicted.eq(target).sum().item() final_test_acc = 100. * test_correct / test_total print("=== FINAL LOOKTHEM STL V1 RESULTS ===") print( f"Test Loss: " f"{test_loss / len(val_loader):.4f} | " f"Test Accuracy: {final_test_acc:.2f}%" ) # Save final trained weights torch.save(model.state_dict(), "LookThem_STL.pth") print( f"Training complete! " f"Final model size: " f"{os.path.getsize('LookThem_STL.pth') / (1024*1024):.2f} MB" ) ``` ## Inference ```python import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms from PIL import Image import math # ========================================================= # 1. LOOKTHEM CORE LAYER # ========================================================= class LookThemLayer(nn.Module): """ Relational token-processing layer used by the LookThem STL architecture. 
""" def __init__(self, num_tokens, in_features, hidden_dim): super(LookThemLayer, self).__init__() self.num_tokens = num_tokens self.in_features = in_features # ------------------------------------------------- # Branch 1 # ------------------------------------------------- self.mod1_w1 = nn.Parameter( torch.randn(num_tokens, in_features, hidden_dim) ) self.mod1_b1 = nn.Parameter( torch.zeros(num_tokens, hidden_dim) ) self.mod1_w2 = nn.Parameter( torch.randn(num_tokens, hidden_dim, 1) ) self.mod1_b2 = nn.Parameter( torch.zeros(num_tokens, 1) ) # ------------------------------------------------- # Branch 2 # ------------------------------------------------- self.mod2_w1 = nn.Parameter( torch.randn(num_tokens, in_features, hidden_dim) ) self.mod2_b1 = nn.Parameter( torch.zeros(num_tokens, hidden_dim) ) self.mod2_w2 = nn.Parameter( torch.randn(num_tokens, hidden_dim, 1) ) self.mod2_b2 = nn.Parameter( torch.zeros(num_tokens, 1) ) # ------------------------------------------------- # Relational transformation # ------------------------------------------------- self.trans_w = nn.Parameter( torch.randn(num_tokens, 1, 1) ) self.trans_b = nn.Parameter( torch.zeros(num_tokens, 1) ) self._init_weights() def _init_weights(self): for w in [ self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w ]: nn.init.kaiming_uniform_( w, a=math.sqrt(5) ) def forward(self, x): N = self.num_tokens # ================================================= # Branch 1 # ================================================= h1 = ( torch.einsum( 'bti,tij->btj', x, self.mod1_w1 ) + self.mod1_b1 ) out_m1 = ( torch.einsum( 'btj,tjk->btk', F.gelu(h1), self.mod1_w2 ) + self.mod1_b2 ) # ================================================= # Branch 2 # ================================================= h2 = ( torch.einsum( 'bti,tij->btj', x, self.mod2_w1 ) + self.mod2_b1 ) out_m2 = ( torch.einsum( 'btj,tjk->btk', F.gelu(h2), self.mod2_w2 ) + self.mod2_b2 ) # Numerical stabilization out_m2_safe = out_m2 + 
1e-5 # ================================================= # Pairwise comparison # ================================================= compare = torch.tanh( out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1) ) compare2 = torch.tanh( out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2) ) # ================================================= # Relational transformation # ================================================= bias_reshaped = self.trans_b.view( 1, 1, N, 1 ) trans_compare = ( torch.einsum( 'bije,jef->bijf', compare, self.trans_w ) + bias_reshaped ) trans_compare2 = ( torch.einsum( 'bije,jef->bijf', compare2, self.trans_w ) + bias_reshaped ) # ================================================= # Interaction fusion # ================================================= interaction = ( trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1) ) / 2 # Remove self-interaction mask = 1.0 - torch.eye( N, device=x.device ) interaction_masked = ( interaction * mask.view(1, N, N, 1) ) return ( interaction_masked.sum(dim=2) / (N - 1.0) ) # ========================================================= # 2. 
# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Inference-side definition of the dual-stream LookThem classifier.

    Sub-module attribute names and the layer order inside each
    ``nn.Sequential`` determine the state_dict keys, so both must match
    the training definition for ``LookThem_STL.pth`` to load.

    Input:  [B, 3, 96, 96] normalized images
    Output: [B, 10] class logits
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, stride):
        """Build one Conv2d(3x3, padding=1) -> BatchNorm2d -> GELU stage."""
        return [
            nn.Conv2d(
                in_channels, out_channels,
                kernel_size=3, stride=stride, padding=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]

    @staticmethod
    def _to_tokens(feature_map):
        """Reshape a [B, 64, 8, 8] map into [B, 64 positions, 64 channels]."""
        batch = feature_map.size(0)
        return feature_map.view(batch, 64, 64).transpose(1, 2)

    def __init__(self):
        super().__init__()
        stage = LookThemSTLV1._conv_stage

        # Stream A: every stage halves the resolution (macro structure).
        self.stream_a = nn.Sequential(
            *stage(3, 16, 2),
            *stage(16, 32, 2),
            *stage(32, 64, 2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Stream B: two full-resolution stages, then one halving (detail).
        self.stream_b = nn.Sequential(
            *stage(3, 16, 1),
            *stage(16, 32, 1),
            *stage(32, 64, 2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Relational processors, arguments (num_tokens, in_features,
        # hidden_dim): one per stream, then the fusion layer over the
        # 128-wide concatenated tokens.
        self.lookthemA = LookThemLayer(64, 64, 16)
        self.lookthemB = LookThemLayer(64, 64, 16)
        self.lookthem = LookThemLayer(64, 128, 32)

        # Average-pools each token's 128 features down to 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # 64 tokens x 32 features -> 10 logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return [B, 10] logits for a batch of images."""
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))
        fused = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))
        return self.classifier(self.compressor(fused))


# =========================================================
# 3. DEVICE SETUP
# =========================================================

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# =========================================================
# 4. CLASS LABELS
# =========================================================

# Entry i names logit i; the order must match the label indices the
# model was trained with (torchvision STL10).
classes = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]

# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================

# Bring an arbitrary image to the 96x96 training resolution (a (h, w)
# tuple stretches non-square images rather than cropping them), then apply
# the same per-channel normalization as the training script.
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2470, 0.2435, 0.2616),
    ),
])
LOAD MODEL # ========================================================= model = LookThemSTLV1().to(device) model.load_state_dict( torch.load( "LookThem_STL.pth", map_location=device ) ) model.eval() print("Model loaded successfully!") # ========================================================= # 7. LOAD IMAGE # ========================================================= # Replace with your image path image_path = "test.jpg" image = Image.open(image_path).convert("RGB") input_tensor = transform(image) # Add batch dimension input_tensor = input_tensor.unsqueeze(0).to(device) # ========================================================= # 8. INFERENCE # ========================================================= with torch.no_grad(): output = model(input_tensor) probabilities = F.softmax(output, dim=1) confidence, predicted = torch.max( probabilities, dim=1 ) predicted_class = classes[predicted.item()] confidence_score = confidence.item() * 100 # ========================================================= # 9. RESULT # ========================================================= print("\n===== INFERENCE RESULT =====") print(f"Predicted Class : {predicted_class}") print(f"Confidence : {confidence_score:.2f}%") print("\n===== CLASS PROBABILITIES =====") for idx, class_name in enumerate(classes): prob = probabilities[0][idx].item() * 100 print(f"{class_name:<10} : {prob:.2f}%") ```