LookThem_STL-10 / CleanedCode.md
ASomeoneWhoInterestedWithAI's picture
Update CleanedCode.md
766a703 verified

Cleaned code

Training

import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader


# =========================================================
# 1. DATA PREPARATION
# =========================================================

# Per-channel normalization statistics shared by both pipelines.
# NOTE(review): these look like the commonly used CIFAR-10 mean/std rather
# than statistics measured on STL10 — confirm. They must stay identical to
# the inference transform, so changing them means retraining.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: a random horizontal flip is the only augmentation.
# STL10 images are already 96x96, so no resize step is needed.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Evaluation pipeline: tensor conversion + normalization only, so the
# model sees test images exactly as they are.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])


# =========================================================
# 2. STL10 DATASET LOADING
# =========================================================

# Both splits are downloaded into ./data on first use.
_stl10_common = dict(root='./data', download=True)

train_dataset = torchvision.datasets.STL10(
    split='train', transform=transform_train, **_stl10_common
)
test_dataset = torchvision.datasets.STL10(
    split='test', transform=transform_test, **_stl10_common
)

# Batching settings shared by both loaders; only training data is
# shuffled. The "validation" loader iterates the STL10 test split.
# NOTE(review): num_workers > 0 without an `if __name__ == "__main__":`
# guard re-executes this script inside worker processes on platforms that
# use the "spawn" start method (e.g. Windows, recent macOS) — confirm the
# target platform or add a guard.
_loader_common = dict(batch_size=64, num_workers=2)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_common)
val_loader = DataLoader(test_dataset, shuffle=False, **_loader_common)

for _label, _ds in (
    ("Training samples : ", train_dataset),
    ("Testing samples  : ", test_dataset),
):
    print(f"{_label}{len(_ds)}")


# =========================================================
# 3. CORE RELATIONAL LAYER — LOOKTHEM LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Relational token-mixing layer.

    Every token position owns two private two-layer scorers
    (branch 1 / branch 2), each reducing that token's feature vector to a
    single scalar. For every ordered pair of tokens (i, j) the layer forms
    tanh(score1 / score2) in both directions, rescales those maps with a
    per-token affine transform, uses them to weight the two tokens'
    features, and averages the result over all *other* tokens.

    Args:
        num_tokens: number of token positions N. It is fixed per instance
            because every parameter carries a leading N axis. Must be >= 2,
            since the final mean divides by N - 1.
        in_features: feature width F of each token.
        hidden_dim: hidden width of each per-token scorer.

    Shapes:
        input  [B, N, F]
        output [B, N, F]
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        def weight(*shape):
            # Per-token weight tensor [N, *shape]; its values are replaced
            # by _init_weights below.
            return nn.Parameter(torch.randn(num_tokens, *shape))

        def bias(*shape):
            # Per-token bias tensor [N, *shape], starting at zero.
            return nn.Parameter(torch.zeros(num_tokens, *shape))

        # The assignment order below is deliberate: it fixes the
        # state_dict / optimizer parameter order and the order of RNG draws.

        # Branch 1 scorer: F -> hidden_dim -> 1, separately for each token.
        self.mod1_w1 = weight(in_features, hidden_dim)
        self.mod1_b1 = bias(hidden_dim)
        self.mod1_w2 = weight(hidden_dim, 1)
        self.mod1_b2 = bias(1)

        # Branch 2 scorer: same shapes, independent parameters.
        self.mod2_w1 = weight(in_features, hidden_dim)
        self.mod2_b1 = bias(hidden_dim)
        self.mod2_w2 = weight(hidden_dim, 1)
        self.mod2_b2 = bias(1)

        # Per-token scalar scale and shift applied to the comparison maps.
        self.trans_w = weight(1, 1)
        self.trans_b = bias(1)

        self._init_weights()

    def _init_weights(self):
        """Re-draw every weight tensor with Kaiming-uniform; biases stay zero."""
        # NOTE(review): PyTorch derives fan_in from dims 1.. of a tensor, so
        # for the [N, in, hidden] first-layer weights fan_in is
        # in_features * hidden_dim rather than in_features — confirm the
        # resulting (smaller) initial scale is intended.
        weights = (
            self.mod1_w1, self.mod2_w1,
            self.mod1_w2, self.mod2_w2,
            self.trans_w,
        )
        for param in weights:
            nn.init.kaiming_uniform_(param, a=math.sqrt(5))

    def _score(self, x, w1, b1, w2, b2):
        """Run one per-token GELU scorer: [B, N, F] -> [B, N, 1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Mix every token with every other token.

        Args:
            x: tensor of shape [B, N, F] with N == self.num_tokens.

        Returns:
            Tensor of shape [B, N, F].
        """
        n = self.num_tokens

        s1 = self._score(x, self.mod1_w1, self.mod1_b1,
                         self.mod1_w2, self.mod1_b2)
        s2 = self._score(x, self.mod2_w1, self.mod2_b1,
                         self.mod2_w2, self.mod2_b2)

        # Shift the denominator by 1e-5.
        # NOTE(review): this only moves it away from zero when s2 >= 0; a
        # branch-2 score of exactly -1e-5 still divides by zero.
        denom = s2 + 1e-5

        # [B, N, N, 1] ratio maps:
        #   fwd[b, i, j] = tanh(s1[b, i] / denom[b, j])
        #   bwd[b, i, j] = tanh(s1[b, j] / denom[b, i])
        fwd = torch.tanh(s1[:, :, None] / denom[:, None])
        bwd = torch.tanh(s1[:, None] / denom[:, :, None])

        # Scalar affine transform selected by the column index j.
        shift = self.trans_b.view(1, 1, n, 1)
        fwd = torch.einsum('bije,jef->bijf', fwd, self.trans_w) + shift
        bwd = torch.einsum('bije,jef->bijf', bwd, self.trans_w) + shift

        # Weight token i's features by fwd and token j's by bwd, then
        # average the two contributions: [B, N, N, F].
        pairs = (fwd * x[:, :, None] + bwd * x[:, None]) / 2

        # Drop the i == j terms and average over the N - 1 other tokens.
        off_diag = 1.0 - torch.eye(n, device=x.device)
        return (pairs * off_diag.view(1, n, n, 1)).sum(dim=2) / (n - 1.0)


