Spaces:
Sleeping
Sleeping
| from collections import Counter, defaultdict | |
| from dataclasses import dataclass | |
| from html import escape | |
| import os | |
| from pathlib import Path | |
| from typing import Iterable, Iterator, List, Optional, Sequence, Tuple | |
| import warnings | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from src.densenet import DenseNet121 | |
| from src.preprocessing import preprocess | |
# Mapping of UI case name -> configuration for each classification task.
#   "key": stable cache key used by _case_state_cache
#   "labels": ordered class names (index matches the model's output index)
#   "default_model_path": checkpoint used when the env var is unset
#   "env_var": environment variable that overrides the checkpoint path
CASE_CONFIGS = {
    "Multiclass (4 Classes)": {
        "key": "multiclass",
        "labels": ["Kolitis Ulseratif", "Infeksi", "Crohn", "Tuberculosis"],
        "default_model_path": "models/multiclass.pth",
        "env_var": "MODEL_PATH_MULTICLASS",
    },
    "Crohn vs TB": {
        "key": "crohn_tb",
        "labels": ["Crohn", "Tuberculosis"],
        "default_model_path": "models/crohn_tb.pth",
        "env_var": "MODEL_PATH_CROHN_TB",
    },
    "Kolitis Ulseratif vs Infeksi": {
        "key": "ku_infeksi",
        "labels": ["Kolitis Ulseratif", "Infeksi"],
        "default_model_path": "models/ku_infeksi.pth",
        "env_var": "MODEL_PATH_KU_INFEKSI",
    },
}
# Case preselected in the UI; also the fallback for unknown case names.
DEFAULT_CASE_NAME = "Multiclass (4 Classes)"
# Dropdown choices, in CASE_CONFIGS insertion order.
CASE_OPTIONS = list(CASE_CONFIGS.keys())
# Lower-cased file extensions accepted as input images.
SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
# Default output directory for saved XDL panels when the user leaves it blank.
DEFAULT_SAVE_DIR = "xdl_results"
# Layer names the user may pick as the GradCAM target layer
# (presumably DenseNet-121 feature blocks — resolved by src.xdl._get_target_layer).
GRADCAM_TARGET_LAYER_OPTIONS = (
    "denseblock3",
    "transition2",
    "transition1",
    "denseblock4",
    "transition3",
    "norm5_last",
)
DEFAULT_GRADCAM_TARGET_LAYER = "denseblock3"
| def _detect_device() -> torch.device: | |
| with warnings.catch_warnings(): | |
| warnings.filterwarnings("ignore", message="CUDA initialization:.*") | |
| try: | |
| has_cuda = torch.cuda.is_available() | |
| except Exception: | |
| has_cuda = False | |
| return torch.device("cuda" if has_cuda else "cpu") | |
# Device shared by every loaded model, resolved once at import time.
DEVICE = _detect_device()
# Per-case cache of loaded model state, keyed by CASE_CONFIGS[...]["key"].
_case_state_cache: dict[str, dict] = {}
# Inference-time transform; the model expects 3x299x299 inputs.
val_transform = preprocess(target_input_size=(3, 299, 299))
@dataclass(frozen=True)
class ClassifiedPrediction:
    """One image whose prediction met the confidence threshold.

    BUG FIX: the @dataclass decorator was missing, so the keyword
    construction used by the classifier (ClassifiedPrediction(path=...,
    pred_idx=..., confidence=...)) raised TypeError at runtime; the class
    had bare annotations and no generated __init__.
    """

    # Source image path.
    path: Path
    # Predicted class index into the case's label list.
    pred_idx: int
    # Softmax probability of the predicted class.
    confidence: float
def _get_case_config(case_name: str) -> dict:
    """Look up the config for *case_name*, falling back to the default case."""
    try:
        return CASE_CONFIGS[case_name]
    except KeyError:
        return CASE_CONFIGS[DEFAULT_CASE_NAME]
