# commit c22b52f (Koni) — fix: numpy float issue
import matplotlib
import numpy as np
matplotlib.use("Agg")
import os
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from hw2.models import ResNetKeypointDetector, Simple_CNN, UNetKeypointDetector
from PIL import Image
# Hugging Face access token for private/gated repos; None if unset (public access).
HF_TOKEN = os.environ.get("HF_TOKEN")
# ── model registry ─────────────────────────────────────────────────────────────
# Maps the UI display name to (Hub repo id, model class used to load it).
MODEL_REGISTRY = {
"Simple CNN": ("KoniHD/Simple_CNN", Simple_CNN),
"ResNet18": ("KoniHD/Fine-Tuned-ResNet18", ResNetKeypointDetector),
"U-Net": ("KoniHD/UNet", UNetKeypointDetector),
}
# Lazily-populated cache of loaded models, keyed by display name (see load_model).
_cache: dict = {}
def load_model(name: str) -> torch.nn.Module:
    """Return the model registered under *name*, loading and caching it on first use.

    Looks the name up in MODEL_REGISTRY, downloads the weights from the Hub
    the first time, puts the model in eval mode, and memoizes it in _cache.
    """
    model = _cache.get(name)
    if model is None:
        repo_id, model_cls = MODEL_REGISTRY[name]
        # from_pretrained reads model_config.json + safetensors from the Hub
        model = model_cls.from_pretrained(repo_id, token=HF_TOKEN)
        model.eval()
        _cache[name] = model
    return model
# ── preprocessing ──────────────────────────────────────────────────────────────
def preprocess(img: np.ndarray) -> torch.Tensor:
    """Convert a numpy HWC uint8 image to a [224, 224] float32 tensor in [0, 1].

    Note: the returned tensor is 2-D (grayscale, no batch/channel dims);
    the caller is responsible for unsqueezing to [1, 1, 224, 224] before
    feeding it to a model.
    """
    # Grayscale + fixed 224x224 input size expected by all three models.
    pil = Image.fromarray(img).convert("L").resize((224, 224), Image.BILINEAR)
    arr = np.array(pil, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr)
    return tensor
# ── visualisation ──────────────────────────────────────────────────────────────
def _fig_to_array(fig: plt.Figure) -> np.ndarray:
    """Rasterize *fig* and return its pixels as an RGB uint8 array.

    The alpha channel from the Agg canvas is dropped; the figure is closed
    afterwards so repeated calls don't accumulate open figures.
    """
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    rgb = rgba[..., :3]  # keep RGB, discard alpha
    plt.close(fig)
    return rgb
def draw_keypoints(img: np.ndarray, kps: np.ndarray, title: str) -> np.ndarray:
    """Overlay predicted landmarks on *img* and return the rendered plot as an array.

    kps is a [1, K, 2] array of (x, y) coordinates normalized to [-1, 1];
    they are mapped back to pixel space of the (H, W) grayscale image before
    plotting. The figure is rasterized via _fig_to_array.
    """
    height, width = img.shape
    points = kps[0]
    # Undo the [-1, 1] normalization: pixel = norm * (dim / 2) + (dim / 2)
    px = points[:, 0] * (width / 2) + (width / 2)
    py = points[:, 1] * (height / 2) + (height / 2)
    fig, axis = plt.subplots(figsize=(4, 4))
    axis.imshow(img, cmap="gray", vmin=0, vmax=1)
    axis.scatter(px, py, s=10, c="r")
    axis.set_title(title, fontsize=10, pad=6)
    axis.axis("off")
    fig.tight_layout(pad=0.3)
    return _fig_to_array(fig)
