itdf-space / app.py
priyadip's picture
Upload app.py with huggingface_hub
655192b verified
"""
ITDF β€” Implicit Topological De-Filtering
Makeup Removal Demo
Upload a makeup face image β†’ get bare face + identity score + ECC topology plot.
"""
import os
import math
import yaml
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from models.itdf import ITDF
# ── constants ──────────────────────────────────────────────────────────────────
# Hugging Face Hub model repository holding the trained checkpoint and config.
MODEL_REPO: str = "priyadip/itdf-model"
CKPT_FILE: str = "ckpt_final.pt"    # checkpoint filename inside MODEL_REPO
CONFIG_FILE: str = "default.yaml"   # YAML config filename inside MODEL_REPO
# Square side length (pixels) that images are resized to for inference and ECC.
IMG_SIZE: int = 512
# Everything in this app runs on the CPU.
DEVICE: torch.device = torch.device("cpu")
# ── load model at startup ──────────────────────────────────────────────────────
# Optional Hub token (e.g. for a private repo); None means anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print("Downloading config …")
cfg_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CONFIG_FILE,
    repo_type="model",
    token=HF_TOKEN,
)
with open(cfg_path, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
print("Downloading checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=CKPT_FILE,
    repo_type="model",
    token=HF_TOKEN,
)
print("Loading ITDF model …")
model = ITDF(cfg).to(DEVICE)
# NOTE(review): torch.load unpickles the file; this is acceptable only because
# the checkpoint comes from the project's own model repository.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# The checkpoint may be a bare state dict, or wrap the weights under a
# 'model' or 'G' key (checked in that order).
state = ckpt.get("model", ckpt.get("G", ckpt))
# strict=False tolerates mismatching entries, but a wrongly keyed checkpoint
# would then leave layers at their random initialisation without any signal.
# Report what did not line up so a bad checkpoint is visible in the logs.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model parameter(s) missing "
        f"from checkpoint, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint entr(y/ies) "
        f"not used by the model, e.g. {load_result.unexpected_keys[:5]}"
    )
model.eval()
print("Model ready.")
# ── FaceNet / identity score ───────────────────────────────────────────────────
# The identity score uses facenet-pytorch's InceptionResnetV1 with VGGFace2
# weights (not ArcFace). Loading is best-effort: on any failure the app keeps
# running, FACENET_OK is False, and the score is reported as unavailable.
try:
    from facenet_pytorch import InceptionResnetV1
    facenet = InceptionResnetV1(pretrained="vggface2").eval().to(DEVICE)
    FACENET_OK = True
    print("FaceNet ready.")
except Exception as e:
    # Keep the name bound so later code never hits a NameError on it.
    facenet = None
    FACENET_OK = False
    print(f"FaceNet not available: {e}")