| def _resolve_case_model_path(case_cfg: dict) -> str: | |
| return os.getenv(case_cfg["env_var"], case_cfg["default_model_path"]) | |
def _get_case_state(case_name: str) -> dict:
    """Return (and memoize) the runtime state for a case.

    The state dict carries the case's labels, resolved checkpoint path, the
    model instance, and "model_error". The model object is created and moved
    to DEVICE even when the checkpoint fails to load; in that case
    "model_error" holds the failure message and callers must surface it
    instead of running inference.
    """
    case_cfg = _get_case_config(case_name)
    case_key = case_cfg["key"]
    # Cache hit: reuse the already-loaded model for this case.
    if case_key in _case_state_cache:
        return _case_state_cache[case_key]
    labels = case_cfg["labels"]
    model_path = _resolve_case_model_path(case_cfg)
    model = DenseNet121(num_classes=len(labels)).to(DEVICE)
    model_error: Optional[str] = None
    try:
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
    except Exception as exc:
        # Keep the (unusable) model in the state; the error string signals
        # downstream code to bail out before inference.
        model_error = f"Model failed to load for case `{case_name}` from `{model_path}`: {exc}"
    state = {
        "case_name": case_name,
        "labels": labels,
        "model_path": model_path,
        "model": model,
        "model_error": model_error,
    }
    # NOTE: the failed-load state is cached too, so a bad checkpoint path is
    # not retried until the process restarts.
    _case_state_cache[case_key] = state
    return state
| def _load_xdl_modules(): | |
| """Lazy-load optional XDL dependencies.""" | |
| try: | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| except Exception as exc: | |
| raise RuntimeError(f"GradCAM import failed: {exc}") | |
| try: | |
| import cv2 | |
| except Exception as exc: | |
| raise RuntimeError(f"OpenCV import failed: {exc}") | |
| try: | |
| from src.xdl import ( | |
| _get_target_layer, | |
| _preprocess_image, | |
| _process_smoothgrad_map, | |
| smoothgrad, | |
| ) | |
| except Exception as exc: | |
| raise RuntimeError(f"Failed to load XDL utilities from src/xdl.py: {exc}") | |
| return { | |
| "cv2": cv2, | |
| "GradCAM": GradCAM, | |
| "show_cam_on_image": show_cam_on_image, | |
| "ClassifierOutputTarget": ClassifierOutputTarget, | |
| "_get_target_layer": _get_target_layer, | |
| "_preprocess_image": _preprocess_image, | |
| "_process_smoothgrad_map": _process_smoothgrad_map, | |
| "smoothgrad": smoothgrad, | |
| } | |
def _iter_image_paths(folder: Path) -> List[Path]:
    """Recursively collect supported image files under *folder*, sorted."""
    found: List[Path] = []
    for candidate in folder.rglob("*"):
        if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS:
            found.append(candidate)
    found.sort()
    return found
def _normalize_uploaded_paths(uploaded_files: Optional[Sequence[str] | str]) -> List[Path]:
    """Expand the upload widget's value into sorted, de-duplicated image paths.

    Accepts None, a single path string, or a sequence of path strings.
    Directory entries are searched recursively for supported images; file
    entries are kept only when their extension is supported. All returned
    paths are resolved.
    """
    if uploaded_files is None:
        return []
    raw_entries = [uploaded_files] if isinstance(uploaded_files, str) else uploaded_files
    unique: set = set()
    for entry in raw_entries:
        if not entry:
            continue
        candidate = Path(str(entry))
        if candidate.is_dir():
            unique.update(p.resolve() for p in _iter_image_paths(candidate))
        elif candidate.is_file() and candidate.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS:
            unique.add(candidate.resolve())
    return sorted(unique)
