# Source file: arm-model/model/cnn_bigru.py
# Uploaded via huggingface_hub by pragadeeshv23 (commit 5b86813, verified)
#!/usr/bin/env python3
"""
CNN-BiGRU Temporal Severity Prediction Module
Based on: Nature Scientific Reports (Nov 2025) - YOLOv11 + CNN-BiGRU
Architecture:
CNN Feature Extractor β†’ Bidirectional GRU β†’ Dense Classifier
The CNN extracts spatial features from cropped anomaly patches.
The BiGRU models temporal relationships across consecutive video frames,
enabling severity-level prediction that accounts for anomaly progression.
Severity Classes:
0 = Minor (cosmetic, no immediate repair needed)
1 = Moderate (scheduled maintenance)
2 = Severe (urgent repair required)
3 = Critical (immediate hazard, road closure recommended)
"""
import os
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
# Module-wide logger used by the dataset and trainer classes below.
logger = logging.getLogger("CNN-BiGRU")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Human-readable severity names, indexed by the model's class id (0-3).
SEVERITY_LABELS = ["Minor", "Moderate", "Severe", "Critical"]
# Anomaly categories — presumably the classes of the upstream YOLO detector
# mentioned in the module docstring; TODO confirm against the detector config.
ANOMALY_CLASSES = ["Alligator Crack", "Longitudinal Crack", "Pothole", "Transverse Crack"]
# Default CNN input patch size: patches are resized to PATCH_SIZE x PATCH_SIZE.
PATCH_SIZE = 64
# ═══════════════════════════════════════════════════════════════════════════
# CNN Feature Extractor
# ═══════════════════════════════════════════════════════════════════════════
class CNNFeatureExtractor(nn.Module):
    """
    Lightweight CNN mapping a (C, H, W) patch to a fixed-length feature
    vector: three conv blocks, the last one closed by a global average
    pool instead of max-pooling. Output dimension: 256.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        # (in, out) channel plan for the three conv blocks: C -> 64 -> 128 -> 256.
        channel_plan = [(in_channels, 64), (64, 128), (128, 256)]
        layers: list = []
        for block_idx, (c_in, c_out) in enumerate(channel_plan):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            # First two blocks halve the spatial size; the final block
            # collapses whatever remains with a global average pool.
            is_last = block_idx == len(channel_plan) - 1
            layers.append(nn.AdaptiveAvgPool2d((1, 1)) if is_last else nn.MaxPool2d(2))
        # Keep everything in one nn.Sequential so state_dict keys
        # ("features.<idx>.*") stay stable across revisions.
        self.features = nn.Sequential(*layers)
        self.out_dim = 256

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C, H, W) -> (B, 256)"""
        pooled = self.features(x)          # (B, 256, 1, 1)
        return pooled.flatten(1)           # (B, 256)
