"""
ITDF — Implicit Topological De-Filtering
Makeup Removal Demo

Upload a makeup face image → get bare face + identity score + ECC topology plot.
"""

import os
import math

import yaml
import torch
import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402,F401  (figures below are built without pyplot)
from matplotlib.figure import Figure  # noqa: E402
from PIL import Image  # noqa: E402
import gradio as gr  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

from models.itdf import ITDF  # noqa: E402

# ── constants ──────────────────────────────────────────────────────────────────
MODEL_REPO = "priyadip/itdf-model"
CKPT_FILE = "ckpt_final.pt"
CONFIG_FILE = "default.yaml"
IMG_SIZE = 512
DEVICE = torch.device("cpu")

# ── load model at startup ──────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # None → anonymous download

print("Downloading config …")
cfg_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CONFIG_FILE,
    repo_type="model",
    token=HF_TOKEN,
)
with open(cfg_path, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print("Downloading checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CKPT_FILE,
    repo_type="model",
    token=HF_TOKEN,
)

print("Loading ITDF model …")
model = ITDF(cfg).to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint. On PyTorch versions where
# `weights_only` defaults to True, checkpoints holding non-tensor Python objects
# may be rejected — confirm against the deployed torch version.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# The checkpoint may wrap the weights under a 'model' or 'G' key; otherwise the
# loaded object itself is treated as the state dict.
state = ckpt.get("model", ckpt.get("G", ckpt))
# strict=False tolerates key mismatches; report them instead of ignoring them,
# since a wholesale mismatch would leave the model with freshly initialised weights.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint keys do not match the model "
        f"({len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected)."
    )
model.eval()
print("Model ready.")

# ── FaceNet / identity score ───────────────────────────────────────────────────
try:
    from facenet_pytorch import InceptionResnetV1

    facenet = InceptionResnetV1(pretrained="vggface2").eval().to(DEVICE)
    FACENET_OK = True
    print("FaceNet ready.")
except Exception as e:  # best-effort: the demo still runs without identity scoring
    FACENET_OK = False
    print(f"FaceNet not available: {e}")


