"""Batch endoscopic-image classification with explainability (GradCAM + SmoothGrad).

Loads a per-case DenseNet121 checkpoint, classifies a folder/upload of images,
aggregates a batch-level prediction, and (optionally) renders and saves
compact GradCAM/SmoothGrad overlay panels. Designed to feed a streaming UI:
``batch_predict_with_xdl_stream`` yields an early classification-only update,
then a final update once overlays are done.
"""

from collections import Counter, defaultdict
from dataclasses import dataclass
from html import escape
import os
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from src.densenet import DenseNet121
from src.preprocessing import preprocess

# Per-case model configuration. Each case maps a display name to its label
# set, default checkpoint path, and the environment variable that overrides
# that path at deploy time.
CASE_CONFIGS = {
    "Multiclass (4 Classes)": {
        "key": "multiclass",
        "labels": ["Kolitis Ulseratif", "Infeksi", "Crohn", "Tuberculosis"],
        "default_model_path": "models/multiclass.pth",
        "env_var": "MODEL_PATH_MULTICLASS",
    },
    "Crohn vs TB": {
        "key": "crohn_tb",
        "labels": ["Crohn", "Tuberculosis"],
        "default_model_path": "models/crohn_tb.pth",
        "env_var": "MODEL_PATH_CROHN_TB",
    },
    "Kolitis Ulseratif vs Infeksi": {
        "key": "ku_infeksi",
        "labels": ["Kolitis Ulseratif", "Infeksi"],
        "default_model_path": "models/ku_infeksi.pth",
        "env_var": "MODEL_PATH_KU_INFEKSI",
    },
}

DEFAULT_CASE_NAME = "Multiclass (4 Classes)"
CASE_OPTIONS = list(CASE_CONFIGS.keys())
SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
DEFAULT_SAVE_DIR = "xdl_results"

# Candidate DenseNet layers a user may target for GradCAM.
GRADCAM_TARGET_LAYER_OPTIONS = (
    "denseblock3",
    "transition2",
    "transition1",
    "denseblock4",
    "transition3",
    "norm5_last",
)
DEFAULT_GRADCAM_TARGET_LAYER = "denseblock3"


def _detect_device() -> torch.device:
    """Pick CUDA when available, silencing noisy CUDA-init warnings.

    ``torch.cuda.is_available()`` can itself raise on broken driver setups,
    so any failure falls back to CPU.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="CUDA initialization:.*")
        try:
            has_cuda = torch.cuda.is_available()
        except Exception:
            has_cuda = False
    return torch.device("cuda" if has_cuda else "cpu")


DEVICE = _detect_device()

# Cache of loaded per-case state (model + labels + any load error), keyed by
# the case's short "key" so each checkpoint is loaded at most once.
_case_state_cache: dict[str, dict] = {}

# Shared validation transform; models expect 299x299 RGB input.
val_transform = preprocess(target_input_size=(3, 299, 299))


@dataclass(frozen=True)
class ClassifiedPrediction:
    """A single image that passed the confidence threshold."""

    path: Path        # source image file
    pred_idx: int     # index into the case's label list
    confidence: float # softmax probability of the predicted class


def _get_case_config(case_name: str) -> dict:
    """Return the config for ``case_name``, falling back to the default case."""
    return CASE_CONFIGS.get(case_name, CASE_CONFIGS[DEFAULT_CASE_NAME])


def _resolve_case_model_path(case_cfg: dict) -> str:
    """Resolve the checkpoint path, honoring the case's env-var override."""
    return os.getenv(case_cfg["env_var"], case_cfg["default_model_path"])


def _get_case_state(case_name: str) -> dict:
    """Load (or fetch from cache) the model and metadata for a case.

    A checkpoint-load failure is captured in ``state["model_error"]`` instead
    of raising, so callers can surface it in the UI.
    """
    case_cfg = _get_case_config(case_name)
    case_key = case_cfg["key"]
    if case_key in _case_state_cache:
        return _case_state_cache[case_key]

    labels = case_cfg["labels"]
    model_path = _resolve_case_model_path(case_cfg)
    model = DenseNet121(num_classes=len(labels)).to(DEVICE)
    model_error: Optional[str] = None
    try:
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
    except Exception as exc:
        model_error = f"Model failed to load for case `{case_name}` from `{model_path}`: {exc}"

    state = {
        "case_name": case_name,
        "labels": labels,
        "model_path": model_path,
        "model": model,
        "model_error": model_error,
    }
    _case_state_cache[case_key] = state
    return state


