| |
| """ |
| CNN-BiGRU Temporal Severity Prediction Module |
| Based on: Nature Scientific Reports (Nov 2025) - YOLOv11 + CNN-BiGRU |
| |
| Architecture: |
CNN Feature Extractor -> Bidirectional GRU -> Dense Classifier
| |
| The CNN extracts spatial features from cropped anomaly patches. |
| The BiGRU models temporal relationships across consecutive video frames, |
| enabling severity-level prediction that accounts for anomaly progression. |
| |
| Severity Classes: |
| 0 = Minor (cosmetic, no immediate repair needed) |
| 1 = Moderate (scheduled maintenance) |
| 2 = Severe (urgent repair required) |
| 3 = Critical (immediate hazard, road closure recommended) |
| """ |
|
|
| import os |
| import logging |
| from pathlib import Path |
| from typing import List, Optional, Tuple, Dict, Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
|
|
| |
| |
| |
# Module-wide logging: timestamped messages with level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("CNN-BiGRU")


# Human-readable names for the 4 severity classes; list index == class id
# (0=Minor ... 3=Critical, matching the module docstring).
SEVERITY_LABELS = ["Minor", "Moderate", "Severe", "Critical"]
# Anomaly categories produced upstream (presumably by the YOLO detector
# mentioned in the module docstring -- not used directly in this file).
ANOMALY_CLASSES = ["Alligator Crack", "Longitudinal Crack", "Pothole", "Transverse Crack"]


# Square side length (pixels) to which every cropped patch is resized.
PATCH_SIZE = 64
|
|
|
|
| |
| |
| |
class CNNFeatureExtractor(nn.Module):
    """Compact CNN encoder mapping a (C, H, W) patch to a 256-d vector.

    Three conv/BN/ReLU stages (64 -> 128 -> 256 channels); the first two
    end with 2x2 max-pooling, the last with global average pooling.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()

        def _stage(c_in: int, c_out: int, final: bool) -> List[nn.Module]:
            # One conv block; `final` swaps max-pool for global avg-pool.
            tail: nn.Module = (
                nn.AdaptiveAvgPool2d((1, 1)) if final else nn.MaxPool2d(2)
            )
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                tail,
            ]

        self.features = nn.Sequential(
            *_stage(in_channels, 64, final=False),
            *_stage(64, 128, final=False),
            *_stage(128, 256, final=True),
        )
        # Fixed output dimensionality of the encoder.
        self.out_dim = 256

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of patches: (B, C, H, W) -> (B, 256)."""
        pooled = self.features(x)
        return pooled.view(pooled.size(0), -1)
