Spaces:
Sleeping
Sleeping
| """ | |
| ITDF — Implicit Topological De-Filtering | |
| Makeup Removal Demo | |
| Upload a makeup face image → get bare face + identity score + ECC topology plot. | |
| """ | |
| import os | |
| import math | |
| import yaml | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from models.itdf import ITDF | |
# -- constants -----------------------------------------------------------------
MODEL_REPO = "priyadip/itdf-model"  # Hub repo that holds both config and checkpoint
CKPT_FILE = "ckpt_final.pt"
CONFIG_FILE = "default.yaml"
IMG_SIZE = 512  # square side (pixels) images are resized to before inference / ECC
DEVICE = torch.device("cpu")

# -- load model at startup -----------------------------------------------------
# Optional Hub token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

print("Downloading config …")
cfg_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CONFIG_FILE,
    repo_type="model",
    token=HF_TOKEN,
)
with open(cfg_path, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print("Downloading checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CKPT_FILE,
    repo_type="model",
    token=HF_TOKEN,
)

print("Loading ITDF model …")
model = ITDF(cfg).to(DEVICE)
# NOTE(review): torch.load unpickles the downloaded file, so the checkpoint is
# only as trustworthy as MODEL_REPO. Consider weights_only=True once it is
# confirmed that the checkpoint contains nothing but tensors.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# The weights may be wrapped under a 'model' or 'G' key; otherwise the
# checkpoint object itself is used as the state dict.
state = ckpt.get("model", ckpt.get("G", ckpt))
# strict=False tolerates key mismatches; surface them so a checkpoint that
# does not actually match the architecture is not loaded silently.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model keys missing from checkpoint.")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in checkpoint.")
model.eval()
print("Model ready.")
# -- face-embedding network used for the identity score ------------------------
# facenet_pytorch is treated as optional: any failure while importing it or
# building the network is reported and recorded in FACENET_OK, which the
# embedding helper checks before touching `facenet`.
try:
    from facenet_pytorch import InceptionResnetV1

    facenet = InceptionResnetV1(pretrained="vggface2")
    facenet = facenet.eval().to(DEVICE)
    FACENET_OK = True
    print("FaceNet ready.")
except Exception as exc:
    FACENET_OK = False
    print(f"FaceNet not available: {exc}")
def _facenet_embed(img_pil: Image.Image) -> torch.Tensor | None:
    """Return the L2-normalised 512-d face embedding of *img_pil*.

    Args:
        img_pil: Face image in any PIL mode. It is converted to RGB first so
            that RGBA, palette or grayscale inputs match the 3-channel
            normalisation below (the other image helpers in this file
            convert to RGB as well).

    Returns:
        The embedding with a leading batch dimension of 1, normalised along
        dim 1, or ``None`` if FaceNet is unavailable.
    """
    if not FACENET_OK:
        return None
    import torchvision.transforms as T

    tf = T.Compose([
        T.Resize((160, 160)),
        T.ToTensor(),
        # Maps [0, 1] to [-1, 1] per channel.
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    x = tf(img_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        emb = facenet(x)
    return torch.nn.functional.normalize(emb, dim=1)
def identity_score(inp: Image.Image, out: Image.Image) -> float | None:
    """Cosine similarity between the face embeddings of *inp* and *out*.

    Both images are embedded with ``_facenet_embed``; because those
    embeddings are unit-normalised, their dot product is the cosine
    similarity. The value is clamped to [-1, 1].

    Returns:
        The similarity as a Python float, or ``None`` when either embedding
        is unavailable.
    """
    embeddings = [_facenet_embed(img) for img in (inp, out)]
    if any(emb is None for emb in embeddings):
        return None
    emb_in, emb_out = embeddings
    cosine = torch.sum(emb_in * emb_out)
    return float(torch.clamp(cosine, -1, 1).item())
| # ββ ECC computation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| N_DIRECTIONS = 9 # 8 spatial-mix + 1 pure-intensity | |
| N_THRESHOLDS = 64 | |
| TAU = 0.1 | |
| def _ecc_from_heights(heights: torch.Tensor, H: int, W: int) -> torch.Tensor: | |
| """Smooth ECC (B, T) from per-pixel lower-star heights (B, H*W).""" | |
| thresholds = torch.linspace(0, 1, N_THRESHOLDS, device=heights.device) | |
| t = thresholds.unsqueeze(0).unsqueeze(0) # (1, 1, T) | |
| chi_v = torch.sigmoid((t - heights.unsqueeze(-1)) / TAU).sum(dim=1) | |
| h2d = heights.reshape(heights.shape[0], H, W) | |
| h_eh = torch.max(h2d[:, :, :-1], h2d[:, :, 1:]).reshape(heights.shape[0], -1) | |
| chi_eh = torch.sigmoid((t - h_eh.unsqueeze(-1)) / TAU).sum(dim=1) | |
| h_ev = torch.max(h2d[:, :-1, :], h2d[:, 1:, :]).reshape(heights.shape[0], -1) | |
| chi_ev = torch.sigmoid((t - h_ev.unsqueeze(-1)) / TAU).sum(dim=1) | |
| h_sq = torch.max( | |
| torch.max(h2d[:, :-1, :-1], h2d[:, :-1, 1:]), | |
| torch.max(h2d[:, 1:, :-1], h2d[:, 1:, 1:]), | |
| ).reshape(heights.shape[0], -1) | |
| chi_sq = torch.sigmoid((t - h_sq.unsqueeze(-1)) / TAU).sum(dim=1) | |
| return chi_v - chi_eh - chi_ev + chi_sq | |
def compute_ecc(img_gray: torch.Tensor) -> torch.Tensor:
    """Compute directional Euler characteristic curves of a grayscale image.

    For each of the ``N_DIRECTIONS - 1`` directions (angles evenly spaced on
    [0, pi)), pixel heights are an equal blend of the normalised projection
    of the pixel position onto that direction and the pixel intensity. One
    extra curve uses the intensity alone.

    Args:
        img_gray: (1, H, W) in [0, 1].

    Returns:
        ecc: (N_DIRECTIONS, N_THRESHOLDS) for a batch of one. The batch axis
        is only squeezed when it has size 1.
    """
    B, H, W = img_gray.shape
    device = img_gray.device

    # Pixel coordinates in the unit square, flattened row-major to (H*W, 2).
    ys = torch.linspace(0, 1, H, device=device)
    xs = torch.linspace(0, 1, W, device=device)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    coords_flat = torch.stack([gx, gy], dim=-1).reshape(-1, 2)
    intensity_flat = img_gray.reshape(B, -1)

    # Direction vectors live on the image's device so the products below do
    # not mix devices; max(..., 1) avoids a division by zero when only the
    # intensity curve is requested.
    n_spatial = N_DIRECTIONS - 1
    angles = torch.arange(n_spatial, dtype=torch.float32, device=device)
    angles = angles * (math.pi / max(n_spatial, 1))
    dirs = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    eccs = []
    for nu in dirs:
        spatial = (coords_flat * nu).sum(-1)
        # Rescale the projection to [0, 1]; the epsilon guards a flat projection.
        spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)
        heights = 0.5 * spatial.unsqueeze(0) + 0.5 * intensity_flat
        eccs.append(_ecc_from_heights(heights, H, W))
    eccs.append(_ecc_from_heights(intensity_flat, H, W))
    return torch.stack(eccs, dim=1).squeeze(0)  # (D+1, T)
def ecc_plot(inp_pil: Image.Image, out_pil: Image.Image) -> plt.Figure:
    """Return a matplotlib figure comparing the ECCs of input and output.

    Both images are converted to RGB, resized to IMG_SIZE x IMG_SIZE and
    reduced to grayscale by averaging the channels. One subplot is drawn per
    curve returned by ``compute_ecc``, with the gap between the two curves
    shaded.

    The figure is removed from pyplot's global registry before it is
    returned, so repeated calls in a long-running process do not accumulate
    open figures. The returned object can still be rendered or saved.
    """
    import torchvision.transforms.functional as TF

    def to_gray(pil):
        # (3, H, W) in [0, 1] -> channel mean -> (1, H, W)
        t = TF.to_tensor(pil.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))
        return t.mean(0, keepdim=True)

    with torch.no_grad():
        ecc_inp = compute_ecc(to_gray(inp_pil)).cpu().numpy()  # (D+1, T)
        ecc_out = compute_ecc(to_gray(out_pil)).cpu().numpy()

    thresholds = np.linspace(0, 1, N_THRESHOLDS)
    direction_labels = [f"Dir {i+1}" for i in range(N_DIRECTIONS - 1)] + ["Intensity"]

    # Near-square grid large enough for every curve (3 x 3 for 9 directions).
    n_cols = max(1, math.ceil(math.sqrt(N_DIRECTIONS)))
    n_rows = max(1, math.ceil(N_DIRECTIONS / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 10), squeeze=False)
    axes = axes.flatten()
    fig.suptitle("Euler Characteristic Curves: Makeup (blue) vs Bare (orange)", fontsize=13)

    for d, (ax, label) in enumerate(zip(axes, direction_labels)):
        ax.plot(thresholds, ecc_inp[d], color="#4C72B0", linewidth=1.5, label="Makeup input")
        ax.plot(thresholds, ecc_out[d], color="#DD8452", linewidth=1.5, label="Bare output")
        ax.fill_between(
            thresholds,
            ecc_inp[d],
            ecc_out[d],
            alpha=0.15, color="purple"
        )
        ax.set_title(label, fontsize=9)
        ax.set_xlabel("Threshold", fontsize=8)
        ax.set_ylabel("χ(t)", fontsize=8)
        ax.tick_params(labelsize=7)
        if d == 0:
            ax.legend(fontsize=7)

    # Hide grid cells that have no curve to show.
    for ax in axes[N_DIRECTIONS:]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.close(fig)  # deregister from pyplot; the Figure object stays usable
    return fig
# -- image preprocessing / postprocessing --------------------------------------
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a (1, 3, IMG_SIZE, IMG_SIZE) tensor in [-1, 1].

    The image is converted to RGB, resized with Lanczos resampling, scaled
    from [0, 255] to [-1, 1], reordered from HWC to CHW, given a batch axis
    and moved to DEVICE.
    """
    resized = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32) / 127.5 - 1.0  # [0,255] -> [-1,1]
    chw = torch.from_numpy(pixels).permute(2, 0, 1)
    batched = chw.unsqueeze(0)  # (1,3,H,W)
    return batched.to(DEVICE)
def postprocess(t: torch.Tensor) -> Image.Image:
    """Convert a (1, 3, H, W) tensor in [-1, 1] to an 8-bit RGB PIL image.

    Values are mapped from [-1, 1] to [0, 255], rounded to the nearest
    integer and clipped. Rounding, rather than the truncation that a bare
    ``astype(np.uint8)`` performs, keeps ``preprocess`` followed by
    ``postprocess`` an exact round trip and avoids biasing pixels one level
    darker.
    """
    arr = t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    arr = np.rint((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)
# -- main inference function ---------------------------------------------------
def remove_makeup(inp_img: Image.Image):
    """Run the full demo pipeline on one uploaded image.

    Args:
        inp_img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple ``(output image, identity-score markdown, ECC figure)``.
        When no image was uploaded the image and figure are ``None`` and the
        text explains why.
    """
    if inp_img is None:
        return None, "No image uploaded.", None

    # 1) ITDF inference
    with torch.no_grad():
        out = model(preprocess(inp_img))
    out_img = postprocess(out)

    # 2) Identity preservation score
    score = identity_score(inp_img, out_img)
    if score is None:
        id_text = "Identity score unavailable (FaceNet not loaded)."
    else:
        # Verdicts ordered from the highest floor down; the first floor the
        # score reaches wins, and anything below the last floor is "Low".
        verdicts = (
            (0.85, "Excellent — face identity strongly preserved."),
            (0.70, "Good — face identity well preserved."),
            (0.50, "Moderate — some identity drift."),
        )
        verdict = next(
            (text for floor, text in verdicts if score >= floor),
            "Low — significant identity change detected.",
        )
        id_text = f"**Identity Preservation Score:** {score:.4f} / 1.0000\n\n{verdict}"

    # 3) ECC topology plot
    fig = ecc_plot(inp_img, out_img)
    return out_img, id_text, fig
# -- Gradio UI -----------------------------------------------------------------
# Two-column layout: upload + button on the left, results on the right.
# The markdown is assembled from explicit lines so that its rendering does
# not depend on source indentation; blank lines separate paragraphs.
with gr.Blocks(title="ITDF — Makeup Removal") as demo:
    gr.Markdown(
        "# ITDF — Implicit Topological De-Filtering\n"
        "**Makeup Removal via Topology-Aware Implicit Neural Representation**\n"
        "\n"
        "Upload a face image with makeup. The model outputs the bare face along with:\n"
        "\n"
        # The score is computed with FaceNet (InceptionResnetV1), not ArcFace.
        "- **Identity Preservation Score** — how well face identity is maintained "
        "(FaceNet cosine similarity)\n"
        "- **ECC Topology Plot** — Euler Characteristic Curves comparing input vs "
        "output topology across 9 directions\n"
        "\n"
        "> ⏳ Running on free CPU — inference takes **~35–70 seconds**. Please wait.\n"
    )
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input: Makeup Face")
            btn = gr.Button("Remove Makeup", variant="primary")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Output: Bare Face")
            id_score = gr.Markdown(label="Identity Score")
            ecc_fig = gr.Plot(label="ECC Topology Plot (Makeup vs Bare)")
    # remove_makeup returns (image, markdown text, figure) in this order.
    btn.click(
        fn=remove_makeup,
        inputs=[inp],
        outputs=[out_img, id_score, ecc_fig],
    )
    gr.Markdown(
        "---\n"
        "\n"
        "**Model:** ViT-B/16 encoder + SIREN implicit decoder\n"
        "\n"
        "**Topology losses:** Differentiable Euler Characteristic Transform (DECT) "
        "+ Multiparameter Persistence\n"
        "\n"
        "**Dataset:** FFHQ-Makeup (73,480 training pairs)\n"
    )

demo.queue()
demo.launch(show_error=True)