# =========================================================
# 4. MAIN ARCHITECTURE — LOOKTHEM STL V1
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Two-stream CNN front end followed by LookThem relational layers.

    * stream_a: three stride-2 conv stages (coarse, large-scale view).
    * stream_b: two stride-1 stages, then one stride-2 stage (finer view).

    Each stream ends in a 64-channel 8x8 map. That map is read as 64 tokens
    of 64 features and refined by the stream's own LookThemLayer. The two
    results are concatenated along the feature axis (64 tokens x 128
    features) and mixed by a shared LookThemLayer. Each token's features
    are then average-pooled to 32 and classified into 10 classes.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_gelu(c_in, c_out, stride):
            # One 3x3 conv -> BatchNorm -> GELU stage; padding=1 keeps the
            # spatial size unchanged when stride == 1.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]

        # Stream A: halve the resolution at every stage, then pool to 8x8.
        self.stream_a = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=2),
            *conv_bn_gelu(16, 32, stride=2),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Stream B: keep full resolution for two stages, downsample once,
        # then pool to the same 8x8 token grid as stream A.
        self.stream_b = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=1),
            *conv_bn_gelu(16, 32, stride=1),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Per-stream relational layers (64 tokens x 64 features) ...
        self.lookthemA = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        self.lookthemB = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        # ... and the fusion layer over the concatenated 128-wide tokens.
        self.lookthem = LookThemLayer(
            num_tokens=64, in_features=128, hidden_dim=32
        )

        # Average-pools the last (feature) axis of each token to width 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # MLP head over the flattened 64 x 32 token grid, with dropout.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """
        Args:
            x: RGB image batch [B, 3, H, W] (96x96 in this project).

        Returns:
            Class logits of shape [B, 10].
        """
        b = x.size(0)

        def as_tokens(feature_map):
            # [B, 64, 8, 8] -> [B, 64 channels, 64 cells]
            #               -> [B, 64 cells (tokens), 64 channels (features)]
            return feature_map.view(b, 64, 64).transpose(1, 2)

        tokens_a = self.lookthemA(as_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(as_tokens(self.stream_b(x)))

        # Fuse along the feature axis: token count stays 64, width 64 + 64.
        fused = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))

        return self.classifier(self.compressor(fused))


# =========================================================
# 5. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================

device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

model = LookThemSTLV1().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

# Cosine decay across the whole run. The scheduler is stepped once per
# epoch, so T_max must equal the epoch count of the training loop (100).
# A shorter T_max (e.g. 40) makes the LR hit its minimum early and then
# climb back toward the base LR, so the final epochs would train at a high
# learning rate instead of annealing.
# NOTE: a resumed run restores the scheduler's saved attributes (including
# T_max) via scheduler.load_state_dict, so older checkpoints keep their
# original schedule.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100
)

# 0-based index of the first epoch to run; replaced on checkpoint resume.
start_epoch = 0
checkpoint_path = "lookthem_stl_checkpoint.pth"


# =========================================================
# CHECKPOINT RESUME
# =========================================================

if os.path.exists(checkpoint_path):

    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )

    # map_location remaps saved tensors onto the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one
    # (and vice versa). Without it, torch.load tries to restore tensors to
    # the device they were saved from.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # The saved 'epoch' is the number of completed epochs, which is also
    # the 0-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']

    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )

print(
    f"Starting LookThem STL V1 training on {device}..."
)


# =========================================================
# TRAINING LOOP
# =========================================================

for epoch in range(start_epoch, 100):

    model.train()

    total_loss = 0
    correct = 0
    total = 0

    for data, target in train_loader:

        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Running statistics for the per-epoch summary line.
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # Read the LR actually used during this epoch *before* the scheduler
    # advances it. Reading it afterwards would log next epoch's LR.
    current_lr = optimizer.param_groups[0]['lr']

    # One scheduler step per epoch.
    scheduler.step()

    acc = 100. * correct / total

    print(
        f"Epoch {epoch+1:02d}/100 | "
        f"Train Loss: "
        f"{total_loss / len(train_loader):.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save (every 5 epochs)
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:

        # Write to a temporary file and then swap it into place, so an
        # interruption during torch.save cannot leave a truncated file
        # where the only resumable checkpoint used to be.
        tmp_checkpoint_path = checkpoint_path + ".tmp"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, tmp_checkpoint_path)
        os.replace(tmp_checkpoint_path, checkpoint_path)

        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )


# =========================================================
# 6. FINAL VALIDATION
# =========================================================

# Inference mode: BatchNorm uses its running statistics, Dropout is off.
model.eval()

test_loss = 0
test_correct = 0
test_total = 0

print("\nStarting final validation...")

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)

        test_loss += criterion(logits, labels).item()
        test_correct += logits.max(dim=1).indices.eq(labels).sum().item()
        test_total += labels.size(0)

final_test_acc = 100. * test_correct / test_total
mean_test_loss = test_loss / len(val_loader)

print("=== FINAL LOOKTHEM STL V1 RESULTS ===")
print(f"Test Loss: {mean_test_loss:.4f} | Test Accuracy: {final_test_acc:.2f}%")

# Persist only the weights (state_dict); the inference script rebuilds
# LookThemSTLV1 and loads this file into it.
final_weights_path = "LookThem_STL.pth"
torch.save(model.state_dict(), final_weights_path)

size_mb = os.path.getsize(final_weights_path) / (1024 * 1024)
print(f"Training complete! Final model size: {size_mb:.2f} MB")

Inference

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from PIL import Image
import math


# =========================================================
# 1. LOOKTHEM CORE LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Relational token-mixing layer.

    Every token position owns two private two-layer scorers
    (branch 1 / branch 2), each reducing that token's feature vector to a
    single scalar. For every ordered pair of tokens (i, j) the layer forms
    tanh(score1 / score2) in both directions, rescales those maps with a
    per-token affine transform, uses them to weight the two tokens'
    features, and averages the result over all *other* tokens.

    Args:
        num_tokens: number of token positions N. It is fixed per instance
            because every parameter carries a leading N axis. Must be >= 2,
            since the final mean divides by N - 1.
        in_features: feature width F of each token.
        hidden_dim: hidden width of each per-token scorer.

    Shapes:
        input  [B, N, F]
        output [B, N, F]
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        def weight(*shape):
            # Per-token weight tensor [N, *shape]; its values are replaced
            # by _init_weights below.
            return nn.Parameter(torch.randn(num_tokens, *shape))

        def bias(*shape):
            # Per-token bias tensor [N, *shape], starting at zero.
            return nn.Parameter(torch.zeros(num_tokens, *shape))

        # The assignment order below is deliberate: it fixes the
        # state_dict parameter order and the order of RNG draws, which must
        # match the training-time definition.

        # Branch 1 scorer: F -> hidden_dim -> 1, separately for each token.
        self.mod1_w1 = weight(in_features, hidden_dim)
        self.mod1_b1 = bias(hidden_dim)
        self.mod1_w2 = weight(hidden_dim, 1)
        self.mod1_b2 = bias(1)

        # Branch 2 scorer: same shapes, independent parameters.
        self.mod2_w1 = weight(in_features, hidden_dim)
        self.mod2_b1 = bias(hidden_dim)
        self.mod2_w2 = weight(hidden_dim, 1)
        self.mod2_b2 = bias(1)

        # Per-token scalar scale and shift applied to the comparison maps.
        self.trans_w = weight(1, 1)
        self.trans_b = bias(1)

        self._init_weights()

    def _init_weights(self):
        """Re-draw every weight tensor with Kaiming-uniform; biases stay zero."""
        # NOTE(review): PyTorch derives fan_in from dims 1.. of a tensor, so
        # for the [N, in, hidden] first-layer weights fan_in is
        # in_features * hidden_dim rather than in_features — confirm the
        # resulting (smaller) initial scale is intended.
        weights = (
            self.mod1_w1, self.mod2_w1,
            self.mod1_w2, self.mod2_w2,
            self.trans_w,
        )
        for param in weights:
            nn.init.kaiming_uniform_(param, a=math.sqrt(5))

    def _score(self, x, w1, b1, w2, b2):
        """Run one per-token GELU scorer: [B, N, F] -> [B, N, 1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Mix every token with every other token.

        Args:
            x: tensor of shape [B, N, F] with N == self.num_tokens.

        Returns:
            Tensor of shape [B, N, F].
        """
        n = self.num_tokens

        s1 = self._score(x, self.mod1_w1, self.mod1_b1,
                         self.mod1_w2, self.mod1_b2)
        s2 = self._score(x, self.mod2_w1, self.mod2_b1,
                         self.mod2_w2, self.mod2_b2)

        # Shift the denominator by 1e-5.
        # NOTE(review): this only moves it away from zero when s2 >= 0; a
        # branch-2 score of exactly -1e-5 still divides by zero.
        denom = s2 + 1e-5

        # [B, N, N, 1] ratio maps:
        #   fwd[b, i, j] = tanh(s1[b, i] / denom[b, j])
        #   bwd[b, i, j] = tanh(s1[b, j] / denom[b, i])
        fwd = torch.tanh(s1[:, :, None] / denom[:, None])
        bwd = torch.tanh(s1[:, None] / denom[:, :, None])

        # Scalar affine transform selected by the column index j.
        shift = self.trans_b.view(1, 1, n, 1)
        fwd = torch.einsum('bije,jef->bijf', fwd, self.trans_w) + shift
        bwd = torch.einsum('bije,jef->bijf', bwd, self.trans_w) + shift

        # Weight token i's features by fwd and token j's by bwd, then
        # average the two contributions: [B, N, N, F].
        pairs = (fwd * x[:, :, None] + bwd * x[:, None]) / 2

        # Drop the i == j terms and average over the N - 1 other tokens.
        off_diag = 1.0 - torch.eye(n, device=x.device)
        return (pairs * off_diag.view(1, n, n, 1)).sum(dim=2) / (n - 1.0)


# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================

class LookThemSTLV1(nn.Module):
    """
    Two-stream CNN front end followed by LookThem relational layers.

    * stream_a: three stride-2 conv stages (coarse, large-scale view).
    * stream_b: two stride-1 stages, then one stride-2 stage (finer view).

    Each stream ends in a 64-channel 8x8 map. That map is read as 64 tokens
    of 64 features and refined by the stream's own LookThemLayer. The two
    results are concatenated along the feature axis (64 tokens x 128
    features) and mixed by a shared LookThemLayer. Each token's features
    are then average-pooled to 32 and classified into 10 classes.

    Module names and order must match the training-time definition so the
    saved state_dict loads.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_gelu(c_in, c_out, stride):
            # One 3x3 conv -> BatchNorm -> GELU stage; padding=1 keeps the
            # spatial size unchanged when stride == 1.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]

        # Stream A: halve the resolution at every stage, then pool to 8x8.
        self.stream_a = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=2),
            *conv_bn_gelu(16, 32, stride=2),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Stream B: keep full resolution for two stages, downsample once,
        # then pool to the same 8x8 token grid as stream A.
        self.stream_b = nn.Sequential(
            *conv_bn_gelu(3, 16, stride=1),
            *conv_bn_gelu(16, 32, stride=1),
            *conv_bn_gelu(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Per-stream relational layers (64 tokens x 64 features) ...
        self.lookthemA = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        self.lookthemB = LookThemLayer(
            num_tokens=64, in_features=64, hidden_dim=16
        )
        # ... and the fusion layer over the concatenated 128-wide tokens.
        self.lookthem = LookThemLayer(
            num_tokens=64, in_features=128, hidden_dim=32
        )

        # Average-pools the last (feature) axis of each token to width 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # MLP head over the flattened 64 x 32 token grid, with dropout.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """
        Args:
            x: RGB image batch [B, 3, H, W] (96x96 in this project).

        Returns:
            Class logits of shape [B, 10].
        """
        b = x.size(0)

        def as_tokens(feature_map):
            # [B, 64, 8, 8] -> [B, 64 channels, 64 cells]
            #               -> [B, 64 cells (tokens), 64 channels (features)]
            return feature_map.view(b, 64, 64).transpose(1, 2)

        tokens_a = self.lookthemA(as_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(as_tokens(self.stream_b(x)))

        # Fuse along the feature axis: token count stays 64, width 64 + 64.
        fused = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))

        return self.classifier(self.compressor(fused))


# =========================================================
# 3. DEVICE SETUP
# =========================================================

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


# =========================================================
# 4. CLASS LABELS
# =========================================================

# Label names indexed by model output position (0-9).
# NOTE(review): assumed to match the STL10 label order used in training.
classes = (
    "airplane bird car cat deer dog horse monkey ship truck"
).split()


# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================

# Normalization must match the one used during training (same mean/std).
_mean = (0.4914, 0.4822, 0.4465)
_std = (0.2470, 0.2435, 0.2616)

# Resize to exactly 96x96, the input size the network was trained on.
# A fixed (h, w) resize stretches non-square images rather than cropping.
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
])


# =========================================================
# 6. LOAD MODEL
# =========================================================

weights_path = "LookThem_STL.pth"

model = LookThemSTLV1().to(device)

# NOTE(review): torch.load unpickles the file. Only load weight files from
# a trusted source, or pass weights_only=True if the installed PyTorch
# supports it.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)

# Dropout off; BatchNorm uses running statistics.
model.eval()

print("Model loaded successfully!")


# =========================================================
# 7. LOAD IMAGE
# =========================================================

# Replace with your image path
image_path = "test.jpg"

# Open inside a context manager so the underlying file handle is released
# as soon as the RGB copy exists. convert("RGB") returns a new 3-channel
# image, which is what the model's first convolutions expect.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Preprocess, add a leading batch dimension, and move to the model device.
input_tensor = transform(image).unsqueeze(0).to(device)


# =========================================================
# 8. INFERENCE
# =========================================================

# Single forward pass without autograd bookkeeping.
with torch.no_grad():
    logits = model(input_tensor)
    probabilities = F.softmax(logits, dim=1)
    confidence, predicted = probabilities.max(dim=1)

# Top-1 label and its probability as a percentage.
predicted_class = classes[int(predicted)]
confidence_score = 100 * confidence.item()


# =========================================================
# 9. RESULT
# =========================================================

print("\n===== INFERENCE RESULT =====")
print(f"Predicted Class : {predicted_class}")
print(f"Confidence      : {confidence_score:.2f}%")

print("\n===== CLASS PROBABILITIES =====")

# One line per class, in model output order, as percentages.
row = probabilities[0].tolist()
for idx, name in enumerate(classes):
    print(f"{name:<10} : {row[idx] * 100:.2f}%")