"""
branches/vit_branch.py
----------------------
Branch 4: Vision Transformer (ViT) Global Semantic Branch
STATUS: BASELINE — ViT-B/16 via PyTorch + timm
- Loads weights from models/vit_branch.pth if available
- Falls back to neutral 0.5 confidence if untrained or timm missing

Trained on CIFAKE: 99.30% validation accuracy.
"""
|
|
| import numpy as np |
| import os |
| from pathlib import Path |
|
|
# Expected location of trained weights: <repo_root>/models/vit_branch.pth
MODEL_PATH = Path(__file__).parent.parent / "models" / "vit_branch.pth"


# Module-level cache, populated lazily by _load_vit_model().
_model = None        # built model instance, or None until first load attempt
_is_trained = False  # True only after weights load successfully
_device = None       # "cuda" or "cpu", selected at load time
|
|
|
|
def _build_vit_model():
    """Construct a ViT-B/16 backbone with a small binary-classification head.

    Returns an un-trained ``nn.Module`` whose state_dict layout
    (``backbone.*`` / ``head.*``) matches the checkpoints saved by training.
    """
    import torch
    import torch.nn as nn
    import timm

    class ViTForensicClassifier(nn.Module):
        """Headless ViT-B/16 feature extractor followed by a 2-way MLP head."""

        def __init__(self):
            super().__init__()
            # num_classes=0 + token pooling -> raw 768-dim CLS features.
            self.backbone = timm.create_model(
                "vit_base_patch16_224",
                pretrained=False,
                num_classes=0,
                global_pool="token",
            )
            # Lightweight classifier: 768 -> 256 -> 2 logits (real / fake).
            head_layers = [
                nn.Linear(768, 256),
                nn.GELU(),
                nn.Dropout(0.4),
                nn.Linear(256, 2),
            ]
            self.head = nn.Sequential(*head_layers)

        def forward(self, x):
            return self.head(self.backbone(x))

    return ViTForensicClassifier()
|
|
|
|
def _load_vit_model():
    """
    Load (or build) the ViT model, caching it in module globals.

    Returns:
        (model, is_trained): the model instance (or None when torch/timm are
        missing or the build fails) and whether trained weights were loaded.
    """
    global _model, _is_trained, _device
    # Serve the cached instance after the first successful build.
    if _model is not None:
        return _model, _is_trained

    try:
        import torch
    except ImportError:
        print("[ViT Branch] ⚠ PyTorch not installed.")
        return None, False

    try:
        import timm  # noqa: F401 — required transitively by _build_vit_model
    except ImportError:
        print("[ViT Branch] ⚠ timm not installed. Run: pip install timm")
        return None, False

    _device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        _model = _build_vit_model().to(_device)
        _model.eval()  # inference only: freeze dropout behavior
    except Exception as e:
        print(f"[ViT Branch] ⚠ Failed to build model: {e}")
        _model = None
        return None, False

    if MODEL_PATH.exists():
        try:
            # weights_only=True blocks arbitrary-code execution through the
            # pickle layer and pins behavior across torch versions (it became
            # the default only in torch >= 2.6). A plain state_dict of tensors
            # loads fine under this restriction.
            state = torch.load(
                str(MODEL_PATH), map_location=_device, weights_only=True
            )
            _model.load_state_dict(state)
            _is_trained = True
            print(f"[ViT Branch] ✓ Loaded trained weights from {MODEL_PATH}")
        except Exception as e:
            print(f"[ViT Branch] ⚠ Failed to load weights: {e}")
            _is_trained = False
    else:
        print(f"[ViT Branch] ℹ No trained weights at {MODEL_PATH}. Returning neutral.")
        _is_trained = False

    return _model, _is_trained
|
|
|
|
def run_vit_branch(img: np.ndarray) -> dict:
    """
    Run the ViT Global Semantic Branch.

    Args:
        img : float32 numpy array (H, W, 3) in [0, 1]

    Returns:
        dict with keys:
            prob_fake    : probability the image is AI-generated (0.5 = neutral)
            confidence   : certainty scaled into [0.1, 0.98] (0.0 on fallback)
            attn_weights : always None here (reserved for attention maps)
            model_loaded : True only when trained weights produced the score
    """
    # Neutral result used whenever the model cannot produce a real score.
    _FALLBACK = {"prob_fake": 0.5, "confidence": 0.0,
                 "attn_weights": None, "model_loaded": False}

    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return _FALLBACK

    model, is_trained = _load_vit_model()

    # An untrained network would emit noise — degrade to neutral instead.
    if model is None or not is_trained:
        return _FALLBACK

    device = _device or "cpu"

    # ImageNet normalization (standard recipe for timm ViT-B/16).
    MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_norm = (img - MEAN) / STD

    # ViT-B/16 requires a fixed 224x224 input. OpenCV is needed only for the
    # resize, so guard its import like torch/timm above — the previous
    # unguarded `import cv2` crashed with ImportError instead of honoring
    # this branch's graceful-fallback contract when OpenCV was absent.
    if img_norm.shape[:2] != (224, 224):
        try:
            import cv2
        except ImportError:
            print("[ViT Branch] ⚠ OpenCV not installed; cannot resize input.")
            return _FALLBACK
        img_norm = cv2.resize(img_norm, (224, 224), interpolation=cv2.INTER_AREA)

    # HWC float array -> NCHW batch-of-one tensor on the model's device.
    tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).float().to(device)

    try:
        with torch.no_grad():
            logits = model(tensor)
            probs = F.softmax(logits, dim=-1)
            prob_fake = float(probs[0, 1].cpu().numpy())  # class index 1 = fake
    except Exception as e:
        print(f"[ViT Branch] ⚠ Inference error: {e}")
        return _FALLBACK

    # Distance from the 0.5 decision boundary, rescaled and clamped to [0.1, 0.98].
    confidence = float(np.clip(abs(prob_fake - 0.5) * 2.0, 0.1, 0.98))

    return {
        "prob_fake": prob_fake,
        "confidence": confidence,
        "attn_weights": None,
        "model_loaded": True,
    }
|
|