lumicero's picture
Update app.py
59fea3a verified
import os
import tempfile
from pathlib import Path
from functools import lru_cache
import numpy as np
import torch
import gradio as gr
from PIL import Image
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from anomalib.data import PredictDataset
from anomalib.models import Patchcore
from anomalib.data.dataclasses.torch import ImageBatch
# Checkpoint location; the CKPT_PATH environment variable overrides the default.
CKPT_PATH = os.getenv("CKPT_PATH", "model.ckpt")

# Integer label emitted by the model -> human-readable class name.
LABEL_MAP = {
    0: "chibi",
    1: "non-chibi",
}

# (width, height): used for PIL resize and as PredictDataset's image_size.
IMAGE_SIZE = (256, 256)
# ---------- Utils ----------
def _viridis_like_colormap(x01: np.ndarray) -> np.ndarray:
    """Colour values in [0, 1] along a viridis-like ramp, returning RGB uint8.

    Uses only NumPy (no extra dependencies): piecewise-linear interpolation
    between five colour stops. Inputs are clipped to [0, 1] first.

    Args:
        x01: Array of any shape.

    Returns:
        uint8 array of shape ``x01.shape + (3,)``.
    """
    # Each row: position in [0, 1], followed by the R, G, B value at that position.
    stops = np.array(
        [
            [0.00, 68, 1, 84],
            [0.25, 59, 82, 139],
            [0.50, 33, 145, 140],
            [0.75, 94, 201, 97],
            [1.00, 253, 231, 37],
        ],
        dtype=np.float32,
    )
    values = np.clip(x01, 0.0, 1.0).astype(np.float32)
    rgb = np.zeros(values.shape + (3,), dtype=np.float32)
    for lo, hi in zip(stops[:-1], stops[1:]):
        in_segment = (values >= lo[0]) & (values <= hi[0])
        if not in_segment.any():
            continue
        frac = (values[in_segment] - lo[0]) / (hi[0] - lo[0] + 1e-12)
        # Values on a shared boundary belong to two segments; the later segment's write is kept.
        for ch in range(3):
            rgb[in_segment, ch] = lo[ch + 1] + (hi[ch + 1] - lo[ch + 1]) * frac
    return np.clip(rgb, 0, 255).astype(np.uint8)
def make_heatmap_overlay(pil_img: Image.Image, anomaly_map, alpha: float = 0.45) -> Image.Image:
    """Overlay a min-max normalised anomaly heatmap on the resized RGB image.

    Args:
        pil_img: Source image. Converted to RGB and resized to ``IMAGE_SIZE``.
        anomaly_map: Per-pixel anomaly scores, either a ``torch.Tensor`` or any
            array-like accepted by ``np.asarray``. Size-1 dimensions are
            squeezed away. A 2-D map whose resolution differs from the resized
            image is bilinearly resampled to match it.
        alpha: Weight of the heatmap in the blend (0 keeps only the image,
            1 only the heatmap).

    Returns:
        The blended RGB image, of size ``IMAGE_SIZE``.

    Raises:
        ValueError: If ``anomaly_map`` is not 2-D after squeezing and its
            element count does not match the image's height * width.
    """
    img = pil_img.convert("RGB").resize(IMAGE_SIZE)
    img_np = np.asarray(img, dtype=np.float32)  # (height, width, 3)
    height, width = img_np.shape[:2]

    if isinstance(anomaly_map, torch.Tensor):
        # Cast to float32 first: bfloat16 tensors cannot be converted with .numpy().
        amap = anomaly_map.detach().float().cpu().numpy()
    else:
        amap = np.asarray(anomaly_map, dtype=np.float32)

    amap = np.squeeze(amap)
    if amap.ndim != 2:
        if amap.size != height * width:
            raise ValueError(
                f"anomaly_map with shape {amap.shape} cannot be viewed as a "
                f"{height}x{width} map"
            )
        amap = amap.reshape(height, width)
    elif amap.shape != (height, width):
        # The map may come out at another resolution; resample it so the blend
        # below does not fail on mismatched array shapes.
        amap = np.asarray(
            Image.fromarray(amap.astype(np.float32)).resize((width, height), Image.BILINEAR)
        )

    amin = float(np.min(amap))
    amax = float(np.max(amap))
    # Epsilon keeps a constant map from dividing by zero (it maps to all-zeros).
    amap01 = (amap - amin) / (amax - amin + 1e-8)
    heat = _viridis_like_colormap(amap01).astype(np.float32)
    overlay = (1.0 - alpha) * img_np + alpha * heat
    return Image.fromarray(np.clip(overlay, 0, 255).astype(np.uint8))
