import os
import tempfile
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from anomalib.data import PredictDataset
from anomalib.data.dataclasses.torch import ImageBatch
from anomalib.models import Patchcore
from lightning.pytorch import Trainer
from PIL import Image
from torch.utils.data import DataLoader

# Checkpoint location; overridable through the CKPT_PATH environment variable.
CKPT_PATH = os.getenv("CKPT_PATH", "model.ckpt")

# 0 -> chibi, 1 -> non-chibi
LABEL_MAP = {0: "chibi", 1: "non-chibi"}

# Passed to PIL's resize() as a (width, height) pair and to PredictDataset as image_size.
IMAGE_SIZE = (256, 256)


# ---------- Utils ----------
def _viridis_like_colormap(x01: np.ndarray) -> np.ndarray:
    """Map values in [0, 1] to RGB uint8 with a viridis-like gradient (no extra deps).

    Args:
        x01: Array of any shape. Values are clipped to [0, 1] before mapping.

    Returns:
        A uint8 array of shape ``x01.shape + (3,)`` holding RGB colours,
        linearly interpolated between five fixed colour stops.
    """
    # Each row is (position, R, G, B).
    stops = np.array(
        [
            [0.00, 68, 1, 84],
            [0.25, 59, 82, 139],
            [0.50, 33, 145, 140],
            [0.75, 94, 201, 97],
            [1.00, 253, 231, 37],
        ],
        dtype=np.float32,
    )

    x = np.clip(x01, 0.0, 1.0).astype(np.float32)
    out = np.zeros((*x.shape, 3), dtype=np.float32)

    for (x0, r0, g0, b0), (x1, r1, g1, b1) in zip(stops[:-1], stops[1:]):
        # Segments share their boundary value; both sides interpolate to the
        # same colour there, so the overlap is harmless.
        mask = (x >= x0) & (x <= x1)
        if not np.any(mask):
            continue
        t = (x[mask] - x0) / (x1 - x0 + 1e-12)
        out[mask, 0] = r0 + (r1 - r0) * t
        out[mask, 1] = g0 + (g1 - g0) * t
        out[mask, 2] = b0 + (b1 - b0) * t

    return np.clip(out, 0, 255).astype(np.uint8)


def make_heatmap_overlay(pil_img: Image.Image, anomaly_map, alpha: float = 0.45) -> Image.Image:
    """Overlay an anomaly heatmap on the image resized to ``IMAGE_SIZE``.

    Args:
        pil_img: Source image; converted to RGB and resized to ``IMAGE_SIZE``.
        anomaly_map: ``torch.Tensor`` or array-like. Singleton dimensions are
            squeezed away; a result that is not 2-D is reshaped to
            ``(IMAGE_SIZE[1], IMAGE_SIZE[0])``.
        alpha: Blend weight of the heatmap (the image gets ``1 - alpha``).

    Returns:
        The blended RGB image, sized ``IMAGE_SIZE``.

    Raises:
        ValueError: If a non-2-D map cannot be reshaped to the image size.
    """
    base = pil_img.convert("RGB").resize(IMAGE_SIZE)
    base_np = np.asarray(base).astype(np.float32)

    if isinstance(anomaly_map, torch.Tensor):
        amap = anomaly_map.detach().cpu().numpy()
    else:
        amap = np.asarray(anomaly_map)

    # float32 keeps half-precision maps usable by PIL further down.
    amap = np.squeeze(amap).astype(np.float32)
    if amap.ndim != 2:
        amap = amap.reshape(IMAGE_SIZE[1], IMAGE_SIZE[0])

    # Min-max normalise; the epsilon keeps a constant map from dividing by zero.
    lo = float(np.min(amap))
    hi = float(np.max(amap))
    amap01 = (amap - lo) / (hi - lo + 1e-8)

    # The map may come back at a different resolution than the display image;
    # without resampling the blend below fails to broadcast.
    # NOTE(review): if the model's pre-processing crops the input, a stretched
    # map is only approximately aligned with the image -- confirm.
    if amap01.shape != base_np.shape[:2]:
        resized = Image.fromarray(amap01.astype(np.float32)).resize(IMAGE_SIZE, Image.BILINEAR)
        amap01 = np.asarray(resized)

    heat = _viridis_like_colormap(amap01).astype(np.float32)
    blended = (1.0 - alpha) * base_np + alpha * heat
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def _first_pred(pred_batches):
    """Return the first prediction from ``Trainer.predict`` output, or ``None``.

    Trainer.predict may return list or list-of-list depending on versions, so
    one level of nesting is unwrapped. An empty outer list, or an empty first
    inner list, yields ``None``.
    """
    if not pred_batches:
        return None
    first = pred_batches[0]
    if isinstance(first, list):
        # An empty inner list carries no prediction; report it as missing
        # instead of handing the caller a list it cannot read fields from.
        return first[0] if first else None
    return first