def _load_xdl_modules():
    """Lazy-load optional XDL dependencies.

    Returns a dict of the callables/modules needed for overlay generation.
    Raises ``RuntimeError`` with a user-readable message when any optional
    dependency (pytorch-grad-cam, OpenCV, or the project's xdl helpers) is
    missing, so the pipeline can degrade to classification-only mode.
    """
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    except Exception as exc:
        raise RuntimeError(f"GradCAM import failed: {exc}") from exc
    try:
        import cv2
    except Exception as exc:
        raise RuntimeError(f"OpenCV import failed: {exc}") from exc
    try:
        from src.xdl import (
            _get_target_layer,
            _preprocess_image,
            _process_smoothgrad_map,
            smoothgrad,
        )
    except Exception as exc:
        raise RuntimeError(f"Failed to load XDL utilities from src/xdl.py: {exc}") from exc
    return {
        "cv2": cv2,
        "GradCAM": GradCAM,
        "show_cam_on_image": show_cam_on_image,
        "ClassifierOutputTarget": ClassifierOutputTarget,
        "_get_target_layer": _get_target_layer,
        "_preprocess_image": _preprocess_image,
        "_process_smoothgrad_map": _process_smoothgrad_map,
        "smoothgrad": smoothgrad,
    }


def _iter_image_paths(folder: Path) -> List[Path]:
    """Recursively collect supported image files under ``folder``, sorted."""
    return sorted(
        p
        for p in folder.rglob("*")
        if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS
    )


def _normalize_uploaded_paths(uploaded_files: Optional[Sequence[str] | str]) -> List[Path]:
    """Expand an upload payload (str, sequence, or None) into unique image paths.

    Directories are walked recursively; non-image files are ignored. The
    result is de-duplicated via ``resolve()`` and sorted for determinism.
    """
    if uploaded_files is None:
        return []
    if isinstance(uploaded_files, str):
        raw_paths: Iterable[str] = [uploaded_files]
    else:
        raw_paths = uploaded_files

    collected: List[Path] = []
    for raw_path in raw_paths:
        if not raw_path:
            continue
        path = Path(str(raw_path))
        if path.is_dir():
            collected.extend(_iter_image_paths(path))
        elif path.is_file() and path.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS:
            collected.append(path)
    return sorted({p.resolve() for p in collected})


def _resolve_input_images(
    uploaded_files: Optional[Sequence[str] | str], folder_path: str
) -> Tuple[List[Path], str]:
    """Merge uploaded files and an optional folder path into one image list.

    Returns ``(paths, error_message)``; ``error_message`` is non-empty when
    the folder is invalid or no images were found at all.
    """
    uploaded_paths = _normalize_uploaded_paths(uploaded_files)
    local_paths: List[Path] = []
    if folder_path:
        folder = Path(folder_path).expanduser().resolve()
        if not folder.exists() or not folder.is_dir():
            return [], f"Invalid folder: `{folder}`"
        local_paths = _iter_image_paths(folder)
    all_paths = sorted({p.resolve() for p in uploaded_paths + local_paths})
    if not all_paths:
        return [], "Upload a folder (or files) or provide a valid folder path."
    return all_paths, ""


def _aggregate_classification(
    classified: List[ClassifiedPrediction], labels: List[str]
) -> Tuple[str, float]:
    """Majority-vote the batch-level class; break ties by mean confidence.

    Returns ``(label, mean_confidence_of_winning_class)``. Assumes
    ``classified`` is non-empty (callers guard this).
    """
    class_counter = Counter(item.pred_idx for item in classified)
    top_count = max(class_counter.values())
    tied_classes = [idx for idx, count in class_counter.items() if count == top_count]
    if len(tied_classes) == 1:
        final_idx = tied_classes[0]
    else:
        # Tie-break: the tied class with the highest mean per-image confidence.
        class_conf = defaultdict(list)
        for item in classified:
            class_conf[item.pred_idx].append(item.confidence)
        final_idx = max(tied_classes, key=lambda idx: float(np.mean(class_conf[idx])))
    final_conf = float(
        np.mean([item.confidence for item in classified if item.pred_idx == final_idx])
    )
    return labels[final_idx], final_conf


