import math
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
from PIL import Image
from torchvision import transforms as T

# =========================================================================
# PART 0: Regularization & Diagnostics
# =========================================================================


class DropPath(nn.Module):
    """Stochastic Depth (DropPath) regularization.

    In training mode each sample in the batch is zeroed with probability
    ``drop_prob`` and survivors are rescaled by ``1 / keep_prob`` so the
    expected activation is unchanged. In eval mode (or with
    ``drop_prob == 0``) the input is returned untouched.
    """

    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over every remaining dim.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        output = x.div(keep_prob) * random_tensor
        return output


class LayerActivations:
    """Forward-hook logger that records summary statistics of layer outputs.

    Intended for debugging feature collapse: every hooked layer stores the
    mean, std, max and shape of its most recent output in ``self.stats``.
    """

    def __init__(self):
        self.hooks = {}             # layer_name -> handle from register_forward_hook
        self.stats = OrderedDict()  # layer_name -> {"mean", "std", "max", "shape"}

    def register_hook(self, layer_name, layer):
        """Attach a forward hook on ``layer`` that records its output stats.

        Registering the same ``layer_name`` twice removes the previous hook
        first, so its handle is not leaked (it would otherwise keep firing).
        """
        previous = self.hooks.pop(layer_name, None)
        if previous is not None:
            previous.remove()

        def hook_fn(module, _inputs, output):
            out_tensor = output.detach()
            self.stats[layer_name] = {
                "mean": out_tensor.mean().item(),
                "std": out_tensor.std().item(),
                "max": out_tensor.max().item(),
                "shape": tuple(out_tensor.shape),
            }

        self.hooks[layer_name] = layer.register_forward_hook(hook_fn)

    def remove_hooks(self):
        """Detach every registered hook and clear the collected stats."""
        for h in self.hooks.values():
            h.remove()
        self.hooks = {}
        self.stats = OrderedDict()

    def get_stats(self):
        """Return the stats collected since the last ``remove_hooks``."""
        return self.stats


# =========================================================================
# PART 1: Architecture Modules
# =========================================================================


class CoordinateAttention(nn.Module):
    """Coordinate attention over a (B, C, H, W) feature map.

    Pools along width and height separately, mixes the two pooled strips
    through a shared 1x1 bottleneck, and produces per-row (``a_h``) and
    per-column (``a_w``) sigmoid gates.

    Returns:
        ``(x * a_h * a_w, a_h)`` where ``a_h`` has shape (B, C, H, 1).
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.Hardswish()
        self.conv_h = nn.Conv2d(mip, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        _, _, h, w = x.size()
        x_h = self.pool_h(x)                      # (B, C, H, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (B, C, W, 1)
        # Concatenate both strips along the spatial axis so one conv serves both.
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = torch.sigmoid(self.conv_h(x_h))
        a_w = torch.sigmoid(self.conv_w(x_w))
        return identity * a_h * a_w, a_h


class VIB_UIFS(nn.Module):
    """Multi-head variational bottleneck with a learned quality score.

    Globally pools the feature map, runs a shared MLP, then ``num_heads``
    independent heads each emit a slice of the posterior mean and log-variance.
    The std is scaled by ``1.3 - 0.6 * quality`` (so in [0.7, 1.3]).

    Returns:
        ``(mu, std, z, mu_heads, quality_score)`` where ``z`` is a
        reparameterized sample in training mode and ``mu`` in eval mode, and
        ``mu_heads`` is ``(B, num_heads, head_dim)``.
    """

    def __init__(self, in_channels, latent_dim=512, num_heads=4):
        super().__init__()
        # Without exact divisibility the concatenated heads would be shorter
        # than latent_dim and every downstream Linear(latent_dim, ...) would
        # fail with an opaque shape error; fail early instead.
        if latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc_shared = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.LayerNorm(in_channels // 2),
            nn.GELU(),
            nn.Dropout(0.3),
        )
        self.head_processors = nn.ModuleList([
            nn.Linear(in_channels // 2, latent_dim) for _ in range(num_heads)
        ])
        self.fc_mu_heads = nn.ModuleList([
            nn.Linear(latent_dim, self.head_dim) for _ in range(num_heads)
        ])
        self.fc_logvar_heads = nn.ModuleList([
            nn.Linear(latent_dim, self.head_dim) for _ in range(num_heads)
        ])
        self.quality_head = nn.Sequential(
            nn.Linear(in_channels // 2, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )
        # Not read in forward(); kept so existing state_dicts still load.
        self.register_buffer('prior_mu', torch.zeros(latent_dim))

    def forward(self, x):
        flat = self.pool(x).flatten(1)
        shared = self.fc_shared(flat)

        mu_heads = []
        logvar_heads = []
        for head_proc, fc_mu, fc_logvar in zip(
            self.head_processors, self.fc_mu_heads, self.fc_logvar_heads
        ):
            head_feat = F.gelu(head_proc(shared))
            mu_heads.append(fc_mu(head_feat))
            logvar_heads.append(fc_logvar(head_feat))

        mu = torch.cat(mu_heads, dim=1)
        logvar = torch.cat(logvar_heads, dim=1)
        quality_score = self.quality_head(shared)

        # Bound the variance to keep exp() and the KL term numerically sane.
        logvar = torch.clamp(logvar, min=-10, max=4)
        std = torch.exp(0.5 * logvar)
        # Higher predicted quality -> tighter posterior.
        modulation = 1.3 - (quality_score * 0.6)
        std = std * modulation

        if self.training:
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            z = mu

        return mu, std, z, torch.stack(mu_heads, dim=1), quality_score


class IAM_Block(nn.Module):
    """Latent-conditioned feature modulation with a learned residual mix.

    ``z`` produces per-channel ``gamma``/``beta`` and a sigmoid gate that are
    applied to the GroupNorm'd features; the result is blended with the
    unmodified input using weight ``sigmoid(alpha)``.
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        self.fc_params = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.utils.spectral_norm(nn.Linear(latent_dim, channels * 2)),
        )
        self.gate_fc = nn.Sequential(
            nn.Linear(latent_dim, channels),
            nn.Sigmoid(),
        )
        # Requires channels to be divisible by 32.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, affine=False)
        self.alpha = nn.Parameter(torch.tensor(0.1))
        self.drop_path = DropPath(0.15)

    def forward(self, spatial_features, z):
        params = self.fc_params(z).unsqueeze(2).unsqueeze(3)
        gamma, beta = params.chunk(2, dim=1)
        gate = self.gate_fc(z).unsqueeze(2).unsqueeze(3)

        normalized = self.norm(spatial_features)
        modulated = normalized * (1 + gate * gamma) + (gate * beta)
        modulated = self.drop_path(modulated)

        weight_mod = torch.sigmoid(self.alpha)
        weight_orig = 1.0 - weight_mod
        return weight_orig * spatial_features + weight_mod * modulated


