# XDL-Colitis-Demo — src/inference.py
# (Header residue from the Hugging Face file viewer: author "Arviano",
#  commit 67f1c25 "Add GradCAM target layer selection and UI defaults integration".)
from collections import Counter, defaultdict
from dataclasses import dataclass
from html import escape
import os
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from src.densenet import DenseNet121
from src.preprocessing import preprocess
# Per-case configuration keyed by the UI display name. Each entry holds:
#   key                - stable key used to cache per-case state
#   labels             - ordered class names (index matches the model's output index)
#   default_model_path - checkpoint path used when the env var is unset
#   env_var            - environment variable that overrides the checkpoint path
CASE_CONFIGS = {
    "Multiclass (4 Classes)": {
        "key": "multiclass",
        "labels": ["Kolitis Ulseratif", "Infeksi", "Crohn", "Tuberculosis"],
        "default_model_path": "models/multiclass.pth",
        "env_var": "MODEL_PATH_MULTICLASS",
    },
    "Crohn vs TB": {
        "key": "crohn_tb",
        "labels": ["Crohn", "Tuberculosis"],
        "default_model_path": "models/crohn_tb.pth",
        "env_var": "MODEL_PATH_CROHN_TB",
    },
    "Kolitis Ulseratif vs Infeksi": {
        "key": "ku_infeksi",
        "labels": ["Kolitis Ulseratif", "Infeksi"],
        "default_model_path": "models/ku_infeksi.pth",
        "env_var": "MODEL_PATH_KU_INFEKSI",
    },
}
# Case used when the UI provides no/unknown selection.
DEFAULT_CASE_NAME = "Multiclass (4 Classes)"
# Dropdown choices, in declaration order.
CASE_OPTIONS = list(CASE_CONFIGS.keys())
# Lowercase file suffixes accepted when scanning uploads/folders for images.
SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
# Default directory for saved XDL overlay panels.
DEFAULT_SAVE_DIR = "xdl_results"
# Layer names selectable as GradCAM targets; unknown values fall back to the
# default (validated in batch_predict_with_xdl_stream).
GRADCAM_TARGET_LAYER_OPTIONS = (
    "denseblock3",
    "transition2",
    "transition1",
    "denseblock4",
    "transition3",
    "norm5_last",
)
DEFAULT_GRADCAM_TARGET_LAYER = "denseblock3"
def _detect_device() -> torch.device:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="CUDA initialization:.*")
try:
has_cuda = torch.cuda.is_available()
except Exception:
has_cuda = False
return torch.device("cuda" if has_cuda else "cpu")
# Resolved once at import time; every case's model is moved to this device.
DEVICE = _detect_device()
# Lazy per-case cache: case key -> state dict built by _get_case_state.
_case_state_cache: dict[str, dict] = {}
# Shared evaluation transform from src.preprocessing for 3x299x299 inputs.
val_transform = preprocess(target_input_size=(3, 299, 299))
@dataclass(frozen=True)
class ClassifiedPrediction:
    """Immutable record of one image whose top-1 confidence passed the threshold."""

    path: Path  # source image file
    pred_idx: int  # index into the active case's label list
    confidence: float  # softmax probability of pred_idx
def _get_case_config(case_name: str) -> dict:
    """Look up the config for *case_name*, falling back to the default case."""
    try:
        return CASE_CONFIGS[case_name]
    except KeyError:
        return CASE_CONFIGS[DEFAULT_CASE_NAME]
def _resolve_case_model_path(case_cfg: dict) -> str:
return os.getenv(case_cfg["env_var"], case_cfg["default_model_path"])
def _get_case_state(case_name: str) -> dict:
    """Build (or fetch from cache) the runtime state for a diagnostic case.

    The state bundles the case name, its class labels, the resolved
    checkpoint path, the model instance, and a human-readable load-error
    message (None when the checkpoint loaded cleanly).
    """
    cfg = _get_case_config(case_name)
    cache_key = cfg["key"]
    cached = _case_state_cache.get(cache_key)
    if cached is not None:
        return cached

    class_labels = cfg["labels"]
    checkpoint_path = _resolve_case_model_path(cfg)
    net = DenseNet121(num_classes=len(class_labels)).to(DEVICE)
    load_error: Optional[str] = None
    try:
        net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        net.eval()
    except Exception as exc:
        # Keep the (randomly initialized) model object; callers surface the
        # error string instead of running inference with it.
        load_error = f"Model failed to load for case `{case_name}` from `{checkpoint_path}`: {exc}"

    state = {
        "case_name": case_name,
        "labels": class_labels,
        "model_path": checkpoint_path,
        "model": net,
        "model_error": load_error,
    }
    _case_state_cache[cache_key] = state
    return state