def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]:
    """Run a single forward pass; return (pred_idx, confidence, input_tensor).

    The preprocessed input tensor is returned so XDL steps can reuse it
    without re-running the transform.
    """
    model_device = next(model.parameters()).device
    input_tensor = val_transform(image).unsqueeze(0).to(model_device)
    with torch.no_grad():
        logits = model(input_tensor)[0]
        probs = F.softmax(logits, dim=0)
    pred_idx = int(torch.argmax(probs).item())
    confidence = float(probs[pred_idx].item())
    return pred_idx, confidence, input_tensor


def _classify_image_paths(
    model: DenseNet121,
    labels: List[str],
    image_paths: List[Path],
    threshold: float,
) -> Tuple[List[ClassifiedPrediction], List[List[str]]]:
    """Classify each image, recording a per-image status row.

    Rows are ``[filename, status, label, detail]`` where status is one of
    ``error`` (unreadable file), ``below_threshold``, or ``classified``.
    Only ``classified`` images are returned for aggregation/XDL.
    """
    classified: List[ClassifiedPrediction] = []
    rows: List[List[str]] = []
    for img_path in image_paths:
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as exc:
            rows.append([img_path.name, "error", "-", str(exc)])
            continue
        pred_idx, confidence, _ = _predict_top1(model, image)
        if confidence < threshold:
            rows.append(
                [img_path.name, "below_threshold", labels[pred_idx], f"{confidence:.4f}"]
            )
            continue
        prediction = ClassifiedPrediction(
            path=img_path, pred_idx=pred_idx, confidence=confidence
        )
        classified.append(prediction)
        rows.append([img_path.name, "classified", labels[pred_idx], f"{confidence:.4f}"])
    return classified, rows


def _resize_rgb(image: np.ndarray, side: int) -> np.ndarray:
    """Resize an RGB uint8 array to ``side`` x ``side`` with bilinear filtering."""
    pil_image = Image.fromarray(image)
    # Pillow >= 9.1 moved resampling constants under Image.Resampling.
    if hasattr(Image, "Resampling"):
        resized = pil_image.resize((side, side), Image.Resampling.BILINEAR)
    else:
        resized = pil_image.resize((side, side), Image.BILINEAR)
    return np.array(resized)


def _build_compact_visual_row(
    original: np.ndarray,
    gradcam_img: np.ndarray,
    smoothgrad_img: np.ndarray,
    tile_size: int = 192,
) -> np.ndarray:
    """Compose [original | GradCAM | SmoothGrad] tiles into one wide panel."""
    original_small = _resize_rgb(original, tile_size)
    gradcam_small = _resize_rgb(gradcam_img, tile_size)
    smoothgrad_small = _resize_rgb(smoothgrad_img, tile_size)
    return np.concatenate([original_small, gradcam_small, smoothgrad_small], axis=1)


def _safe_stem(path: Path) -> str:
    """Sanitize a filename stem to [alnum_-], capped at 80 chars."""
    stem = path.stem
    cleaned = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in stem)
    return cleaned[:80] if cleaned else "image"


def _save_compact_panel(panel: np.ndarray, image_path: Path, save_dir: Path) -> Path:
    """Save a panel as JPEG under ``save_dir``, suffixing to avoid overwrites."""
    base = _safe_stem(image_path)
    out_path = save_dir / f"{base}_xdl.jpg"
    suffix = 1
    while out_path.exists():
        out_path = save_dir / f"{base}_xdl_{suffix}.jpg"
        suffix += 1
    Image.fromarray(panel).save(out_path, format="JPEG", quality=95)
    return out_path


