neuro-ssl / model.py
ABCREATIVEAKSHAY's picture
Deploy NeuroSSL dementia classifier application
01668e2 verified
Raw
History Blame Contribute Delete
9.39 kB
import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from einops import rearrange
class DropPath(nn.Module):
def __init__(self, drop_prob=0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class PatchEmbed2D(nn.Module):
def __init__(self, img_size, patch_size, in_chans=1, embed_dim=384):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return rearrange(x, 'b e h w -> b (h w) e')
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, drop_path=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(mlp_dim, embed_dim), nn.Dropout(dropout)
)
self.drop_path = DropPath(drop_path)
def forward(self, x):
x_norm = self.norm1(x)
attn_out, attn_weights = self.attn(x_norm, x_norm, x_norm, need_weights=True, average_attn_weights=False)
x = x + self.drop_path(attn_out)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn_weights
class MultiScaleViT2DEncoder(nn.Module):
def __init__(self, img_size, patch_size, embed_dim=384, depth=12, num_heads=6,
mlp_ratio=4.0, dropout=0.0, drop_path_rate=0.1):
super().__init__()
self.patch_embed = PatchEmbed2D(img_size, patch_size, 1, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, dpr[i]) for i in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.scale_indices = [2, 5, 8, 11]
self.scale_fusion = nn.Sequential(
nn.Linear(embed_dim * len(self.scale_indices), embed_dim), nn.GELU(), nn.LayerNorm(embed_dim)
)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x, mask=None, return_all_tokens=False):
B = x.size(0)
tokens = self.patch_embed(x)
tokens = tokens + self.pos_embed[:, 1:tokens.shape[1]+1]
if mask is not None:
tokens = tokens[mask].reshape(B, -1, tokens.size(-1))
cls = self.cls_token.expand(B, -1, -1)
tokens = torch.cat([cls, tokens], dim=1)
tokens = tokens + self.pos_embed[:, :tokens.shape[1]]
intermediate_features = []
attn_maps = []
for i, block in enumerate(self.blocks):
tokens, attn = block(tokens)
if i in self.scale_indices:
intermediate_features.append(tokens[:, 1:])
attn_maps.append(attn)
tokens = self.norm(tokens)
if len(intermediate_features) > 0:
pooled_scales = [f.mean(dim=1) for f in intermediate_features]
multi_scale_vec = torch.cat(pooled_scales, dim=-1)
multi_scale_vec = self.scale_fusion(multi_scale_vec)
else:
multi_scale_vec = tokens[:, 0]
if return_all_tokens:
return tokens, multi_scale_vec, attn_maps
return tokens
class SecondOrderClassifier(nn.Module):
def __init__(self, encoder, num_classes=2, dropout=0.3, use_second_order=True):
super().__init__()
self.encoder = encoder
self.use_second_order = use_second_order
embed_dim = encoder.blocks[0].norm1.normalized_shape[0]
if use_second_order:
self.cov_proj = nn.Sequential(
nn.Linear(embed_dim * embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
nn.Linear(embed_dim * 4, embed_dim)
)
self.head = nn.Sequential(
nn.LayerNorm(embed_dim * 2), nn.Dropout(dropout), nn.Linear(embed_dim * 2, num_classes)
)
else:
self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(dropout), nn.Linear(embed_dim, num_classes))
self.mc_dropout = nn.Dropout(dropout)
def forward(self, x, return_features=False, mc_dropout=False):
tokens, multi_scale, _ = self.encoder(x, return_all_tokens=True)
cls = tokens[:, 0]
if self.use_second_order:
patches = tokens[:, 1:]
mean_patches = patches.mean(dim=1, keepdim=True)
centered = patches - mean_patches
cov = torch.bmm(centered.transpose(1,2), centered) / centered.size(1)
cov_vec = cov.reshape(cov.size(0), -1)
second_order = self.cov_proj(cov_vec)
second_order = F.normalize(second_order, p=2, dim=-1)
features = torch.cat([cls, second_order], dim=-1)
else:
features = cls
if mc_dropout:
features = self.mc_dropout(features)
logits = self.head(features)
if return_features:
return logits, features, cls
return logits
def get_model(checkpoint_path, device):
"""Load model once and transfer to target device."""
print(f"Loading checkpoint from {checkpoint_path}...")
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
encoder = MultiScaleViT2DEncoder((224, 224), (16, 16), 384, 12, 6, drop_path_rate=0.0)
model = SecondOrderClassifier(encoder, num_classes=2, dropout=0.3, use_second_order=True)
model.load_state_dict(ckpt['model_state'])
model = model.to(device).eval()
print("Model loaded successfully!")
return model
def preprocess_image(image_path_or_bytes, target_size=(224, 224)):
"""Convert and normalize image to Z-score tensor."""
if isinstance(image_path_or_bytes, (str, bytes)) and isinstance(image_path_or_bytes, str):
img = Image.open(image_path_or_bytes).convert('L')
else:
img = Image.open(io.BytesIO(image_path_or_bytes)).convert('L')
img = img.resize((target_size[1], target_size[0]), Image.BILINEAR)
tensor = torch.from_numpy(np.array(img, dtype=np.float32)).unsqueeze(0) / 255.0
mu = tensor.mean()
std = tensor.std() + 1e-8
tensor = (tensor - mu) / std
return tensor.unsqueeze(0), img
def attention_rollout(model, x, device):
"""Compute the vision transformer attention maps using rollout."""
model.eval()
with torch.no_grad():
_, _, attn_maps = model.encoder(x.to(device), return_all_tokens=True)
B = x.size(0)
result = torch.eye(attn_maps[0].size(-1), device=device).unsqueeze(0).repeat(B, 1, 1)
for attn in attn_maps:
attn = attn.mean(dim=1)
attn = attn + torch.eye(attn.size(-1), device=device).unsqueeze(0)
attn = attn / attn.sum(dim=-1, keepdim=True)
result = torch.bmm(result, attn)
mask = result[:, 0, 1:]
H = W = int(np.sqrt(mask.size(1)))
mask = mask.reshape(B, H, W)
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
return mask
def predict_with_uncertainty(model, tensor, device, temperature=1.6995, mc_samples=10):
"""Run model inference, MC Dropout uncertainty estimation, and attention rollout."""
tensor = tensor.to(device)
# 1. Main deterministic calibrated prediction
model.eval()
with torch.no_grad():
logits = model(tensor)
raw_probs = torch.softmax(logits, dim=1)
calibrated_logits = logits / temperature
calibrated_probs = torch.softmax(calibrated_logits, dim=1)
# 2. MC Dropout predictions for uncertainty quantification
mc_probs = []
if mc_samples > 0:
model.train()
with torch.no_grad():
for _ in range(mc_samples):
logit = model(tensor, mc_dropout=True)
prob = torch.softmax(logit, dim=1)[:, 1]
mc_probs.append(prob.item())
model.eval()
mc_std = np.std(mc_probs)
else:
mc_std = 0.0
# 3. Attention rollout Map
attn_mask = attention_rollout(model, tensor, device)
attn_mask_np = attn_mask[0].cpu().numpy() # shape (14, 14)
return {
'raw_probability': float(raw_probs[0, 1]),
'calibrated_probability': float(calibrated_probs[0, 1]),
'uncertainty': float(mc_std),
'attention_map': attn_mask_np.tolist() # converts to nested list (14x14)
}