FuzzyPSI-hamming / src /ui /components.py
acprk
Add shared runtime face enrollment flow
48b5d08
Raw
History Blame Contribute Delete
17.5 kB
from __future__ import annotations
import json
import gradio as gr
import pandas as pd
from PIL import Image
from src import config
from src.data.gallery_store import DemoGallery
from src.pipeline.image_pipeline import ImagePipeline
from src.protocol.fpsi_adapter import FuzzyPSIAdapter
# Module-level singletons shared by every callback defined in this file.
gallery = DemoGallery()
image_pipeline = ImagePipeline()
# The adapter starts in the configured default mode; _match_from_image
# reassigns ``adapter.mode`` on every request.
adapter = FuzzyPSIAdapter(config.DEFAULT_MODE)
# Stylesheet passed to gr.Blocks in build_app(): hides the Gradio footer and
# branding links, and styles the title banner (.main-title), the monospace
# note/JSON text boxes (.compact-note) and the small helper text (.small-copy).
CSS = """
footer, .built-with, a[href*="gradio.app"] { display: none !important; }
.main-title { text-align: center; padding: 10px 0 6px; }
.main-title h1 { font-size: 1.7em; margin: 0; letter-spacing: -0.5px; }
.main-title p { color: #555; margin: 2px 0 0; font-size: 0.92em; }
.compact-note textarea { font-family: ui-monospace, SFMono-Regular, Consolas, monospace; }
.small-copy { color: #666; font-size: 0.9em; }
"""
def _example_paths() -> list[list[str]]:
    """Return the gallery's example image paths as single-column rows.

    Every entry of ``gallery.example_choices()`` is unpacked as a pair; only
    the second element (the path) is kept, wrapped in its own one-item list.
    """
    rows: list[list[str]] = []
    for _label, image_path in gallery.example_choices():
        rows.append([image_path])
    return rows
def _empty_metrics() -> pd.DataFrame:
return pd.DataFrame(columns=["Metric", "Value"])
def _source_label(match_source: str) -> str:
if match_source == "enrolled":
return "Enrolled runtime library"
return "Packaged LFW gallery"
def _gallery_metrics(summary_json: str, notes: str) -> tuple[str, pd.DataFrame, str, str]:
    """Turn an exported JSON summary into the values shown in the result panel.

    Args:
        summary_json: JSON text holding ``summary``, ``match`` and ``details``
            objects.
        notes: Pre-formatted notes text; passed through unchanged.

    Returns:
        ``(verdict, metrics_df, compact_stats, notes)`` where ``verdict`` is a
        two-line headline, ``metrics_df`` a ``Metric``/``Value`` table and
        ``compact_stats`` a one-line ``key=value`` digest.
    """
    payload = json.loads(summary_json)
    stats, best, extra = payload["summary"], payload["match"], payload["details"]

    # A missing "match_source" is treated as the packaged ("demo") gallery.
    source_key = extra.get("match_source", "demo")
    label = _source_label(source_key)

    if best["matched"]:
        verdict = (
            f"MATCH: {best['person']}\n"
            f"Source: {label} | Distance: {best['hamming_distance']} | Score: {best['score']:.4f}"
        )
    else:
        verdict = (
            f"NO MATCH\nBest: {best['person']} | Source: {label} | "
            f"Distance: {best['hamming_distance']} | Score: {best['score']:.4f}"
        )

    rows = [
        ("Result source", label),
        ("Mode", stats["mode"]),
        ("Dimension", stats["dim"]),
        ("Threshold δ", stats["delta"]),
        ("E-LSH L", stats["L"]),
        ("Latency (s)", f"{stats['time_s']:.3f}"),
        ("Communication (MB)", f"{stats['communication_mb']:.3f}"),
        ("Gallery images", str(stats["gallery_size"])),
        ("Gao feasible", "YES" if stats["gao_feasible"] else "NO"),
        ("Top-5 distances", ", ".join(map(str, extra["top5_distances"]))),
    ]
    metrics_df = pd.DataFrame(rows, columns=["Metric", "Value"])

    compact_stats = " | ".join(
        (
            f"src={source_key}",
            f"mode={stats['mode']}",
            f"d={stats['dim']}",
            f"δ={stats['delta']}",
            f"L={stats['L']}",
            f"latency={stats['time_s']:.3f}s",
        )
    )
    return verdict, metrics_df, compact_stats, notes
