# Provenance: uploaded by dk2430098 via huggingface_hub (commit 928b74f, verified)
"""
branches/vit_branch.py
-----------------------
Branch 4: Vision Transformer (ViT) Global Semantic Branch
STATUS: BASELINE — ViT-B/16 via PyTorch + timm
- Loads weights from models/vit_branch.pth if available
- Falls back to neutral 0.5 confidence if untrained or timm missing
Trained on CIFAKE: 99.30% validation accuracy.
"""
import numpy as np
import os
from pathlib import Path
MODEL_PATH = Path(__file__).parent.parent / "models" / "vit_branch.pth"
_model = None
_is_trained = False
_device = None
def _build_vit_model():
"""Build ViT-B/16 binary classifier using timm."""
import torch
import torch.nn as nn
import timm
class ViTForensicClassifier(nn.Module):
def __init__(self):
super().__init__()
self.backbone = timm.create_model(
"vit_base_patch16_224",
pretrained=False, # Don't re-download pretrained for inference
num_classes=0,
global_pool="token",
)
self.head = nn.Sequential(
nn.Linear(768, 256),
nn.GELU(),
nn.Dropout(0.4),
nn.Linear(256, 2),
)
def forward(self, x):
features = self.backbone(x)
return self.head(features)
return ViTForensicClassifier()
def _load_vit_model():
"""Load or build the ViT model. Cached globally."""
global _model, _is_trained, _device
if _model is not None:
return _model, _is_trained
try:
import torch
except ImportError:
print("[ViT Branch] ⚠ PyTorch not installed.")
return None, False
try:
import timm
except ImportError:
print("[ViT Branch] ⚠ timm not installed. Run: pip install timm")
return None, False
_device = "cuda" if torch.cuda.is_available() else "cpu"
try:
_model = _build_vit_model().to(_device)
_model.eval()
except Exception as e:
print(f"[ViT Branch] ⚠ Failed to build model: {e}")
_model = None
return None, False
if MODEL_PATH.exists():
try:
state = torch.load(str(MODEL_PATH), map_location=_device)
_model.load_state_dict(state)
_is_trained = True
print(f"[ViT Branch] ✓ Loaded trained weights from {MODEL_PATH}")
except Exception as e:
print(f"[ViT Branch] ⚠ Failed to load weights: {e}")
_is_trained = False
else:
print(f"[ViT Branch] ℹ No trained weights at {MODEL_PATH}. Returning neutral.")
_is_trained = False
return _model, _is_trained
def run_vit_branch(img: np.ndarray) -> dict:
"""
Run the ViT Global Semantic Branch.
Args:
img : float32 numpy array (H, W, 3) in [0, 1]
Returns:
dict with prob_fake, confidence, attn_weights, model_loaded
"""
_FALLBACK = {"prob_fake": 0.5, "confidence": 0.0,
"attn_weights": None, "model_loaded": False}
try:
import torch
import torch.nn.functional as F
except ImportError:
return _FALLBACK
model, is_trained = _load_vit_model()
if model is None or not is_trained:
return _FALLBACK
device = _device or "cpu"
# ImageNet normalization
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
img_norm = (img - MEAN) / STD
import cv2
if img_norm.shape[:2] != (224, 224):
img_norm = cv2.resize(img_norm, (224, 224), interpolation=cv2.INTER_AREA)
tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).float().to(device)
try:
with torch.no_grad():
logits = model(tensor) # (1, 2)
probs = F.softmax(logits, dim=-1) # (1, 2)
prob_fake = float(probs[0, 1].cpu().numpy())
except Exception as e:
print(f"[ViT Branch] ⚠ Inference error: {e}")
return _FALLBACK
confidence = float(np.clip(abs(prob_fake - 0.5) * 2.0, 0.1, 0.98))
return {
"prob_fake": prob_fake,
"confidence": confidence,
"attn_weights": None, # Simplified — no attention hook
"model_loaded": True,
}