def _render_summary_html(
    *,
    case_name: str,
    model_path: str,
    device_name: str,
    processed: int,
    classified: int,
    threshold: float,
    final_class: str,
    mean_confidence: Optional[float],
    distribution: str,
    xdl_status: str,
    save_status: str,
) -> str:
    """Render the batch-summary card as an HTML fragment.

    All dynamic values are escaped. Shows a key/value table plus two badges:
    the predicted class and a mean-confidence progress indicator.

    NOTE(review): the original markup of this fragment was corrupted in the
    file; the tag structure below is a faithful reconstruction of the intact
    text content and data flow, but exact styling may differ from the original.
    """
    if mean_confidence is None:
        mean_conf_text = "N/A"
        progress_percent = 0.0
    else:
        progress_percent = float(np.clip(mean_confidence * 100.0, 0.0, 100.0))
        mean_conf_text = f"{mean_confidence:.4f}"
    rows = [
        ("Case", case_name),
        ("Model Path", model_path),
        ("Device", device_name),
        ("Processed Images", str(processed)),
        ("Classified / Skipped", f"{classified} / {processed - classified}"),
        ("Confidence Threshold", f"{threshold:.2f}"),
        ("Final Class", final_class),
        ("Mean Confidence", mean_conf_text),
        ("Class Distribution", distribution),
        ("XDL", xdl_status),
        ("Saved XDL", save_status),
    ]
    row_html = "".join(
        "<tr>"
        f"<td>{escape(key)}</td>"
        f"<td>{escape(value)}</td>"
        "</tr>"
        for key, value in rows
    )
    return (
        "<div class='xdl-summary'>"
        "<div class='xdl-summary-title'>Batch Summary</div>"
        "<div class='xdl-summary-table'>"
        "<table>"
        f"{row_html}"
        "</table>"
        "</div>"
        "<div class='xdl-summary-badges'>"
        "<div class='xdl-badge'>"
        "<div class='xdl-badge-label'>Predicted Class</div>"
        f"<div class='xdl-badge-value'>{escape(final_class)}</div>"
        "</div>"
        "<div class='xdl-badge'>"
        "<div class='xdl-badge-label'>Mean Conf</div>"
        "<div class='xdl-progress'>"
        f"<div class='xdl-progress-bar' style='width:{progress_percent:.1f}%'>{progress_percent:.1f}%</div>"
        "</div>"
        "</div>"
        "</div>"
        "</div>"
    )


def _render_error_html(message: str) -> str:
    """Render an input/validation error as an HTML fragment (escaped).

    NOTE(review): markup reconstructed; see _render_summary_html.
    """
    return (
        "<div class='xdl-error'>"
        "<div class='xdl-error-title'>Input/Error</div>"
        f"<div class='xdl-error-message'>{escape(message)}</div>"
        "</div>"
    )


def batch_predict_with_xdl(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
):
    """Non-streaming wrapper: run the full pipeline and return the last payload.

    Returns ``(summary_html, rows, gallery_items)``; see the streaming
    variant for the payload contract.
    """
    last_output = None
    for payload in batch_predict_with_xdl_stream(
        uploaded_files=uploaded_files,
        selected_case=selected_case,
        folder_path=folder_path,
        confidence_threshold=confidence_threshold,
        smoothgrad_samples=smoothgrad_samples,
        smoothgrad_noise=smoothgrad_noise,
        save_xdl_results=save_xdl_results,
        save_xdl_dir=save_xdl_dir,
        gradcam_target_layer=gradcam_target_layer,
    ):
        last_output = payload
    if last_output is None:
        return _render_error_html("No output generated."), [], []
    return last_output