def _query_gallery(query_embedding, features, people, filenames, dim: int):
if len(people) == 0:
return None
return adapter.query_against_gallery(
query_embedding,
features,
people,
filenames,
dim,
calibration_features=gallery.calibration_features,
calibration_people=gallery.calibration_people,
)
def _match_from_image(image: Image.Image, dim: int, mode: str):
    """Identify the face in ``image``: enrolled users first, packaged gallery second.

    The enrolled-user library is queried only when ``gallery.has_enrollments()``
    is true. If that query reports a match it is used; otherwise the packaged
    gallery is queried and its result is reported instead.

    Args:
        image: Query face image.
        dim: Binary dimension forwarded to the gallery query.
        mode: Backend mode assigned to the shared adapter before querying.

    Returns:
        A 7-tuple ``(preview, gallery_image, verdict, compact_stats,
        metrics_df, notes, summary_json)``.

    Raises:
        gr.Error: If the packaged gallery has no entries to fall back to.
    """
    # The adapter is a module-level singleton, so the requested mode is set on it directly.
    adapter.mode = mode
    analysis = image_pipeline.analyze(image)

    checked_enrolled = gallery.has_enrollments()
    enrolled_result = None
    if checked_enrolled:
        enrolled_result = _query_gallery(
            analysis.embedding,
            gallery.enrolled_features,
            gallery.enrolled_people,
            gallery.enrolled_filenames,
            dim,
        )

    if enrolled_result is not None and enrolled_result[0].matched:
        match, summary, details = enrolled_result
        details["match_source"] = "enrolled"
        notes_prefix = "matched enrolled-user overlay"
    else:
        demo_result = _query_gallery(
            analysis.embedding,
            gallery.features,
            gallery.people,
            gallery.filenames,
            dim,
        )
        if demo_result is None:
            # _query_gallery returns None for an empty gallery; unpacking that
            # would fail with an opaque TypeError, so report it explicitly.
            raise gr.Error("The packaged gallery is empty; there is nothing to match against.")
        match, summary, details = demo_result
        details["match_source"] = "demo"
        if checked_enrolled:
            notes_prefix = "no enrolled-user match; used packaged LFW gallery"
        else:
            notes_prefix = "no enrolled users registered; used packaged LFW gallery"

    details["checked_enrolled"] = checked_enrolled
    details["enrolled_gallery_size"] = len(gallery.enrolled_people)

    gallery_image = Image.open(gallery.gallery_image_path(match.filename, match.person)).convert("RGB")
    summary_json = adapter.export_summary(match, summary, details)
    # Unpacking accepts any iterable of notes, not only lists.
    note_lines = [notes_prefix, *analysis.notes, *summary.notes]
    notes = "\n".join(f"- {note}" for note in note_lines)
    verdict, metrics_df, compact_stats, note_text = _gallery_metrics(summary_json, notes)
    return analysis.preview, gallery_image, verdict, compact_stats, metrics_df, note_text, summary_json