|
|
|
|
| |
| |
| |
class CNNBiGRU(nn.Module):
    """
    Hybrid CNN + Bidirectional GRU for temporal severity prediction.

    Input : a sequence of image patches (B, T, C, H, W)
            -- T consecutive frames of the same anomaly region
    Output: severity class logits (B, num_severity_classes)

    Pipeline:
        1. CNN extracts a 256-d feature vector per frame
        2. BiGRU processes the sequence -> 2 x hidden_size summary
        3. Dense head maps to severity classes
    """

    def __init__(
        self,
        in_channels: int = 3,
        hidden_size: int = 128,
        num_gru_layers: int = 2,
        gru_dropout: float = 0.3,
        num_severity_classes: int = 4,
        fc_dropout: float = 0.5,
    ):
        """
        Args:
            in_channels: channels of each input patch (3 for RGB).
            hidden_size: GRU hidden size per direction.
            num_gru_layers: number of stacked GRU layers.
            gru_dropout: inter-layer GRU dropout (ignored for one layer).
            num_severity_classes: number of output severity classes.
            fc_dropout: dropout rate in the dense head (second dropout
                layer uses half this rate).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_severity_classes = num_severity_classes

        # Per-frame spatial feature extractor (256-d output).
        self.cnn = CNNFeatureExtractor(in_channels)

        # Bidirectional GRU over per-frame features.
        self.bigru = nn.GRU(
            input_size=self.cnn.out_dim,
            hidden_size=hidden_size,
            num_layers=num_gru_layers,
            batch_first=True,
            bidirectional=True,
            # PyTorch warns if dropout is set on a single-layer RNN.
            dropout=gru_dropout if num_gru_layers > 1 else 0.0,
        )

        # Dense head: 2*hidden (fwd + bwd) -> severity logits.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(fc_dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(fc_dropout * 0.5),
            nn.Linear(64, num_severity_classes),
        )

    def _frame_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CNN over every frame: (B, T, C, H, W) -> (B, T, 256)."""
        B, T, C, H, W = x.size()
        return self.cnn(x.view(B * T, C, H, W)).view(B, T, -1)

    def _sequence_summary(self, gru_out: torch.Tensor) -> torch.Tensor:
        """
        Build a whole-sequence summary from bidirectional GRU outputs.

        Bug fix vs. the previous `gru_out[:, -1, :]`: in a bidirectional
        GRU the backward direction's full-sequence state lives at t=0,
        not t=T-1 (at t=T-1 the backward RNN has seen only one frame).
        Concatenate forward@last with backward@first instead.
        """
        fwd_last = gru_out[:, -1, : self.hidden_size]
        bwd_last = gru_out[:, 0, self.hidden_size :]
        return torch.cat([fwd_last, bwd_last], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, C, H, W) -- batch of temporal patch sequences.

        Returns:
            logits: (B, num_severity_classes)
        """
        feats = self._frame_features(x)
        gru_out, _ = self.bigru(feats)
        summary = self._sequence_summary(gru_out)
        return self.classifier(summary)

    def predict_severity(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convenience wrapper: returns (predicted_class, probabilities).

        Runs in eval mode under no_grad, then restores the model's
        previous train/eval state (the old version silently left the
        model in eval mode as a side effect).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = F.softmax(logits, dim=-1)
                preds = probs.argmax(dim=-1)
        finally:
            self.train(was_training)
        return preds, probs

    def forward_with_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Same as forward() but attends over all time-steps instead of
        using only the endpoint states.

        Returns:
            logits: (B, num_severity_classes)
            attn_wts: (B, T) -- attention weights for interpretability
        """
        feats = self._frame_features(x)
        gru_out, _ = self.bigru(feats)

        # Additive-style scoring: tanh squashes, sum over features gives
        # one scalar energy per time-step.
        energy = torch.tanh(gru_out).sum(dim=-1)
        attn_wts = F.softmax(energy, dim=-1)

        # Weighted sum of GRU outputs: (B, 1, T) x (B, T, 2H) -> (B, 2H).
        context = torch.bmm(attn_wts.unsqueeze(1), gru_out).squeeze(1)

        logits = self.classifier(context)
        return logits, attn_wts
|
|
|
|
| |
| |
| |
class CNNSeverityClassifier(nn.Module):
    """Single-frame severity classifier (no temporal context).

    Fallback for when only one cropped patch of an anomaly is available:
    the shared CNN encoder followed by a small fully connected head.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_severity_classes: int = 4,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.cnn = CNNFeatureExtractor(in_channels)
        head_layers = [
            nn.Linear(self.cnn.out_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_severity_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C, H, W) -> logits (B, num_severity_classes)"""
        features = self.cnn(x)
        return self.head(features)