def _resolve_input_images(uploaded_files: Optional[Sequence[str] | str], folder_path: str) -> Tuple[List[Path], str]:
    """Merge uploaded paths with an optional local folder scan.

    Returns (paths, error_message); error_message is "" on success and the
    path list is empty whenever an error is reported.
    """
    candidates: List[Path] = list(_normalize_uploaded_paths(uploaded_files))
    if folder_path:
        folder = Path(folder_path).expanduser().resolve()
        if not (folder.exists() and folder.is_dir()):
            return [], f"Invalid folder: `{folder}`"
        candidates.extend(_iter_image_paths(folder))
    merged = sorted({p.resolve() for p in candidates})
    if merged:
        return merged, ""
    return [], "Upload a folder (or files) or provide a valid folder path."
def _aggregate_classification(classified: List[ClassifiedPrediction], labels: List[str]) -> Tuple[str, float]:
    """Majority-vote the batch prediction; break ties by mean confidence.

    Returns (winning label, mean confidence over the winning class's images).
    Assumes *classified* is non-empty — callers guard this.
    """
    votes = Counter(item.pred_idx for item in classified)
    best_count = max(votes.values())
    contenders = [idx for idx, n in votes.items() if n == best_count]
    if len(contenders) > 1:
        # Tie: pick the contender with the highest mean confidence.
        conf_by_class = defaultdict(list)
        for item in classified:
            conf_by_class[item.pred_idx].append(item.confidence)
        winner = max(contenders, key=lambda idx: float(np.mean(conf_by_class[idx])))
    else:
        winner = contenders[0]
    winner_confs = [item.confidence for item in classified if item.pred_idx == winner]
    return labels[winner], float(np.mean(winner_confs))
def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]:
    """Single forward pass; return (argmax index, its probability, input batch).

    The returned tensor is the preprocessed 1xCxHxW batch on the model's
    device, so callers can reuse it for explanation overlays.
    """
    device = next(model.parameters()).device
    batch = val_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = F.softmax(model(batch)[0], dim=0)
    top_idx = int(torch.argmax(probabilities).item())
    top_prob = float(probabilities[top_idx].item())
    return top_idx, top_prob, batch
def _classify_image_paths(
    model: DenseNet121,
    labels: List[str],
    image_paths: List[Path],
    threshold: float,
) -> Tuple[List[ClassifiedPrediction], List[List[str]]]:
    """Classify each image, keeping predictions at or above *threshold*.

    Returns (kept predictions, per-image report rows). Row statuses are
    "error" (unreadable file), "below_threshold", or "classified".
    """
    kept: List[ClassifiedPrediction] = []
    report: List[List[str]] = []
    for path in image_paths:
        try:
            rgb = Image.open(path).convert("RGB")
        except Exception as exc:
            # Unreadable image: record the failure and move on.
            report.append([path.name, "error", "-", str(exc)])
            continue
        idx, conf, _ = _predict_top1(model, rgb)
        accepted = conf >= threshold
        status = "classified" if accepted else "below_threshold"
        report.append([path.name, status, labels[idx], f"{conf:.4f}"])
        if accepted:
            kept.append(ClassifiedPrediction(path=path, pred_idx=idx, confidence=conf))
    return kept, report
def _resize_rgb(image: np.ndarray, side: int) -> np.ndarray:
    """Bilinearly resize an RGB uint8 array to (side, side).

    Handles both modern Pillow (Image.Resampling enum) and older releases
    where the filter constants live directly on the Image module.
    """
    if hasattr(Image, "Resampling"):
        resample = Image.Resampling.BILINEAR
    else:
        resample = Image.BILINEAR
    return np.array(Image.fromarray(image).resize((side, side), resample))
def _build_compact_visual_row(
    original: np.ndarray,
    gradcam_img: np.ndarray,
    smoothgrad_img: np.ndarray,
    tile_size: int = 192,
) -> np.ndarray:
    """Concatenate original | GradCAM | SmoothGrad as equal-sized tiles."""
    tiles = [
        _resize_rgb(img, tile_size)
        for img in (original, gradcam_img, smoothgrad_img)
    ]
    return np.concatenate(tiles, axis=1)