def run_image_demo(image: Image.Image, dim: int, mode: str):
    """Validate the uploaded image and run the matching pipeline on it.

    Returns:
        The 7-tuple produced by ``_match_from_image``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is not None:
        return _match_from_image(image, dim, mode)
    raise gr.Error("Please upload or capture a face image.")
def run_realtime_demo(image: Image.Image | None, dim: int, mode: str):
    """Match one webcam frame without ever raising.

    Returns:
        The 7-tuple from ``_match_from_image``. When there is no frame, or
        matching fails, a placeholder tuple is returned instead whose verdict
        slot carries the status or error text.
    """

    def _placeholder(message: str):
        # Same arity as a successful result: only the verdict slot is filled in.
        return None, None, message, "", _empty_metrics(), "", ""

    if image is None:
        return _placeholder("Waiting for webcam frame.")
    try:
        outputs = _match_from_image(image, dim, mode)
    except Exception as exc:
        # Deliberately broad: any failure is shown in the verdict box rather
        # than propagated to the caller.
        return _placeholder(f"Realtime error: {exc}")
    return outputs
def register_person(image: Image.Image | None, person: str):
    """Enroll a named face into the shared gallery and refresh the Database view.

    Args:
        image: Face image to register.
        person: Identity name; surrounding and repeated whitespace is collapsed.

    Returns:
        ``(preview, saved_image, status, details, gallery_items, database_info,
        page)`` where the last three values describe gallery page 1.

    Raises:
        gr.Error: If no image is given or the name is empty after normalization.
    """
    if image is None:
        raise gr.Error("Please upload or capture a face image.")

    # split() with no arguments drops all whitespace runs; a missing name is treated as "".
    name_tokens = (person or "").split()
    if not name_tokens:
        raise gr.Error("Please enter a name before registering.")
    normalized_name = " ".join(name_tokens)

    analysis = image_pipeline.analyze(image)
    record, replaced = gallery.enroll(normalized_name, analysis.preview, analysis.embedding)
    saved_image = Image.open(record.gallery_path).convert("RGB")

    verb = "Updated" if replaced else "Registered"
    status = f"{verb} {normalized_name} in the shared temporary database."
    details = (
        "Upload and Realtime queries will search enrolled users first. "
        "This shared temporary database may reset when the Space restarts."
    )

    first_page = 1
    return (
        analysis.preview,
        saved_image,
        status,
        details,
        gallery.gallery_items(first_page),
        gallery.database_info(first_page),
        first_page,
    )
def load_gallery_page(page: float | int | None):
    """Return the gallery items and info text for the requested page.

    Args:
        page: Page number; ``None`` selects page 1, anything else is passed
            through ``int()``.

    Returns:
        ``(gallery_items, database_info)`` for that page.
    """
    if page is None:
        page_num = 1
    else:
        page_num = int(page)
    items = gallery.gallery_items(page_num)
    info = gallery.database_info(page_num)
    return items, info
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire its callbacks.

    Builds four tabs (Database, Enroll, Upload, Realtime) plus a
    "Protocol Details" accordion. The app is constructed but not launched.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    # NOTE(review): container nesting below follows the grouping of components;
    # confirm the exact placement of accordions/buttons against the rendered UI.
    # Snapshot taken once at build time; the values shown from it are not refreshed later.
    summary = gallery.summary()
    with gr.Blocks(title=config.SPACE_TITLE, theme=gr.themes.Base(), css=CSS) as demo:
        gr.HTML(
            f'<div class="main-title"><h1>{config.SPACE_TITLE}</h1><p>{config.SPACE_SUBTITLE}</p></div>'
        )
        with gr.Tabs():
            # --- Database tab: paged thumbnail browser plus static parameters ---
            with gr.Tab("Database"):
                with gr.Row():
                    with gr.Column(scale=4):
                        database_gallery = gr.Gallery(
                            value=gallery.gallery_items(1),
                            label="Visible Gallery",
                            columns=4,
                            rows=4,
                            height=520,
                            object_fit="cover",
                            preview=True,
                            allow_preview=True,
                        )
                    with gr.Column(scale=1):
                        database_info = gr.Textbox(
                            value=gallery.database_info(1),
                            label="Info",
                            interactive=False,
                            lines=4,
                        )
                        # The maximum is read from the gallery once, when the UI is built.
                        gallery_page = gr.Number(label="Page", value=1, minimum=1, maximum=gallery.total_pages(), precision=0)
                        load_page_btn = gr.Button("Load")
                        with gr.Accordion("Parameters", open=False):
                            gr.Dataframe(
                                headers=["Param", "Value"],
                                value=[
                                    ["Backend", config.DEFAULT_MODE],
                                    ["Dimensions", ", ".join(map(str, config.DEFAULT_DIMENSIONS))],
                                    ["Matching images", str(summary["gallery_size"])],
                                    ["Matching identities", str(summary["identity_count"])],
                                    ["Enrolled images", str(summary["enrolled_gallery_size"])],
                                    ["Enrolled identities", str(summary["enrolled_identity_count"])],
                                    ["Visible thumbnails", str(summary["visible_gallery_size"])],
                                    ["Calibration", str(summary["calibration_size"])],
                                    ["Embedding dim", str(summary["dimensions"])],
                                ],
                                interactive=False,
                            )
                load_page_btn.click(load_gallery_page, inputs=[gallery_page], outputs=[database_gallery, database_info], api_name=False)
            # --- Enroll tab: register a named face into the shared gallery ---
            with gr.Tab("Enroll"):
                gr.Markdown("Register a face with a name into the shared temporary database. Later Upload and Realtime queries will search this registered library first.")
                with gr.Row():
                    with gr.Column():
                        enroll_image = gr.Image(
                            label="Registration Face",
                            type="pil",
                            sources=["upload", "webcam"],
                            height=280,
                        )
                        enroll_name = gr.Textbox(label="Name", placeholder="Enter the identity name to register")
                        enroll_btn = gr.Button("Register Face", variant="primary")
                    with gr.Column():
                        enroll_saved = gr.Image(label="Saved Database Face", height=280, interactive=False)
                        enroll_status = gr.Textbox(label="Registration Status", lines=2, interactive=False)
                with gr.Row():
                    with gr.Column():
                        enroll_preview = gr.Image(label="Processed Registration Face", height=220, interactive=False)
                    with gr.Column():
                        enroll_notes = gr.Textbox(label="Notes", lines=4, interactive=False)
                # register_person also returns page-1 gallery items, info text and the
                # page number, so the Database tab widgets are listed as outputs too.
                enroll_btn.click(
                    register_person,
                    inputs=[enroll_image, enroll_name],
                    outputs=[enroll_preview, enroll_saved, enroll_status, enroll_notes, database_gallery, database_info, gallery_page],
                    api_name=False,
                )
            # --- Upload tab: one-shot identification of an uploaded or captured image ---
            with gr.Tab("Upload"):
                gr.Markdown("Upload a query face and match it against the enrolled-user overlay first, then the larger LFW-derived database.")
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Query Face",
                            type="pil",
                            sources=["upload", "webcam"],
                            height=280,
                        )
                        gr.Examples(examples=_example_paths(), inputs=[image_input], label="Examples")
                        with gr.Accordion("Advanced Settings", open=False):
                            upload_dim = gr.Dropdown(
                                choices=list(config.DEFAULT_DIMENSIONS),
                                value=config.DEFAULT_DIM,
                                label="Binary dimension",
                            )
                            upload_mode = gr.Dropdown(
                                choices=list(config.SUPPORTED_MODES),
                                value=config.DEFAULT_MODE,
                                label="Backend mode",
                            )
                        upload_btn = gr.Button("Run Identification", variant="primary")
                    with gr.Column():
                        upload_match = gr.Image(label="Matched Identity", height=280, interactive=False)
                        upload_verdict = gr.Textbox(label="Verdict", lines=3, interactive=False)
                        upload_summary = gr.Textbox(label="Summary", lines=2, interactive=False)
                with gr.Row():
                    with gr.Column():
                        upload_preview = gr.Image(label="Processed Query", height=220, interactive=False)
                    with gr.Column():
                        upload_metrics = gr.Dataframe(label="Protocol Metrics", interactive=False)
                with gr.Accordion("Execution Notes", open=False):
                    upload_notes = gr.Textbox(label="Notes", lines=6, interactive=False, elem_classes=["compact-note"])
                with gr.Accordion("Detailed JSON Summary", open=False):
                    upload_json = gr.Textbox(label="JSON", lines=14, interactive=False, elem_classes=["compact-note"])
                # Output order mirrors the 7-tuple returned by run_image_demo.
                upload_btn.click(
                    run_image_demo,
                    inputs=[image_input, upload_dim, upload_mode],
                    outputs=[upload_preview, upload_match, upload_verdict, upload_summary, upload_metrics, upload_notes, upload_json],
                    api_name=False,
                )
            # --- Realtime tab: streaming webcam identification ---
            with gr.Tab("Realtime"):
                gr.Markdown("Browser webcam recognition. Frames are processed on a throttled interval for stability.")
                with gr.Row():
                    with gr.Column():
                        webcam_input = gr.Image(
                            label="Realtime Camera",
                            type="pil",
                            sources=["webcam"],
                            streaming=True,
                            height=280,
                        )
                        gr.Markdown(
                            f"<div class='small-copy'>Auto-matches every {config.REALTIME_STREAM_SECONDS:.1f}s while frames are available.</div>"
                        )
                        with gr.Accordion("Advanced Settings", open=False):
                            realtime_dim = gr.Dropdown(
                                choices=list(config.DEFAULT_DIMENSIONS),
                                value=config.DEFAULT_DIM,
                                label="Binary dimension",
                            )
                            realtime_mode = gr.Dropdown(
                                choices=list(config.SUPPORTED_MODES),
                                value=config.DEFAULT_MODE,
                                label="Backend mode",
                            )
                        realtime_btn = gr.Button("Identify Current Frame")
                    with gr.Column():
                        realtime_match = gr.Image(label="Matched Identity", height=280, interactive=False)
                        realtime_verdict = gr.Textbox(label="Realtime Verdict", lines=3, interactive=False)
                        realtime_summary = gr.Textbox(label="Realtime Summary", lines=2, interactive=False)
                with gr.Row():
                    with gr.Column():
                        realtime_preview = gr.Image(label="Current Frame Preview", height=220, interactive=False)
                    with gr.Column():
                        realtime_metrics = gr.Dataframe(label="Realtime Metrics", interactive=False)
                with gr.Accordion("Realtime Notes", open=False):
                    realtime_notes = gr.Textbox(label="Notes", lines=6, interactive=False, elem_classes=["compact-note"])
                with gr.Accordion("Realtime JSON Summary", open=False):
                    realtime_json = gr.Textbox(label="JSON", lines=14, interactive=False, elem_classes=["compact-note"])
                # Streaming handler: run_realtime_demo returns placeholder values instead of
                # raising, so a bad frame only changes the verdict text. The stream interval
                # comes from config.REALTIME_STREAM_SECONDS and the event is limited to one
                # concurrent run.
                webcam_input.stream(
                    run_realtime_demo,
                    inputs=[webcam_input, realtime_dim, realtime_mode],
                    outputs=[realtime_preview, realtime_match, realtime_verdict, realtime_summary, realtime_metrics, realtime_notes, realtime_json],
                    api_name=False,
                    trigger_mode="always_last",
                    concurrency_limit=1,
                    show_progress="hidden",
                    stream_every=config.REALTIME_STREAM_SECONDS,
                )
                # Manual trigger reuses the same handler and outputs as the stream.
                realtime_btn.click(
                    run_realtime_demo,
                    inputs=[webcam_input, realtime_dim, realtime_mode],
                    outputs=[realtime_preview, realtime_match, realtime_verdict, realtime_summary, realtime_metrics, realtime_notes, realtime_json],
                    api_name=False,
                )
        with gr.Accordion("Protocol Details", open=False):
            # The Markdown body is kept flush-left so the text is passed to Gradio without leading spaces.
            gr.Markdown(
                f"""
**Pipeline**: image/webcam frame → face embedding → binary projection → enrolled-user overlay match → fallback LFW E-LSH/Hamming verification.
**Modes**: `simulation` for public reliability, `full` for optional native FPSI execution when runtime support is present.
**Database**: {summary['demo_gallery_size']} packaged LFW images plus a shared temporary enrolled-user overlay are used for matching and browsing.
"""
            )
    return demo