svg-annotation / app.py
davda54's picture
Update app.py
383827b verified
from __future__ import annotations
import os
import gradio as gr
import json
import random
from datetime import datetime
from typing import Dict, List, Tuple
import hashlib
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
import threading
from collections.abc import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
def _load_user_ids(*env_names):
    """Collect annotator IDs from environment variables holding JSON lists.

    Unset or empty variables contribute no IDs, so a missing variable no
    longer crashes the import with ``json.loads(None)``.

    Args:
        *env_names: Names of the environment variables to read.

    Returns:
        A set with the union of all IDs found.
    """
    user_ids = set()
    for env_name in env_names:
        raw = os.environ.get(env_name)
        if raw:
            user_ids.update(json.loads(raw))
    return user_ids


HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is not None:
    # os.environ only accepts strings; assigning None raises TypeError.
    os.environ['HF_AUTH'] = HF_TOKEN
HfApi(token=HF_TOKEN)
# Whitelist of annotator IDs accepted by the login form.
USER_IDS = _load_user_ids("USER_IDS", "USER_IDS_2")
if not USER_IDS:
    print("Warning: no annotator IDs configured (USER_IDS / USER_IDS_2); nobody can log in")
class Soft(Base):
    """Gradio theme named "soft": indigo/gray palette, drop shadows, no block borders.

    All constructor arguments are keyword-only and are forwarded unchanged to
    the ``Base`` theme constructor; the remaining theme variables are then
    overridden through ``Base.set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.indigo,
        secondary_hue: colors.Color | str = colors.indigo,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            "ui-monospace",
            "Consolas",
            "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)
        self.name = "soft"

        # Theme variable overrides, grouped by the UI area they affect.
        overrides = {
            # Colors
            "background_fill_primary": "*neutral_50",
            "slider_color": "*primary_500",
            "slider_color_dark": "*primary_600",
            # Shadows
            "shadow_drop": "0 1px 4px 0 rgb(0 0 0 / 0.1)",
            "shadow_drop_lg": "0 2px 5px 0 rgb(0 0 0 / 0.2)",
            # Block Labels
            "block_background_fill": "white",
            "block_label_padding": "*spacing_sm *spacing_md",
            "block_label_background_fill": "*primary_100",
            "block_label_background_fill_dark": "*primary_600",
            "block_label_radius": "*radius_md",
            "block_label_text_size": "*text_md",
            "block_label_text_weight": "600",
            "block_label_text_color": "*primary_500",
            "block_label_text_color_dark": "white",
            "block_title_radius": "*block_label_radius",
            "block_title_padding": "*block_label_padding",
            "block_title_background_fill": "*block_label_background_fill",
            "block_title_text_weight": "600",
            "block_title_text_color": "*primary_500",
            "block_title_text_color_dark": "white",
            "block_label_margin": "*spacing_md",
            # Inputs
            "input_background_fill": "white",
            "input_border_color": "*neutral_100",
            "input_shadow": "*shadow_drop",
            "input_shadow_focus": "*shadow_drop_lg",
            "checkbox_shadow": "none",
            # Buttons
            "shadow_spread": "6px",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_primary_shadow_hover": "*shadow_drop_lg",
            "button_primary_shadow_active": "*shadow_inset",
            "button_secondary_shadow": "*shadow_drop_lg",
            "button_secondary_shadow_hover": "*shadow_drop_lg",
            "button_secondary_shadow_active": "*shadow_inset",
            "checkbox_label_shadow": "*shadow_drop_lg",
            "button_primary_background_fill": "*primary_500",
            "button_primary_background_fill_hover": "*primary_400",
            "button_primary_background_fill_hover_dark": "*primary_500",
            "button_primary_text_color": "white",
            "button_secondary_background_fill": "white",
            "button_secondary_background_fill_hover": "*neutral_100",
            "button_secondary_background_fill_hover_dark": "*primary_500",
            "button_secondary_text_color": "*neutral_800",
            "button_cancel_background_fill": "*button_secondary_background_fill",
            "button_cancel_background_fill_hover": "*button_secondary_background_fill_hover",
            "button_cancel_background_fill_hover_dark": "*button_secondary_background_fill_hover",
            "button_cancel_text_color": "*button_secondary_text_color",
            "checkbox_label_background_fill_selected": "*primary_500",
            "checkbox_label_background_fill_selected_dark": "*primary_600",
            "checkbox_border_width": "1px",
            "checkbox_border_color": "*neutral_100",
            "checkbox_border_color_dark": "*neutral_600",
            "checkbox_background_color_selected": "*primary_600",
            "checkbox_background_color_selected_dark": "*primary_700",
            "checkbox_border_color_focus": "*primary_500",
            "checkbox_border_color_focus_dark": "*primary_600",
            "checkbox_border_color_selected": "*primary_600",
            "checkbox_border_color_selected_dark": "*primary_700",
            "checkbox_label_text_color_selected": "white",
            # Borders
            "block_border_width": "0px",
            "panel_border_width": "0px",
        }
        super().set(**overrides)
# Configuration for the output dataset
# Hub dataset repository that train.jsonl is downloaded from and uploaded to.
ANNOTATIONS_REPO = "ltg/svg-qa-annotations"
# Local directory holding the working copy of the annotations file.
DATA_DIR = "annotation_data"
# Local jsonl file: one annotation record per line.
ANNOTATIONS_FILE = os.path.join(DATA_DIR, "train.jsonl")
# Question shown above every SVG in the annotation view.
PROMPT_TEXT = "Which of the four captions best describes the following SVG image?"
# Shared Hub client used to upload the annotations file.
api = HfApi(token=HF_TOKEN)
# Initialize: download existing annotations file if it exists
def init_annotations_file():
    """Fetch train.jsonl from the hub into DATA_DIR, or make sure a local file exists.

    If the download fails for any reason, the local annotations file is
    created when it is missing. An existing local file is left untouched:
    it is opened in append mode, so a failed download can no longer wipe
    annotations that are already on disk.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    try:
        path = hf_hub_download(
            repo_id=ANNOTATIONS_REPO,
            filename="train.jsonl",
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=DATA_DIR,
        )
    except Exception as e:
        print(f"No existing annotations file on hub (will create new): {e}")
        # "a" creates the file when absent but never truncates existing content.
        with open(ANNOTATIONS_FILE, "a"):
            pass
    else:
        print(f"Downloaded existing annotations from hub: {path}")


