bsvaz
deploy
82c9f6d
"""
Module for creating and configuring the Gradio interface.
Handles the UI layout and component setup.
"""
import gradio as gr
def create_interface(classifier, category_examples, custom_theme):
    """Build the Gradio Blocks UI for landmark image classification.

    Args:
        classifier: Object exposing ``classify_image(image)``. Its return value
            is rendered by a ``gr.Label`` (presumably a label -> confidence
            mapping; confirm against the classifier implementation).
        category_examples: Mapping of category name -> list of examples. Each
            example is indexed as ``example[1]['label']`` to build the list of
            supported landmarks. The list itself is passed unchanged to
            ``gr.Examples``.
        custom_theme: Theme object forwarded to ``gr.Blocks(theme=...)``.

    Returns:
        The constructed ``gr.Blocks`` interface. It is not launched here.
    """

    def classify_with_loading(image):
        # Despite the name, no loading indicator is produced here. The wrapper
        # only short-circuits when the image input is empty/cleared, so the
        # output Label is reset instead of calling the classifier with None.
        if image is None:
            return None
        return classifier.classify_image(image)

    with gr.Blocks(theme=custom_theme) as iface:
        # Header. The emoji is spelled as escapes (U+1F3DB CLASSICAL BUILDING
        # followed by the U+FE0F emoji variation selector) because a previous
        # literal was corrupted into mojibake by a wrong-codec round trip.
        gr.Markdown("# \U0001F3DB\uFE0F Landmark Image Classification")

        # About section.
        gr.Markdown("""
This Gradio-based application allows users to classify famous landmarks using a Vision Transformer (ViT) model. Users can upload an image or select from provided examples to identify landmarks.
""")

        # Two equal-width columns: input on the left, predictions on the right.
        with gr.Row():
            with gr.Column(scale=1):
                # type="pil" hands the classifier a PIL image rather than an array.
                input_image = gr.Image(type="pil", label="Input Image")
                submit_btn = gr.Button("Classify Landmark", variant="primary")
            with gr.Column(scale=1):
                output_label = gr.Label(num_top_classes=5, label="Predictions")

        # Examples section: one collapsed accordion per category.
        gr.Markdown("## Example Categories")
        for category, examples in category_examples.items():
            with gr.Accordion(str(category), open=False):
                # Summarise every landmark this category covers. The
                # comprehension is naturally empty when `examples` is empty/None.
                supported_landmarks = [example[1]["label"] for example in examples or ()]
                landmarks_text = ", ".join(supported_landmarks) or "No landmarks available"
                gr.Markdown(f"**Supported landmarks in this category:** {landmarks_text}")

                if examples:
                    # NOTE(review): each example carries metadata at index 1
                    # (see supported_landmarks above), but only one input
                    # component is bound. Confirm the installed Gradio version
                    # ignores the extra column. Also, with cache_examples=False
                    # and no run_on_click=True, Gradio may only populate the
                    # input on click rather than run `fn`. Verify that this is
                    # the intended UX.
                    gr.Examples(
                        examples=examples,
                        inputs=input_image,
                        outputs=output_label,
                        fn=classify_with_loading,
                        cache_examples=False,
                        label=None,
                        # Large page size so a category's examples are not paginated.
                        examples_per_page=1000,
                    )
                else:
                    gr.Markdown(f"No example images available for {category}")

        # Wire the button; api_name also exposes this as the "classify" API endpoint.
        submit_btn.click(
            fn=classify_with_loading,
            inputs=input_image,
            outputs=output_label,
            api_name="classify",
        )

    return iface