# ═══════════════════════════════════════════════════════════════════════════
# CNN-BiGRU Model
# ═══════════════════════════════════════════════════════════════════════════
class CNNBiGRU(nn.Module):
    """
    Hybrid CNN + Bidirectional GRU for temporal severity prediction.

    Input : a sequence of image patches (B, T, C, H, W) —
            T consecutive frames of the same anomaly region.
    Output: severity class logits (B, num_severity_classes).

    Pipeline:
        1. CNN extracts a 256-d feature vector per frame.
        2. BiGRU models the sequence -> 2 * hidden_size features.
        3. Dense head maps to severity classes.
    """

    def __init__(
        self,
        in_channels: int = 3,
        hidden_size: int = 128,
        num_gru_layers: int = 2,
        gru_dropout: float = 0.3,
        num_severity_classes: int = 4,
        fc_dropout: float = 0.5,
    ):
        """
        Args:
            in_channels: channels of each input patch (3 for RGB).
            hidden_size: GRU hidden size per direction.
            num_gru_layers: number of stacked GRU layers.
            gru_dropout: inter-layer GRU dropout (only applies with >1 layer).
            num_severity_classes: number of severity classes to predict.
            fc_dropout: dropout rate in the dense classifier head.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_severity_classes = num_severity_classes
        # --- CNN backbone ---
        self.cnn = CNNFeatureExtractor(in_channels)
        # --- Bidirectional GRU ---
        self.bigru = nn.GRU(
            input_size=self.cnn.out_dim,  # 256
            hidden_size=hidden_size,      # 128
            num_layers=num_gru_layers,    # 2
            batch_first=True,
            bidirectional=True,
            # nn.GRU warns if dropout is set with a single layer.
            dropout=gru_dropout if num_gru_layers > 1 else 0.0,
        )
        # --- Classification head ---
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 128),  # bidirectional -> 2*hidden inputs
            nn.ReLU(inplace=True),
            nn.Dropout(fc_dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(fc_dropout * 0.5),
            nn.Linear(64, num_severity_classes),
        )

    def _sequence_summary(self, gru_out: torch.Tensor) -> torch.Tensor:
        """
        Summarize BiGRU output (B, T, 2*hidden) into one (B, 2*hidden) vector.

        BUGFIX: the previous code used gru_out[:, -1, :], but at the last
        time-step the *backward* direction has only processed a single frame.
        Concatenating the forward direction at t = T-1 with the backward
        direction at t = 0 gives two halves that each summarize the whole
        sequence (standard BiRNN last-state extraction).
        """
        h = self.hidden_size
        return torch.cat([gru_out[:, -1, :h], gru_out[:, 0, h:]], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, C, H, W) batch of temporal patch sequences.
        Returns:
            logits: (B, num_severity_classes)
        """
        B, T, C, H, W = x.size()
        # 1. Extract CNN features for every frame in the sequence.
        feats = self.cnn(x.view(B * T, C, H, W)).view(B, T, -1)  # (B, T, 256)
        # 2. BiGRU temporal modelling.
        gru_out, _ = self.bigru(feats)                           # (B, T, 2*hidden)
        # 3. Full-sequence summary (forward-last + backward-first).
        summary = self._sequence_summary(gru_out)                # (B, 2*hidden)
        # 4. Classify severity.
        return self.classifier(summary)                          # (B, num_classes)

    def predict_severity(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convenience wrapper: returns (predicted_class, probabilities).

        BUGFIX: the previous version left the module permanently in eval
        mode; the original train/eval state is now restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = F.softmax(logits, dim=-1)
                preds = probs.argmax(dim=-1)
        finally:
            self.train(was_training)
        return preds, probs

    # ------------------------------------------------------------------
    # Attention-weighted variant (optional enhancement)
    # ------------------------------------------------------------------
    def forward_with_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Same as forward() but attends over all time-steps instead of
        using only the last-state summary.

        Returns:
            logits: (B, num_severity_classes)
            attn_wts: (B, T) attention weights for interpretability
        """
        B, T, C, H, W = x.size()
        feats = self.cnn(x.view(B * T, C, H, W)).view(B, T, -1)
        gru_out, _ = self.bigru(feats)                # (B, T, 2*hidden)
        # Simple additive attention: per-step scalar energy -> softmax over T.
        energy = torch.tanh(gru_out).sum(dim=-1)      # (B, T)
        attn_wts = F.softmax(energy, dim=-1)          # (B, T)
        # Attention-weighted combination of all time-steps.
        context = torch.bmm(
            attn_wts.unsqueeze(1), gru_out            # (B,1,T) x (B,T,2H) -> (B,1,2H)
        ).squeeze(1)                                  # (B, 2*hidden)
        return self.classifier(context), attn_wts
# ═══════════════════════════════════════════════════════════════════════════
# Single-Frame Severity Classifier (when temporal data unavailable)
# ═══════════════════════════════════════════════════════════════════════════
class CNNSeverityClassifier(nn.Module):
    """
    Fallback severity classifier for a *single* cropped patch, used when
    no temporal context is available. CNN backbone + small dense head.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_severity_classes: int = 4,
        dropout: float = 0.5,
    ):
        super().__init__()
        # Shared lightweight backbone producing a 256-d feature vector.
        self.cnn = CNNFeatureExtractor(in_channels)
        feat_dim = self.cnn.out_dim
        self.head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_severity_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C, H, W) -> logits (B, num_severity_classes)"""
        features = self.cnn(x)
        return self.head(features)
# ═══════════════════════════════════════════════════════════════════════════
# Dataset for CNN-BiGRU training
# ═══════════════════════════════════════════════════════════════════════════
class AnomalySequenceDataset(Dataset):
    """
    Loads temporal sequences of cropped anomaly patches + severity labels.

    Expected directory layout::

        root/
            sequences/
                seq_0000/        <- one tracked anomaly
                    frame_00.jpg
                    frame_01.jpg
                    ...
                seq_0001/
                    ...
            labels.csv           <- columns: seq_id, severity (0-3)

    If labels.csv is missing, or a sequence has no entry in it, the
    severity defaults to 0 (Minor).  (The docstring previously claimed a
    SEVERITY_WEIGHTS heuristic; no such heuristic exists in this module.)
    """

    def __init__(
        self,
        root: str,
        seq_len: int = 8,
        patch_size: int = PATCH_SIZE,
        transform=None,
    ):
        """
        Args:
            root: dataset root containing 'sequences/' and optional labels.csv.
            seq_len: fixed number of frames returned per item.
            patch_size: square size each frame is resized to.
            transform: optional callable applied to each float32 RGB frame.

        Raises:
            FileNotFoundError: if root has no 'sequences' folder.
        """
        self.root = Path(root)
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.transform = transform
        # Discover sequences (one sub-directory per tracked anomaly).
        seq_dir = self.root / "sequences"
        if not seq_dir.exists():
            raise FileNotFoundError(f"No 'sequences' folder in {self.root}")
        self.sequences = sorted(d for d in seq_dir.iterdir() if d.is_dir())
        # Load labels (seq_id -> severity 0-3).
        label_file = self.root / "labels.csv"
        self.labels: Dict[str, int] = {}
        if label_file.exists():
            import csv  # local import: only needed when labels.csv exists
            with open(label_file) as f:
                for row in csv.DictReader(f):
                    self.labels[row["seq_id"]] = int(row["severity"])
        else:
            logger.warning("labels.csv not found – all severities default to 0")
        logger.info("Loaded %d sequences from %s", len(self.sequences), self.root)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Returns:
            (seq_tensor, severity): float tensor of shape
            (seq_len, C, H, W) in [0, 1] RGB, and the integer label.

        Raises:
            RuntimeError: if a sequence directory is empty or a frame
            cannot be decoded.
        """
        seq_path = self.sequences[idx]
        frames_paths = sorted(seq_path.glob("*.jpg")) + sorted(seq_path.glob("*.png"))
        # BUGFIX: an empty directory previously crashed with IndexError
        # inside the padding loop; fail with an explicit message instead.
        if not frames_paths:
            raise RuntimeError(f"Sequence {seq_path} contains no .jpg/.png frames")
        # Sample / pad to fixed seq_len.
        if len(frames_paths) >= self.seq_len:
            # Evenly-spaced temporal subsampling.
            indices = np.linspace(0, len(frames_paths) - 1, self.seq_len, dtype=int)
            frames_paths = [frames_paths[i] for i in indices]
        else:
            # Repeat the last frame to pad short sequences.
            while len(frames_paths) < self.seq_len:
                frames_paths.append(frames_paths[-1])
        # Load & preprocess. Local import keeps cv2 an optional dependency
        # of the module (only needed when actually loading data).
        import cv2
        frames: List[np.ndarray] = []
        for fp in frames_paths:
            img = cv2.imread(str(fp))
            # BUGFIX: cv2.imread returns None on unreadable files; the
            # old code then crashed inside cv2.resize with a cryptic error.
            if img is None:
                raise RuntimeError(f"Failed to read frame {fp}")
            img = cv2.resize(img, (self.patch_size, self.patch_size))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.0
            if self.transform:
                img = self.transform(img)
            frames.append(img)
        # Stack -> (T, H, W, C) -> (T, C, H, W).
        seq_np = np.stack(frames)
        seq_tensor = torch.from_numpy(seq_np).permute(0, 3, 1, 2)
        # Unlabelled sequences default to severity 0.
        severity = self.labels.get(seq_path.name, 0)
        return seq_tensor, severity