def _load_xdl_modules():
"""Lazy-load optional XDL dependencies."""
try:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except Exception as exc:
raise RuntimeError(f"GradCAM import failed: {exc}")
try:
import cv2
except Exception as exc:
raise RuntimeError(f"OpenCV import failed: {exc}")
try:
from src.xdl import (
_get_target_layer,
_preprocess_image,
_process_smoothgrad_map,
smoothgrad,
)
except Exception as exc:
raise RuntimeError(f"Failed to load XDL utilities from src/xdl.py: {exc}")
return {
"cv2": cv2,
"GradCAM": GradCAM,
"show_cam_on_image": show_cam_on_image,
"ClassifierOutputTarget": ClassifierOutputTarget,
"_get_target_layer": _get_target_layer,
"_preprocess_image": _preprocess_image,
"_process_smoothgrad_map": _process_smoothgrad_map,
"smoothgrad": smoothgrad,
}
def _iter_image_paths(folder: Path) -> List[Path]:
return sorted(
p
for p in folder.rglob("*")
if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS
)
def _normalize_uploaded_paths(uploaded_files: Optional[Sequence[str] | str]) -> List[Path]:
    """Turn the raw upload payload into a sorted, de-duplicated path list.

    Accepts None, a single path string, or a sequence of path strings.
    Directories are expanded recursively; files with unsupported suffixes
    and empty entries are silently dropped.
    """
    if uploaded_files is None:
        return []
    candidates: Iterable[str] = (
        [uploaded_files] if isinstance(uploaded_files, str) else uploaded_files
    )
    found: List[Path] = []
    for candidate in candidates:
        if not candidate:
            continue
        entry = Path(str(candidate))
        if entry.is_dir():
            found.extend(_iter_image_paths(entry))
        elif entry.is_file() and entry.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS:
            found.append(entry)
    # Resolve to absolute paths so duplicates via different spellings collapse.
    return sorted({entry.resolve() for entry in found})
def _resolve_input_images(uploaded_files: Optional[Sequence[str] | str], folder_path: str) -> Tuple[List[Path], str]:
    """Merge uploaded files and an optional folder scan into one image list.

    Returns (paths, error_message); error_message is "" on success and the
    path list is empty whenever an error is reported.
    """
    from_uploads = _normalize_uploaded_paths(uploaded_files)
    from_folder: List[Path] = []
    if folder_path:
        folder = Path(folder_path).expanduser().resolve()
        if not (folder.exists() and folder.is_dir()):
            return [], f"Invalid folder: `{folder}`"
        from_folder = _iter_image_paths(folder)
    merged = sorted({p.resolve() for p in from_uploads + from_folder})
    if merged:
        return merged, ""
    return [], "Upload a folder (or files) or provide a valid folder path."
def _aggregate_classification(classified: List[ClassifiedPrediction], labels: List[str]) -> Tuple[str, float]:
    """Majority-vote the per-image predictions into a single final class.

    Ties on vote count are broken by the higher mean confidence among the
    tied classes. Returns (label, mean confidence over the winner's images).
    """
    votes = Counter(item.pred_idx for item in classified)
    best_count = max(votes.values())
    leaders = [idx for idx, count in votes.items() if count == best_count]
    if len(leaders) > 1:
        # Tie-break: average confidence per tied class.
        conf_by_class = defaultdict(list)
        for item in classified:
            conf_by_class[item.pred_idx].append(item.confidence)
        winner = max(leaders, key=lambda idx: float(np.mean(conf_by_class[idx])))
    else:
        winner = leaders[0]
    winner_confs = [item.confidence for item in classified if item.pred_idx == winner]
    return labels[winner], float(np.mean(winner_confs))
def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]:
    """Run one forward pass; return (argmax index, its softmax prob, input batch)."""
    # Use the model's own device so this works regardless of where it lives.
    device = next(model.parameters()).device
    batch = val_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = F.softmax(model(batch)[0], dim=0)
    top_idx = int(torch.argmax(probabilities).item())
    top_prob = float(probabilities[top_idx].item())
    return top_idx, top_prob, batch