| def _safe_stem(path: Path) -> str: | |
| stem = path.stem | |
| cleaned = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in stem) | |
| return cleaned[:80] if cleaned else "image" | |
def _save_compact_panel(panel: np.ndarray, image_path: Path, save_dir: Path) -> Path:
    """Write *panel* as a JPEG under *save_dir*, avoiding name collisions.

    The filename is derived from the source image's sanitized stem; a numeric
    suffix is appended until an unused name is found.
    """
    stem = _safe_stem(image_path)
    target = save_dir / f"{stem}_xdl.jpg"
    counter = 1
    while target.exists():
        target = save_dir / f"{stem}_xdl_{counter}.jpg"
        counter += 1
    Image.fromarray(panel).save(target, format="JPEG", quality=95)
    return target
| def _render_summary_html( | |
| *, | |
| case_name: str, | |
| model_path: str, | |
| device_name: str, | |
| processed: int, | |
| classified: int, | |
| threshold: float, | |
| final_class: str, | |
| mean_confidence: Optional[float], | |
| distribution: str, | |
| xdl_status: str, | |
| save_status: str, | |
| ) -> str: | |
| if mean_confidence is None: | |
| mean_conf_text = "N/A" | |
| progress_percent = 0.0 | |
| else: | |
| progress_percent = float(np.clip(mean_confidence * 100.0, 0.0, 100.0)) | |
| mean_conf_text = f"{mean_confidence:.4f}" | |
| rows = [ | |
| ("Case", case_name), | |
| ("Model Path", model_path), | |
| ("Device", device_name), | |
| ("Processed Images", str(processed)), | |
| ("Classified / Skipped", f"{classified} / {processed - classified}"), | |
| ("Confidence Threshold", f"{threshold:.2f}"), | |
| ("Final Class", final_class), | |
| ("Mean Confidence", mean_conf_text), | |
| ("Class Distribution", distribution), | |
| ("XDL", xdl_status), | |
| ("Saved XDL", save_status), | |
| ] | |
| row_html = "".join( | |
| "<tr>" | |
| f"<th style='text-align:left;padding:8px 10px;border-bottom:1px solid #e5e7eb;width:220px;color:#374151'>{escape(key)}</th>" | |
| f"<td style='padding:8px 10px;border-bottom:1px solid #e5e7eb;color:#111827'>{escape(value)}</td>" | |
| "</tr>" | |
| for key, value in rows | |
| ) | |
| return ( | |
| "<div style='border:1px solid #d1d5db;border-radius:10px;padding:12px;background:#f9fafb'>" | |
| "<div style='font-weight:600;font-size:16px;margin-bottom:8px;color:#111827'>Batch Summary</div>" | |
| "<div style='display:flex;gap:14px;align-items:stretch;flex-wrap:wrap'>" | |
| "<div style='flex:1 1 430px;min-width:320px'>" | |
| "<table style='width:100%;border-collapse:collapse;font-size:14px'>" | |
| f"{row_html}" | |
| "</table>" | |
| "</div>" | |
| "<div style='flex:0 0 220px;display:flex;align-items:center;justify-content:center'>" | |
| "<div style='display:flex;flex-direction:column;align-items:center;gap:8px'>" | |
| "<div style='font-size:13px;color:#374151'>Predicted Class</div>" | |
| f"<div style='font-size:15px;font-weight:600;color:#111827'>{escape(final_class)}</div>" | |
| "<div style='position:relative;width:140px;height:140px'>" | |
| "<div style='position:absolute;inset:0;border-radius:50%;" | |
| f"background:conic-gradient(#0ea5e9 {progress_percent:.2f}%, #e5e7eb 0);'></div>" | |
| "<div style='position:absolute;inset:12px;border-radius:50%;background:white;" | |
| "display:flex;flex-direction:column;align-items:center;justify-content:center'>" | |
| "<div style='font-size:12px;color:#6b7280'>Mean Conf</div>" | |
| f"<div style='font-size:20px;font-weight:700;color:#0f172a'>{progress_percent:.1f}%</div>" | |
| "</div>" | |
| "</div>" | |
| "</div>" | |
| "</div>" | |
| "</div>" | |
| ) | |
| def _render_error_html(message: str) -> str: | |
| return ( | |
| "<div style='border:1px solid #fca5a5;border-radius:10px;padding:12px;background:#fef2f2;color:#7f1d1d'>" | |
| "<div style='font-weight:600;margin-bottom:6px'>Input/Error</div>" | |
| f"<div>{escape(message)}</div>" | |
| "</div>" | |
| ) | |
def batch_predict_with_xdl(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
):
    """Non-streaming wrapper: drain the streaming variant, return its last payload."""
    stream = batch_predict_with_xdl_stream(
        uploaded_files=uploaded_files,
        selected_case=selected_case,
        folder_path=folder_path,
        confidence_threshold=confidence_threshold,
        smoothgrad_samples=smoothgrad_samples,
        smoothgrad_noise=smoothgrad_noise,
        save_xdl_results=save_xdl_results,
        save_xdl_dir=save_xdl_dir,
        gradcam_target_layer=gradcam_target_layer,
    )
    final_payload = None
    for final_payload in stream:
        pass
    if final_payload is None:
        return _render_error_html("No output generated."), [], []
    return final_payload
