Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from functools import lru_cache | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from lightning.pytorch import Trainer | |
| from anomalib.data import PredictDataset | |
| from anomalib.models import Patchcore | |
| from anomalib.data.dataclasses.torch import ImageBatch | |
# Path to the PatchCore Lightning checkpoint; override via the CKPT_PATH env var.
CKPT_PATH = os.getenv("CKPT_PATH", "model.ckpt")
# Mapping from the model's integer prediction to a display name.
# 0 -> chibi, 1 -> non-chibi
LABEL_MAP = {0: "chibi", 1: "non-chibi"}
# (width, height) every input image is resized to before inference.
IMAGE_SIZE = (256, 256)
| # ---------- Utils ---------- | |
| def _viridis_like_colormap(x01: np.ndarray) -> np.ndarray: | |
| """Map [0..1] -> RGB uint8 using a viridis-like gradient (no extra deps).""" | |
| stops = np.array( | |
| [ | |
| [0.00, 68, 1, 84], | |
| [0.25, 59, 82, 139], | |
| [0.50, 33, 145, 140], | |
| [0.75, 94, 201, 97], | |
| [1.00, 253, 231, 37], | |
| ], | |
| dtype=np.float32, | |
| ) | |
| x = np.clip(x01, 0.0, 1.0).astype(np.float32) | |
| out = np.zeros((*x.shape, 3), dtype=np.float32) | |
| for i in range(len(stops) - 1): | |
| x0, r0, g0, b0 = stops[i] | |
| x1, r1, g1, b1 = stops[i + 1] | |
| mask = (x >= x0) & (x <= x1) | |
| if not np.any(mask): | |
| continue | |
| t = (x[mask] - x0) / (x1 - x0 + 1e-12) | |
| out[mask, 0] = r0 + (r1 - r0) * t | |
| out[mask, 1] = g0 + (g1 - g0) * t | |
| out[mask, 2] = b0 + (b1 - b0) * t | |
| return np.clip(out, 0, 255).astype(np.uint8) | |
def make_heatmap_overlay(pil_img: Image.Image, anomaly_map, alpha: float = 0.45) -> Image.Image:
    """Overlay anomaly heatmap on resized RGB image.

    The anomaly map (tensor or array-like) is min-max normalized, colored
    with the viridis-like gradient, and alpha-blended over the image
    resized to IMAGE_SIZE.
    """
    base = np.asarray(pil_img.convert("RGB").resize(IMAGE_SIZE)).astype(np.float32)
    if isinstance(anomaly_map, torch.Tensor):
        raw = anomaly_map.detach().cpu().numpy()
    else:
        raw = np.asarray(anomaly_map)
    raw = np.squeeze(raw)
    if raw.ndim != 2:
        # Hard fallback: assumes the map holds (H, W) = reversed IMAGE_SIZE
        # elements — TODO confirm against the model's anomaly_map shape.
        raw = raw.reshape(IMAGE_SIZE[1], IMAGE_SIZE[0])
    lo = float(np.min(raw))
    hi = float(np.max(raw))
    # Epsilon avoids division by zero on a constant map.
    normalized = (raw - lo) / (hi - lo + 1e-8)
    heat = _viridis_like_colormap(normalized).astype(np.float32)
    blended = np.clip((1.0 - alpha) * base + alpha * heat, 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
| def _first_pred(pred_batches): | |
| """Trainer.predict may return list or list-of-list depending on versions.""" | |
| if not pred_batches: | |
| return None | |
| x = pred_batches[0] | |
| if isinstance(x, list) and x: | |
| x = x[0] | |
| return x | |
def _safe_load_state_dict(model: Patchcore, ckpt_path: Path):
    """Load weights from a Lightning checkpoint, tolerating key mismatches.

    Accepts either a full Lightning checkpoint (with a "state_dict" key)
    or a bare state dict, and loads non-strictly so missing/extra keys do
    not raise.

    Returns:
        (missing_keys, unexpected_keys) as plain lists.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    missing, unexpected = model.load_state_dict(weights, strict=False)
    return list(missing), list(unexpected)
| # ---------- Runtime (cached) ---------- | |
@lru_cache(maxsize=1)
def get_runtime():
    """Build the inference runtime once and reuse it across requests.

    Bug fix: the section is titled "Runtime (cached)" and `lru_cache` is
    imported, but nothing was cached — every prediction re-read the
    checkpoint and rebuilt the PatchCore model. The decorator makes this
    a lazy singleton: construction happens only on the first call.

    Returns:
        (trainer, model, device_type, missing_keys, unexpected_keys)

    Raises:
        FileNotFoundError: if CKPT_PATH does not point to an existing file.
    """
    ckpt_path = Path(CKPT_PATH)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt_path.resolve()}\n"
            f"Tip: upload it into your Space repo (e.g. weights/model.ckpt) and set CKPT_PATH accordingly."
        )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    accelerator = "gpu" if device.type == "cuda" else "cpu"
    # Bare predict-only trainer: no logging, checkpointing, or progress UI.
    trainer = Trainer(
        accelerator=accelerator,
        devices=1,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    model = Patchcore()
    missing, unexpected = _safe_load_state_dict(model, ckpt_path)
    model.to(device)
    model.eval()
    return trainer, model, device.type, missing, unexpected
| # ---------- Inference ---------- | |
def predict_single(image: Image.Image):
    """Run PatchCore inference on a single PIL image.

    Returns a 4-tuple (label_name, anomaly_score, heatmap_overlay, debug_text);
    the first three are None when no prediction could be made.
    """
    if image is None:
        return None, None, None, "No image provided."
    trainer, model, dev_type, missing, unexpected = get_runtime()
    with tempfile.TemporaryDirectory() as workdir:
        # PredictDataset wants a file on disk, so round-trip via a temp PNG.
        tmp_png = Path(workdir) / "input.png"
        image.convert("RGB").save(tmp_png)
        loader = DataLoader(
            PredictDataset(path=tmp_png, image_size=IMAGE_SIZE),
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=ImageBatch.collate,
        )
        with torch.inference_mode():
            batches = trainer.predict(
                model=model,
                dataloaders=loader,
                return_predictions=True,
            )
    prediction = _first_pred(batches)
    if prediction is None:
        return None, None, None, "No prediction returned."
    raw_label = prediction.pred_label
    raw_score = prediction.pred_score
    pred_label = int(raw_label.item()) if isinstance(raw_label, torch.Tensor) else int(raw_label)
    pred_score = float(raw_score.item()) if isinstance(raw_score, torch.Tensor) else float(raw_score)
    label_name = LABEL_MAP.get(pred_label, str(pred_label))
    heat_overlay = None
    try:
        if getattr(prediction, "anomaly_map", None) is not None:
            heat_overlay = make_heatmap_overlay(image, prediction.anomaly_map, alpha=0.45)
    except Exception:
        # Heatmap is best-effort; never fail the whole prediction over it.
        heat_overlay = None
    debug = (
        f"device={dev_type} | pred_label={pred_label} -> '{label_name}' | anomaly_score={pred_score:.6f}\n"
        f"state_dict: missing={len(missing)} unexpected={len(unexpected)}"
    )
    return label_name, pred_score, heat_overlay, debug
| # ---------- UI (English) ---------- | |
# Custom CSS for the Gradio UI: centers the title and mutes the hint text.
css = """
#title {text-align: center;}
.subtle {opacity: 0.78; font-size: 0.95rem;}
"""
# Bug fixes:
# 1) css/theme must be given to the gr.Blocks() constructor — they were
#    defined but never applied (and wrongly passed to launch()).
# 2) predict_single returns 4 values but only 3 outputs were wired to
#    btn.click, which Gradio rejects at event time; add a debug Textbox.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
<h1 id="title">🧸 Chibi Style Detection</h1>
<p class="subtle">
Upload a single image. Output label: <b>chibi</b> (0) or <b>non-chibi</b> (1).<br/>
Note: The anomaly score is not a probability (it is PatchCore's anomaly score).
</p>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label="Input Image", sources=["upload", "webcam"], type="pil")
            btn = gr.Button("Detect", variant="primary")
            gr.Markdown(
                "<p class='subtle'>Tip: Use clear character images for best results.</p>"
            )
        with gr.Column(scale=1):
            out_label = gr.Label(label="Prediction")
            out_score = gr.Number(label="Anomaly Score")
            out_heat = gr.Image(label="Heatmap Overlay", type="pil")
            out_debug = gr.Textbox(label="Debug", interactive=False)
    btn.click(
        fn=predict_single,
        inputs=[inp],
        outputs=[out_label, out_score, out_heat, out_debug],
    )
if __name__ == "__main__":
    # HF Spaces sets PORT; fall back to Gradio's default 7860 locally.
    port = int(os.getenv("PORT", "7860"))
    # Serialize inference: the shared model/trainer runtime is not safe to
    # use from concurrent requests.
    demo.queue(default_concurrency_limit=1)
    # Bug fix: launch() accepts neither `theme` nor `css` — passing them
    # raised TypeError at startup. They belong on the gr.Blocks() constructor.
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        ssr_mode=False,
    )