"""
Wrapper for Gradient Field CNN submodel.
"""
import json
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from PIL import Image
from torchvision import transforms
from app.core.errors import InferenceError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
from app.services.explainability import heatmap_to_base64, compute_focus_summary
logger = get_logger(__name__)
class CompactGradientNet(nn.Module):
    """
    CNN for gradient field classification with discriminative features.

    Takes a single-channel luminance image and expands it on-device into a
    6-channel gradient field [luminance, Gx, Gy, log-scaled magnitude,
    normalized angle, coherence] before running a small conv stack.

    Args:
        depth: Number of Conv→BN→ReLU→MaxPool stages. Each stage halves the
            spatial resolution and doubles the channel count.
        base_filters: Channel count of the first conv stage.
        dropout: Dropout2d probability appended after each stage (0 disables,
            which also removes the Dropout modules from the Sequential).
        embedding_dim: Size of the penultimate embedding vector.
    """

    def __init__(self, depth=4, base_filters=32, dropout=0.3, embedding_dim=128):
        super().__init__()
        # Fixed (non-trainable) Sobel kernels for first-order gradients.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', sobel_x)
        self.register_buffer('sobel_y', sobel_y)
        # 5x5 binomial kernel (Gaussian approximation) used to smooth the
        # structure tensor components.
        gaussian = torch.tensor([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4],
                                 [6, 24, 36, 24, 6], [4, 16, 24, 16, 4],
                                 [1, 4, 6, 4, 1]], dtype=torch.float32) / 256.0
        self.register_buffer('gaussian', gaussian.view(1, 1, 5, 5))
        # Input normalization and learned 1x1 channel mixing over the
        # 6-channel gradient field.
        self.input_norm = nn.BatchNorm2d(6)
        self.channel_mix = nn.Sequential(
            nn.Conv2d(6, 6, kernel_size=1),
            nn.ReLU()
        )
        # CNN stages: Conv -> BN -> ReLU -> MaxPool (-> optional Dropout2d).
        layers = []
        in_ch = 6
        # Fix: keep out_ch defined even when depth == 0, so the embedding
        # layer can still be built (previously raised NameError).
        out_ch = in_ch
        for i in range(depth):
            out_ch = base_filters * (2**i)
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ])
            if dropout > 0:
                layers.append(nn.Dropout2d(dropout))
            in_ch = out_ch
        self.cnn = nn.Sequential(*layers)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.embedding = nn.Linear(out_ch, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, 1)

    def compute_gradient_field(self, luminance):
        """Compute 6-channel gradient field on GPU (includes luminance).

        Args:
            luminance: Tensor of shape (B, 1, H, W).

        Returns:
            Tensor of shape (B, 6, H, W): [luminance, Gx, Gy,
            log1p-scaled magnitude, angle / pi, structure-tensor coherence].
        """
        G_x = F.conv2d(luminance, self.sobel_x, padding=1)
        G_y = F.conv2d(luminance, self.sobel_y, padding=1)
        # Epsilon keeps sqrt differentiable where the gradient is zero.
        magnitude = torch.sqrt(G_x**2 + G_y**2 + 1e-8)
        angle = torch.atan2(G_y, G_x) / math.pi
        # Structure tensor: Gaussian-smoothed outer products of the gradient.
        Gxx, Gxy, Gyy = G_x * G_x, G_x * G_y, G_y * G_y
        Sxx = F.conv2d(Gxx, self.gaussian, padding=2)
        Sxy = F.conv2d(Gxy, self.gaussian, padding=2)
        Syy = F.conv2d(Gyy, self.gaussian, padding=2)
        trace = Sxx + Syy
        det_term = torch.sqrt((Sxx - Syy)**2 + 4 * Sxy**2 + 1e-8)
        # Eigenvalues of the 2x2 structure tensor; their normalized
        # difference measures local orientation coherence.
        lambda1, lambda2 = 0.5 * (trace + det_term), 0.5 * (trace - det_term)
        # NOTE(review): on perfectly flat regions the sqrt epsilon makes
        # det_term ~1e-4 while the trace is ~0, so coherence can be very
        # large there. Left unchanged because trained weights depend on it.
        coherence = ((lambda1 - lambda2) / (lambda1 + lambda2 + 1e-8))**2
        # Compress the heavy-tailed magnitude distribution.
        magnitude_scaled = torch.log1p(magnitude * 10)
        return torch.cat([luminance, G_x, G_y, magnitude_scaled, angle, coherence], dim=1)

    def forward(self, luminance):
        """Run the network.

        Args:
            luminance: Tensor of shape (B, 1, H, W).

        Returns:
            Tuple of (logits of shape (B,), embeddings of shape
            (B, embedding_dim)).
        """
        x = self.compute_gradient_field(luminance)
        x = self.input_norm(x)
        x = self.channel_mix(x)
        x = self.cnn(x)
        x = self.global_pool(x).flatten(1)
        emb = self.embedding(x)
        logit = self.classifier(emb)
        return logit.squeeze(1), emb
