from collections import Counter, defaultdict from dataclasses import dataclass from html import escape import os from pathlib import Path from typing import Iterable, Iterator, List, Optional, Sequence, Tuple import warnings import numpy as np import torch import torch.nn.functional as F from PIL import Image from src.densenet import DenseNet121 from src.preprocessing import preprocess CASE_CONFIGS = { "Multiclass (4 Classes)": { "key": "multiclass", "labels": ["Kolitis Ulseratif", "Infeksi", "Crohn", "Tuberculosis"], "default_model_path": "models/multiclass.pth", "env_var": "MODEL_PATH_MULTICLASS", }, "Crohn vs TB": { "key": "crohn_tb", "labels": ["Crohn", "Tuberculosis"], "default_model_path": "models/crohn_tb.pth", "env_var": "MODEL_PATH_CROHN_TB", }, "Kolitis Ulseratif vs Infeksi": { "key": "ku_infeksi", "labels": ["Kolitis Ulseratif", "Infeksi"], "default_model_path": "models/ku_infeksi.pth", "env_var": "MODEL_PATH_KU_INFEKSI", }, } DEFAULT_CASE_NAME = "Multiclass (4 Classes)" CASE_OPTIONS = list(CASE_CONFIGS.keys()) SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"} DEFAULT_SAVE_DIR = "xdl_results" GRADCAM_TARGET_LAYER_OPTIONS = ( "denseblock3", "transition2", "transition1", "denseblock4", "transition3", "norm5_last", ) DEFAULT_GRADCAM_TARGET_LAYER = "denseblock3" def _detect_device() -> torch.device: with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="CUDA initialization:.*") try: has_cuda = torch.cuda.is_available() except Exception: has_cuda = False return torch.device("cuda" if has_cuda else "cpu") DEVICE = _detect_device() _case_state_cache: dict[str, dict] = {} val_transform = preprocess(target_input_size=(3, 299, 299)) @dataclass(frozen=True) class ClassifiedPrediction: path: Path pred_idx: int confidence: float def _get_case_config(case_name: str) -> dict: return CASE_CONFIGS.get(case_name, CASE_CONFIGS[DEFAULT_CASE_NAME]) def _resolve_case_model_path(case_cfg: dict) -> str: return os.getenv(case_cfg["env_var"], case_cfg["default_model_path"]) def _get_case_state(case_name: str) -> dict: case_cfg = _get_case_config(case_name) case_key = case_cfg["key"] if case_key in _case_state_cache: return _case_state_cache[case_key] labels = case_cfg["labels"] model_path = _resolve_case_model_path(case_cfg) model = DenseNet121(num_classes=len(labels)).to(DEVICE) model_error: Optional[str] = None try: model.load_state_dict(torch.load(model_path, map_location=DEVICE)) model.eval() except Exception as exc: model_error = f"Model failed to load for case `{case_name}` from `{model_path}`: {exc}" state = { "case_name": case_name, "labels": labels, "model_path": model_path, "model": model, "model_error": model_error, } _case_state_cache[case_key] = state return state def _load_xdl_modules(): """Lazy-load optional XDL dependencies.""" try: from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget except Exception as exc: raise RuntimeError(f"GradCAM import failed: {exc}") try: import cv2 except Exception as exc: raise RuntimeError(f"OpenCV import failed: {exc}") try: from src.xdl import ( _get_target_layer, _preprocess_image, _process_smoothgrad_map, smoothgrad, ) except Exception as exc: raise RuntimeError(f"Failed to load XDL utilities from src/xdl.py: {exc}") return { "cv2": cv2, "GradCAM": GradCAM, "show_cam_on_image": show_cam_on_image, "ClassifierOutputTarget": ClassifierOutputTarget, "_get_target_layer": _get_target_layer, "_preprocess_image": _preprocess_image, "_process_smoothgrad_map": _process_smoothgrad_map, "smoothgrad": smoothgrad, } def _iter_image_paths(folder: Path) -> List[Path]: return sorted( p for p in folder.rglob("*") if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS ) def _normalize_uploaded_paths(uploaded_files: Optional[Sequence[str] | str]) -> List[Path]: if uploaded_files is None: return [] if isinstance(uploaded_files, str): raw_paths: Iterable[str] = [uploaded_files] else: raw_paths = uploaded_files collected: List[Path] = [] for raw_path in raw_paths: if not raw_path: continue path = Path(str(raw_path)) if path.is_dir(): collected.extend(_iter_image_paths(path)) elif path.is_file() and path.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS: collected.append(path) return sorted({p.resolve() for p in collected}) def _resolve_input_images(uploaded_files: Optional[Sequence[str] | str], folder_path: str) -> Tuple[List[Path], str]: uploaded_paths = _normalize_uploaded_paths(uploaded_files) local_paths: List[Path] = [] if folder_path: folder = Path(folder_path).expanduser().resolve() if not folder.exists() or not folder.is_dir(): return [], f"Invalid folder: `{folder}`" local_paths = _iter_image_paths(folder) all_paths = sorted({p.resolve() for p in uploaded_paths + local_paths}) if not all_paths: return [], "Upload a folder (or files) or provide a valid folder path." return all_paths, "" def _aggregate_classification(classified: List[ClassifiedPrediction], labels: List[str]) -> Tuple[str, float]: class_counter = Counter(item.pred_idx for item in classified) top_count = max(class_counter.values()) tied_classes = [idx for idx, count in class_counter.items() if count == top_count] if len(tied_classes) == 1: final_idx = tied_classes[0] else: class_conf = defaultdict(list) for item in classified: class_conf[item.pred_idx].append(item.confidence) final_idx = max(tied_classes, key=lambda idx: float(np.mean(class_conf[idx]))) final_conf = float(np.mean([item.confidence for item in classified if item.pred_idx == final_idx])) return labels[final_idx], final_conf def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]: model_device = next(model.parameters()).device input_tensor = val_transform(image).unsqueeze(0).to(model_device) with torch.no_grad(): logits = model(input_tensor)[0] probs = F.softmax(logits, dim=0) pred_idx = int(torch.argmax(probs).item()) confidence = float(probs[pred_idx].item()) return pred_idx, confidence, input_tensor def _classify_image_paths( model: DenseNet121, labels: List[str], image_paths: List[Path], threshold: float, ) -> Tuple[List[ClassifiedPrediction], List[List[str]]]: classified: List[ClassifiedPrediction] = [] rows: List[List[str]] = [] for img_path in image_paths: try: image = Image.open(img_path).convert("RGB") except Exception as exc: rows.append([img_path.name, "error", "-", str(exc)]) continue pred_idx, confidence, _ = _predict_top1(model, image) if confidence < threshold: rows.append([img_path.name, "below_threshold", labels[pred_idx], f"{confidence:.4f}"]) continue prediction = ClassifiedPrediction(path=img_path, pred_idx=pred_idx, confidence=confidence) classified.append(prediction) rows.append([img_path.name, "classified", labels[pred_idx], f"{confidence:.4f}"]) return classified, rows def _resize_rgb(image: np.ndarray, side: int) -> np.ndarray: pil_image = Image.fromarray(image) if hasattr(Image, "Resampling"): resized = pil_image.resize((side, side), Image.Resampling.BILINEAR) else: resized = pil_image.resize((side, side), Image.BILINEAR) return np.array(resized) def _build_compact_visual_row( original: np.ndarray, gradcam_img: np.ndarray, smoothgrad_img: np.ndarray, tile_size: int = 192, ) -> np.ndarray: original_small = _resize_rgb(original, tile_size) gradcam_small = _resize_rgb(gradcam_img, tile_size) smoothgrad_small = _resize_rgb(smoothgrad_img, tile_size) return np.concatenate([original_small, gradcam_small, smoothgrad_small], axis=1) def _safe_stem(path: Path) -> str: stem = path.stem cleaned = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in stem) return cleaned[:80] if cleaned else "image" def _save_compact_panel(panel: np.ndarray, image_path: Path, save_dir: Path) -> Path: base = _safe_stem(image_path) out_path = save_dir / f"{base}_xdl.jpg" suffix = 1 while out_path.exists(): out_path = save_dir / f"{base}_xdl_{suffix}.jpg" suffix += 1 Image.fromarray(panel).save(out_path, format="JPEG", quality=95) return out_path def _render_summary_html( *, case_name: str, model_path: str, device_name: str, processed: int, classified: int, threshold: float, final_class: str, mean_confidence: Optional[float], distribution: str, xdl_status: str, save_status: str, ) -> str: if mean_confidence is None: mean_conf_text = "N/A" progress_percent = 0.0 else: progress_percent = float(np.clip(mean_confidence * 100.0, 0.0, 100.0)) mean_conf_text = f"{mean_confidence:.4f}" rows = [ ("Case", case_name), ("Model Path", model_path), ("Device", device_name), ("Processed Images", str(processed)), ("Classified / Skipped", f"{classified} / {processed - classified}"), ("Confidence Threshold", f"{threshold:.2f}"), ("Final Class", final_class), ("Mean Confidence", mean_conf_text), ("Class Distribution", distribution), ("XDL", xdl_status), ("Saved XDL", save_status), ] row_html = "".join( "