class CrossAttentionPooling(nn.Module):
    """Multi-head attention pooling with a single latent query.

    ``z_prior`` (B, latent_dim) forms one query per head; keys/values come
    from every spatial position of ``refined_feats`` (B, C, H, W).

    Returns:
        ``(pooled, attn_weights)`` with ``pooled`` of shape (B, latent_dim)
        and ``attn_weights`` of shape (B, num_heads, 1, H*W) (pre-dropout).
    """

    def __init__(self, feature_dim, latent_dim, num_heads=8):
        super().__init__()
        if feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.w_q = nn.Linear(latent_dim, feature_dim)
        self.w_k = nn.Conv2d(feature_dim, feature_dim, kernel_size=1)
        self.w_v = nn.Conv2d(feature_dim, feature_dim, kernel_size=1)
        self.ln_q = nn.LayerNorm(feature_dim)
        self.ln_k = nn.LayerNorm(feature_dim)
        self.proj_out = nn.Sequential(
            nn.Linear(feature_dim, latent_dim),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(self, refined_feats, z_prior):
        B, C, H, W = refined_feats.shape
        N = H * W

        q = self.w_q(z_prior)
        q = self.ln_q(q).view(B, self.num_heads, self.head_dim).unsqueeze(2)

        k = self.w_k(refined_feats).view(B, C, N).permute(0, 2, 1)
        k = self.ln_k(k).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        v = self.w_v(refined_feats).view(B, C, N).permute(0, 2, 1)
        v = v.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_logits, dim=-1)

        if self.training:
            attn_weights_drop = F.dropout(attn_weights, p=0.1, training=True)
        else:
            attn_weights_drop = attn_weights

        context = torch.matmul(attn_weights_drop, v).squeeze(2).reshape(B, C)
        return self.proj_out(context), attn_weights


# =========================================================================
# PART 2: Main Model (Robust Config)
# =========================================================================