def batch_predict_with_xdl_stream(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
) -> Iterator[Tuple[str, List[List[str]], List[Tuple[np.ndarray, str]]]]:
    """Stream the batch pipeline: classify, then generate XDL overlays.

    Yields ``(summary_html, rows, gallery_items)`` payloads:
      1. immediately after classification (empty gallery), and
      2. after overlay generation (final statuses + gallery).
    On input/model errors a single error payload is yielded and the
    generator returns. Overlay failures are per-image and recorded as
    ``xdl_error`` rows rather than aborting the batch.
    """
    case_state = _get_case_state(selected_case)
    model: DenseNet121 = case_state["model"]
    labels: List[str] = case_state["labels"]
    case_name: str = case_state["case_name"]
    model_path: str = case_state["model_path"]
    model_error: Optional[str] = case_state["model_error"]
    if model_error:
        yield _render_error_html(model_error), [], []
        return

    model_device = next(model.parameters()).device

    # Sanitize user-controlled knobs.
    threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
    smoothgrad_samples = int(max(1, smoothgrad_samples))
    smoothgrad_noise = float(max(0.0, smoothgrad_noise))
    gradcam_target_layer = str(gradcam_target_layer or DEFAULT_GRADCAM_TARGET_LAYER).strip().lower()
    if gradcam_target_layer not in GRADCAM_TARGET_LAYER_OPTIONS:
        gradcam_target_layer = DEFAULT_GRADCAM_TARGET_LAYER

    image_paths, input_error = _resolve_input_images(uploaded_files, folder_path)
    if input_error:
        yield _render_error_html(input_error), [], []
        return

    # Prepare the save directory; failure disables saving but not the run.
    save_dir_path: Optional[Path] = None
    save_error = ""
    saved_count = 0
    if save_xdl_results:
        raw_dir = save_xdl_dir.strip() if save_xdl_dir else DEFAULT_SAVE_DIR
        try:
            save_dir_path = Path(raw_dir).expanduser().resolve()
            save_dir_path.mkdir(parents=True, exist_ok=True)
        except Exception as exc:
            save_error = f"Save disabled: {exc}"
            save_dir_path = None

    classified, rows = _classify_image_paths(model, labels, image_paths, threshold)

    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None and not classified:
            save_status = f"No files saved (0 classified). Target: {save_dir_path}"
        else:
            save_status = "Pending..."
    else:
        save_status = "Disabled"

    if classified:
        final_class, mean_conf = _aggregate_classification(classified, labels)
        class_counter = Counter(item.pred_idx for item in classified)
        class_stats = ", ".join(
            f"{labels[idx]}: {count}" for idx, count in class_counter.items()
        )
        initial_xdl_status = f"Processing overlays... (GradCAM layer: {gradcam_target_layer})"
    else:
        final_class = "N/A"
        mean_conf = None
        class_stats = "N/A"
        initial_xdl_status = "Skipped (no classified images)"

    summary_initial = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=initial_xdl_status,
        save_status=save_status,
    )
    # First UI update: classification results first.
    yield summary_initial, rows, []

    if not classified:
        return

    xdl = None
    xdl_error = ""
    try:
        xdl = _load_xdl_modules()
    except RuntimeError as exc:
        xdl_error = str(exc)

    gallery_items: List[Tuple[np.ndarray, str]] = []
    xdl_error_count = 0
    if xdl is not None:
        try:
            target_layer = xdl["_get_target_layer"](model, layer_name=gradcam_target_layer)
        except TypeError:
            # Backward compatibility for older helper signature: _get_target_layer(model)
            target_layer = xdl["_get_target_layer"](model)
        cam = xdl["GradCAM"](model=model, target_layers=[target_layer])
        for item in classified:
            try:
                image = Image.open(item.path).convert("RGB")
                input_tensor = val_transform(image).unsqueeze(0).to(model_device)
                base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
                h, w = base_img_uint8.shape[:2]
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[xdl["ClassifierOutputTarget"](item.pred_idx)],
                )[0, :]
                gradcam_overlay = xdl["show_cam_on_image"](
                    base_img_float, grayscale_cam, use_rgb=True
                )
                smooth_raw = xdl["smoothgrad"](
                    model,
                    input_tensor,
                    item.pred_idx,
                    n_samples=smoothgrad_samples,
                    noise_level=smoothgrad_noise,
                    use_amp=(model_device.type == "cuda"),
                )
                _, smooth_heatmap = xdl["_process_smoothgrad_map"](
                    smooth_raw,
                    img_shape=(h, w),
                    percentile=95,
                    colormap="hot",
                )
                # Match old-original-xdl.py blending behavior exactly.
                smooth_overlay = xdl["cv2"].addWeighted(
                    base_img_uint8, 0.5, smooth_heatmap, 0.5, 0
                )
                compact_panel = _build_compact_visual_row(
                    base_img_uint8, gradcam_overlay, smooth_overlay
                )
                caption = f"{item.path.name} | {labels[item.pred_idx]} ({item.confidence:.3f})"
                gallery_items.append((compact_panel, caption))
                if save_dir_path is not None:
                    _save_compact_panel(compact_panel, item.path, save_dir_path)
                    saved_count += 1
            except Exception as exc:
                xdl_error_count += 1
                rows.append([item.path.name, "xdl_error", labels[item.pred_idx], str(exc)])

    if xdl_error:
        xdl_status = f"Disabled: {xdl_error}"
    elif xdl_error_count:
        xdl_status = f"Completed with {xdl_error_count} overlay errors"
    else:
        xdl_status = f"Completed ({len(gallery_items)} overlays, layer: {gradcam_target_layer})"

    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None:
            save_status = f"Saved {saved_count} files to {save_dir_path}"
        else:
            save_status = "Disabled"
    else:
        save_status = "Disabled"

    summary_final = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=xdl_status,
        save_status=save_status,
    )
    yield summary_final, rows, gallery_items