import matplotlib
import numpy as np

# Select the headless Agg backend BEFORE pyplot is imported below, so figures
# can be rendered on a server with no display.
matplotlib.use("Agg")
import os

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from hw2.models import ResNetKeypointDetector, Simple_CNN, UNetKeypointDetector
from PIL import Image
|
|
# Hugging Face access token for pulling model weights from the Hub
# (may be None if the repos are public).
HF_TOKEN = os.environ.get("HF_TOKEN")


# Maps the UI display name to (Hub repo id, model class).
MODEL_REGISTRY = {
    "Simple CNN": ("KoniHD/Simple_CNN", Simple_CNN),
    "ResNet18": ("KoniHD/Fine-Tuned-ResNet18", ResNetKeypointDetector),
    "U-Net": ("KoniHD/UNet", UNetKeypointDetector),
}


# Lazily populated cache of loaded models, keyed by display name,
# so each model is downloaded at most once per process.
_cache: dict[str, torch.nn.Module] = {}
|
|
|
|
def load_model(name: str) -> torch.nn.Module:
    """Return the model registered under *name*, loading it on first use.

    Loaded models are memoized in ``_cache`` and put into eval mode, so
    repeated calls with the same name reuse the same instance.
    """
    try:
        return _cache[name]
    except KeyError:
        repo_id, model_cls = MODEL_REGISTRY[name]
        detector = model_cls.from_pretrained(repo_id, token=HF_TOKEN)
        detector.eval()
        _cache[name] = detector
        return detector
|
|
|
|
| |
|
|
|
|
def preprocess(img: np.ndarray) -> torch.Tensor:
    """Convert an image array to a model-ready grayscale tensor.

    Args:
        img: Input image as a numpy array (HWC uint8 or grayscale).

    Returns:
        A ``[224, 224]`` float32 tensor with values in ``[0, 1]``.
        NOTE: batch and channel dimensions are NOT added here — the caller
        (``run_inference``) unsqueezes to ``[1, 1, 224, 224]`` itself.
        (The previous docstring incorrectly claimed a 4-D return.)
    """
    pil = Image.fromarray(img).convert("L").resize((224, 224), Image.BILINEAR)
    arr = np.array(pil, dtype=np.float32) / 255.0
    return torch.from_numpy(arr)
|
|
|
|
| |
|
|
|
|
def _fig_to_array(fig: plt.Figure) -> np.ndarray:
    """Render *fig* to an RGB uint8 array, closing the figure afterwards."""
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    rgb = rgba[:, :, :3]  # drop the alpha channel
    plt.close(fig)
    return rgb
|
|
|
|
def draw_keypoints(img: np.ndarray, kps: np.ndarray, title: str) -> np.ndarray:
    """Overlay predicted keypoints on *img* and return the plot as an array.

    *kps* carries normalized ``[-1, 1]`` (x, y) coordinates with a leading
    batch dimension; only the first sample is drawn.
    """
    height, width = img.shape
    points = kps[0]
    # Map normalized [-1, 1] coordinates back to pixel positions.
    px = points[:, 0] * (width / 2) + (width / 2)
    py = points[:, 1] * (height / 2) + (height / 2)
    fig, axis = plt.subplots(figsize=(4, 4))
    axis.imshow(img, cmap="gray", vmin=0, vmax=1)
    axis.scatter(px, py, s=10, c="r")
    axis.set_title(title, fontsize=10, pad=6)
    axis.axis("off")
    fig.tight_layout(pad=0.3)
    return _fig_to_array(fig)
