# Cleaned code
# Training script
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# =========================================================
# 1. DATA PREPARATION
# =========================================================
# Per-channel normalization statistics shared by both pipelines below.
# NOTE(review): these values look like the commonly used CIFAR-10 channel
# statistics rather than ones measured on STL10. The inference script's
# transform uses the same numbers, so change both together or neither.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random horizontal flip as the only augmentation, then
# tensor conversion and normalization. STL10 images are already 96x96, so
# no resize is required.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: no augmentation, identical normalization, so test-time
# inputs are scaled exactly like training inputs.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
# =========================================================
# 2. STL10 DATASET LOADING
# =========================================================
def _load_stl10(split, transform):
    """Return the STL10 ``split`` from ./data, downloading it if missing."""
    return torchvision.datasets.STL10(
        root='./data',
        split=split,
        download=True,
        transform=transform,
    )


train_dataset = _load_stl10('train', transform_train)
test_dataset = _load_stl10('test', transform_test)

# Loader settings shared by both loaders; only shuffling differs.
# NOTE(review): num_workers > 0 starts worker processes. Under the "spawn"
# start method (the default on Windows and macOS) workers re-import the main
# module, and this script has no `if __name__ == "__main__":` guard —
# confirm the target platform.
_loader_opts = {'batch_size': 64, 'num_workers': 2}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
# The STL10 test split doubles as the validation set; nothing is held out
# from the training split.
val_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print(f"Training samples : {len(train_dataset)}")
print(f"Testing samples : {len(test_dataset)}")
# =========================================================
# 3. CORE RELATIONAL LAYER — LOOKTHEM LAYER
# =========================================================
class LookThemLayer(nn.Module):
    """
    Relational token-processing layer.

    Each token owns its own tiny dual-network pair (``mod1`` and ``mod2``).
    Each is a two-layer per-token MLP that maps the token's features to one
    scalar. The two scalars are compared across every ordered pair of tokens
    through ``tanh`` of their ratio. The comparison maps are then rescaled by
    a per-token affine map and used to mix token features back into the
    token space.

    Args:
        num_tokens: number of tokens N; every token has its own weights.
        in_features: feature width F of each token.
        hidden_dim: hidden width of each per-token branch MLP.

    Shapes:
        input ``[B, num_tokens, in_features]`` ->
        output ``[B, num_tokens, in_features]``.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # -------------------------------------------------
        # Branch 1 parameters: per-token F -> hidden_dim -> 1
        # -------------------------------------------------
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # -------------------------------------------------
        # Branch 2 parameters: same shapes, independent weights
        # -------------------------------------------------
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # -------------------------------------------------
        # Relational transformation parameters: one scalar weight and bias
        # per token, applied along the "other token" (j) axis.
        # -------------------------------------------------
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """
        Kaiming-uniform initialization (a=sqrt(5)) for all weight tensors;
        biases keep their zero initialization.

        NOTE(review): for a 3-D tensor PyTorch computes fan_in as
        size(1) * size(2), so mod*_w1 ([N, F, hidden]) gets
        fan_in = in_features * hidden_dim rather than the per-token
        in_features. Kept unchanged; confirm that scale is intended.
        """
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _branch(x, w1, b1, w2, b2):
        """Apply a per-token two-layer MLP: [B, N, F] -> [B, N, 1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Args:
            x: tensor of shape [B, num_tokens, in_features].

        Returns:
            Tensor of shape [B, num_tokens, in_features]. Each output token
            is the average, over the other N - 1 tokens, of its bidirectional
            interactions. A single-token layer has no other tokens and
            returns zeros.
        """
        N = self.num_tokens

        out_m1 = self._branch(x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2)
        out_m2 = self._branch(x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2)

        # NOTE: a constant shift, not a sign-aware epsilon. It moves the
        # zero-divisor point from 0 to -1e-5 rather than removing it. Kept
        # as-is so trained checkpoints produce the same outputs.
        m1 = out_m1.squeeze(-1)            # [B, N]
        m2 = (out_m2 + 1e-5).squeeze(-1)   # [B, N]

        # compare[b, i, j]  = tanh(m1[i] / m2[j])   (token-to-token)
        # compare2[b, i, j] = tanh(m1[j] / m2[i])   (reverse direction)
        compare = torch.tanh(m1.unsqueeze(2) / m2.unsqueeze(1))
        compare2 = torch.tanh(m1.unsqueeze(1) / m2.unsqueeze(2))

        # Per-token scalar affine map indexed by the column token j
        # (a 1x1 linear map per token, since trans_w is [N, 1, 1]).
        weight = self.trans_w.view(1, 1, N)
        bias = self.trans_b.view(1, 1, N)
        trans_compare = compare * weight + bias
        trans_compare2 = compare2 * weight + bias

        # Remove self-interaction (i == j).
        mask = 1.0 - torch.eye(N, device=x.device, dtype=trans_compare.dtype)
        trans_compare = trans_compare * mask
        trans_compare2 = trans_compare2 * mask

        # out[i] = (x[i] * sum_j tc[i, j] + sum_j tc2[i, j] * x[j]) / 2,
        # averaged over the N - 1 other tokens. Factoring the sums this way
        # gives the same result without materializing a [B, N, N, F]
        # interaction tensor.
        own = trans_compare.sum(dim=2, keepdim=True) * x     # [B, N, F]
        others = torch.matmul(trans_compare2, x)             # [B, N, F]
        # max(..., 1) keeps a single-token layer at zeros instead of 0/0 = NaN.
        return (own + others) / (2.0 * max(N - 1, 1))
# =========================================================
# 4. MAIN ARCHITECTURE — LOOKTHEM STL V1
# =========================================================
class LookThemSTLV1(nn.Module):
    """
    Dual-stream relational vision architecture.

    Stream A (macro structure) halves the resolution three times
    (96 -> 48 -> 24 -> 12 for a 96x96 input). Stream B (micro detail) keeps
    full resolution for two convolutions and halves it once (96 -> 48).
    Both end in an adaptive max-pool to an 8x8 grid with 64 channels. That
    grid is read as 64 tokens of 64 features and refined by a per-stream
    LookThemLayer. The two token sets are concatenated feature-wise
    (64 tokens x 128 features) and fused by a third LookThemLayer. Finally,
    each token is average-pooled to 32 features and the result is classified
    into 10 classes.
    """

    @staticmethod
    def _conv_bn_gelu(in_ch, out_ch, stride):
        """Return [3x3 Conv2d (padding 1), BatchNorm2d, GELU] as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]

    def __init__(self):
        super().__init__()
        unit = self._conv_bn_gelu

        # Stream A — aggressive downsampling for large-scale structure.
        self.stream_a = nn.Sequential(
            *unit(3, 16, stride=2),
            *unit(16, 32, stride=2),
            *unit(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),  # final spatial alignment
        )

        # Stream B — slower reduction that preserves more local detail.
        self.stream_b = nn.Sequential(
            *unit(3, 16, stride=1),
            *unit(16, 32, stride=1),
            *unit(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),  # match Stream A token resolution
        )

        # Stream-specific relational processors (64 tokens x 64 features).
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)

        # Fusion processor over the concatenated 128-wide token features.
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # Average-pools each token's 128 features down to 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # Dense head with dropout: 64 * 32 -> 512 -> 256 -> 10 logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _to_tokens(feature_map):
        """Reshape [B, C, H, W] into [B, H*W, C]: one token per spatial cell."""
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, x):
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))
        # Feature-level fusion: token count stays 64, width doubles to 128.
        fused = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))
        return self.classifier(self.compressor(fused))