# ═══════════════════════════════════════════════════════════════════════════
# Training utilities
# ═══════════════════════════════════════════════════════════════════════════
class BiGRUTrainer:
    """
    End-to-end trainer for the CNN-BiGRU severity model.

    Handles AdamW optimization with a cosine LR schedule, gradient
    clipping, evaluation, early stopping and checkpointing.
    """

    def __init__(
        self,
        model: CNNBiGRU,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        device: Optional[str] = None,
    ):
        """
        Args:
            model: the CNNBiGRU instance to train.
            lr: initial learning rate for AdamW.
            weight_decay: AdamW weight decay.
            device: "cuda" / "cpu"; auto-detected when None.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay,
        )
        # NOTE(review): T_max is fixed at 50; if fit() runs a different
        # number of epochs the cosine period will not match — confirm.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=50, eta_min=1e-6,
        )

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Run one optimization pass; returns {"loss", "accuracy"}."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for sequences, labels in dataloader:
            sequences = sequences.to(self.device)   # (B, T, C, H, W)
            labels = labels.to(self.device).long()  # (B,)
            self.optimizer.zero_grad()
            logits = self.model(sequences)          # (B, num_classes)
            loss = self.criterion(logits, labels)
            loss.backward()
            # Clip to stabilize backprop through the recurrent layers.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            # Accumulate sample-weighted loss and accuracy counts.
            total_loss += loss.item() * sequences.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        self.scheduler.step()  # one cosine step per epoch
        return {
            "loss": total_loss / max(total, 1),
            "accuracy": correct / max(total, 1),
        }

    @torch.no_grad()
    def evaluate(self, dataloader: DataLoader) -> Dict[str, Any]:
        """
        Evaluate on `dataloader` without gradient tracking.

        Returns a dict with scalar "loss"/"accuracy" plus the full
        "predictions" and "labels" lists (annotation fixed: the old
        Dict[str, float] did not cover the list entries).
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        all_preds: List[int] = []
        all_labels: List[int] = []
        for sequences, labels in dataloader:
            sequences = sequences.to(self.device)
            labels = labels.to(self.device).long()
            logits = self.model(sequences)
            loss = self.criterion(logits, labels)
            total_loss += loss.item() * sequences.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
        return {
            "loss": total_loss / max(total, 1),
            "accuracy": correct / max(total, 1),
            "predictions": all_preds,
            "labels": all_labels,
        }

    def fit(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        epochs: int = 50,
        save_dir: str = "bigru_checkpoints",
        patience: int = 10,
    ) -> Dict[str, List[float]]:
        """
        Full training loop with early stopping.

        Args:
            train_loader: training batches of (sequences, labels).
            val_loader: optional validation loader; enables early stopping
                and best-checkpoint saving.
            epochs: maximum number of epochs.
            save_dir: directory for checkpoint files.
            patience: epochs without val-loss improvement before stopping.

        Returns:
            History dict with per-epoch train/val loss and accuracy.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        history: Dict[str, List[float]] = {
            "train_loss": [], "train_acc": [],
            "val_loss": [], "val_acc": [],
        }
        best_val_loss = float("inf")
        no_improve = 0
        for epoch in range(1, epochs + 1):
            train_metrics = self.train_epoch(train_loader)
            history["train_loss"].append(train_metrics["loss"])
            history["train_acc"].append(train_metrics["accuracy"])
            log_msg = (
                f"Epoch {epoch:3d}/{epochs} "
                f"train_loss={train_metrics['loss']:.4f} "
                f"train_acc={train_metrics['accuracy']:.4f}"
            )
            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                history["val_loss"].append(val_metrics["loss"])
                history["val_acc"].append(val_metrics["accuracy"])
                log_msg += (
                    f" val_loss={val_metrics['loss']:.4f} "
                    f"val_acc={val_metrics['accuracy']:.4f}"
                )
                if val_metrics["loss"] < best_val_loss:
                    best_val_loss = val_metrics["loss"]
                    no_improve = 0
                    torch.save(self.model.state_dict(), save_path / "best_bigru.pth")
                else:
                    no_improve += 1
            # BUGFIX: log (and checkpoint) before any early-stopping break,
            # so the final epoch's metrics are no longer silently dropped.
            logger.info(log_msg)
            # Periodic checkpoint.
            if epoch % 10 == 0:
                torch.save(self.model.state_dict(), save_path / f"bigru_epoch{epoch}.pth")
            if val_loader is not None and no_improve >= patience:
                logger.info("Early stopping at epoch %d", epoch)
                break
        # Save final weights regardless of early stopping.
        torch.save(self.model.state_dict(), save_path / "last_bigru.pth")
        logger.info("Training complete – best val_loss=%.4f", best_val_loss)
        return history

    def save(self, path: str):
        """Persist the model's state_dict to `path`."""
        torch.save(self.model.state_dict(), path)
        logger.info("Model saved β†’ %s", path)

    def load(self, path: str):
        """Load a state_dict checkpoint from `path` onto self.device."""
        # NOTE(review): torch.load unpickles arbitrary objects; since these
        # checkpoints are plain state_dicts, consider weights_only=True on
        # torch >= 1.13 when loading untrusted files.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        logger.info("Model loaded ← %s", path)