|
|
|
|
| |
|
|
|
|
def run_inference(image: np.ndarray, model_name: str) -> tuple[np.ndarray, str]:
    """
    Run facial keypoint detection on a face image and return the annotated result.

    Args:
        image (np.ndarray): Input face image. Pass as a public URL when using via MCP.
        model_name (str): Model to use. One of: "Simple CNN", "ResNet18", "U-Net".

    Returns:
        tuple: (keypoint_image, info_string) where keypoint_image is the face annotated
        with 68 predicted landmarks, and info_string describes the model and output.

    Raises:
        ValueError: If no image is provided.
    """
    if image is None:
        raise ValueError("No image provided.")

    img = preprocess(image)
    tensor = img.unsqueeze(0).unsqueeze(0)  # [224, 224] -> [1, 1, 224, 224]
    model = load_model(model_name)  # already in eval mode via load_model
    is_unet = model_name == "U-Net"

    with torch.inference_mode():
        output = model(tensor)

    if is_unet:
        # Heatmap head: one map per landmark; take each map's argmax and
        # convert the flat index back to normalized [-1, 1] coordinates
        # (the inverse of the mapping in draw_keypoints).
        # torch.sigmoid replaces the deprecated F.sigmoid.
        heatmaps = torch.sigmoid(output).cpu().squeeze(0)
        c, h, w = heatmaps.shape
        flat_idx = heatmaps.view(c, -1).argmax(dim=1).numpy()
        ys = (flat_idx // w).astype(float) / (h / 2) - 1
        xs = (flat_idx % w).astype(float) / (w / 2) - 1
        kps = np.stack([xs, ys], axis=1)[np.newaxis]
        kp_img = draw_keypoints(img, kps, f"{model_name} → heatmap-based")
        info = f"{model_name} | 68 keypoints | heatmap-based"
    else:
        # Regression head: 136 raw values reshaped to [1, 68, 2] (x, y) pairs.
        kps = output.view(-1, 68, 2).cpu().numpy()
        kp_img = draw_keypoints(img, kps, f"{model_name} → direct regression")
        info = f"{model_name} | 68 keypoints | direct regression"

    return kp_img, info
|
|
|
|
| |
|
|
# Custom dark-theme stylesheet for the Gradio UI (IBM Plex fonts, gradient
# run button, monospace info box, hidden footer).
css = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;500&display=swap');
body, .gradio-container {
  background: #0a0a0f !important;
  font-family: 'IBM Plex Sans', sans-serif !important;
}
h1 { font-family: 'IBM Plex Mono', monospace !important; letter-spacing: -0.03em; }
.run-btn {
  background: linear-gradient(135deg, #ff3366 0%, #ff6b35 100%) !important;
  border: none !important; border-radius: 6px !important;
  font-family: 'IBM Plex Mono', monospace !important;
  font-weight: 600 !important; letter-spacing: 0.05em !important;
  color: white !important; transition: opacity 0.15s ease !important;
}
.run-btn:hover { opacity: 0.85 !important; }
.info-box textarea {
  background: #111118 !important; border: 1px solid #1e1e2e !important;
  color: #66ffcc !important; font-family: 'IBM Plex Mono', monospace !important;
  font-size: 0.82rem !important;
}
footer { display: none !important; }
"""
|
|
# FIX: css must be passed to gr.Blocks(); Blocks.launch() has no css
# parameter, so the stylesheet defined above was silently never applied.
with gr.Blocks(title="Facial Keypoint Detection", css=css) as demo:
    gr.HTML("""
    <div style="padding:2rem 0 1rem; text-align:center;">
      <h1 style="color:#f0f0f8; font-size:1.9rem; margin:0;">facial keypoint detection</h1>
      <p style="color:#555577; font-family:'IBM Plex Mono',monospace;
                font-size:0.8rem; margin-top:0.4rem; letter-spacing:0.08em;">
        CS280 · HW2 · three models · 68 landmarks
      </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Input image",
                type="numpy",
                value="data/example.jpg",
                sources=["upload", "webcam", "clipboard"],
                height=300,
            )
            model_selector = gr.Radio(
                choices=list(MODEL_REGISTRY.keys()),
                value="Simple CNN",
                label="Model",
            )
            run_btn = gr.Button("▶ Run inference", elem_classes=["run-btn"])
            info_box = gr.Textbox(
                label="",
                interactive=False,
                elem_classes=["info-box"],
                max_lines=1,
            )

        with gr.Column(scale=2):
            kp_output = gr.Image(label="Predicted keypoints", height=320)

    # UI events are deliberately not exposed as API endpoints (api_name=False);
    # gr.api below registers the single programmatic/MCP entry point instead.
    run_btn.click(
        fn=run_inference,
        inputs=[image_input, model_selector],
        outputs=[kp_output, info_box],
        api_name=False,
    )
    image_input.upload(
        fn=run_inference,
        inputs=[image_input, model_selector],
        outputs=[kp_output, info_box],
        api_name=False,
    )

    gr.api(
        fn=run_inference,
        api_name="run_inference",
    )


if __name__ == "__main__":
    demo.launch(mcp_server=True)
|
|