def _classify_image_paths(
    model: DenseNet121,
    labels: List[str],
    image_paths: List[Path],
    threshold: float,
) -> Tuple[List[ClassifiedPrediction], List[List[str]]]:
    """Classify each image, splitting results by the confidence threshold.

    Returns (classified, rows): the predictions that met the threshold, plus
    one report row per input image with status "classified",
    "below_threshold", or "error" (unreadable file).
    """
    accepted: List[ClassifiedPrediction] = []
    report_rows: List[List[str]] = []
    for path in image_paths:
        try:
            img = Image.open(path).convert("RGB")
        except Exception as exc:
            report_rows.append([path.name, "error", "-", str(exc)])
            continue
        idx, conf, _ = _predict_top1(model, img)
        if conf >= threshold:
            accepted.append(ClassifiedPrediction(path=path, pred_idx=idx, confidence=conf))
            report_rows.append([path.name, "classified", labels[idx], f"{conf:.4f}"])
        else:
            report_rows.append([path.name, "below_threshold", labels[idx], f"{conf:.4f}"])
    return accepted, report_rows
def _resize_rgb(image: np.ndarray, side: int) -> np.ndarray:
    """Resize an RGB array to (side, side) with bilinear sampling.

    Works across Pillow versions: uses Image.Resampling.BILINEAR when the
    Resampling enum exists, else the legacy Image.BILINEAR constant.
    """
    resampling_source = getattr(Image, "Resampling", Image)
    pil_img = Image.fromarray(image)
    return np.array(pil_img.resize((side, side), resampling_source.BILINEAR))
def _build_compact_visual_row(
    original: np.ndarray,
    gradcam_img: np.ndarray,
    smoothgrad_img: np.ndarray,
    tile_size: int = 192,
) -> np.ndarray:
    """Resize the three views to square tiles and stack them left-to-right."""
    tiles = [
        _resize_rgb(view, tile_size)
        for view in (original, gradcam_img, smoothgrad_img)
    ]
    return np.concatenate(tiles, axis=1)
def _safe_stem(path: Path) -> str:
stem = path.stem
cleaned = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in stem)
return cleaned[:80] if cleaned else "image"
def _save_compact_panel(panel: np.ndarray, image_path: Path, save_dir: Path) -> Path:
    """Write the composite panel as a JPEG under *save_dir*, avoiding collisions.

    Existing files are never overwritten; a numeric suffix is appended until
    a free name is found. Returns the path actually written.
    """
    stem = _safe_stem(image_path)
    candidate = save_dir / f"{stem}_xdl.jpg"
    counter = 1
    while candidate.exists():
        candidate = save_dir / f"{stem}_xdl_{counter}.jpg"
        counter += 1
    Image.fromarray(panel).save(candidate, format="JPEG", quality=95)
    return candidate