# =========================================================
# 5. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LookThemSTLV1().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Total number of training epochs. The cosine schedule spans the whole run,
# so the LR decays monotonically to its minimum at the last epoch. With a
# shorter T_max, CosineAnnealingLR keeps following the cosine and the LR
# rises again after T_max epochs.
num_epochs = 100
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

start_epoch = 0
checkpoint_path = "lookthem_stl_checkpoint.pth"

# =========================================================
# CHECKPOINT RESUME
# =========================================================
if os.path.exists(checkpoint_path):
    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )
    # map_location lets a checkpoint written on a GPU machine be resumed on
    # a CPU-only one (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Restores the saved scheduler state, including the T_max it was built
    # with, so an older run resumes on its original schedule.
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # 'epoch' holds the number of completed epochs, i.e. the next zero-based
    # epoch index to run.
    start_epoch = checkpoint['epoch']
    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )

print(f"Starting LookThem STL V1 training on {device}...")
# =========================================================
# TRAINING LOOP
# =========================================================
# Runs the remaining epochs up to a hard-coded total of 100 (the same number
# appears in the progress message below).
for epoch in range(start_epoch, 100):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Epoch statistics. The reported loss is the mean of per-batch mean
        # losses, so a final, smaller batch counts as much as a full one.
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # Per-epoch LR schedule step, after all optimizer steps of the epoch.
    scheduler.step()

    acc = 100. * correct / total
    # Read after scheduler.step(), so this is the LR the next epoch will use.
    current_lr = optimizer.param_groups[0]['lr']
    print(
        f"Epoch {epoch+1:02d}/100 | "
        f"Train Loss: "
        f"{total_loss / len(train_loader):.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save (every 5 epochs)
    # -----------------------------------------------------
    # Write to a temporary file first and swap it in with os.replace, so an
    # interruption during torch.save cannot corrupt the only resume file.
    if (epoch + 1) % 5 == 0:
        tmp_checkpoint_path = checkpoint_path + ".tmp"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, tmp_checkpoint_path)
        os.replace(tmp_checkpoint_path, checkpoint_path)
        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )
# =========================================================
# 6. FINAL VALIDATION
# =========================================================
# Evaluate on val_loader (the STL10 test split) in eval mode: Dropout is
# disabled and BatchNorm uses its running statistics.
model.eval()
test_loss = 0
test_correct = 0
test_total = 0
print("\nStarting final validation...")

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        test_loss += criterion(logits, labels).item()
        test_total += labels.size(0)
        test_correct += logits.max(1).indices.eq(labels).sum().item()

final_test_acc = 100. * test_correct / test_total
mean_test_loss = test_loss / len(val_loader)

print("=== FINAL LOOKTHEM STL V1 RESULTS ===")
print(f"Test Loss: {mean_test_loss:.4f} | Test Accuracy: {final_test_acc:.2f}%")

# Persist only the trained weights (state_dict); the inference script loads
# this same file.
final_weights_path = "LookThem_STL.pth"
torch.save(model.state_dict(), final_weights_path)
size_mb = os.path.getsize(final_weights_path) / (1024 * 1024)
print(f"Training complete! Final model size: {size_mb:.2f} MB")
# Inference script
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import math
# =========================================================
# 1. LOOKTHEM CORE LAYER
# =========================================================
class LookThemLayer(nn.Module):
    """
    Relational token-processing layer used by the LookThem STL architecture.

    Each token owns its own tiny dual-network pair (``mod1`` and ``mod2``).
    Each is a two-layer per-token MLP that maps the token's features to one
    scalar. The two scalars are compared across every ordered pair of tokens
    through ``tanh`` of their ratio. The comparison maps are then rescaled by
    a per-token affine map and used to mix token features back into the
    token space.

    Args:
        num_tokens: number of tokens N; every token has its own weights.
        in_features: feature width F of each token.
        hidden_dim: hidden width of each per-token branch MLP.

    Shapes:
        input ``[B, num_tokens, in_features]`` ->
        output ``[B, num_tokens, in_features]``.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # -------------------------------------------------
        # Branch 1: per-token F -> hidden_dim -> 1
        # -------------------------------------------------
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # -------------------------------------------------
        # Branch 2: same shapes, independent weights
        # -------------------------------------------------
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # -------------------------------------------------
        # Relational transformation: one scalar weight and bias per token,
        # applied along the "other token" (j) axis.
        # -------------------------------------------------
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """
        Kaiming-uniform initialization (a=sqrt(5)) for all weight tensors;
        biases keep their zero initialization.

        NOTE(review): for a 3-D tensor PyTorch computes fan_in as
        size(1) * size(2), so mod*_w1 ([N, F, hidden]) gets
        fan_in = in_features * hidden_dim rather than the per-token
        in_features. Kept unchanged; confirm that scale is intended.
        """
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _branch(x, w1, b1, w2, b2):
        """Apply a per-token two-layer MLP: [B, N, F] -> [B, N, 1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Args:
            x: tensor of shape [B, num_tokens, in_features].

        Returns:
            Tensor of shape [B, num_tokens, in_features]. Each output token
            is the average, over the other N - 1 tokens, of its bidirectional
            interactions. A single-token layer has no other tokens and
            returns zeros.
        """
        N = self.num_tokens

        out_m1 = self._branch(x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2)
        out_m2 = self._branch(x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2)

        # NOTE: a constant shift, not a sign-aware epsilon. It moves the
        # zero-divisor point from 0 to -1e-5 rather than removing it. Kept
        # as-is so trained checkpoints produce the same outputs.
        m1 = out_m1.squeeze(-1)            # [B, N]
        m2 = (out_m2 + 1e-5).squeeze(-1)   # [B, N]

        # compare[b, i, j]  = tanh(m1[i] / m2[j])   (token-to-token)
        # compare2[b, i, j] = tanh(m1[j] / m2[i])   (reverse direction)
        compare = torch.tanh(m1.unsqueeze(2) / m2.unsqueeze(1))
        compare2 = torch.tanh(m1.unsqueeze(1) / m2.unsqueeze(2))

        # Per-token scalar affine map indexed by the column token j
        # (a 1x1 linear map per token, since trans_w is [N, 1, 1]).
        weight = self.trans_w.view(1, 1, N)
        bias = self.trans_b.view(1, 1, N)
        trans_compare = compare * weight + bias
        trans_compare2 = compare2 * weight + bias

        # Remove self-interaction (i == j).
        mask = 1.0 - torch.eye(N, device=x.device, dtype=trans_compare.dtype)
        trans_compare = trans_compare * mask
        trans_compare2 = trans_compare2 * mask

        # out[i] = (x[i] * sum_j tc[i, j] + sum_j tc2[i, j] * x[j]) / 2,
        # averaged over the N - 1 other tokens. Factoring the sums this way
        # gives the same result without materializing a [B, N, N, F]
        # interaction tensor.
        own = trans_compare.sum(dim=2, keepdim=True) * x     # [B, N, F]
        others = torch.matmul(trans_compare2, x)             # [B, N, F]
        # max(..., 1) keeps a single-token layer at zeros instead of 0/0 = NaN.
        return (own + others) / (2.0 * max(N - 1, 1))
# =========================================================
# 2. LOOKTHEM STL MODEL
# =========================================================
class LookThemSTLV1(nn.Module):
    """
    Dual-stream relational vision architecture (must match the training
    definition exactly so the saved state_dict loads).

    Stream A halves the resolution three times (96 -> 12 for a 96x96 input).
    Stream B keeps full resolution for two convolutions and halves it once
    (96 -> 48). Both are max-pooled to an 8x8 grid with 64 channels, read as
    64 tokens of 64 features and refined by a per-stream LookThemLayer. The
    two token sets are concatenated feature-wise, fused by a third
    LookThemLayer, average-pooled to 32 features per token and classified
    into 10 classes.
    """

    @staticmethod
    def _conv_bn_gelu(in_ch, out_ch, stride):
        """Return [3x3 Conv2d (padding 1), BatchNorm2d, GELU] as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]

    def __init__(self):
        super().__init__()
        unit = self._conv_bn_gelu

        # Stream A — macro structure (aggressive downsampling).
        self.stream_a = nn.Sequential(
            *unit(3, 16, stride=2),
            *unit(16, 32, stride=2),
            *unit(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Stream B — micro detail (slower reduction).
        self.stream_b = nn.Sequential(
            *unit(3, 16, stride=1),
            *unit(16, 32, stride=1),
            *unit(32, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # Relational processors: two per-stream layers and one fusion layer.
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=16)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # Average-pools each token's 128 features down to 32.
        self.compressor = nn.AdaptiveAvgPool1d(32)

        # Dense head: 64 * 32 -> 512 -> 256 -> 10 logits, with dropout.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _to_tokens(feature_map):
        """Reshape [B, C, H, W] into [B, H*W, C]: one token per spatial cell."""
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, x):
        tokens_a = self.lookthemA(self._to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(self._to_tokens(self.stream_b(x)))
        # Feature fusion: token count stays 64, width doubles to 128.
        fused = self.lookthem(torch.cat((tokens_a, tokens_b), dim=2))
        return self.classifier(self.compressor(fused))
# =========================================================
# 3. DEVICE SETUP
# =========================================================
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# =========================================================
# 4. CLASS LABELS
# =========================================================
# Logit index -> class name; the order must match the integer labels the
# model was trained on.
classes = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]

# =========================================================
# 5. IMAGE TRANSFORM
# =========================================================
# Resize to 96x96 (aspect ratio is not preserved), convert to a tensor and
# normalize with the same per-channel constants the training script uses.
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616),
    ),
])
# =========================================================
# 6. LOAD MODEL
# =========================================================
# Build the architecture on the target device, then restore the trained
# weights. map_location places the stored tensors on that device regardless
# of where they were saved.
# NOTE(review): torch.load unpickles the file, so only load weight files
# from a trusted source.
model = LookThemSTLV1().to(device)
trained_weights = torch.load("LookThem_STL.pth", map_location=device)
model.load_state_dict(trained_weights)