class DSIR_VIB(nn.Module):
    """Face embedding model on a VGGFace2-pretrained InceptionResnetV1 trunk.

    The spatial trunk output goes through coordinate attention, a variational
    bottleneck (``uifs``), latent-conditioned modulation (``iam``) and latent
    query cross-attention pooling. The bottleneck mean and the pooled residual
    are fused by ``final_project`` and L2-normalized.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        print("Initializing DSIR-VIB with Strong Regularization & GELU...")

        self.backbone = InceptionResnetV1(pretrained='vggface2', classify=False)
        self._freeze_early_layers()

        self.diagnostics = LayerActivations()
        self.feat_dim = 1792  # channel count of the trunk's block8 output

        self.coord_attn = CoordinateAttention(self.feat_dim)
        self.uifs = VIB_UIFS(self.feat_dim, latent_dim, num_heads=4)
        self.iam = IAM_Block(latent_dim, self.feat_dim)
        self.cross_attn = CrossAttentionPooling(self.feat_dim, latent_dim, num_heads=8)

        self.feature_proj = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.LayerNorm(self.feat_dim),
            nn.GELU(),
            nn.Dropout(0.3),
        )
        self.final_project = nn.Linear(latent_dim * 2, latent_dim)

        self._init_warm_start()

    def _freeze_early_layers(self):
        """Freeze the trunk up to ``repeat_2``; unfreeze the last stages."""
        freeze_until = ['conv2d_1a', 'conv2d_2a', 'conv2d_2b', 'maxpool_3a',
                        'conv2d_3b', 'conv2d_4a', 'conv2d_4b', 'repeat_1',
                        'mixed_6a', 'repeat_2']
        trainable = ['mixed_7a', 'repeat_3', 'block8', 'avgpool_1a']
        for name, module in self.backbone.named_children():
            if name in freeze_until:
                for param in module.parameters():
                    param.requires_grad = False
            elif name in trainable:
                for param in module.parameters():
                    param.requires_grad = True

    def _init_warm_start(self):
        """Start the IAM gate nearly closed and Xavier-init the fusion layer."""
        if hasattr(self.iam, 'gate_fc'):
            for m in self.iam.gate_fc:
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                    # sigmoid(-2) ~= 0.12, so modulation starts weak.
                    nn.init.constant_(m.bias, -2.0)

        nn.init.xavier_uniform_(self.final_project.weight, gain=1.0)
        nn.init.zeros_(self.final_project.bias)

    def extract_spatial(self, x):
        """Run the trunk to ``block8`` and apply coordinate attention.

        Returns:
            ``(features, attn_map)``; ``features`` is (B, feat_dim, H, W) with
            a small (0.1x) globally-projected residual added.
        """
        x = self.backbone.conv2d_1a(x)
        x = self.backbone.conv2d_2a(x)
        x = self.backbone.conv2d_2b(x)
        x = self.backbone.maxpool_3a(x)
        x = self.backbone.conv2d_3b(x)
        x = self.backbone.conv2d_4a(x)
        x = self.backbone.conv2d_4b(x)
        x = self.backbone.repeat_1(x)
        x = self.backbone.mixed_6a(x)
        x = self.backbone.repeat_2(x)
        x = self.backbone.mixed_7a(x)
        x = self.backbone.repeat_3(x)
        x = self.backbone.block8(x)

        x, attn_map = self.coord_attn(x)

        B, C, H, W = x.shape
        x_flat = x.view(B, C, -1).mean(dim=2)
        x_proj = self.feature_proj(x_flat).view(B, C, 1, 1).expand_as(x)

        # Residual connection
        x = x + 0.1 * x_proj
        return x, attn_map

    def forward(self, x, return_heads_quality=False, return_intermediate=False, debug=False):
        """Embed a batch of face crops.

        Args:
            x: input image batch.
            return_heads_quality: return a dict of all intermediates.
            return_intermediate: return ``(mu_final, std, mu_prior, z_res)``.
            debug: hook three trunk layers and print their output stats.

        Returns:
            ``(mu_final, std)`` by default, where ``mu_final`` is
            L2-normalized along dim 1; see the flags above for alternatives.
        """
        if debug:
            self.diagnostics.register_hook('conv2d_4a', self.backbone.conv2d_4a)
            self.diagnostics.register_hook('mixed_6a', self.backbone.mixed_6a)
            self.diagnostics.register_hook('block8', self.backbone.block8)

        try:
            spatial_feats, spatial_attn_map = self.extract_spatial(x)
            if debug:
                print("\n--- Diagnostic Log ---")
                stats = self.diagnostics.get_stats()
                for layer, data in stats.items():
                    print(f"[{layer}] Mean: {data['mean']:.4f} | Std: {data['std']:.4f} | Max: {data['max']:.4f}")
                print("----------------------\n")
        finally:
            # Always detach debug hooks, even if the trunk raised, so they
            # do not keep firing on later calls.
            if debug:
                self.diagnostics.remove_hooks()

        mu_prior, std, z_prior, mu_heads, quality = self.uifs(spatial_feats)
        refined_feats = self.iam(spatial_feats, z_prior)
        z_res, cross_attn_weights = self.cross_attn(refined_feats, z_prior)

        combined = torch.cat([mu_prior, z_res], dim=1)
        mu_final = self.final_project(combined)
        mu_final = F.normalize(mu_final, p=2, dim=1)

        if return_heads_quality:
            return {
                'mu_final': mu_final,
                'std': std,
                'mu_prior': mu_prior,
                'z_res': z_res,
                'mu_heads': mu_heads,
                'quality': quality,
                'cross_attn': cross_attn_weights,
                'spatial_feats': spatial_feats,
            }
        if return_intermediate:
            return mu_final, std, mu_prior, z_res
        return mu_final, std


# =========================================================================
# PART 3: Helpers (CRITICAL UPDATES)
# =========================================================================


def wasserstein_distance(mu1, std1, mu2, std2, temperature=0.07):
    """
    Mixed embedding distance (lower = more similar), divided by temperature.

    Despite the name this is not a Wasserstein distance; it is
    50% angular + 30% cosine + 15% Euclidean + 5% uncertainty penalty.
    Inputs may be single vectors (D,) or batches (..., D); the reduction
    is always over the last dimension, so each sample gets its own score.
    """
    mu1_norm = F.normalize(mu1, dim=-1)
    mu2_norm = F.normalize(mu2, dim=-1)

    # 1. Cosine similarity in [-1, 1]
    cosine_sim = torch.sum(mu1_norm * mu2_norm, dim=-1)

    # 2. Angular distance, normalized to [0, 1]; clamp keeps acos finite.
    cosine_sim_clamped = torch.clamp(cosine_sim, -0.9999, 0.9999)
    angular_dist = torch.acos(cosine_sim_clamped) / math.pi

    # 3. Standard cosine distance
    cosine_dist = 1.0 - cosine_sim

    # 4. Euclidean distance on the raw (un-normalized) inputs
    euclidean_dist = torch.norm(mu1 - mu2, p=2, dim=-1)

    # 5. Uncertainty penalty, per sample (reducing over the whole tensor
    #    would mix every sample's std in a batched call).
    uncertainty_penalty = torch.abs(std1.mean(dim=-1) - std2.mean(dim=-1))

    total_dist = (0.50 * angular_dist) + \
                 (0.30 * cosine_dist) + \
                 (0.15 * euclidean_dist) + \
                 (0.05 * uncertainty_penalty)

    return total_dist / temperature


def get_transforms(augment=False):
    """Return the 160x160 preprocessing pipeline normalized to [-1, 1].

    With ``augment=True`` random flip, color jitter and affine jitter are
    applied before tensor conversion.
    """
    if augment:
        return transforms.Compose([
            transforms.Resize((160, 160)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
    return transforms.Compose([
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def create_synthetic_variation(img, severity=1):
    """Create a randomized variation of a PIL face image for enrollment.

    ``severity`` 1 = mild (flip, light jitter), 2 = moderate (small affine,
    jitter); any other value = strong (larger affine, jitter, Gaussian blur).
    Returns a new PIL image.
    """
    img_t = T.ToTensor()(img)

    if severity == 1:
        trans = T.Compose([
            T.RandomHorizontalFlip(p=0.3),
            T.ColorJitter(brightness=0.1, contrast=0.1),
        ])
    elif severity == 2:
        trans = T.Compose([
            T.RandomAffine(degrees=10, translate=(0.05, 0.05)),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
        ])
    else:
        trans = T.Compose([
            T.RandomAffine(degrees=20, translate=(0.1, 0.1)),
            T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ])

    img_t = trans(img_t)
    return T.ToPILImage()(img_t)


# =========================================================================
# PART 4: Losses
# =========================================================================


class EfficientCenterLoss(nn.Module):
    """Center loss: scaled MSE between features and their class centers."""

    def __init__(self, num_classes, feat_dim, lambda_c=0.01):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_c = lambda_c
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        batch_centers = self.centers[labels]
        loss = F.mse_loss(features, batch_centers, reduction='mean')
        return self.lambda_c * loss


class AngularMarginLoss(nn.Module):
    """Additive angular margin softmax loss with a per-sample margin.

    The target-class logit is ``s * cos(theta + m_i)`` where
    ``m_i = m * (1 + quality_i)`` when quality scores are given, otherwise
    ``m * (1 + sigmoid(||f_i|| - 1))``.
    """

    def __init__(self, num_classes, feat_dim, s=64.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.randn(num_classes, feat_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, labels, quality_scores=None):
        weight_norm = F.normalize(self.weight, dim=1)
        feature_norm = F.normalize(features, dim=1)
        cosine = F.linear(feature_norm, weight_norm)

        if quality_scores is not None:
            adaptive_m = self.m * (1.0 + quality_scores)
        else:
            feature_norms = torch.norm(features, p=2, dim=1, keepdim=True)
            adaptive_m = self.m * (1.0 + torch.sigmoid(feature_norms - 1.0))

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)

        theta = torch.acos(torch.clamp(cosine, -0.9999, 0.9999))
        # cos(theta + m) only decreases with theta while theta + m < pi;
        # beyond that the "margin" would reward moving away from the class
        # center. Use the standard ArcFace fallback cos(theta) - m*sin(m)
        # there (the margin here can reach 1 rad, so this regime is real).
        margin_valid = (theta + adaptive_m) < math.pi
        target_logit = torch.where(
            margin_valid,
            torch.cos(theta + adaptive_m),
            cosine - torch.sin(adaptive_m) * adaptive_m,
        )

        output = one_hot * target_logit + (1.0 - one_hot) * cosine
        output *= self.s
        return F.cross_entropy(output, labels)


class ConsistencyLoss(nn.Module):
    """Negative mean temperature-scaled cosine similarity of paired views."""

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, mu_a, mu_b):
        mu_a_norm = F.normalize(mu_a, dim=1)
        mu_b_norm = F.normalize(mu_b, dim=1)
        pos_sim = torch.sum(mu_a_norm * mu_b_norm, dim=1) / self.temperature
        return -pos_sim.mean()


class DiscriminativeMiningLoss(nn.Module):
    """Encourage attention on spatial positions with high cross-batch variance.

    ``running_var`` is an EMA (per spatial position) of the feature variance
    across the batch; it is resized to H*W on the first training step whose
    spatial size differs from ``num_regions``.
    """

    def __init__(self, num_regions=64, momentum=0.9):
        super().__init__()
        self.register_buffer('running_var', torch.zeros(num_regions))
        self.momentum = momentum

    def forward(self, attn_weights, features):
        B, C, H, W = features.shape
        N = H * W
        feats_flat = features.view(B, C, N)
        # Variance over the batch, averaged over channels -> shape (N,)
        spatial_var = torch.var(feats_flat, dim=0, unbiased=False).mean(dim=0)

        if self.training:
            if self.running_var.shape[0] != N:
                self.running_var = torch.zeros(N, device=features.device)
            self.running_var = self.momentum * self.running_var + \
                (1 - self.momentum) * spatial_var.detach()

        # assumes attn_weights is (B, heads, 1, N) as produced by
        # CrossAttentionPooling -> averaged over heads to (B, N)
        attn_map = attn_weights.mean(dim=1).squeeze(1)
        target_var = self.running_var if self.training else spatial_var
        return -torch.mean(attn_map * target_var)


class DSIR_Enhanced_Loss(nn.Module):
    """Combined training objective for DSIR_VIB.

    total = angular + 0.01*center + beta*KL + 0.5*consistency
            + 0.1*mining + 0.1*orthogonality
    ``beta`` ramps linearly from ``beta_start`` to ``beta_target`` over
    ``warmup_epochs`` via ``update_beta``.
    """

    def __init__(self, num_classes, embedding_dim=512, beta_start=0.0,
                 beta_target=1e-4, warmup_epochs=5):
        super().__init__()
        self.angular_loss = AngularMarginLoss(num_classes, embedding_dim)
        self.consistency_loss = ConsistencyLoss()
        self.center_loss = EfficientCenterLoss(num_classes, embedding_dim)
        self.mining_loss = DiscriminativeMiningLoss()
        self.beta = beta_start
        self.beta_start = beta_start
        self.beta_target = beta_target
        self.warmup_epochs = warmup_epochs

    def update_beta(self, epoch):
        """Linearly warm up the KL weight; constant after ``warmup_epochs``."""
        if epoch < self.warmup_epochs:
            self.beta = self.beta_start + \
                (self.beta_target - self.beta_start) * (epoch / self.warmup_epochs)
        else:
            self.beta = self.beta_target

    def forward(self, output_dict, labels, output_dict_b=None):
        """Compute the total loss from a ``return_heads_quality=True`` dict.

        Returns:
            ``(total_loss, {'id', 'kl', 'orth'})`` with the components as floats.
        """
        mu = output_dict['mu_final']
        std = output_dict['std']
        mu_prior = output_dict['mu_prior']
        quality = output_dict['quality']

        loss_id = self.angular_loss(mu, labels, quality_scores=quality)
        loss_center = self.center_loss(mu, labels)
        loss_mining = self.mining_loss(output_dict['cross_attn'], output_dict['spatial_feats'])

        # Penalize alignment between the bottleneck mean and the residual.
        loss_orth = torch.mean(torch.abs(torch.sum(
            F.normalize(mu_prior, dim=1) * F.normalize(output_dict['z_res'], dim=1),
            dim=1,
        )))

        # KL(N(mu_prior, std^2) || N(0, I)). `std` is the posterior std of the
        # bottleneck mean `mu_prior`, so the KL must use `mu_prior` -- not
        # `mu_final`, which is a different (projected, L2-normalized) vector.
        var = std.pow(2)
        kl_loss = -0.5 * torch.sum(1 + torch.log(var) - mu_prior.pow(2) - var, dim=1).mean()

        loss_const = 0
        if output_dict_b is not None:
            loss_const = self.consistency_loss(mu, output_dict_b['mu_final'])

        total_loss = loss_id + 0.01 * loss_center + self.beta * kl_loss + \
            0.5 * loss_const + 0.1 * loss_mining + 0.1 * loss_orth

        return total_loss, {
            'id': loss_id.item(),
            'kl': kl_loss.item(),
            'orth': loss_orth.item(),
        }
# =========================================================================
# PART 5: Inference & Application Logic (DIVERSITY ENABLED)
# =========================================================================

# Global storage: name -> {'mus': [Tensor], 'stds': [Tensor], 'count': int}
GALLERY = {}


def select_diverse_embeddings(all_mus, all_stds, top_k=5):
    """
    Smart Embedding Selection: keep up to ``top_k`` mutually diverse embeddings.

    Greedy farthest-first traversal under cosine distance: start from the
    embedding with the largest L2 norm, then repeatedly add the candidate
    whose distance to its nearest already-selected embedding is largest.

    Args:
        all_mus: list of 1-D embedding tensors.
        all_stds: list of tensors paired index-wise with ``all_mus``.
        top_k: maximum number of embeddings to keep.

    Returns:
        ``(selected_mus, selected_stds)``. When ``len(all_mus) <= top_k`` the
        input lists themselves are returned unchanged.
    """
    if len(all_mus) <= top_k:
        return all_mus, all_stds

    mus_tensor = torch.stack(all_mus)
    n = mus_tensor.shape[0]

    # Start with the embedding that has the highest norm.
    first_idx = int(torch.argmax(torch.norm(mus_tensor, dim=1)))
    selected_indices = [first_idx]

    # nearest_dist[i] = cosine distance from i to its closest selected point.
    # Updated incrementally with each pick: O(n) per step instead of
    # re-scanning every selected point for every candidate.
    nearest_dist = torch.full((n,), float('inf'), dtype=mus_tensor.dtype,
                              device=mus_tensor.device)
    while len(selected_indices) < top_k:
        last = selected_indices[-1]
        dist_to_last = 1.0 - F.cosine_similarity(mus_tensor, mus_tensor[last:last + 1], dim=1)
        nearest_dist = torch.minimum(nearest_dist, dist_to_last)

        candidates = nearest_dist.clone()
        candidates[selected_indices] = float('-inf')
        # argmax returns the first maximal index, matching a strict ">" scan.
        selected_indices.append(int(torch.argmax(candidates)))

    selected_mus = [all_mus[i] for i in selected_indices]
    selected_stds = [all_stds[i] for i in selected_indices]
    return selected_mus, selected_stds


def enroll_with_variations(model, device, files, name, num_variations=8):
    """
    Enroll ``name`` into GALLERY from one or more face images.

    Each image is embedded ``num_variations`` times: once unchanged and then
    with random synthetic variations of cycling severity. Up to 8 mutually
    diverse embeddings are stored.

    Args:
        model: callable returning ``(mu, std)`` for a (1, 3, 160, 160) batch.
        device: device the input tensor is moved to.
        files: paths or objects with a ``.name`` path attribute.
        name: identity label to store under.
        num_variations: embeddings computed per image (including the original).

    Returns:
        A status string; failures start with ``"Error: "``.
    """
    if not files or not name:
        return "Error: Missing files or name"

    transform_base = get_transforms(augment=False)
    all_mus = []
    all_stds = []

    print(f"\n=== Enrolling {name} with {num_variations} variations per image ===")

    try:
        for f in files:
            path = f.name if hasattr(f, 'name') else f
            # The context manager closes the file handle; convert() loads the
            # pixel data first so the image stays usable afterwards.
            with Image.open(path) as raw_img:
                img = raw_img.convert('RGB')

            for v in range(num_variations):
                if v == 0:
                    img_var = img
                else:
                    # Severity cycles 2, 3, 1, 2, ... with the variation index.
                    img_var = create_synthetic_variation(img, severity=(v % 3) + 1)

                t = transform_base(img_var).unsqueeze(0).to(device)
                with torch.no_grad():
                    mu, std = model(t)

                # DSIR_VIB.forward L2-normalizes its embedding, so any norm
                # threshold above 1 rejects every sample. Only discard
                # degenerate outputs (non-finite or all-zero).
                if torch.isfinite(mu).all() and torch.norm(mu) > 1e-6:
                    all_mus.append(mu[0].cpu())
                    all_stds.append(std[0].cpu())

        if not all_mus:
            # Never store an empty entry: recognition would fail on it later.
            return f"Error: No valid embeddings extracted for '{name}'"

        selected_mus, selected_stds = select_diverse_embeddings(all_mus, all_stds, top_k=8)

        GALLERY[name] = {
            'mus': selected_mus,
            'stds': selected_stds,
            'count': len(selected_mus),
        }
        return f"Enrolled '{name}' with {len(selected_mus)} diverse embeddings"

    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"


def recognize_with_variations(model, device, image):
    """Identify ``image`` against every identity in GALLERY.

    Each identity is scored as 0.6*best + 0.2*median + 0.2*mean of the
    distances to its stored embeddings (lower is better). A best score below
    8.0 is reported as a match.

    Returns:
        ``(message, scores)`` where ``scores`` maps name -> score in ascending
        order, or ``(message, None)`` on early exit / error.
    """
    if image is None:
        return "No Image", None
    if not GALLERY:
        return "Gallery Empty", None

    transform_base = get_transforms(augment=False)

    try:
        # Accept raw arrays and non-RGB images: the transform pipeline
        # normalizes exactly three channels.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')

        t = transform_base(image).unsqueeze(0).to(device)
        with torch.no_grad():
            p_mu, p_std = model(t)
        p_mu, p_std = p_mu[0].cpu(), p_std[0].cpu()

        res = {}
        print(f"\n=== Recognizing Probe ===")

        for name, data in GALLERY.items():
            # Stricter temperature=0.07 (matches the threshold below).
            distances = [
                wasserstein_distance(p_mu, p_std, g_mu, g_std, temperature=0.07).item()
                for g_mu, g_std in zip(data['mus'], data['stds'])
            ]
            if not distances:
                continue  # nothing enrolled under this name

            best_d = min(distances)
            median_d = np.median(distances)
            avg_d = np.mean(distances)

            # Favor the best match but penalize inconsistency across views.
            final_score = (0.6 * best_d) + (0.2 * median_d) + (0.2 * avg_d)
            res[name] = final_score
            print(f" > {name}: Score={final_score:.4f} (Best={best_d:.4f})")

        sorted_res = dict(sorted(res.items(), key=lambda x: x[1]))
        if not sorted_res:
            return "No Matches Found", None

        best_name = next(iter(sorted_res))
        best_score = sorted_res[best_name]

        # Threshold is on the temperature-scaled (1/0.07) distance.
        threshold = 8.0

        if best_score < threshold:
            if len(sorted_res) > 1:
                second_score = list(sorted_res.values())[1]
                margin = second_score - best_score
                if margin > 2.0:  # margin scaled for temp=0.07
                    conf = "HIGH CONFIDENCE"
                else:
                    conf = "MEDIUM CONFIDENCE"
            else:
                conf = "SINGLE IDENTITY"
            msg = f"✅ MATCH: {best_name}\nScore: {best_score:.4f} ({conf})"
        else:
            msg = f"❌ UNKNOWN PERSON\n(Best: {best_name}, Score: {best_score:.4f})"

        return msg, sorted_res

    except Exception as e:
        return f"Error: {str(e)}", None


def check_gallery_health():
    """Print a per-identity diversity report for GALLERY.

    Diversity = 1 - mean pairwise cosine similarity of an identity's stored
    embeddings; identities with fewer than 3 embeddings are flagged.
    """
    print("\n=== Gallery Health Check ===")
    if not GALLERY:
        print("Gallery is empty.")
        return

    for name, data in GALLERY.items():
        if data['count'] < 3:
            print(f"⚠️ {name}: Only {data['count']} embedding(s) - Suggest Re-enrollment")
            continue

        mus = data['mus']
        sims = []
        for i in range(len(mus)):
            for j in range(i + 1, len(mus)):
                sim = F.cosine_similarity(mus[i].unsqueeze(0), mus[j].unsqueeze(0)).item()
                sims.append(sim)

        avg_sim = np.mean(sims) if sims else 1.0
        health_score = 1.0 - avg_sim  # higher = more internal variation

        print(f"{name}: Diversity Score = {health_score:.3f}")
        if avg_sim > 0.99:
            print(f" ⚠️ COLLAPSE RISK: Embeddings identical. Re-enroll with variation.")
        elif avg_sim < 0.85:
            print(f" ✅ HEALTHY: Good internal variation.")


def check_model_health(model, device):
    """Sanity-check a model on two random-noise inputs and print the result.

    Puts ``model`` in eval mode. If the two noise embeddings have cosine
    similarity above 0.9, ``model.final_project.weight`` is re-initialized.
    """
    print("\n[System] Running Model Health Check...")
    model.eval()

    # Check 1: random noise inputs should map to dissimilar embeddings.
    noise1 = torch.randn(1, 3, 160, 160).to(device)
    noise2 = torch.randn(1, 3, 160, 160).to(device)

    with torch.no_grad():
        mu1, std1 = model(noise1)
        mu2, std2 = model(noise2)

    similarity = F.cosine_similarity(mu1, mu2).item()
    std_mean = std1.mean().item()

    print(f" > Noise Similarity: {similarity:.4f} (Should be < 0.1)")
    print(f" > Latent Std Mean: {std_mean:.4f} (Should be ~0.8-1.2)")

    if similarity > 0.9:
        print(" !! CRITICAL FAILURE !! Model has collapsed (Outputs identical).")
        print(" -> Reinitializing last layer...")
        nn.init.xavier_uniform_(model.final_project.weight)
    elif std_mean < 0.1:
        print(" !! WARNING !! Variance collapse detected.")
    else:
        print(" > PASS: Model Healthy.")
    print("------------------------------------------------\n")


# WRAPPERS FOR APP.PY COMPATIBILITY
def process_and_enroll(model, device, files, name):
    """Enroll ``name`` with 8 variations per image (app entry point)."""
    return enroll_with_variations(model, device, files, name, num_variations=8)


def recognize(model, device, image):
    """Recognize ``image`` against GALLERY (app entry point)."""
    return recognize_with_variations(model, device, image)


def precision_weighted_fusion(mu_list, std_list):
    """Fuse several embeddings of one face into one.

    NOTE(review): despite the name, no precision weighting is applied -- the
    fused mean is the unweighted average of ``mu_list`` re-normalized to unit
    L2 norm, and the fused std is the unweighted average of ``std_list``.

    Returns:
        ``(fused_mu, fused_std)``, or ``(None, None)`` for empty input.
    """
    if len(mu_list) == 0:
        return None, None

    mus = torch.stack(mu_list)
    stds = torch.stack(std_list)

    fused_mu = torch.mean(mus, dim=0)
    fused_mu = F.normalize(fused_mu, p=2, dim=0)
    fused_std = torch.mean(stds, dim=0)

    return fused_mu, fused_std


def test_time_augmentation(model, image, n_augments=5):
    """Embed ``image`` plus ``n_augments - 1`` light augmentations and fuse them.

    The un-augmented image is always included, so at least one embedding is
    produced even when ``n_augments < 1``.
    """
    device = next(model.parameters()).device
    transform_base = get_transforms(augment=False)
    mu_list = []
    std_list = []

    img_t = transform_base(image).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, std = model(img_t)
    mu_list.append(mu[0].cpu())
    std_list.append(std[0].cpu())

    for _ in range(n_augments - 1):
        aug_img = image.copy()
        if random.random() > 0.5:
            aug_img = T.functional.hflip(aug_img)
        aug_img = T.ColorJitter(brightness=0.1, contrast=0.1)(aug_img)

        img_t = transform_base(aug_img).unsqueeze(0).to(device)
        with torch.no_grad():
            mu, std = model(img_t)
        mu_list.append(mu[0].cpu())
        std_list.append(std[0].cpu())

    return precision_weighted_fusion(mu_list, std_list)


__all__ = [
    'DSIR_VIB', 'wasserstein_distance', 'precision_weighted_fusion',
    'get_transforms', 'test_time_augmentation', 'process_and_enroll',
    'recognize', 'check_gallery_health', 'check_model_health', 'GALLERY',
]

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}")
    model = DSIR_VIB(latent_dim=512).to(device)
    model.eval()
    check_model_health(model, device)