def _first_pred(pred_batches):
    """Return the first prediction from ``Trainer.predict`` output, or ``None``.

    ``Trainer.predict`` may return either a flat list of predictions or a list
    of such lists (presumably one per dataloader), depending on the Lightning
    version and setup. Both forms are handled.

    Args:
        pred_batches: Value returned by ``Trainer.predict``; may be ``None`` or empty.

    Returns:
        The first prediction object, or ``None`` when there is none. This
        includes a nested form whose first inner list is empty.
    """
    if not pred_batches:
        return None
    first = pred_batches[0]
    if isinstance(first, list):
        # An empty inner list means nothing was predicted; callers check for None.
        return first[0] if first else None
    return first
def _safe_load_state_dict(model: Patchcore, ckpt_path: Path):
    """Load checkpoint weights into ``model`` non-strictly (in place).

    Accepts a Lightning-style checkpoint (a dict with a ``"state_dict"`` entry)
    or a bare state dict.

    Args:
        model: Module to load the weights into.
        ckpt_path: Path of the checkpoint file.

    Returns:
        ``(missing_keys, unexpected_keys)`` as lists, as reported by
        ``load_state_dict(strict=False)``.

    Raises:
        TypeError: If the file does not deserialize to a dict.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects, which can execute
    # code while loading. Only point this at checkpoint files you trust.
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected a checkpoint dict in {ckpt_path}, got {type(ckpt).__name__}"
        )
    state_dict = ckpt.get("state_dict", ckpt)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    return list(missing), list(unexpected)
# ---------- Runtime (cached) ----------
@lru_cache(maxsize=1)
def get_runtime():
    """Build the trainer and load the PatchCore model, memoised per process.

    ``lru_cache(maxsize=1)`` stores the result of the first successful call,
    so later calls reuse the same trainer and model.

    Returns:
        ``(trainer, model, device_type, missing_keys, unexpected_keys)`` where
        ``device_type`` is ``"cuda"`` or ``"cpu"`` and the key lists come from
        ``_safe_load_state_dict``.

    Raises:
        FileNotFoundError: If no file exists at ``CKPT_PATH``.
    """
    ckpt_file = Path(CKPT_PATH)
    if not ckpt_file.exists():
        hint = (
            f"Checkpoint not found: {ckpt_file.resolve()}\n"
            "Tip: upload it into your Space repo (e.g. weights/model.ckpt) and set CKPT_PATH accordingly."
        )
        raise FileNotFoundError(hint)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Inference-only trainer on one device: logging, checkpointing, model
    # summary and the progress bar are all disabled.
    trainer_options = dict(
        accelerator={"cuda": "gpu", "cpu": "cpu"}[device.type],
        devices=1,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer = Trainer(**trainer_options)

    model = Patchcore()
    load_report = _safe_load_state_dict(model, ckpt_file)
    model.to(device).eval()
    return (trainer, model, device.type, *load_report)
# ---------- Inference ----------
def _as_python_scalar(value, cast):
    """Return ``cast(value)``, unwrapping a single-element ``torch.Tensor`` first."""
    if isinstance(value, torch.Tensor):
        value = value.item()
    return cast(value)


def predict_single(image: Image.Image):
    """Run PatchCore on one image and format the results for the Gradio UI.

    Args:
        image: Image from the UI, or ``None`` when nothing was provided.

    Returns:
        ``(label_name, anomaly_score, heatmap_overlay, debug_text)``. On the
        early exits (no image / no prediction) the first three items are
        ``None`` and ``debug_text`` explains why. ``heatmap_overlay`` is
        ``None`` when the prediction has no anomaly map or the overlay could
        not be built; in the latter case the error is appended to ``debug_text``.
    """
    if image is None:
        return None, None, None, "No image provided."

    trainer, model, dev_type, missing, unexpected = get_runtime()

    # PredictDataset is given a file path, so write the image to a temp PNG.
    with tempfile.TemporaryDirectory() as td:
        img_file = Path(td) / "input.png"
        image.convert("RGB").save(img_file)
        dataset = PredictDataset(path=img_file, image_size=IMAGE_SIZE)
        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=ImageBatch.collate,
        )
        # Prediction must run while the temp file still exists.
        with torch.inference_mode():
            pred_batches = trainer.predict(
                model=model,
                dataloaders=loader,
                return_predictions=True,
            )

    p = _first_pred(pred_batches)
    if p is None:
        return None, None, None, "No prediction returned."

    pred_label = _as_python_scalar(p.pred_label, int)
    pred_score = _as_python_scalar(p.pred_score, float)
    label_name = LABEL_MAP.get(pred_label, str(pred_label))

    # The heatmap is best-effort: a failure must not hide the prediction, but it
    # is reported in the debug text instead of being silently dropped.
    heat_overlay = None
    heat_error = None
    try:
        anomaly_map = getattr(p, "anomaly_map", None)
        if anomaly_map is not None:
            heat_overlay = make_heatmap_overlay(image, anomaly_map, alpha=0.45)
    except Exception as exc:
        heat_overlay = None
        heat_error = f"{type(exc).__name__}: {exc}"

    debug = (
        f"device={dev_type} | pred_label={pred_label} -> '{label_name}' | anomaly_score={pred_score:.6f}\n"
        f"state_dict: missing={len(missing)} unexpected={len(unexpected)}"
    )
    if heat_error is not None:
        debug += f"\nheatmap: failed ({heat_error})"
    return label_name, pred_score, heat_overlay, debug
# ---------- UI (English) ----------
# Page CSS: centre the element with id "title" and dim text with class "subtle".
css = "\n".join(
    [
        "",
        "#title {text-align: center;}",
        ".subtle {opacity: 0.78; font-size: 0.95rem;}",
        "",
    ]
)
with gr.Blocks() as demo:
    # Layout: header, then input column (image + button) beside the results column.
    gr.Markdown(
        """