def batch_predict_with_xdl_stream(
    uploaded_files,
    selected_case: str,
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
    save_xdl_results: bool,
    save_xdl_dir: str,
    gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
) -> Iterator[Tuple[str, List[List[str]], List[Tuple[np.ndarray, str]]]]:
    """Classify a batch of images and stream (summary_html, rows, gallery) payloads.

    Yields at most twice: once right after classification (empty gallery) so
    the UI can update early, and once after GradCAM/SmoothGrad overlays have
    been computed. On a model-load failure or invalid input it yields a
    single error payload and returns. Shapes match the UI components:
    summary HTML string, table rows, and (image, caption) gallery tuples.
    """
    case_state = _get_case_state(selected_case)
    model: DenseNet121 = case_state["model"]
    labels: List[str] = case_state["labels"]
    case_name: str = case_state["case_name"]
    model_path: str = case_state["model_path"]
    model_error: Optional[str] = case_state["model_error"]
    if model_error:
        # Checkpoint failed to load earlier; surface the cached error and stop.
        yield _render_error_html(model_error), [], []
        return
    model_device = next(model.parameters()).device
    # Sanitize user-supplied knobs into safe ranges.
    threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
    smoothgrad_samples = int(max(1, smoothgrad_samples))
    smoothgrad_noise = float(max(0.0, smoothgrad_noise))
    # Unknown/blank layer names silently fall back to the default.
    gradcam_target_layer = str(gradcam_target_layer or DEFAULT_GRADCAM_TARGET_LAYER).strip().lower()
    if gradcam_target_layer not in GRADCAM_TARGET_LAYER_OPTIONS:
        gradcam_target_layer = DEFAULT_GRADCAM_TARGET_LAYER
    image_paths, input_error = _resolve_input_images(uploaded_files, folder_path)
    if input_error:
        yield _render_error_html(input_error), [], []
        return
    # Optionally prepare the save directory; a failure here only disables
    # saving, it does not abort the batch.
    save_dir_path: Optional[Path] = None
    save_error = ""
    saved_count = 0
    if save_xdl_results:
        raw_dir = save_xdl_dir.strip() if save_xdl_dir else DEFAULT_SAVE_DIR
        try:
            save_dir_path = Path(raw_dir).expanduser().resolve()
            save_dir_path.mkdir(parents=True, exist_ok=True)
        except Exception as exc:
            save_error = f"Save disabled: {exc}"
            save_dir_path = None
    classified, rows = _classify_image_paths(model, labels, image_paths, threshold)
    # Provisional save status for the first (pre-overlay) summary.
    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None and not classified:
            save_status = f"No files saved (0 classified). Target: {save_dir_path}"
        else:
            save_status = "Pending..."
    else:
        save_status = "Disabled"
    if classified:
        final_class, mean_conf = _aggregate_classification(classified, labels)
        class_counter = Counter(item.pred_idx for item in classified)
        class_stats = ", ".join(f"{labels[idx]}: {count}" for idx, count in class_counter.items())
        initial_xdl_status = f"Processing overlays... (GradCAM layer: {gradcam_target_layer})"
    else:
        final_class = "N/A"
        mean_conf = None
        class_stats = "N/A"
        initial_xdl_status = "Skipped (no classified images)"
    summary_initial = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=initial_xdl_status,
        save_status=save_status,
    )
    # First UI update: classification results first.
    yield summary_initial, rows, []
    if not classified:
        return
    # XDL dependencies are optional; an import failure disables overlays but
    # still produces a final summary.
    xdl = None
    xdl_error = ""
    try:
        xdl = _load_xdl_modules()
    except RuntimeError as exc:
        xdl_error = str(exc)
    gallery_items: List[Tuple[np.ndarray, str]] = []
    xdl_error_count = 0
    if xdl is not None:
        try:
            target_layer = xdl["_get_target_layer"](model, layer_name=gradcam_target_layer)
        except TypeError:
            # Backward compatibility for older helper signature: _get_target_layer(model)
            target_layer = xdl["_get_target_layer"](model)
        cam = xdl["GradCAM"](model=model, target_layers=[target_layer])
        for item in classified:
            try:
                image = Image.open(item.path).convert("RGB")
                input_tensor = val_transform(image).unsqueeze(0).to(model_device)
                base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
                h, w = base_img_uint8.shape[:2]
                # GradCAM heatmap for the already-predicted class.
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[xdl["ClassifierOutputTarget"](item.pred_idx)],
                )[0, :]
                gradcam_overlay = xdl["show_cam_on_image"](base_img_float, grayscale_cam, use_rgb=True)
                smooth_raw = xdl["smoothgrad"](
                    model,
                    input_tensor,
                    item.pred_idx,
                    n_samples=smoothgrad_samples,
                    noise_level=smoothgrad_noise,
                    use_amp=(model_device.type == "cuda"),
                )
                _, smooth_heatmap = xdl["_process_smoothgrad_map"](
                    smooth_raw,
                    img_shape=(h, w),
                    percentile=95,
                    colormap="hot",
                )
                # Match old-original-xdl.py blending behavior exactly.
                smooth_overlay = xdl["cv2"].addWeighted(base_img_uint8, 0.5, smooth_heatmap, 0.5, 0)
                compact_panel = _build_compact_visual_row(base_img_uint8, gradcam_overlay, smooth_overlay)
                caption = f"{item.path.name} | {labels[item.pred_idx]} ({item.confidence:.3f})"
                gallery_items.append((compact_panel, caption))
                if save_dir_path is not None:
                    _save_compact_panel(compact_panel, item.path, save_dir_path)
                    saved_count += 1
            except Exception as exc:
                # Per-image overlay failures are recorded as rows but do not
                # abort the remaining images.
                xdl_error_count += 1
                rows.append([item.path.name, "xdl_error", labels[item.pred_idx], str(exc)])
    if xdl_error:
        xdl_status = f"Disabled: {xdl_error}"
    elif xdl_error_count:
        xdl_status = f"Completed with {xdl_error_count} overlay errors"
    else:
        xdl_status = f"Completed ({len(gallery_items)} overlays, layer: {gradcam_target_layer})"
    # Final save status now that overlays (and saves) have actually run.
    if save_xdl_results:
        if save_error:
            save_status = save_error
        elif save_dir_path is not None:
            save_status = f"Saved {saved_count} files to {save_dir_path}"
        else:
            save_status = "Disabled"
    else:
        save_status = "Disabled"
    summary_final = _render_summary_html(
        case_name=case_name,
        model_path=model_path,
        device_name=model_device.type,
        processed=len(image_paths),
        classified=len(classified),
        threshold=threshold,
        final_class=final_class,
        mean_confidence=mean_conf,
        distribution=class_stats,
        xdl_status=xdl_status,
        save_status=save_status,
    )
    yield summary_final, rows, gallery_items