# ═══════════════════════════════════════════════════════════════════════════
# Quick self-test
# ═══════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
print("=" * 60)
print("CNN-BiGRU Module – Smoke Test")
print("=" * 60)
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Test CNNBiGRU ---
model = CNNBiGRU(
in_channels=3,
hidden_size=128,
num_gru_layers=2,
num_severity_classes=4,
).to(device)
# Synthetic input: batch=2, seq_len=8, C=3, H=64, W=64
dummy = torch.randn(2, 8, 3, PATCH_SIZE, PATCH_SIZE, device=device)
logits = model(dummy)
print(f"CNNBiGRU output shape : {logits.shape}") # (2, 4)
preds, probs = model.predict_severity(dummy)
print(f"Predicted classes : {preds.tolist()}")
print(f"Probabilities : {probs.cpu().numpy().round(3)}")
# --- Test attention variant ---
logits_att, attn = model.forward_with_attention(dummy)
print(f"Attention weights : {attn.cpu().numpy().round(3)}")
# --- Test single-frame classifier ---
clf = CNNSeverityClassifier(num_severity_classes=4).to(device)
single = torch.randn(2, 3, PATCH_SIZE, PATCH_SIZE, device=device)
out = clf(single)
print(f"SingleFrame output : {out.shape}") # (2, 4)
# --- Parameter counts ---
total_params = sum(p.numel() for p in model.parameters())
print(f"\nCNN-BiGRU parameters : {total_params:,}")
print("βœ… All smoke tests passed")