def _render_summary_html(
*,
case_name: str,
model_path: str,
device_name: str,
processed: int,
classified: int,
threshold: float,
final_class: str,
mean_confidence: Optional[float],
distribution: str,
xdl_status: str,
save_status: str,
) -> str:
if mean_confidence is None:
mean_conf_text = "N/A"
progress_percent = 0.0
else:
progress_percent = float(np.clip(mean_confidence * 100.0, 0.0, 100.0))
mean_conf_text = f"{mean_confidence:.4f}"
rows = [
("Case", case_name),
("Model Path", model_path),
("Device", device_name),
("Processed Images", str(processed)),
("Classified / Skipped", f"{classified} / {processed - classified}"),
("Confidence Threshold", f"{threshold:.2f}"),
("Final Class", final_class),
("Mean Confidence", mean_conf_text),
("Class Distribution", distribution),
("XDL", xdl_status),
("Saved XDL", save_status),
]
row_html = "".join(
"<tr>"
f"<th style='text-align:left;padding:8px 10px;border-bottom:1px solid #e5e7eb;width:220px;color:#374151'>{escape(key)}</th>"
f"<td style='padding:8px 10px;border-bottom:1px solid #e5e7eb;color:#111827'>{escape(value)}</td>"
"</tr>"
for key, value in rows
)
return (
"<div style='border:1px solid #d1d5db;border-radius:10px;padding:12px;background:#f9fafb'>"
"<div style='font-weight:600;font-size:16px;margin-bottom:8px;color:#111827'>Batch Summary</div>"
"<div style='display:flex;gap:14px;align-items:stretch;flex-wrap:wrap'>"
"<div style='flex:1 1 430px;min-width:320px'>"
"<table style='width:100%;border-collapse:collapse;font-size:14px'>"
f"{row_html}"
"</table>"
"</div>"
"<div style='flex:0 0 220px;display:flex;align-items:center;justify-content:center'>"
"<div style='display:flex;flex-direction:column;align-items:center;gap:8px'>"
"<div style='font-size:13px;color:#374151'>Predicted Class</div>"
f"<div style='font-size:15px;font-weight:600;color:#111827'>{escape(final_class)}</div>"
"<div style='position:relative;width:140px;height:140px'>"
"<div style='position:absolute;inset:0;border-radius:50%;"
f"background:conic-gradient(#0ea5e9 {progress_percent:.2f}%, #e5e7eb 0);'></div>"
"<div style='position:absolute;inset:12px;border-radius:50%;background:white;"
"display:flex;flex-direction:column;align-items:center;justify-content:center'>"
"<div style='font-size:12px;color:#6b7280'>Mean Conf</div>"
f"<div style='font-size:20px;font-weight:700;color:#0f172a'>{progress_percent:.1f}%</div>"
"</div>"
"</div>"
"</div>"
"</div>"
"</div>"
)
def _render_error_html(message: str) -> str:
return (
"<div style='border:1px solid #fca5a5;border-radius:10px;padding:12px;background:#fef2f2;color:#7f1d1d'>"
"<div style='font-weight:600;margin-bottom:6px'>Input/Error</div>"
f"<div>{escape(message)}</div>"
"</div>"
)
def batch_predict_with_xdl(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
):
    """Blocking wrapper around batch_predict_with_xdl_stream.

    Drains the streaming generator and returns only its final payload
    (summary_html, table_rows, gallery_items). Returns an error card with
    empty rows/gallery if the stream yields nothing.
    """
    stream = batch_predict_with_xdl_stream(
        uploaded_files=uploaded_files,
        selected_case=selected_case,
        folder_path=folder_path,
        confidence_threshold=confidence_threshold,
        smoothgrad_samples=smoothgrad_samples,
        smoothgrad_noise=smoothgrad_noise,
        save_xdl_results=save_xdl_results,
        save_xdl_dir=save_xdl_dir,
        gradcam_target_layer=gradcam_target_layer,
    )
    final_payload = None
    for final_payload in stream:
        pass
    if final_payload is None:
        return _render_error_html("No output generated."), [], []
    return final_payload