def _safe_load_state_dict(model: Patchcore, ckpt_path: Path):
    """Load checkpoint weights into ``model`` non-strictly.

    Args:
        model: Model that receives the weights.
        ckpt_path: Path to the checkpoint file.

    Returns:
        ``(missing, unexpected)`` lists of state-dict keys as reported by
        ``load_state_dict(strict=False)``.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects. Only point
    # CKPT_PATH at checkpoints from a trusted source.
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    # Use the nested "state_dict" entry when present, else the object itself.
    state_dict = ckpt.get("state_dict", ckpt)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    return list(missing), list(unexpected)


# ---------- Runtime (cached) ----------
@lru_cache(maxsize=1)
def get_runtime():
    """Build the trainer and model once and cache them for later calls.

    Returns:
        ``(trainer, model, device_type, missing, unexpected)`` where
        ``device_type`` is ``"cuda"`` or ``"cpu"`` and the last two items are
        the state-dict key lists from ``_safe_load_state_dict``.

    Raises:
        FileNotFoundError: If ``CKPT_PATH`` does not exist.
    """
    ckpt_path = Path(CKPT_PATH)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt_path.resolve()}\n"
            f"Tip: upload it into your Space repo (e.g. weights/model.ckpt) and set CKPT_PATH accordingly."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    accelerator = "gpu" if device.type == "cuda" else "cpu"

    trainer = Trainer(
        accelerator=accelerator,
        devices=1,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )

    model = Patchcore()
    missing, unexpected = _safe_load_state_dict(model, ckpt_path)
    model.to(device)
    model.eval()

    return trainer, model, device.type, missing, unexpected


# ---------- Inference ----------
def _to_scalar(value):
    """Unwrap a single-element tensor to a Python scalar; pass other values through."""
    return value.item() if isinstance(value, torch.Tensor) else value


def predict_single(image: Image.Image):
    """Classify one image and build its heatmap overlay.

    Args:
        image: Input PIL image, or ``None`` when nothing was provided.

    Returns:
        ``(label_name, pred_score, heat_overlay, debug)``. The first three are
        ``None`` when there is no input or no prediction; ``heat_overlay`` is
        also ``None`` when the overlay could not be built, in which case the
        reason is appended to ``debug``.
    """
    if image is None:
        return None, None, None, "No image provided."

    trainer, model, dev_type, missing, unexpected = get_runtime()

    # The dataset is built from a path, so the image is written to a temporary
    # file that must stay alive until prediction has finished.
    with tempfile.TemporaryDirectory() as td:
        img_file = Path(td) / "input.png"
        image.convert("RGB").save(img_file)

        dataset = PredictDataset(path=img_file, image_size=IMAGE_SIZE)
        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=ImageBatch.collate,
        )

        with torch.inference_mode():
            pred_batches = trainer.predict(
                model=model,
                dataloaders=loader,
                return_predictions=True,
            )

    p = _first_pred(pred_batches)
    if p is None:
        return None, None, None, "No prediction returned."

    pred_label = int(_to_scalar(p.pred_label))
    pred_score = float(_to_scalar(p.pred_score))
    label_name = LABEL_MAP.get(pred_label, str(pred_label))

    # The overlay is best-effort: a failure must not hide the prediction, but
    # the reason is reported instead of being dropped.
    heat_overlay = None
    heat_note = ""
    try:
        if getattr(p, "anomaly_map", None) is not None:
            heat_overlay = make_heatmap_overlay(image, p.anomaly_map, alpha=0.45)
    except Exception as exc:
        heat_overlay = None
        heat_note = f"\nheatmap unavailable: {type(exc).__name__}: {exc}"

    debug = (
        f"device={dev_type} | pred_label={pred_label} -> '{label_name}' | anomaly_score={pred_score:.6f}\n"
        f"state_dict: missing={len(missing)} unexpected={len(unexpected)}"
    ) + heat_note

    return label_name, pred_score, heat_overlay, debug


# ---------- UI (English) ----------
css = """
#title {text-align: center;}
.subtle {opacity: 0.78; font-size: 0.95rem;}
"""

with gr.Blocks() as demo:
    # The CSS above targets "#title" and ".subtle"; the components carry them
    # through elem_id / elem_classes so the rules actually apply.
    gr.Markdown("# 🧸 Chibi Style Detection", elem_id="title")
    gr.Markdown(
        "Upload a single image. Output label: chibi (0) or non-chibi (1).\n\n"
        "Note: The anomaly score is not a probability (it is PatchCore's anomaly score).",
        elem_classes=["subtle"],
    )

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label="Input Image", sources=["upload", "webcam"], type="pil")
            btn = gr.Button("Detect", variant="primary")
            gr.Markdown(
                "Tip: Use clear character images for best results.",
                elem_classes=["subtle"],
            )

        with gr.Column(scale=1):
            out_label = gr.Label(label="Prediction")
            out_score = gr.Number(label="Anomaly Score")
            out_heat = gr.Image(label="Heatmap Overlay", type="pil")
            # Receives the fourth value returned by predict_single.
            out_debug = gr.Textbox(label="Debug Info", lines=3, interactive=False)

    btn.click(
        fn=predict_single,
        inputs=[inp],
        outputs=[out_label, out_score, out_heat, out_debug],
    )

if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.queue(default_concurrency_limit=1)
    # NOTE(review): theme/css are passed to launch(); this presumably targets a
    # Gradio release whose launch() accepts them -- confirm the pinned version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        ssr_mode=False,
        theme=gr.themes.Soft(),
        css=css,
    )