|
|
|
|
| |
| |
| |
class AnomalySequenceDataset(Dataset):
    """
    Loads temporal sequences of cropped anomaly patches + severity labels.

    Expected directory layout:
        root/
            sequences/
                seq_0000/        <- one tracked anomaly
                    frame_00.jpg
                    frame_01.jpg
                    ...
                seq_0001/
                    ...
            labels.csv           <- columns: seq_id, severity (0-3)

    If labels.csv is missing, every sequence defaults to severity 0
    (a warning is logged).
    """

    def __init__(
        self,
        root: str,
        seq_len: int = 8,
        patch_size: int = PATCH_SIZE,
        transform=None,
    ):
        """
        Args:
            root: dataset root containing `sequences/` (and optionally
                `labels.csv`).
            seq_len: fixed number of frames returned per sequence;
                longer sequences are uniformly subsampled, shorter ones
                are padded by repeating the last frame.
            patch_size: side length each frame is resized to.
            transform: optional callable applied per frame to the
                float32 RGB HxWx3 array (must return an array-like of
                the same shape).

        Raises:
            FileNotFoundError: if `root/sequences` does not exist.
        """
        self.root = Path(root)
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.transform = transform

        seq_dir = self.root / "sequences"
        if not seq_dir.exists():
            raise FileNotFoundError(f"No 'sequences' folder in {self.root}")

        # One subdirectory per tracked anomaly, in deterministic order.
        self.sequences = sorted(d for d in seq_dir.iterdir() if d.is_dir())

        label_file = self.root / "labels.csv"
        self.labels: Dict[str, int] = {}
        if label_file.exists():
            import csv
            with open(label_file, newline="") as f:
                for row in csv.DictReader(f):
                    self.labels[row["seq_id"]] = int(row["severity"])
        else:
            logger.warning("labels.csv not found - all severities default to 0")

        logger.info("Loaded %d sequences from %s", len(self.sequences), self.root)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Returns:
            (seq_tensor, severity): float tensor of shape
            (seq_len, 3, patch_size, patch_size) in [0, 1], and the
            integer severity label (0 when unlabeled).

        Raises:
            RuntimeError: if the sequence directory contains no frames
                or a frame cannot be decoded.
        """
        seq_path = self.sequences[idx]
        # Single merged sort so mixed .jpg/.png sequences keep their
        # temporal (filename) order -- the old jpg-then-png concatenation
        # scrambled it.
        frame_paths = sorted(
            list(seq_path.glob("*.jpg")) + list(seq_path.glob("*.png"))
        )
        if not frame_paths:
            # The old padding loop would raise IndexError (or spin) here.
            raise RuntimeError(f"Sequence {seq_path} contains no .jpg/.png frames")

        if len(frame_paths) >= self.seq_len:
            # Uniformly subsample to exactly seq_len frames.
            indices = np.linspace(0, len(frame_paths) - 1, self.seq_len, dtype=int)
            frame_paths = [frame_paths[i] for i in indices]
        else:
            # Pad short sequences by repeating the final frame.
            pad = self.seq_len - len(frame_paths)
            frame_paths = frame_paths + [frame_paths[-1]] * pad

        import cv2  # local import keeps OpenCV an optional dependency

        frames: List[np.ndarray] = []
        for fp in frame_paths:
            img = cv2.imread(str(fp))
            if img is None:
                # imread returns None on unreadable/corrupt files; fail
                # loudly instead of crashing inside cv2.resize.
                raise RuntimeError(f"Failed to read frame {fp}")
            img = cv2.resize(img, (self.patch_size, self.patch_size))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.0
            if self.transform:
                img = self.transform(img)
            frames.append(img)

        # (T, H, W, 3) -> (T, 3, H, W) for PyTorch.
        seq_tensor = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2)
        severity = self.labels.get(seq_path.name, 0)
        return seq_tensor, severity
|
|
|
|
| |
| |
| |
class BiGRUTrainer:
    """
    End-to-end trainer for the CNN-BiGRU severity model.

    Handles optimization (AdamW + cosine LR schedule), gradient clipping,
    evaluation, checkpointing and early stopping.
    """

    def __init__(
        self,
        model: CNNBiGRU,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        device: Optional[str] = None,
        t_max: int = 50,
    ):
        """
        Args:
            model: the CNN-BiGRU network to train.
            lr: initial AdamW learning rate.
            weight_decay: L2 regularization strength.
            device: "cuda"/"cpu"; auto-detected when None.
            t_max: cosine-annealing period in epochs (previously
                hard-coded to 50; the default preserves that behavior).
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay,
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=t_max, eta_min=1e-6,
        )

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Run one training epoch; returns sample-weighted {"loss", "accuracy"}."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for sequences, labels in dataloader:
            sequences = sequences.to(self.device)
            labels = labels.to(self.device).long()

            self.optimizer.zero_grad()
            logits = self.model(sequences)
            loss = self.criterion(logits, labels)
            loss.backward()
            # Clip to stabilize backprop through the recurrent layers.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item() * sequences.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # One scheduler step per epoch (cosine annealing over t_max epochs).
        self.scheduler.step()
        return {
            "loss": total_loss / max(total, 1),
            "accuracy": correct / max(total, 1),
        }

    @torch.no_grad()
    def evaluate(self, dataloader: DataLoader) -> Dict[str, Any]:
        """
        Evaluate the model on a dataloader.

        Returns:
            dict with scalar "loss"/"accuracy" plus "predictions" and
            "labels" lists (annotation fixed: values are not all floats).
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        all_preds: List[int] = []
        all_labels: List[int] = []

        for sequences, labels in dataloader:
            sequences = sequences.to(self.device)
            labels = labels.to(self.device).long()

            logits = self.model(sequences)
            loss = self.criterion(logits, labels)

            total_loss += loss.item() * sequences.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

        return {
            "loss": total_loss / max(total, 1),
            "accuracy": correct / max(total, 1),
            "predictions": all_preds,
            "labels": all_labels,
        }

    def fit(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        epochs: int = 50,
        save_dir: str = "bigru_checkpoints",
        patience: int = 10,
    ) -> Dict[str, List[float]]:
        """
        Full training loop with checkpointing and early stopping.

        Args:
            train_loader: training batches of (sequences, labels).
            val_loader: optional validation batches; enables best-model
                checkpointing and early stopping.
            epochs: maximum number of epochs.
            save_dir: directory for checkpoint files.
            patience: epochs without val-loss improvement before stopping.

        Returns:
            History dict of per-epoch train/val loss and accuracy lists.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        history: Dict[str, List[float]] = {
            "train_loss": [], "train_acc": [],
            "val_loss": [], "val_acc": [],
        }
        best_val_loss = float("inf")
        no_improve = 0

        for epoch in range(1, epochs + 1):
            train_metrics = self.train_epoch(train_loader)
            history["train_loss"].append(train_metrics["loss"])
            history["train_acc"].append(train_metrics["accuracy"])

            log_msg = (
                f"Epoch {epoch:3d}/{epochs} "
                f"train_loss={train_metrics['loss']:.4f} "
                f"train_acc={train_metrics['accuracy']:.4f}"
            )

            stop_early = False
            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                history["val_loss"].append(val_metrics["loss"])
                history["val_acc"].append(val_metrics["accuracy"])
                log_msg += (
                    f" val_loss={val_metrics['loss']:.4f} "
                    f"val_acc={val_metrics['accuracy']:.4f}"
                )

                if val_metrics["loss"] < best_val_loss:
                    best_val_loss = val_metrics["loss"]
                    no_improve = 0
                    torch.save(self.model.state_dict(), save_path / "best_bigru.pth")
                else:
                    no_improve += 1
                stop_early = no_improve >= patience

            # Bug fix: always log the epoch's metrics -- the old code
            # broke out of the loop before logging the final epoch.
            logger.info(log_msg)

            # Periodic checkpoint every 10 epochs.
            if epoch % 10 == 0:
                torch.save(self.model.state_dict(), save_path / f"bigru_epoch{epoch}.pth")

            if stop_early:
                logger.info("Early stopping at epoch %d", epoch)
                break

        # Always keep the final weights alongside the best ones.
        torch.save(self.model.state_dict(), save_path / "last_bigru.pth")
        logger.info("Training complete - best val_loss=%.4f", best_val_loss)
        return history

    def save(self, path: str):
        """Persist the model's state_dict to `path`."""
        torch.save(self.model.state_dict(), path)
        logger.info("Model saved to %s", path)

    def load(self, path: str):
        """Load a state_dict from `path` onto the trainer's device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        logger.info("Model loaded from %s", path)
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Lightweight smoke test: exercises every public entry point with
    # random tensors -- no dataset or checkpoints required.
    print("=" * 60)
    print("CNN-BiGRU Module - Smoke Test")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Temporal model: sequence of patches -> severity logits.
    model = CNNBiGRU(
        in_channels=3,
        hidden_size=128,
        num_gru_layers=2,
        num_severity_classes=4,
    ).to(device)

    # (B=2, T=8) dummy sequence batch.
    dummy = torch.randn(2, 8, 3, PATCH_SIZE, PATCH_SIZE, device=device)
    logits = model(dummy)
    print(f"CNNBiGRU output shape : {logits.shape}")
    preds, probs = model.predict_severity(dummy)
    print(f"Predicted classes     : {preds.tolist()}")
    print(f"Probabilities         : {probs.cpu().numpy().round(3)}")

    # Attention variant also exposes per-frame weights.
    logits_att, attn = model.forward_with_attention(dummy)
    print(f"Attention weights     : {attn.cpu().numpy().round(3)}")

    # Single-frame fallback classifier.
    clf = CNNSeverityClassifier(num_severity_classes=4).to(device)
    single = torch.randn(2, 3, PATCH_SIZE, PATCH_SIZE, device=device)
    out = clf(single)
    print(f"SingleFrame output    : {out.shape}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nCNN-BiGRU parameters  : {total_params:,}")
    # Bug fix: the original final print was an unterminated string
    # literal split across two lines (syntax error).
    print("All smoke tests passed")
|
|