init_annotations_file()
def load_existing_annotations(path=None):
    """Load stored annotations from a jsonl file, grouped by user.

    Blank lines, malformed JSON lines, non-object records and records without
    a truthy "user_id" are skipped, so one bad line does not discard the
    annotations that follow it.

    Args:
        path: jsonl file to read. Defaults to ANNOTATIONS_FILE.

    Returns:
        Dict mapping user_id to that user's annotation dicts in file order.
        Empty when the file does not exist or cannot be read.
    """
    if path is None:
        path = ANNOTATIONS_FILE
    annotations = {}
    if not os.path.exists(path):
        return annotations
    try:
        # Records are written with ensure_ascii=False, so read them as UTF-8.
        with open(path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    ann = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping malformed annotation on line {line_no}: {e}")
                    continue
                user_id = ann.get("user_id") if isinstance(ann, dict) else None
                if user_id:
                    annotations.setdefault(user_id, []).append(ann)
        print(f"Loaded {sum(len(v) for v in annotations.values())} existing annotations")
    except Exception as e:
        print(f"Error loading annotations: {e}")
    return annotations
# save_annotation_to_file runs on one background thread per annotation; this
# lock keeps the local append and the upload of the file from interleaving.
_ANNOTATIONS_FILE_LOCK = threading.Lock()


def save_annotation_to_file(annotation_data):
    """Append one annotation to the local jsonl file and upload the file to the hub.

    Errors are printed rather than raised, so a failed save never propagates
    to the caller.

    Args:
        annotation_data: JSON-serializable annotation record.
    """
    try:
        line = json.dumps(annotation_data, ensure_ascii=False)
        with _ANNOTATIONS_FILE_LOCK:
            # Append to local jsonl file (UTF-8, since non-ASCII text is kept as-is)
            with open(ANNOTATIONS_FILE, "a", encoding="utf-8") as f:
                f.write(f"{line}\n")
            # Upload to hub via HTTP API
            api.upload_file(
                path_or_fileobj=ANNOTATIONS_FILE,
                path_in_repo="train.jsonl",
                repo_id=ANNOTATIONS_REPO,
                repo_type="dataset",
            )
    except Exception as e:
        print(f"Error saving annotation: {e}")
def load_dataset_samples():
    """Build the flat list of caption-QA samples from the "ltg/svg-qa" train split.

    Every non-null annotation of an image becomes its own sample dict with
    the keys id, original_index, annotation_index, filename, svg, correct and
    incorrect. If anything fails while loading, a single hard-coded dummy
    sample is returned instead.
    """
    try:
        dataset = load_dataset("ltg/svg-qa", split="train", token=HF_TOKEN)
        samples = []
        for record in dataset:
            image_index = record["index"]
            svg_markup = record["image"]
            caption_sets = record["annotations"]
            # Images without any annotations contribute no samples
            if caption_sets is not None:
                for position, caption_set in enumerate(caption_sets):
                    if caption_set is None:
                        continue
                    right_caption = caption_set["correct"]
                    wrong_captions = caption_set["incorrect"]  # list of 3
                    # The annotation index is only appended when the image
                    # has more than one annotation
                    if len(caption_sets) > 1:
                        unique_id = f"{image_index}_{position}"
                    else:
                        unique_id = str(image_index)
                    samples.append({
                        "id": unique_id,
                        "original_index": image_index,
                        "annotation_index": position,
                        "filename": record["filename"],
                        "svg": svg_markup,
                        "correct": right_caption,
                        "incorrect": wrong_captions,
                    })
        print(f"Loaded {len(samples)} SVG QA samples (from {len(dataset)} images)")
        return samples
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Using dummy data for testing...")
        dummy_sample = {
            "id": "0",
            "original_index": 0,
            "annotation_index": 0,
            "filename": "dummy.svg",
            "svg": '<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200"><circle cx="100" cy="100" r="80" fill="red"/></svg>',
            "correct": "a red circle",
            "incorrect": [
                "a blue square",
                "a green triangle",
                "a yellow star"
            ],
        }
        return [dummy_sample]


# Samples are loaded once, at import time
DATASET_SAMPLES = load_dataset_samples()
def get_shuffled_options(sample):
    """Return the sample's captions in shuffled order and the correct one's position.

    The shuffle is seeded on the SVG content, so every user sees the same
    order for a given image. A private ``random.Random`` instance is used
    instead of reseeding the module-level generator: the order is the same
    as with ``random.seed(svg); random.shuffle(...)``, but the global random
    state is left alone and concurrent requests cannot interleave between
    the seed and the shuffle.

    Args:
        sample: Dict with "svg" (str), "correct" (caption) and "incorrect"
            (iterable of captions).

    Returns:
        Tuple ``(options, correct_index)`` where ``options[correct_index]``
        is the correct caption.
    """
    options = [sample["correct"], *sample["incorrect"]]
    random.Random(sample["svg"]).shuffle(options)
    return options, options.index(sample["correct"])
class AnnotationManager:
    """In-memory bookkeeping of annotations per user, with background persistence."""

    def __init__(self):
        # user_id -> list of annotation dicts, as loaded from the jsonl file
        self.annotations = load_existing_annotations()
        # user_id -> {"annotations": [sample_id, ...]}
        self.user_states = {}
        for user_id in self.annotations:
            self._ensure_user_state(user_id)

    def _ensure_user_state(self, user_id: str) -> Dict:
        """Return the user's state, creating it from stored annotations if missing."""
        if user_id not in self.user_states:
            self.user_states[user_id] = {
                "annotations": [
                    ann["sample_id"] for ann in self.annotations.get(user_id, [])
                ]
            }
        return self.user_states[user_id]

    def get_user_seed(self, user_id: str) -> int:
        """Generate consistent seed for user (MD5 of the ID, read as an integer)."""
        return int(hashlib.md5(user_id.encode()).hexdigest(), 16)

    def get_user_samples(self, user_id: str) -> List[Dict]:
        """Get a copy of DATASET_SAMPLES shuffled with the user's seed."""
        seed = self.get_user_seed(user_id)
        samples = DATASET_SAMPLES.copy()
        random.Random(seed).shuffle(samples)
        return samples

    def get_next_sample(self, user_id: str) -> Tuple[Dict | None, int, int]:
        """Get next unannotated sample for user.

        Returns:
            ``(sample, position, total)``. ``position`` is the number of
            annotations recorded for the user plus one. When nothing is left
            to annotate, returns ``(None, total, total)``.
        """
        state = self._ensure_user_state(user_id)
        samples = self.get_user_samples(user_id)
        total_annotated = len(state["annotations"])
        # Build the lookup once instead of rescanning the user's annotations
        # for every candidate sample.
        done_ids = {ann["sample_id"] for ann in self.annotations.get(user_id, [])}
        for sample in samples:
            if sample["id"] not in done_ids:
                return sample, total_annotated + 1, len(samples)
        return None, len(samples), len(samples)

    def is_annotated(self, user_id: str, sample_id) -> bool:
        """Check if user has annotated this sample"""
        return any(
            ann["sample_id"] == sample_id
            for ann in self.annotations.get(user_id, [])
        )

    def save_annotation(self, user_id: str, sample_id, choice_index: int,
                        chosen_caption: str, correct_caption: str,
                        options: List[str], correct_index: int,
                        filename: str | None = None):
        """Record the user's annotation in memory and persist it on a background thread."""
        annotation = {
            "user_id": user_id,
            "sample_id": sample_id,
            "filename": filename,
            "chosen_index": choice_index,
            "chosen_caption": chosen_caption,
            "correct_caption": correct_caption,
            "correct_index": correct_index,
            "is_correct": chosen_caption == correct_caption,
            "options": options,
            "timestamp": datetime.now().isoformat()
        }
        # Resolve the state before appending so a freshly created state does
        # not already contain this sample_id.
        state = self._ensure_user_state(user_id)
        self.annotations.setdefault(user_id, []).append(annotation)
        state["annotations"].append(sample_id)
        # File append and hub upload happen off the request thread
        threading.Thread(
            target=save_annotation_to_file,
            args=(annotation,)
        ).start()
        print(f"Saved annotation: user={user_id}, sample={sample_id}, correct={annotation['is_correct']}")

    def get_user_progress(self, user_id: str) -> Dict:
        """Get user's annotation progress as ``{"completed": n, "total": m}``."""
        completed = len(self.annotations.get(user_id, []))
        return {"completed": completed, "total": len(DATASET_SAMPLES)}
# Initialize manager
# Single module-level instance used by the login and confirm handlers.
manager = AnnotationManager()
def render_svg_html(svg_content: str) -> str:
    """Wrap SVG content in HTML for display, scaled to fit container width.

    Only the first ``<svg ...>`` tag (the root element) is rewritten; nested
    ``<svg>`` elements keep their own attributes.

    - If the root tag has no viewBox but has numeric ``width``/``height``
      (an optional ``px`` suffix is accepted), a matching
      ``viewBox="0 0 w h"`` is added. Non-numeric sizes such as ``100%`` do
      not yield a viewBox, because they would make it invalid.
    - The root's explicit ``width``/``height`` are then removed so the
      container decides the rendered size.

    Args:
        svg_content: SVG markup. Attributes are expected in double quotes.

    Returns:
        The adjusted SVG wrapped in a styled ``<div>``.
    """
    import re

    root = re.search(r'<svg\b[^>]*>', svg_content)
    if root:
        tag = root.group(0)
        if 'viewbox' not in tag.lower():
            w_match = re.search(r'\swidth="([^"]*)"', tag)
            h_match = re.search(r'\sheight="([^"]*)"', tag)
            if w_match and h_match:
                # Accept plain numbers or px lengths; anything else (%, em, ...)
                # cannot be used as a viewBox coordinate.
                sizes_ok = [
                    re.fullmatch(r'\s*(\d+(?:\.\d+)?|\.\d+)\s*(?:px)?\s*', m.group(1))
                    for m in (w_match, h_match)
                ]
                if all(sizes_ok):
                    w, h = sizes_ok[0].group(1), sizes_ok[1].group(1)
                    tag = f'<svg viewBox="0 0 {w} {h}"' + tag[len('<svg'):]
        # Remove explicit width/height so the SVG scales to container via CSS
        tag = re.sub(r'\s+width="[^"]*"', '', tag, count=1)
        tag = re.sub(r'\s+height="[^"]*"', '', tag, count=1)
        svg_content = svg_content[:root.start()] + tag + svg_content[root.end():]
    return f"""
<div style="width: 100%; background: #e5e7eb; border-radius: 8px; padding: 10px; box-sizing: border-box;">
{svg_content}
</div>
"""
def get_sample_display_data(sample):
    """Return ``(svg_html, options, correct_index)`` for showing a sample.

    ``options`` and ``correct_index`` come from get_shuffled_options;
    ``svg_html`` is the sample's SVG passed through render_svg_html.
    """
    shuffled, answer_position = get_shuffled_options(sample)
    return render_svg_html(sample["svg"]), shuffled, answer_position
def build_caption_buttons_html(options):
    """Build the HTML for the caption choice buttons.

    Each option becomes a ``<button>`` labelled A-D whose ``onclick`` calls
    ``selectCaption(this, i)``. Caption text is HTML-escaped so characters
    such as ``<`` or ``&`` are shown literally instead of being parsed as
    markup. At most four options are rendered; fewer options give fewer
    buttons.

    Args:
        options: Caption strings in display order.

    Returns:
        A ``<div id="caption-buttons">`` element containing the buttons.
    """
    from html import escape

    labels = ["A", "B", "C", "D"]
    buttons_html = "".join(
        f'<button class="caption-btn-html" data-index="{i}" onclick="selectCaption(this, {i})">'
        f'{label}: {escape(str(caption))}</button>'
        for i, (label, caption) in enumerate(zip(labels, options))
    )
    return f'<div id="caption-buttons">{buttons_html}</div>'
def login(user_id: str) -> Tuple:
    """Validate the annotator ID and prepare the first sample.

    The returned tuple has one entry per login output, in this order:
    login_interface, annotation_interface, user_state, login_status,
    svg_display, progress, current_options, correct_index, caption_html,
    confirmed_index.
    """
    cleaned_id = user_id.strip() if user_id else ""

    # Unknown or empty ID: stay on the login view and show an error
    if not cleaned_id or cleaned_id not in USER_IDS:
        rejected = (
            gr.update(visible=True),  # login_interface
            gr.update(visible=False),  # annotation_interface
            "",  # user_state
            gr.update(value="Please enter a valid ID"),  # login_status
            gr.update(),  # svg_display
            gr.update(),  # progress
            [],  # current_options state
            -1,  # correct_index state
            gr.update(),  # caption_html
            gr.update(),  # confirmed_index hidden
        )
        return rejected

    user_id = cleaned_id
    sample, current, total = manager.get_next_sample(user_id)

    # Nothing left for this user: stay on the login view with a completion note
    if sample is None:
        finished = (
            gr.update(visible=True),
            gr.update(visible=False),
            user_id,
            gr.update(value=f"All {total} samples completed for user: {user_id}! 🎉"),
            gr.update(),
            gr.update(),
            [],
            -1,
            gr.update(),
            gr.update(),
        )
        return finished

    svg_html, options, correct_index = get_sample_display_data(sample)
    buttons_markup = build_caption_buttons_html(options)
    progress_text = f"Progress: {current}/{total}"
    return (
        gr.update(visible=False),  # login_interface
        gr.update(visible=True),  # annotation_interface
        user_id,  # user_state
        gr.update(value=""),  # login_status
        gr.update(value=svg_html),  # svg_display
        gr.update(value=progress_text),  # progress
        options,  # current_options state
        correct_index,  # correct_index state
        gr.update(value=buttons_markup),  # caption_html
        gr.update(value=""),  # confirmed_index hidden
    )
def confirm_choice(choice_index_str: str, user_id: str, current_options: list,
                   correct_index: int) -> Tuple:
    """Handle confirmed annotation submission (second click).

    Saves the chosen caption for the user's current sample and returns the
    values for the next one.

    A rejected submission (no user, no choice, or a choice that is not a
    valid index into ``current_options``) leaves the displayed sample and
    the options/correct-index state unchanged. Wiping that state would make
    the next confirmation fail on an empty options list.

    Args:
        choice_index_str: Index of the chosen option as text from the hidden textbox.
        user_id: Logged-in annotator ID ("" when nobody is logged in).
        current_options: Captions currently displayed, in display order.
        correct_index: Position of the correct caption in ``current_options``.

    Returns:
        Tuple with one entry per confirm output: svg_display, progress,
        status, current_options, correct_index, caption_html, confirmed_index.
    """
    def reject(message: str) -> Tuple:
        return (
            gr.update(),  # svg_display
            gr.update(),  # progress
            gr.update(value=message, visible=True),  # status
            current_options,  # current_options (kept as is)
            correct_index,  # correct_index (kept as is)
            gr.update(),  # caption_html
            gr.update(),  # confirmed_index hidden
        )

    if not user_id:
        return reject("Error: No user logged in")
    if not choice_index_str:
        return reject("Error: No caption selected")
    try:
        choice_index = int(choice_index_str)
    except (TypeError, ValueError):
        return reject("Error: Invalid caption choice")

    # Save annotation
    sample, _, _ = manager.get_next_sample(user_id)
    if sample:
        if not 0 <= choice_index < len(current_options):
            return reject("Error: Invalid caption choice")
        chosen_caption = current_options[choice_index]
        manager.save_annotation(
            user_id=user_id,
            sample_id=sample["id"],
            choice_index=choice_index,
            chosen_caption=chosen_caption,
            correct_caption=sample["correct"],
            options=current_options,
            correct_index=correct_index,
            filename=sample.get("filename"),
        )

    # Get next sample
    next_sample, current, total = manager.get_next_sample(user_id)
    if next_sample is None:
        done_html = "<h2 style='text-align:center;padding:40px;'>All samples completed! Thank you for your annotations. 🎉</h2>"
        return (
            gr.update(value=done_html),
            gr.update(value=f"Progress: {total}/{total} — Complete!"),
            gr.update(value="All annotations complete!", visible=True),
            [],
            -1,
            gr.update(value=""),
            gr.update(value=""),
        )

    svg_html, options, new_correct_index = get_sample_display_data(next_sample)
    return (
        gr.update(value=svg_html),
        gr.update(value=f"Progress: {current}/{total}"),
        gr.update(value="Annotation saved!", visible=True),
        options,
        new_correct_index,
        gr.update(value=build_caption_buttons_html(options)),
        gr.update(value=""),
    )
def logout() -> Tuple:
    """Return the values that switch back to the login view and clear the session.

    The entries follow the same output order as the login handler.
    """
    show_login = gr.update(visible=True)
    hide_annotation = gr.update(visible=False)
    cleared_status = gr.update(value="")
    cleared_svg = gr.update(value="")
    cleared_progress = gr.update(value="")
    cleared_captions = gr.update(value="")
    cleared_choice = gr.update(value="")
    return (
        show_login,
        hide_annotation,
        "",
        cleared_status,
        cleared_svg,
        cleared_progress,
        [],
        -1,
        cleared_captions,
        cleared_choice,
    )
# Custom CSS handed to gr.Blocks(css=...). It styles the login panel
# (#login-group), the shared .light-shadow / .narrow-center helpers, the
# HTML caption buttons (.caption-btn-html and its .selected state), the
# #confirm-hint line, and hides everything tagged .hidden-el.
custom_css = """
#login-group {
background-color: white !important;
}
#login-group > * {
background-color: white !important;
}
#login-group .gr-group {
background-color: white !important;
}
#login-group .gr-form {
background-color: white !important;
}
.light-shadow {
box-shadow: 0 1px 4px 0 rgb(0 0 0 / 0.1) !important;
}
.narrow-center {
max-width: 33.33%;
min-width: 320px;
margin-left: auto !important;
margin-right: auto !important;
}
/* HTML caption buttons */
.caption-btn-html {
display: block;
width: 100%;
text-align: left;
white-space: normal;
min-height: 50px;
padding: 12px 16px;
margin-bottom: 8px;
border: 1px solid #d1d5db;
border-radius: 8px;
background: white;
color: #1f2937;
font-size: 16px;
font-family: inherit;
cursor: pointer;
box-shadow: 0 2px 5px 0 rgb(0 0 0 / 0.2);
transition: background 0.15s, border-color 0.15s, color 0.15s;
}
.caption-btn-html:hover {
background: #f3f4f6;
}
.caption-btn-html.selected {
background: #6366f1;
color: white;
border-color: #6366f1;
}
.caption-btn-html.selected:hover {
background: #818cf8;
border-color: #818cf8;
}
#confirm-hint {
text-align: center;
color: #6366f1;
font-weight: 600;
margin-top: 4px;
min-height: 24px;
}
/* Hide the hidden elements */
.hidden-el {
display: none !important;
}
"""
# JS for caption button selection and confirmation.
# selectCaption(btn, index) implements a two-click flow:
#   * first click marks the button as selected and shows the confirm hint;
#   * a second click on the same button writes the index into the hidden
#     #confirmed-index-box field, fires an input event on it, and clicks
#     #confirm-btn after 50 ms.
# A MutationObserver resets the selection and clears the hint whenever a
# new #caption-buttons element is added to the page.
# NOTE(review): this markup is injected through gr.HTML; whether the browser
# executes the <script> may depend on the Gradio version — confirm.
caption_js = """
<script>
var _selectedIndex = -1;
function selectCaption(btn, index) {
if (_selectedIndex === index) {
// Second click on same button: confirm
_selectedIndex = -1;
// Set the hidden textbox value and click hidden confirm button
var hiddenInput = document.querySelector('#confirmed-index-box textarea') ||
document.querySelector('#confirmed-index-box input');
if (hiddenInput) {
var proto = hiddenInput.tagName === 'TEXTAREA'
? window.HTMLTextAreaElement.prototype
: window.HTMLInputElement.prototype;
var nativeSetter = Object.getOwnPropertyDescriptor(proto, 'value').set;
nativeSetter.call(hiddenInput, String(index));
hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
}
setTimeout(function() {
var confirmBtn = document.querySelector('#confirm-btn');
if (confirmBtn) confirmBtn.click();
}, 50);
} else {
// First click: select
_selectedIndex = index;
var allBtns = document.querySelectorAll('.caption-btn-html');
allBtns.forEach(function(b) { b.classList.remove('selected'); });
btn.classList.add('selected');
var hint = document.getElementById('confirm-hint');
if (hint) hint.textContent = 'Click again to confirm your choice.';
}
}
// Reset selection state when new buttons are rendered
var _observer = new MutationObserver(function(mutations) {
mutations.forEach(function(m) {
m.addedNodes.forEach(function(node) {
if (node.nodeType === 1 && (node.id === 'caption-buttons' || node.querySelector && node.querySelector('#caption-buttons'))) {
_selectedIndex = -1;
var hint = document.getElementById('confirm-hint');
if (hint) hint.textContent = '';
}
});
});
});
_observer.observe(document.body, { childList: true, subtree: true });
</script>
"""
# Create Gradio interface
# The page has two views toggled through `visible`: a login form and the
# annotation view. Per-session values live in gr.State components.
with gr.Blocks(
theme=Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial"]),
title="SVG Caption Annotation Tool",
css=custom_css
) as app:
gr.Markdown("# SVG Caption Annotation")
# Inject the selectCaption script used by the caption buttons
gr.HTML(caption_js)
# Session state: annotator ID, displayed captions, index of the correct one
user_state = gr.State("")
current_options_state = gr.State([])
correct_index_state = gr.State(-1)
# Login Interface
with gr.Column(visible=True) as login_interface:
with gr.Column(variant="panel", elem_id="login-group", elem_classes="light-shadow"):
gr.Markdown("## Log in", padding=True)
user_id_input = gr.Textbox(
label="Enter your unique annotator ID to begin",
placeholder="Annotator ID"
)
with gr.Row():
login_btn = gr.Button("Login", variant="primary", scale=0.2, min_width=100)
gr.HTML("")
login_status = gr.Markdown("", padding=True)
# Annotation Interface
with gr.Column(visible=False, elem_id="annotation-group") as annotation_interface:
with gr.Column(elem_classes="narrow-center"):
progress_label = gr.Markdown("")
gr.Markdown(f"**{PROMPT_TEXT}**")
# SVG display
with gr.Row(elem_classes="light-shadow"):
svg_display = gr.HTML(value="", label="SVG Image")
# Caption choice buttons (rendered as HTML)
caption_html = gr.HTML(value="", elem_id="caption-buttons-container")
# Target element for the "click again to confirm" hint set by the script
gr.HTML('<div id="confirm-hint"></div>')
status_message = gr.Markdown("", visible=False)
# Hidden elements for confirm mechanism (CSS-hidden, kept in DOM for JS)
with gr.Column(elem_classes="hidden-el"):
confirmed_index = gr.Textbox(value="", elem_id="confirmed-index-box")
confirm_btn = gr.Button("Confirm", elem_id="confirm-btn")
# The logout button sits in a row created with visible=False
with gr.Row(visible=False):
logout_btn = gr.Button("Logout", variant="stop", size="sm")
# --- Event handlers ---
# Output order must match the tuples returned by login() and logout()
login_outputs = [
login_interface,
annotation_interface,
user_state,
login_status,
svg_display,
progress_label,
current_options_state,
correct_index_state,
caption_html,
confirmed_index,
]
# Both the button and pressing Enter in the textbox trigger login
login_btn.click(
fn=login,
inputs=[user_id_input],
outputs=login_outputs
)
user_id_input.submit(
fn=login,
inputs=[user_id_input],
outputs=login_outputs
)
# Output order must match the tuples returned by confirm_choice()
confirm_outputs = [
svg_display,
progress_label,
status_message,
current_options_state,
correct_index_state,
caption_html,
confirmed_index,
]
# Clicked programmatically by the script after the second click on a caption
confirm_btn.click(
fn=confirm_choice,
inputs=[confirmed_index, user_state, current_options_state, correct_index_state],
outputs=confirm_outputs
)
logout_btn.click(
fn=logout,
inputs=[],
outputs=login_outputs
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
app.launch()