class GradfieldCNNWrapper(BaseSubmodelWrapper):
    """
    Wrapper for Gradient Field CNN model.

    Model expects 256x256 luminance images (size configurable via
    preprocess.json). Internally computes Sobel gradients and other
    discriminative features.
    """

    # BT.709 luminance coefficients
    R_COEFF = 0.2126
    G_COEFF = 0.7152
    B_COEFF = 0.0722

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Args:
            repo_id: Repository identifier, used for logging and error details.
            config: Submodel configuration; reads "threshold" (default 0.5),
                "model_parameters" and "labels".
            local_path: Directory containing the weights and preprocess.json.
        """
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._resize: Optional[transforms.Resize] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized GradfieldCNNWrapper for {repo_id}")

    def load(self) -> None:
        """Load the Gradient Field CNN model with trained weights.

        Raises:
            ConfigurationError: If no weights file is found or loading fails.
        """
        # Try different weight file names, newest first.
        weights_path = None
        for fname in ["gradient_field_cnn_v3_finetuned.pth", "gradient_field_cnn_v2.pth", "weights.pt", "model.pth"]:
            candidate = Path(self.local_path) / fname
            if candidate.exists():
                weights_path = candidate
                break
        preprocess_path = Path(self.local_path) / "preprocess.json"
        if weights_path is None:
            raise ConfigurationError(
                message=f"No weights file found in {self.local_path}",
                details={"repo_id": self.repo_id}
            )
        try:
            # Load preprocessing config (optional file).
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)
            # Get input size (default 256 for gradient field).
            input_size = preprocess_config.get("input_size", 256)
            if isinstance(input_size, list):
                input_size = input_size[0]
            self._resize = transforms.Resize((input_size, input_size))
            # Get model parameters from config.
            model_params = self.config.get("model_parameters", {})
            depth = model_params.get("depth", 4)
            base_filters = model_params.get("base_filters", 32)
            dropout = model_params.get("dropout", 0.3)
            embedding_dim = model_params.get("embedding_dim", 128)
            # Create model
            self._model = CompactGradientNet(
                depth=depth,
                base_filters=base_filters,
                dropout=dropout,
                embedding_dim=embedding_dim
            )
            # Load trained weights
            # Note: weights_only=False needed because checkpoint contains numpy types
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=False)
            # Handle different checkpoint formats (raw state dict or
            # trainer checkpoints that nest it under a known key).
            if isinstance(state_dict, dict):
                if "model_state_dict" in state_dict:
                    state_dict = state_dict["model_state_dict"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                elif "model" in state_dict:
                    state_dict = state_dict["model"]
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()
            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded Gradient Field CNN model from {self.repo_id}")
        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load Gradient Field CNN model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _rgb_to_luminance(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert RGB tensor to luminance using BT.709 coefficients.
        Args:
            img_tensor: RGB tensor of shape (3, H, W) with values in [0, 1]
        Returns:
            Luminance tensor of shape (1, H, W)
        """
        luminance = (
            self.R_COEFF * img_tensor[0] +
            self.G_COEFF * img_tensor[1] +
            self.B_COEFF * img_tensor[2]
        )
        return luminance.unsqueeze(0)

    def _run_inference(
        self,
        luminance_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on a preprocessed (1, 1, H, W) luminance tensor.

        Args:
            luminance_tensor: Batched luminance input on the model device.
            explain: If True, also compute a Grad-CAM heatmap.

        Returns:
            Dict with "logits", "prob_fake", "pred_int", "embedding" and,
            when explain is True, a "heatmap" numpy array.
        """
        heatmap = None
        if explain:
            # Custom GradCAM implementation for single-logit binary model
            # Using absolute CAM values to capture both positive and negative contributions
            # Select the last Conv2d of the backbone by type. A fixed index
            # (the old cnn[-5]) is only correct when Dropout2d layers are
            # present; searching by type works for any dropout setting.
            conv_layers = [m for m in self._model.cnn if isinstance(m, nn.Conv2d)]
            target_layer = conv_layers[-1]
            activations = None
            gradients = None

            def forward_hook(module, input, output):
                nonlocal activations
                activations = output.detach()

            def backward_hook(module, grad_input, grad_output):
                nonlocal gradients
                gradients = grad_output[0].detach()

            h_fwd = target_layer.register_forward_hook(forward_hook)
            h_bwd = target_layer.register_full_backward_hook(backward_hook)
            try:
                # Forward pass with gradients
                input_tensor = luminance_tensor.clone().requires_grad_(True)
                logits, embedding = self._model(input_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0
                # Backward pass
                self._model.zero_grad()
                logits.backward()
                if gradients is not None and activations is not None:
                    # Compute Grad-CAM weights (global average pooled gradients)
                    weights = gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
                    # Weighted combination of activation maps
                    cam = (weights * activations).sum(dim=1, keepdim=True)  # [1, 1, H, W]
                    # Use absolute values instead of ReLU to capture all contributions
                    # This is important for models where negative gradients carry meaning
                    cam = torch.abs(cam)
                    # Normalize to [0, 1]
                    cam = cam - cam.min()
                    cam_max = cam.max()
                    if cam_max > 0:
                        cam = cam / cam_max
                    # Resize heatmap to match the input spatial size
                    # (256x256 by default).
                    cam = F.interpolate(
                        cam,
                        size=tuple(luminance_tensor.shape[-2:]),
                        mode='bilinear',
                        align_corners=False
                    )
                    heatmap = cam.squeeze().cpu().numpy()
                else:
                    logger.warning("GradCAM: gradients or activations not captured")
                    heatmap = np.zeros(tuple(luminance_tensor.shape[-2:]), dtype=np.float32)
            finally:
                # Always detach hooks, even if the backward pass fails.
                h_fwd.remove()
                h_bwd.remove()
        else:
            with torch.no_grad():
                logits, embedding = self._model(luminance_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0
        # detach() is a no-op for tensors produced under no_grad, so both
        # branches can share the same conversion.
        result = {
            "logits": logits.detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int,
            "embedding": embedding.detach().cpu().numpy().tolist()
        }
        if heatmap is not None:
            result["heatmap"] = heatmap
        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.
        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute GradCAM heatmap
        Returns:
            Standardized prediction dictionary with optional heatmap
        Raises:
            InferenceError: If the model is not loaded, no image is given,
                or inference fails.
        """
        if self._model is None or self._resize is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )
        try:
            # Convert bytes to PIL Image if needed
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )
            # Resize
            image = self._resize(image)
            # Convert to tensor
            img_tensor = transforms.functional.to_tensor(image)
            # Convert to luminance
            luminance = self._rgb_to_luminance(img_tensor)
            luminance = luminance.unsqueeze(0).to(self._device)  # Add batch dim
            # Run inference
            result = self._run_inference(luminance, explain=explain)
            # Standardize output
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]
            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold
                }
            }
            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap) + " (edge-based analysis)"
            return output
        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e