def _facenet_embed(img_pil: Image.Image) -> torch.Tensor | None:
    """Embed a face image with FaceNet.

    The image is converted to RGB, resized to 160x160, scaled to [-1, 1] and
    passed through the module-level ``facenet`` model.

    Args:
        img_pil: Face image in any PIL mode (it is converted to RGB).

    Returns:
        A (1, 512) L2-normalised embedding, or None if FaceNet failed to load
        at startup.
    """
    if not FACENET_OK:
        return None
    import torchvision.transforms as T

    tf = T.Compose([
        T.Resize((160, 160)),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    # Normalize() above is 3-channel; RGBA or grayscale input would otherwise
    # fail with a shape mismatch. preprocess()/ecc_plot() convert likewise.
    x = tf(img_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        emb = facenet(x)
    return torch.nn.functional.normalize(emb, dim=1)
def identity_score(inp: Image.Image, out: Image.Image) -> float | None:
    """Cosine similarity of the FaceNet embeddings of two face images.

    Both embeddings are always computed. Because they are L2-normalised, their
    dot product is the cosine similarity, clamped to [-1, 1].

    Returns:
        The similarity as a float, or None when either embedding is unavailable.
    """
    embeddings = (_facenet_embed(inp), _facenet_embed(out))
    if any(emb is None for emb in embeddings):
        return None
    first, second = embeddings
    cosine = torch.sum(first * second).clamp(-1, 1)
    return float(cosine.item())
# ── ECC computation ────────────────────────────────────────────────────────────
# Filtration directions: 8 mixed spatial/intensity directions + 1 pure-intensity.
N_DIRECTIONS: int = 9
# Number of evenly spaced thresholds in [0, 1] at which every curve is sampled.
N_THRESHOLDS: int = 64
# Temperature of the sigmoid that replaces the hard "height <= t" indicator.
TAU: float = 0.1
def _ecc_from_heights(heights: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Smooth Euler characteristic curve of a lower-star filtration on a pixel grid.

    Pixels are vertices. Horizontally and vertically adjacent pixel pairs are
    edges, and every 2x2 block is a square. Each cell enters the filtration at
    the maximum height of its pixels. chi(t) = #V - #E_h - #E_v + #F, where each
    hard indicator [height <= t] is replaced by sigmoid((t - height) / TAU).

    Args:
        heights: (B, H*W) per-pixel filtration values in row-major order.
        H: Grid height.
        W: Grid width.

    Returns:
        (B, N_THRESHOLDS) curve sampled at linspace(0, 1, N_THRESHOLDS).
    """
    batch = heights.shape[0]
    grid_t = torch.linspace(0, 1, N_THRESHOLDS, device=heights.device)[None, None, :]

    def soft_count(cell_heights: torch.Tensor) -> torch.Tensor:
        # (B, n_cells) -> (B, T): smooth count of cells whose height is <= t.
        return torch.sigmoid((grid_t - cell_heights[..., None]) / TAU).sum(dim=1)

    grid = heights.reshape(batch, H, W)
    edge_h = torch.max(grid[:, :, :-1], grid[:, :, 1:]).reshape(batch, -1)
    edge_v = torch.max(grid[:, :-1, :], grid[:, 1:, :]).reshape(batch, -1)
    upper = torch.max(grid[:, :-1, :-1], grid[:, :-1, 1:])
    lower = torch.max(grid[:, 1:, :-1], grid[:, 1:, 1:])
    square = torch.max(upper, lower).reshape(batch, -1)

    n_vertices = soft_count(heights)
    n_edges_h = soft_count(edge_h)
    n_edges_v = soft_count(edge_v)
    n_squares = soft_count(square)
    return n_vertices - n_edges_h - n_edges_v + n_squares
def compute_ecc(img_gray: torch.Tensor) -> torch.Tensor:
    """Smooth Euler characteristic curves of grayscale images in N_DIRECTIONS filtrations.

    The first N_DIRECTIONS - 1 filtrations blend position and intensity
    equally. For each, height = 0.5 * <(x, y), nu> + 0.5 * intensity, where the
    projection is rescaled to [0, 1] and the unit directions nu are evenly
    spaced over [0, pi). The last filtration uses intensity alone.

    Args:
        img_gray: (B, H, W) intensities in [0, 1]. The app passes B = 1.

    Returns:
        (N_DIRECTIONS, N_THRESHOLDS) when B == 1, otherwise
        (B, N_DIRECTIONS, N_THRESHOLDS).
    """
    B, H, W = img_gray.shape
    device = img_gray.device
    ys = torch.linspace(0, 1, H, device=device)
    xs = torch.linspace(0, 1, W, device=device)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    coords_flat = torch.stack([gx, gy], dim=-1).reshape(-1, 2)  # (H*W, 2) as (x, y)
    intensity_flat = img_gray.reshape(B, -1)
    # Create the directions on the image's device. Previously they were always
    # on the CPU, so the dot product below failed for non-CPU inputs.
    n_spatial = N_DIRECTIONS - 1
    angles = torch.arange(n_spatial, dtype=torch.float32, device=device) * (math.pi / n_spatial)
    dirs = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
    eccs = []
    for nu in dirs:
        spatial = (coords_flat * nu).sum(-1)
        # Rescale the projection to [0, 1] so it is on the same scale as intensity.
        spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)
        heights = 0.5 * spatial.unsqueeze(0) + 0.5 * intensity_flat
        eccs.append(_ecc_from_heights(heights, H, W))
    eccs.append(_ecc_from_heights(intensity_flat, H, W))
    return torch.stack(eccs, dim=1).squeeze(0)
def ecc_plot(inp_pil: Image.Image, out_pil: Image.Image) -> plt.Figure:
    """Plot the ECCs of the makeup input against those of the bare output.

    Both images are converted to grayscale at IMG_SIZE x IMG_SIZE and passed
    through compute_ecc. Each direction gets a subplot that overlays the two
    curves and shades the gap between them.

    Args:
        inp_pil: Makeup (input) face image.
        out_pil: Bare (model output) face image.

    Returns:
        A matplotlib Figure that has already been removed from pyplot's figure
        registry (it can still be saved/rendered), so repeated requests do not
        accumulate open figures.
    """
    import torchvision.transforms.functional as TF

    def to_gray(pil: Image.Image) -> torch.Tensor:
        # (3, H, W) RGB in [0, 1] -> (1, H, W) channel mean, i.e. a batch of one.
        rgb = TF.to_tensor(pil.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))
        return rgb.mean(0, keepdim=True)

    with torch.no_grad():
        ecc_inp = compute_ecc(to_gray(inp_pil)).cpu().numpy()  # (N_DIRECTIONS, T)
        ecc_out = compute_ecc(to_gray(out_pil)).cpu().numpy()

    thresholds = np.linspace(0, 1, N_THRESHOLDS)
    direction_labels = [f"Dir {i + 1}" for i in range(N_DIRECTIONS - 1)] + ["Intensity"]

    # Grid derived from N_DIRECTIONS; the default 9 directions gives 3x3 at 12x10 in.
    n_cols = 3
    n_rows = math.ceil(N_DIRECTIONS / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 10 * n_rows / 3), squeeze=False)
    axes = axes.flatten()
    fig.suptitle("Euler Characteristic Curves: Makeup (blue) vs Bare (orange)", fontsize=13)
    for ax, label, curve_in, curve_out in zip(axes, direction_labels, ecc_inp, ecc_out):
        ax.plot(thresholds, curve_in, color="#4C72B0", linewidth=1.5, label="Makeup input")
        ax.plot(thresholds, curve_out, color="#DD8452", linewidth=1.5, label="Bare output")
        ax.fill_between(thresholds, curve_in, curve_out, alpha=0.15, color="purple")
        ax.set_title(label, fontsize=9)
        ax.set_xlabel("Threshold", fontsize=8)
        ax.set_ylabel("χ(t)", fontsize=8)
        ax.tick_params(labelsize=7)
    axes[0].legend(fontsize=7)
    # Hide grid cells left over when N_DIRECTIONS is not a multiple of n_cols.
    for ax in axes[N_DIRECTIONS:]:
        ax.set_visible(False)
    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    # Drop the figure from pyplot's registry so it is not leaked per request;
    # the Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig
# ── image preprocessing / postprocessing ──────────────────────────────────────
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a batched model input.

    The image is converted to RGB, resized to IMG_SIZE x IMG_SIZE with Lanczos
    resampling, and rescaled from [0, 255] to [-1, 1].

    Returns:
        A float32 tensor of shape (1, 3, IMG_SIZE, IMG_SIZE) on DEVICE.
    """
    resized = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    pixels = np.asarray(resized, dtype=np.float32)
    scaled = pixels / 127.5 - 1.0
    chw = torch.from_numpy(scaled).permute(2, 0, 1)
    return chw.unsqueeze(0).to(DEVICE)
def postprocess(t: torch.Tensor) -> Image.Image:
    """Convert a model output in [-1, 1] back to an 8-bit PIL image.

    Args:
        t: Output tensor, expected as (1, C, H, W) with C = 3, i.e. the layout
           and scaling produced by preprocess().

    Returns:
        An H x W PIL image built from the uint8 pixel array.
    """
    arr = t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    # Round before the uint8 cast. astype() truncates toward zero, which biases
    # every pixel downward (a value of 199.99998 would become 199), so even an
    # identity mapping would not round-trip preprocess() exactly.
    arr = np.rint((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)
# ── main inference function ────────────────────────────────────────────────────
def _identity_summary(score: float | None) -> str:
    """Markdown describing an identity-preservation score, or its absence (None)."""
    if score is None:
        return "Identity score unavailable (FaceNet not loaded)."
    text = f"**Identity Preservation Score:** {score:.4f} / 1.0000\n\n"
    if score >= 0.85:
        verdict = "Excellent — face identity strongly preserved."
    elif score >= 0.70:
        verdict = "Good — face identity well preserved."
    elif score >= 0.50:
        verdict = "Moderate — some identity drift."
    else:
        verdict = "Low — significant identity change detected."
    return text + verdict


def remove_makeup(inp_img: Image.Image):
    """Gradio callback: run ITDF on an uploaded image and build the report.

    Args:
        inp_img: Uploaded face image, or None if nothing was uploaded.

    Returns:
        A tuple (bare-face image, identity-score markdown, ECC figure).
        When no image is uploaded, the tuple is (None, message, None).
    """
    if inp_img is None:
        return None, "No image uploaded.", None
    # 1) ITDF inference
    with torch.no_grad():
        x = preprocess(inp_img)
        out = model(x)
    out_img = postprocess(out)
    # 2) Identity preservation score
    id_text = _identity_summary(identity_score(inp_img, out_img))
    # 3) ECC topology plot
    fig = ecc_plot(inp_img, out_img)
    return out_img, id_text, fig
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Page layout: intro text, an input column and an output column, then model
# details. Launched at import time (the Space runs this file directly).
with gr.Blocks(title="ITDF — Makeup Removal") as demo:
    gr.Markdown(
        """
# ITDF — Implicit Topological De-Filtering

**Makeup Removal via Topology-Aware Implicit Neural Representation**

Upload a face image with makeup. The model outputs the bare face along with:

- **Identity Preservation Score** — how well face identity is maintained (FaceNet cosine similarity)
- **ECC Topology Plot** — Euler Characteristic Curves comparing input vs output topology across 9 directions

> ⏳ Running on free CPU — inference takes **~35–70 seconds**. Please wait.
"""
    )
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input: Makeup Face")
            btn = gr.Button("Remove Makeup", variant="primary")
        with gr.Column():
            out_img = gr.Image(type="pil", label="Output: Bare Face")
            id_score = gr.Markdown(label="Identity Score")
            ecc_fig = gr.Plot(label="ECC Topology Plot (Makeup vs Bare)")
    btn.click(
        fn=remove_makeup,
        inputs=[inp],
        outputs=[out_img, id_score, ecc_fig],
    )
    # Blank lines keep each item on its own line instead of one merged paragraph.
    gr.Markdown(
        """
---
**Model:** ViT-B/16 encoder + SIREN implicit decoder

**Topology losses:** Differentiable Euler Characteristic Transform (DECT) + Multiparameter Persistence

**Dataset:** FFHQ-Makeup (73,480 training pairs)
"""
    )
demo.queue()
demo.launch(show_error=True)