# Inference mode: Dropout disabled, BatchNorm uses running statistics.
model.eval()
print("Model loaded successfully!")
# =========================================================
# 7. LOAD IMAGE
# =========================================================
# Replace with your image path
image_path = "test.jpg"

# Open inside a context manager so the file handle is always released.
# Image.open() leaves the file open, and loading (triggered by convert())
# only closes it automatically for single-frame images; multi-frame inputs
# such as GIF/TIFF would otherwise hold it until garbage collection.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Preprocess and add a batch dimension: [3, 96, 96] -> [1, 3, 96, 96].
input_tensor = transform(image).unsqueeze(0).to(device)

# =========================================================
# 8. INFERENCE
# =========================================================
with torch.no_grad():
    output = model(input_tensor)
    probabilities = F.softmax(output, dim=1)
    confidence, predicted = torch.max(probabilities, dim=1)

predicted_class = classes[predicted.item()]
confidence_score = confidence.item() * 100

# =========================================================
# 9. RESULT
# =========================================================
print("\n===== INFERENCE RESULT =====")
print(f"Predicted Class : {predicted_class}")
print(f"Confidence : {confidence_score:.2f}%")
print("\n===== CLASS PROBABILITIES =====")
for class_name, prob in zip(classes, probabilities[0].tolist()):
    print(f"{class_name:<10} : {prob * 100:.2f}%")