# ── inference ──────────────────────────────────────────────────────────────────
def run_inference(image: np.ndarray, model_name: str) -> tuple[np.ndarray, str]:
    """
    Run facial keypoint detection on a face image and return the annotated result.
    Args:
        image (np.ndarray): Input face image. Pass as a public URL when using via MCP.
        model_name (str): Model to use. One of: "Simple CNN", "ResNet18", "U-Net".
    Returns:
        tuple: (keypoint_image, info_string) where keypoint_image is the face annotated
            with 68 predicted landmarks, and info_string describes the model and output.
    Raises:
        ValueError: If *image* is None.
    """
    if image is None:
        raise ValueError("No image provided.")
    img = preprocess(image)                 # [224, 224] float32 in [0, 1]
    tensor = img.unsqueeze(0).unsqueeze(0)  # add batch + channel dims -> [1, 1, 224, 224]
    model = load_model(model_name)          # cached and already in eval mode
    img_np = img.numpy()                    # draw_keypoints expects a numpy array
    with torch.inference_mode():
        output = model(tensor)
    if model_name == "U-Net":
        # Heatmap head: one channel per landmark; take each channel's argmax
        # and convert the flat index back to normalized [-1, 1] coordinates.
        # torch.sigmoid replaces the deprecated F.sigmoid.
        heatmaps = torch.sigmoid(output).cpu().squeeze(0)
        c, h, w = heatmaps.shape
        flat_idx = heatmaps.view(c, -1).argmax(dim=1).numpy()
        ys = (flat_idx // w).astype(float) / (h / 2) - 1
        xs = (flat_idx % w).astype(float) / (w / 2) - 1
        kps = np.stack([xs, ys], axis=1)[np.newaxis]  # [1, 68, 2]
        kp_img = draw_keypoints(img_np, kps, f"{model_name} — heatmap-based")
        info = f"{model_name} | 68 keypoints | heatmap-based"
    else:
        # Regression head: model emits 136 values -> [1, 68, 2] normalized coords.
        kps = output.view(-1, 68, 2).cpu().numpy()
        kp_img = draw_keypoints(img_np, kps, f"{model_name} — direct regression")
        info = f"{model_name} | 68 keypoints | direct regression"
    return kp_img, info
# ── UI ─────────────────────────────────────────────────────────────────────────
# Custom stylesheet injected into the Gradio app: dark theme, IBM Plex fonts,
# gradient run button. Class selectors (.run-btn, .info-box) must stay in sync
# with the elem_classes assigned to the corresponding components below.
css = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;500&display=swap');
body, .gradio-container {
background: #0a0a0f !important;
font-family: 'IBM Plex Sans', sans-serif !important;
}
h1 { font-family: 'IBM Plex Mono', monospace !important; letter-spacing: -0.03em; }
.run-btn {
background: linear-gradient(135deg, #ff3366 0%, #ff6b35 100%) !important;
border: none !important; border-radius: 6px !important;
font-family: 'IBM Plex Mono', monospace !important;
font-weight: 600 !important; letter-spacing: 0.05em !important;
color: white !important; transition: opacity 0.15s ease !important;
}
.run-btn:hover { opacity: 0.85 !important; }
.info-box textarea {
background: #111118 !important; border: 1px solid #1e1e2e !important;
color: #66ffcc !important; font-family: 'IBM Plex Mono', monospace !important;
font-size: 0.82rem !important;
}
footer { display: none !important; }
"""
# Custom CSS must be passed to the gr.Blocks constructor — Blocks.launch()
# has no `css` keyword, so passing it there fails at launch time.
with gr.Blocks(title="Facial Keypoint Detection", css=css) as demo:
    gr.HTML("""
    <div style="padding:2rem 0 1rem; text-align:center;">
        <h1 style="color:#f0f0f8; font-size:1.9rem; margin:0;">facial keypoint detection</h1>
        <p style="color:#555577; font-family:'IBM Plex Mono',monospace;
           font-size:0.8rem; margin-top:0.4rem; letter-spacing:0.08em;">
            CS280 · HW2 · three models · 68 landmarks
        </p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Input image",
                type="numpy",
                value="data/example.jpg",
                sources=["upload", "webcam", "clipboard"],
                height=300,
            )
            model_selector = gr.Radio(
                choices=list(MODEL_REGISTRY.keys()),
                value="Simple CNN",
                label="Model",
            )
            run_btn = gr.Button("▶ Run inference", elem_classes=["run-btn"])
            info_box = gr.Textbox(
                label="",
                interactive=False,
                elem_classes=["info-box"],
                max_lines=1,
            )
        with gr.Column(scale=2):
            kp_output = gr.Image(label="Predicted keypoints", height=320)

    # UI events — not exposed to MCP (api_name=False)
    run_btn.click(
        fn=run_inference,
        inputs=[image_input, model_selector],
        outputs=[kp_output, info_box],
        api_name=False,
    )
    image_input.upload(
        fn=run_inference,
        inputs=[image_input, model_selector],
        outputs=[kp_output, info_box],
        api_name=False,
    )
    # Explicit MCP/API tool — only this endpoint is exposed
    gr.api(
        fn=run_inference,
        api_name="run_inference",
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)