def _facenet_embed(img_pil: Image.Image) -> torch.Tensor | None:
    """Return an L2-normalised 512-d face embedding, or None if FaceNet is unavailable.

    The whole image is resized to 160×160; no face detection or cropping is done.
    """
    if not FACENET_OK:
        return None
    import torchvision.transforms as T

    tf = T.Compose([
        T.Resize((160, 160)),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] → [-1, 1]
    ])
    # Force 3 channels so RGBA / grayscale inputs match the 3-channel Normalize.
    x = tf(img_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        emb = facenet(x)
    return torch.nn.functional.normalize(emb, dim=1)


def identity_score(inp: Image.Image, out: Image.Image) -> float | None:
    """Cosine similarity in [-1, 1] between the FaceNet embeddings of *inp* and *out*.

    Returns None when FaceNet could not be loaded.
    """
    e1 = _facenet_embed(inp)
    e2 = _facenet_embed(out)
    if e1 is None or e2 is None:
        return None
    # Both embeddings are unit-norm, so their dot product is the cosine similarity;
    # clamp guards against tiny floating-point overshoot.
    return float((e1 * e2).sum().clamp(-1, 1).item())


# ── ECC computation ────────────────────────────────────────────────────────────
N_DIRECTIONS = 9  # 8 spatial-mix + 1 pure-intensity
N_THRESHOLDS = 64
TAU = 0.1  # sigmoid temperature for the smooth (differentiable) cell indicator


def _ecc_from_heights(heights: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Smooth ECC (B, T) from per-pixel lower-star heights (B, H*W).

    Uses the cubical-complex Euler characteristic χ = V − E_h − E_v + F: each
    horizontal/vertical edge and each unit square enters the filtration at the
    max height of its pixels. The hard indicator [h <= t] is replaced by
    sigmoid((t − h) / TAU) so the curve is differentiable.
    """
    B = heights.shape[0]
    thresholds = torch.linspace(0, 1, N_THRESHOLDS, device=heights.device)
    t = thresholds.unsqueeze(0).unsqueeze(0)  # (1, 1, T)

    def soft_count(h: torch.Tensor) -> torch.Tensor:
        """Soft number of cells (heights h: (B, N)) present at each threshold → (B, T)."""
        return torch.sigmoid((t - h.unsqueeze(-1)) / TAU).sum(dim=1)

    h2d = heights.reshape(B, H, W)
    h_eh = torch.max(h2d[:, :, :-1], h2d[:, :, 1:]).reshape(B, -1)  # horizontal edges
    h_ev = torch.max(h2d[:, :-1, :], h2d[:, 1:, :]).reshape(B, -1)  # vertical edges
    h_sq = torch.max(  # unit squares
        torch.max(h2d[:, :-1, :-1], h2d[:, :-1, 1:]),
        torch.max(h2d[:, 1:, :-1], h2d[:, 1:, 1:]),
    ).reshape(B, -1)

    return soft_count(heights) - soft_count(h_eh) - soft_count(h_ev) + soft_count(h_sq)


def compute_ecc(img_gray: torch.Tensor) -> torch.Tensor:
    """
    Args:
        img_gray: (1, H, W) in [0, 1]
    Returns:
        ecc: (D+1, T) — one curve per spatial direction, plus a final
        intensity-only curve.
    """
    B, H, W = img_gray.shape
    device = img_gray.device
    ys = torch.linspace(0, 1, H, device=device)
    xs = torch.linspace(0, 1, W, device=device)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    coords_flat = torch.stack([gx, gy], dim=-1).reshape(-1, 2)  # (H*W, 2)
    intensity_flat = img_gray.reshape(B, -1)  # (B, H*W)

    # N_DIRECTIONS - 1 directions evenly spaced over [0, π); created on the input's
    # device so the dot product with coords_flat does not mix devices.
    angles = torch.arange(N_DIRECTIONS - 1, dtype=torch.float32, device=device) * (
        math.pi / (N_DIRECTIONS - 1)
    )
    dirs = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    eccs = []
    for nu in dirs:
        spatial = (coords_flat * nu).sum(-1)
        spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)
        # Height = equal mix of normalised position along nu and pixel intensity.
        heights = 0.5 * spatial.unsqueeze(0) + 0.5 * intensity_flat
        eccs.append(_ecc_from_heights(heights, H, W))
    eccs.append(_ecc_from_heights(intensity_flat, H, W))
    return torch.stack(eccs, dim=1).squeeze(0)  # (D+1, T)


def ecc_plot(inp_pil: Image.Image, out_pil: Image.Image) -> Figure:
    """Return a matplotlib figure comparing the ECC of input vs output.

    The figure is created with matplotlib.figure.Figure rather than pyplot, so it
    is not registered in pyplot's global figure manager and is garbage-collected
    once Gradio has rendered it (pyplot figures would accumulate per request).
    """
    import torchvision.transforms.functional as TF

    def to_gray(pil: Image.Image) -> torch.Tensor:
        """Resize to IMG_SIZE×IMG_SIZE and return the channel-mean grayscale, (1, H, W)."""
        t = TF.to_tensor(pil.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))  # (3, H, W)
        return t.mean(0, keepdim=True)

    with torch.no_grad():
        ecc_inp = compute_ecc(to_gray(inp_pil)).cpu().numpy()  # (D+1, T)
        ecc_out = compute_ecc(to_gray(out_pil)).cpu().numpy()

    thresholds = np.linspace(0, 1, N_THRESHOLDS)
    direction_labels = [f"Dir {i+1}" for i in range(N_DIRECTIONS - 1)] + ["Intensity"]

    # 3 columns; rows grow with N_DIRECTIONS (3×3 → figsize (12, 10) as before).
    ncols = 3
    nrows = math.ceil(N_DIRECTIONS / ncols)
    fig = Figure(figsize=(4 * ncols, 10 * nrows / 3))
    axes = fig.subplots(nrows, ncols, squeeze=False).flatten()
    fig.suptitle("Euler Characteristic Curves: Makeup (blue) vs Bare (orange)", fontsize=13)

    for d in range(N_DIRECTIONS):
        ax = axes[d]
        ax.plot(thresholds, ecc_inp[d], color="#4C72B0", linewidth=1.5, label="Makeup input")
        ax.plot(thresholds, ecc_out[d], color="#DD8452", linewidth=1.5, label="Bare output")
        ax.fill_between(thresholds, ecc_inp[d], ecc_out[d], alpha=0.15, color="purple")
        ax.set_title(direction_labels[d], fontsize=9)
        ax.set_xlabel("Threshold", fontsize=8)
        ax.set_ylabel("χ(t)", fontsize=8)
        ax.tick_params(labelsize=7)
        if d == 0:
            ax.legend(fontsize=7)

    for ax in axes[N_DIRECTIONS:]:  # unused grid cells
        ax.set_visible(False)

    fig.tight_layout()
    return fig


# ── image preprocessing / postprocessing ──────────────────────────────────────
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Resize to IMG_SIZE×IMG_SIZE RGB and scale to a (1, 3, H, W) tensor in [-1, 1]."""
    img = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    arr = np.array(img, dtype=np.float32) / 127.5 - 1.0  # [0,255] → [-1,1]
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # (1,3,H,W)
    return t.to(DEVICE)


def postprocess(t: torch.Tensor) -> Image.Image:
    """Convert a (1, 3, H, W) tensor in [-1, 1] back to an 8-bit RGB PIL image."""
    arr = t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    arr = ((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)


# ── main inference function ────────────────────────────────────────────────────
def remove_makeup(inp_img: Image.Image):
    """Gradio callback: run ITDF on *inp_img*.

    Returns (bare-face image, identity-score markdown, ECC figure), or
    (None, message, None) when no image was uploaded.
    """
    if inp_img is None:
        return None, "No image uploaded.", None

    # 1) ITDF inference
    with torch.no_grad():
        x = preprocess(inp_img)
        out = model(x)
    out_img = postprocess(out)

    # 2) Identity preservation score
    score = identity_score(inp_img, out_img)
    if score is not None:
        id_text = f"**Identity Preservation Score:** {score:.4f} / 1.0000\n\n"
        if score >= 0.85:
            id_text += "Excellent — face identity strongly preserved."
        elif score >= 0.70:
            id_text += "Good — face identity well preserved."
        elif score >= 0.50:
            id_text += "Moderate — some identity drift."
        else:
            id_text += "Low — significant identity change detected."
    else:
        id_text = "Identity score unavailable (FaceNet not loaded)."

    # 3) ECC topology plot
    fig = ecc_plot(inp_img, out_img)

    return out_img, id_text, fig


# ── Gradio UI ──────────────────────────────────────────────────────────────────
with gr.Blocks(title="ITDF — Makeup Removal") as demo:
    gr.Markdown(
        """
# ITDF — Implicit Topological De-Filtering
**Makeup Removal via Topology-Aware Implicit Neural Representation**

Upload a face image with makeup. The model outputs the bare face along with:
- **Identity Preservation Score** — how well face identity is maintained (FaceNet cosine similarity)
- **ECC Topology Plot** — Euler Characteristic Curves comparing input vs output topology across 9 directions

> ⏳ Running on free CPU — inference takes **~35–70 seconds**. Please wait.
"""
    )

    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input: Makeup Face")
            btn = gr.Button("Remove Makeup", variant="primary")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Output: Bare Face")
            id_score = gr.Markdown(label="Identity Score")

    ecc_fig = gr.Plot(label="ECC Topology Plot (Makeup vs Bare)")

    btn.click(
        fn=remove_makeup,
        inputs=[inp],
        outputs=[out_img, id_score, ecc_fig],
    )

    gr.Markdown(
        """
---
**Model:** ViT-B/16 encoder + SIREN implicit decoder
**Topology losses:** Differentiable Euler Characteristic Transform (DECT) + Multiparameter Persistence
**Dataset:** FFHQ-Makeup (73,480 training pairs)
"""
    )

demo.queue()
demo.launch(show_error=True)