Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image | |
| from einops import rearrange | |
class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples of a residual branch.

    In training mode each sample along dimension 0 is zeroed with
    probability ``drop_prob`` and the surviving samples are divided by
    ``1 - drop_prob``. In eval mode, or when ``drop_prob`` is 0, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply per-sample drop to ``x`` (any rank, batch on dimension 0)."""
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Falling through would compute
            # x / 0 * 0, which is NaN rather than the intended zeros.
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all other dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor
class PatchEmbed2D(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Args:
        img_size: ``(height, width)`` pair used to precompute ``num_patches``.
        patch_size: ``(height, width)`` of one patch; used as both kernel
            size and stride of the projection.
        in_chans: Number of input channels.
        embed_dim: Number of output channels of the projection.
    """

    def __init__(self, img_size, patch_size, in_chans=1, embed_dim=384):
        super().__init__()
        self.patch_size = patch_size
        # Integer division: any remainder rows/columns are not counted.
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]
        self.num_patches = rows * cols
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` and flatten the spatial grid to ``(b, h*w, e)``."""
        feature_map = self.proj(x)
        return rearrange(feature_map, 'b e h w -> b (h w) e')
class TransformerBlock(nn.Module):
    """Pre-norm self-attention + MLP block that also returns attention weights.

    Both residual branches pass through the same ``DropPath`` module.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability for attention and both MLP dropouts.
        drop_path: Drop probability handed to ``DropPath``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, drop_path=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        # Attribute names and registration order define the state-dict keys.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.drop_path = DropPath(drop_path)

    def forward(self, x):
        """Return ``(updated_tokens, attention_weights)``.

        The weights are requested per head (``average_attn_weights=False``).
        """
        normed = self.norm1(x)
        attended, weights = self.attn(
            normed,
            normed,
            normed,
            need_weights=True,
            average_attn_weights=False,
        )
        x = x + self.drop_path(attended)
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_out)
        return x, weights
class MultiScaleViT2DEncoder(nn.Module):
    """Single-channel ViT encoder that also fuses features from several depths.

    Patch tokens taken after selected blocks are mean-pooled, concatenated
    and passed through ``scale_fusion`` to form one multi-scale vector.

    Args:
        img_size: ``(height, width)`` used to size the positional embedding.
        patch_size: ``(height, width)`` of one patch.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability passed to every block.
        drop_path_rate: Maximum drop-path rate; block rates increase linearly
            from 0 to this value.
    """

    # Block indices (0-based) whose patch tokens feed the multi-scale vector.
    _DEFAULT_SCALE_INDICES = (2, 5, 8, 11)

    def __init__(self, img_size, patch_size, embed_dim=384, depth=12, num_heads=6,
                 mlp_ratio=4.0, dropout=0.0, drop_path_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed2D(img_size, patch_size, 1, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Keep only indices that exist for this depth. Previously the fusion
        # layer always expected four scales, so any 3 <= depth <= 11 collected
        # fewer features and failed in forward with a shape mismatch.
        self.scale_indices = [i for i in self._DEFAULT_SCALE_INDICES if i < depth]
        # With no usable index the fusion layer is never called; it keeps its
        # former four-scale size so existing state dicts still load.
        num_scales = len(self.scale_indices) or len(self._DEFAULT_SCALE_INDICES)
        self.scale_fusion = nn.Sequential(
            nn.Linear(embed_dim * num_scales, embed_dim), nn.GELU(), nn.LayerNorm(embed_dim)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x, mask=None, return_all_tokens=False):
        """Encode ``x``.

        Args:
            x: Input batch handed to the patch embedding.
            mask: Optional index used to select patch tokens.
                NOTE(review): presumably a boolean (batch, patches) mask
                keeping the same number of tokens per sample, since the
                selection is reshaped back to (batch, -1, embed) - confirm.
            return_all_tokens: When true, return
                ``(tokens, multi_scale_vec, attn_maps)``; otherwise only the
                normalized tokens.

        Returns:
            Normalized tokens (CLS first), or the 3-tuple described above,
            where ``attn_maps`` holds one attention tensor per block.
        """
        B = x.size(0)
        tokens = self.patch_embed(x)
        tokens = tokens + self.pos_embed[:, 1:tokens.shape[1] + 1]
        if mask is not None:
            tokens = tokens[mask].reshape(B, -1, tokens.size(-1))
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        # NOTE(review): positional embeddings are added a second time here, so
        # patch tokens receive them twice (and, with a mask, by position in
        # the kept sequence). Left as-is because trained weights depend on it.
        tokens = tokens + self.pos_embed[:, :tokens.shape[1]]

        intermediate_features = []
        attn_maps = []
        for i, block in enumerate(self.blocks):
            tokens, attn = block(tokens)
            if i in self.scale_indices:
                intermediate_features.append(tokens[:, 1:])  # patch tokens only
            attn_maps.append(attn)
        tokens = self.norm(tokens)

        if not return_all_tokens:
            # The fused vector is only part of the extended return value.
            return tokens

        if intermediate_features:
            pooled_scales = [f.mean(dim=1) for f in intermediate_features]
            multi_scale_vec = self.scale_fusion(torch.cat(pooled_scales, dim=-1))
        else:
            # Too shallow for any scale index: fall back to the CLS token.
            multi_scale_vec = tokens[:, 0]
        return tokens, multi_scale_vec, attn_maps
class SecondOrderClassifier(nn.Module):
    """Classifier head on top of an encoder, optionally using patch covariance.

    With ``use_second_order`` the covariance matrix of the patch tokens is
    flattened, projected to ``embed_dim``, L2-normalized and concatenated to
    the CLS token before the linear head. Otherwise the head sees only the
    CLS token.

    Args:
        encoder: Module exposing ``blocks[0].norm1`` (used to read the
            embedding size) and callable as
            ``encoder(x, return_all_tokens=True)`` returning a 3-tuple whose
            first element holds the tokens with CLS at index 0.
        num_classes: Number of output logits.
        dropout: Dropout probability for the projection, head and MC dropout.
        use_second_order: Enable the covariance branch.
    """

    def __init__(self, encoder, num_classes=2, dropout=0.3, use_second_order=True):
        super().__init__()
        self.encoder = encoder
        self.use_second_order = use_second_order
        embed_dim = encoder.blocks[0].norm1.normalized_shape[0]

        head_in = embed_dim
        if use_second_order:
            head_in = embed_dim * 2  # CLS token + projected covariance
            self.cov_proj = nn.Sequential(
                nn.Linear(embed_dim * embed_dim, embed_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(embed_dim * 4, embed_dim),
            )
        self.head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Dropout(dropout),
            nn.Linear(head_in, num_classes),
        )
        self.mc_dropout = nn.Dropout(dropout)

    def _second_order(self, patch_tokens):
        """Return the normalized projection of the patch-token covariance."""
        deviations = patch_tokens - patch_tokens.mean(dim=1, keepdim=True)
        num_tokens = deviations.size(1)
        covariance = torch.bmm(deviations.transpose(1, 2), deviations) / num_tokens
        flat = covariance.reshape(covariance.size(0), -1)
        return F.normalize(self.cov_proj(flat), p=2, dim=-1)

    def forward(self, x, return_features=False, mc_dropout=False):
        """Classify ``x``.

        Args:
            x: Input batch forwarded to the encoder.
            return_features: Also return the head input and the CLS token.
            mc_dropout: Pass the features through ``self.mc_dropout`` first;
                this only has an effect while the module is in training mode.

        Returns:
            ``logits``, or ``(logits, features, cls)`` when
            ``return_features`` is set.
        """
        # The encoder's multi-scale vector and attention maps are unused here.
        tokens, _multi_scale, _attn = self.encoder(x, return_all_tokens=True)
        cls = tokens[:, 0]

        features = cls
        if self.use_second_order:
            features = torch.cat([cls, self._second_order(tokens[:, 1:])], dim=-1)
        if mc_dropout:
            features = self.mc_dropout(features)

        logits = self.head(features)
        return (logits, features, cls) if return_features else logits
def get_model(checkpoint_path, device):
    """Build the classifier, load checkpoint weights and move it to ``device``.

    The architecture is fixed here: a 224x224 / 16x16-patch encoder with
    embedding size 384, 12 blocks and 6 heads, topped by a two-class
    second-order head. Weights are read from the checkpoint's
    ``'model_state'`` entry. The returned model is in eval mode.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # this must only ever be pointed at trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    backbone = MultiScaleViT2DEncoder(
        img_size=(224, 224),
        patch_size=(16, 16),
        embed_dim=384,
        depth=12,
        num_heads=6,
        drop_path_rate=0.0,
    )
    classifier = SecondOrderClassifier(
        backbone, num_classes=2, dropout=0.3, use_second_order=True
    )
    classifier.load_state_dict(checkpoint['model_state'])

    classifier = classifier.to(device)
    classifier = classifier.eval()
    print("Model loaded successfully!")
    return classifier
def preprocess_image(image_path_or_bytes, target_size=(224, 224)):
    """Load an image as grayscale, resize it and z-score normalize it.

    Args:
        image_path_or_bytes: A filesystem path (``str`` or ``os.PathLike``),
            the encoded image as a bytes-like object, or an open binary
            file-like object.
        target_size: ``(height, width)`` of the output.

    Returns:
        ``(tensor, img)`` where ``tensor`` has shape
        ``(1, 1, height, width)`` and is normalized with the image's own mean
        and standard deviation, and ``img`` is the resized grayscale PIL
        image.
    """
    if isinstance(image_path_or_bytes, (str, os.PathLike)):
        source = image_path_or_bytes
    elif hasattr(image_path_or_bytes, 'read'):
        source = image_path_or_bytes  # already a file-like object
    else:
        source = io.BytesIO(image_path_or_bytes)

    # convert() loads the pixel data into a new image, so the underlying file
    # can be released as soon as it returns instead of lingering until GC.
    with Image.open(source) as opened:
        img = opened.convert('L')

    # PIL expects (width, height); target_size is (height, width).
    img = img.resize((target_size[1], target_size[0]), Image.BILINEAR)
    tensor = torch.from_numpy(np.array(img, dtype=np.float32)).unsqueeze(0) / 255.0
    mu = tensor.mean()
    std = tensor.std() + 1e-8  # epsilon keeps constant images from dividing by zero
    tensor = (tensor - mu) / std
    return tensor.unsqueeze(0), img
def attention_rollout(model, x, device):
    """Compute CLS-to-patch attention rollout maps for a batch.

    For every layer the head-averaged attention is mixed with the identity
    (to account for the residual connection) and row-normalized; the layers
    are then multiplied so that later layers are applied on the left. The CLS
    row of the product, without the CLS column, is reshaped to a square grid
    and min-max normalized per sample.

    Args:
        model: Object with ``eval()`` and an ``encoder`` callable as
            ``encoder(x, return_all_tokens=True)`` whose third return value is
            a list of per-layer attention tensors shaped
            ``(batch, heads, tokens, tokens)``.
        x: Input batch; moved to ``device`` before encoding.
        device: Device on which the rollout is computed.

    Returns:
        Tensor of shape ``(batch, side, side)`` with values in ``[0, 1]``,
        where ``side * side`` is the number of patch tokens.
    """
    model.eval()
    with torch.no_grad():
        _, _, attn_maps = model.encoder(x.to(device), return_all_tokens=True)

    B = x.size(0)
    identity = torch.eye(attn_maps[0].size(-1), device=device).unsqueeze(0)
    result = identity.repeat(B, 1, 1)
    for attn in attn_maps:
        attn = attn.mean(dim=1)  # average over heads
        attn = attn + identity
        attn = attn / attn.sum(dim=-1, keepdim=True)
        # Rollout recurrence: R_l = A_l @ R_{l-1}. Multiplying on the right
        # instead would make the CLS row start from the first layer's
        # attention rather than the last layer's.
        result = torch.bmm(attn, result)

    mask = result[:, 0, 1:]  # CLS row, patch columns
    H = W = int(np.sqrt(mask.size(1)))
    # Normalize each sample on its own so one image's map does not depend on
    # the other images that happen to share the batch.
    lowest = mask.min(dim=1, keepdim=True).values
    highest = mask.max(dim=1, keepdim=True).values
    mask = (mask - lowest) / (highest - lowest + 1e-8)
    return mask.reshape(B, H, W)
def predict_with_uncertainty(model, tensor, device, temperature=1.6995, mc_samples=10):
    """Run inference with temperature scaling, MC-dropout spread and rollout.

    Only the first sample of the batch is reported, and the MC-dropout loop
    calls ``.item()`` on the per-sample probability, so ``tensor`` must hold
    exactly one sample when ``mc_samples > 0``.

    Args:
        model: Classifier callable as ``model(tensor)`` and
            ``model(tensor, mc_dropout=True)``, returning logits with the
            class on dimension 1.
        tensor: Input batch; moved to ``device``.
        device: Device used for all forward passes.
        temperature: Divisor applied to the logits before the calibrated
            softmax.
        mc_samples: Number of stochastic forward passes in training mode;
            ``0`` or less skips them and reports an uncertainty of ``0.0``.

    Returns:
        Dict with ``raw_probability`` and ``calibrated_probability`` of class
        1, ``uncertainty`` (standard deviation of the class-1 probability
        over the stochastic passes) and ``attention_map`` (nested list from
        ``attention_rollout``).
    """
    tensor = tensor.to(device)

    # 1. Main deterministic calibrated prediction
    model.eval()
    with torch.no_grad():
        logits = model(tensor)
        raw_probs = torch.softmax(logits, dim=1)
        calibrated_probs = torch.softmax(logits / temperature, dim=1)

    # 2. MC Dropout predictions for uncertainty quantification
    mc_std = 0.0
    if mc_samples > 0:
        model.train()  # enables dropout layers for the stochastic passes
        try:
            with torch.no_grad():
                mc_probs = [
                    torch.softmax(model(tensor, mc_dropout=True), dim=1)[:, 1].item()
                    for _ in range(mc_samples)
                ]
        finally:
            # Never leave a shared model in training mode, even if a
            # stochastic pass raised.
            model.eval()
        mc_std = np.std(mc_probs)

    # 3. Attention rollout map for the first sample (square patch grid)
    attn_mask = attention_rollout(model, tensor, device)
    attn_mask_np = attn_mask[0].cpu().numpy()

    return {
        'raw_probability': float(raw_probs[0, 1]),
        'calibrated_probability': float(calibrated_probs[0, 1]),
        'uncertainty': float(mc_std),
        'attention_map': attn_mask_np.tolist(),  # nested list, JSON-friendly
    }