def batch_predict_with_xdl_stream(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
) -> Iterator[Tuple[str, List[List[str]], List[Tuple[np.ndarray, str]]]]:
    """Classify a batch of images and stream (summary_html, rows, gallery) payloads.

    Yields at most twice: first right after classification so the UI can
    show results early, then again once GradCAM/SmoothGrad overlays are
    rendered. A model-load or input-validation failure yields a single
    error payload and stops.

    Args:
        uploaded_files: Upload payload — None, one path string, or a
            sequence of path strings (files and/or directories).
        selected_case: Display name of the diagnostic case (CASE_CONFIGS key).
        folder_path: Optional extra folder to scan for images.
        confidence_threshold: Minimum top-1 probability (clamped to [0, 1])
            for an image to count as classified.
        smoothgrad_samples: SmoothGrad sample count (forced to >= 1).
        smoothgrad_noise: SmoothGrad noise level (forced to >= 0.0).
        save_xdl_results: Whether to save composite overlay panels to disk.
        save_xdl_dir: Target directory for panels (DEFAULT_SAVE_DIR if blank).
        gradcam_target_layer: GradCAM layer name; unrecognized values fall
            back to DEFAULT_GRADCAM_TARGET_LAYER.
    """
    case_state = _get_case_state(selected_case)
    model: DenseNet121 = case_state["model"]
    labels: List[str] = case_state["labels"]
    case_name: str = case_state["case_name"]
    model_path: str = case_state["model_path"]
    model_error: Optional[str] = case_state["model_error"]
    # Abort immediately if the checkpoint never loaded.
    if model_error:
        yield _render_error_html(model_error), [], []
        return
    model_device = next(model.parameters()).device
    # Sanitize all UI-provided values before use.
    threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
    smoothgrad_samples = int(max(1, smoothgrad_samples))
    smoothgrad_noise = float(max(0.0, smoothgrad_noise))
    gradcam_target_layer = str(gradcam_target_layer or DEFAULT_GRADCAM_TARGET_LAYER).strip().lower()
    if gradcam_target_layer not in GRADCAM_TARGET_LAYER_OPTIONS:
        gradcam_target_layer = DEFAULT_GRADCAM_TARGET_LAYER
    image_paths, input_error = _resolve_input_images(uploaded_files, folder_path)
    if input_error:
        yield _render_error_html(input_error), [], []
        return
    # Prepare the save directory up-front. Saving is best-effort: a failure
    # here only disables saving instead of aborting the whole batch.
    save_dir_path: Optional[Path] = None
    save_error = ""
    saved_count = 0
    if save_xdl_results:
        raw_dir = save_xdl_dir.strip() if save_xdl_dir else DEFAULT_SAVE_DIR
        try:
            save_dir_path = Path(raw_dir).expanduser().resolve()
            save_dir_path.mkdir(parents=True, exist_ok=True)
        except Exception as exc:
            save_error = f"Save disabled: {exc}"
            save_dir_path = None
    classified, rows = _classify_image_paths(model, labels, image_paths, threshold)
    # Interim save status for the first yield (finalized after overlays).
    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None and not classified:
            save_status = f"No files saved (0 classified). Target: {save_dir_path}"
        else:
            save_status = "Pending..."
    else:
        save_status = "Disabled"
    if classified:
        final_class, mean_conf = _aggregate_classification(classified, labels)
        class_counter = Counter(item.pred_idx for item in classified)
        class_stats = ", ".join(f"{labels[idx]}: {count}" for idx, count in class_counter.items())
        initial_xdl_status = f"Processing overlays... (GradCAM layer: {gradcam_target_layer})"
    else:
        final_class = "N/A"
        mean_conf = None
        class_stats = "N/A"
        initial_xdl_status = "Skipped (no classified images)"
    summary_initial = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=initial_xdl_status,
        save_status=save_status,
    )
    # First UI update: classification results first.
    yield summary_initial, rows, []
    if not classified:
        return
    # The XDL stack is optional; a failed import disables overlays but the
    # classification payload above has already been delivered.
    xdl = None
    xdl_error = ""
    try:
        xdl = _load_xdl_modules()
    except RuntimeError as exc:
        xdl_error = str(exc)
    gallery_items: List[Tuple[np.ndarray, str]] = []
    xdl_error_count = 0
    if xdl is not None:
        try:
            target_layer = xdl["_get_target_layer"](model, layer_name=gradcam_target_layer)
        except TypeError:
            # Backward compatibility for older helper signature: _get_target_layer(model)
            target_layer = xdl["_get_target_layer"](model)
        cam = xdl["GradCAM"](model=model, target_layers=[target_layer])
        for item in classified:
            try:
                image = Image.open(item.path).convert("RGB")
                input_tensor = val_transform(image).unsqueeze(0).to(model_device)
                base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
                h, w = base_img_uint8.shape[:2]
                # GradCAM heat map for the image's own predicted class.
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[xdl["ClassifierOutputTarget"](item.pred_idx)],
                )[0, :]
                gradcam_overlay = xdl["show_cam_on_image"](base_img_float, grayscale_cam, use_rgb=True)
                smooth_raw = xdl["smoothgrad"](
                    model,
                    input_tensor,
                    item.pred_idx,
                    n_samples=smoothgrad_samples,
                    noise_level=smoothgrad_noise,
                    use_amp=(model_device.type == "cuda"),
                )
                _, smooth_heatmap = xdl["_process_smoothgrad_map"](
                    smooth_raw,
                    img_shape=(h, w),
                    percentile=95,
                    colormap="hot",
                )
                # Match old-original-xdl.py blending behavior exactly.
                smooth_overlay = xdl["cv2"].addWeighted(base_img_uint8, 0.5, smooth_heatmap, 0.5, 0)
                compact_panel = _build_compact_visual_row(base_img_uint8, gradcam_overlay, smooth_overlay)
                caption = f"{item.path.name} | {labels[item.pred_idx]} ({item.confidence:.3f})"
                gallery_items.append((compact_panel, caption))
                if save_dir_path is not None:
                    _save_compact_panel(compact_panel, item.path, save_dir_path)
                    saved_count += 1
            except Exception as exc:
                # A per-image overlay failure becomes a report row; the batch continues.
                xdl_error_count += 1
                rows.append([item.path.name, "xdl_error", labels[item.pred_idx], str(exc)])
    # Final status strings for the second (last) yield.
    if xdl_error:
        xdl_status = f"Disabled: {xdl_error}"
    elif xdl_error_count:
        xdl_status = f"Completed with {xdl_error_count} overlay errors"
    else:
        xdl_status = f"Completed ({len(gallery_items)} overlays, layer: {gradcam_target_layer})"
    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None:
            save_status = f"Saved {saved_count} files to {save_dir_path}"
        else:
            save_status = "Disabled"
    else:
        save_status = "Disabled"
    summary_final = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=xdl_status,
        save_status=save_status,
    )
    yield summary_final, rows, gallery_items