<h1 id="title">🧸 Chibi Style Detection</h1>
<p class="subtle">
Upload a single image. Output label: <b>chibi</b> (0) or <b>non-chibi</b> (1).<br/>
Note: The anomaly score is not a probability (it is PatchCore's anomaly score).
</p>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label="Input Image", sources=["upload", "webcam"], type="pil")
            btn = gr.Button("Detect", variant="primary")
            gr.Markdown(
                "<p class='subtle'>Tip: Use clear character images for best results.</p>"
            )
        with gr.Column(scale=1):
            out_label = gr.Label(label="Prediction")
            out_score = gr.Number(label="Anomaly Score")
            out_heat = gr.Image(label="Heatmap Overlay", type="pil")
            # predict_single returns four values; the fourth (debug text) needs
            # its own output component, otherwise it is dropped.
            out_debug = gr.Textbox(label="Debug", lines=2, interactive=False)
    btn.click(
        fn=predict_single,
        inputs=[inp],
        outputs=[out_label, out_score, out_heat, out_debug],
    )
if __name__ == "__main__":
    # Port comes from the PORT environment variable (default 7860).
    server_port = int(os.getenv("PORT", "7860"))
    # Default concurrency limit of 1: events are processed one at a time
    # unless a listener overrides it.
    demo.queue(default_concurrency_limit=1)
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": server_port,
        "ssr_mode": False,
        "theme": gr.themes.Soft(),
        "css": css,
    }
    demo.launch(**launch_options)