# MATCHA / app.py
# Author: Chris Addis
# Commit: e8242a3 ("remove A/B")
# (Header lines preserved from the Hugging Face file viewer; commented out so
#  the module parses as Python. Original file size: 12.4 kB.)
import gradio as gr
import numpy as np
from PIL import Image
import io
import os
import requests
import json
from dotenv import load_dotenv
import openai
import base64
import csv
import tempfile
import datetime
# Load environment variables from .env file if it exists (for local development).
# On Hugging Face Spaces, secrets are automatically exposed as environment variables.
if os.path.exists(".env"):
    load_dotenv()

from io import BytesIO
import numpy as np
import requests
from PIL import Image

# Project-local helpers: model API wrappers, HTML utilities, prompt builders.
# NOTE(review): wildcard imports are kept as-is; OpenRouterAPI and prompt_new
# are presumed to come from library.utils_model / library.utils_prompt.
from library.utils_model import *
from library.utils_html import *
from library.utils_prompt import *

# Default OpenRouter client (reads its API key from the environment).
OR = OpenRouterAPI()
# Separate client pointed at Google's OpenAI-compatible Gemini endpoint.
gemini = OpenRouterAPI(
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

# Path for storing user preferences.
PREFERENCES_FILE = "data/user_preferences.csv"
# Ensure the data directory exists before anything tries to write to it.
os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
def get_sys_prompt(length="medium"):
    """Return the system (developer) prompt for the requested response length.

    Args:
        length: One of "short", "medium", or "long". Any other value falls
            through to the long-form prompt.

    Returns:
        Prompt string instructing the model on alt-text style and length
        limits ("Repsonses"/"maxium" typos in the original prompts fixed).
    """
    if length == "short":
        dev_prompt = """You are a museum curator tasked with generating alt-text (as defined in WCAG 2.1) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be a maximum of 130 characters."""
    elif length == "medium":
        dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in WCAG 2.1) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be between 250-300 characters in length."""
    else:
        dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in WCAG 2.1) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be a maximum of 450 characters."""
    return dev_prompt
def create_csv_file_simple(results):
    """Write analysis results to a temporary CSV file and return its path.

    Args:
        results: List of dicts; each may contain "image_id" and "content"
            keys (missing keys are written as empty strings).

    Returns:
        Filesystem path of the created .csv file. The caller (Gradio's File
        component) serves it; the temp file is not deleted here.
    """
    fd, path = tempfile.mkstemp(suffix='.csv')
    # newline='' is required by the csv module; explicit utf-8 keeps
    # non-ASCII captions intact regardless of the platform's default encoding.
    with os.fdopen(fd, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # Write header
        writer.writerow(['image_id', 'content'])
        # Write data rows
        for result in results:
            writer.writerow([
                result.get('image_id', ''),
                result.get('content', '')
            ])
    return path
def get_base_filename(filepath):
    """Return the file name from *filepath* with directory and extension removed.

    Falsy inputs (None, empty string) yield an empty string.
    """
    if not filepath:
        return ""
    # basename strips the directory part; splitext drops the final extension.
    stem, _ext = os.path.splitext(os.path.basename(filepath))
    return stem
# Define the Gradio interface
def create_demo():
    """Build and return the Gradio Blocks app for alt-text generation.

    Layout: a left column with upload/model/length controls plus a preview
    gallery, and a right column showing the current image, prev/next
    navigation, and the generated alt-text. Per-batch results live in
    gr.State components so navigation does not re-call the model.
    """
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# AI Alt-text Generator")
        gr.Markdown("Upload one or more images to generate Alt-text")
        gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool.")
        with gr.Row():
            # Left column: Controls and uploads
            with gr.Column(scale=1):
                # Upload interface
                upload_button = gr.UploadButton(
                    "Click to Upload Images",
                    file_types=["image"],
                    file_count="multiple"
                )
                # Add model selection dropdown with new model choices
                model_choice = gr.Dropdown(
                    choices=["google/gemini-2.0-flash-001", "anthropic/claude-3.7-sonnet", "openai/chatgpt-4o-latest"],
                    label="Select Model",
                    value="anthropic/claude-3.7-sonnet",
                    visible=True
                )
                # Add response length selection (maps to get_sys_prompt)
                length_choice = gr.Radio(
                    choices=["short", "medium", "long"],
                    label="Response Length",
                    value="medium",
                    info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
                )
                # Preview gallery for uploaded images
                gr.Markdown("### Uploaded Images")
                input_gallery = gr.Gallery(
                    label="",
                    columns=3,
                    height=150,
                    object_fit="contain"
                )
                # Analysis button
                analyze_button = gr.Button("Analyze Images", variant="primary", size="lg")
                # Hidden state: uploaded file paths and their base filenames
                # (the filenames become image_ids in the CSV output).
                image_state = gr.State([])
                filename_state = gr.State([])
                # CSV download component
                csv_download = gr.File(label="CSV Results")
            # Right column: Display area
            with gr.Column(scale=2):
                # Image display
                current_image = gr.Image(
                    label="Current Image",
                    height=400,
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True
                )
                # Navigation row
                with gr.Row():
                    prev_button = gr.Button("← Previous", size="sm")
                    image_counter = gr.Markdown("", elem_id="image-counter")
                    next_button = gr.Button("Next →", size="sm")
                # Alt-text heading and output
                gr.Markdown("### Generated Alt-text")
                # Alt-text
                analysis_text = gr.Textbox(
                    label="",
                    value="Please analyze images to see results",
                    lines=6,
                    max_lines=10,
                    interactive=False,
                    show_label=False
                )
        # Hidden state for gallery navigation
        current_index = gr.State(0)
        all_images = gr.State([])
        all_results = gr.State([])

        # Handle file uploads - store files for use during analysis
        def handle_upload(files):
            """Collect uploaded temp-file paths and derived base filenames.

            Returns (gallery paths, analysis paths, filenames) — the same
            path list feeds both the gallery preview and image_state.
            """
            file_paths = []
            file_names = []
            for file in files:
                file_paths.append(file.name)
                # Extract filename without path or extension for later use
                file_names.append(get_base_filename(file.name))
            return file_paths, file_paths, file_names

        upload_button.upload(
            fn=handle_upload,
            inputs=[upload_button],
            outputs=[input_gallery, image_state, filename_state]
        )

        # Function to analyze images
        def analyze_images(image_paths, model_choice, length_choice, filenames):
            """Caption each uploaded image with the selected model/length.

            Returns a 7-tuple matching analyze_button's outputs:
            (all paths, per-image results, initial index 0, first image path,
            counter text, first caption, CSV file path).
            """
            if not image_paths:
                return [], [], 0, "", "No images", "", ""
            # Get system prompt based on length selection
            sys_prompt = get_sys_prompt(length_choice)
            image_results = []
            for i, image_path in enumerate(image_paths):
                # Use original filename as image_id if available
                if i < len(filenames) and filenames[i]:
                    image_id = filenames[i]
                else:
                    image_id = f"Image {i+1}"
                try:
                    # Open the image file for analysis
                    img = Image.open(image_path)
                    prompt0 = prompt_new()  # Using the new prompt function
                    # Use the selected model (OR is the module-level
                    # OpenRouterAPI client)
                    result = OR.generate_caption(
                        img,
                        model=model_choice,
                        max_image_size=512,
                        prompt=prompt0,
                        prompt_dev=sys_prompt,
                        temperature=1
                    )
                    # Add to results
                    image_results.append({
                        "image_id": image_id,
                        "content": result
                    })
                except Exception as e:
                    # Record the failure as this image's caption so the batch
                    # continues and the user can see which image failed.
                    error_message = f"Error: {str(e)}"
                    image_results.append({
                        "image_id": image_id,
                        "content": error_message
                    })
            # Create a CSV file for download
            csv_path = create_csv_file_simple(image_results)
            # Set up initial display with first image
            if len(image_paths) > 0:
                initial_image = image_paths[0]
                initial_counter = f"{1} of {len(image_paths)}"
                initial_text = image_results[0]["content"]
            else:
                initial_image = ""
                initial_text = "No images analyzed"
                initial_counter = "0 of 0"
            return (image_paths, image_results, 0, initial_image, initial_counter,
                    initial_text, csv_path)

        # Function to navigate to previous image (wraps to the last image)
        def go_to_prev(current_idx, images, results):
            if not images or len(images) == 0:
                return current_idx, "", "0 of 0", ""
            new_idx = (current_idx - 1) % len(images) if current_idx > 0 else len(images) - 1
            counter_html = f"{new_idx + 1} of {len(images)}"
            return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])

        # Function to navigate to next image (wraps to the first image)
        def go_to_next(current_idx, images, results):
            if not images or len(images) == 0:
                return current_idx, "", "0 of 0", ""
            new_idx = (current_idx + 1) % len(images)
            counter_html = f"{new_idx + 1} of {len(images)}"
            return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])

        # Connect the analyze button
        analyze_button.click(
            fn=analyze_images,
            inputs=[image_state, model_choice, length_choice, filename_state],
            outputs=[
                all_images, all_results, current_index, current_image, image_counter,
                analysis_text, csv_download
            ]
        )
        # Connect navigation buttons
        prev_button.click(
            fn=go_to_prev,
            inputs=[current_index, all_images, all_results],
            outputs=[current_index, current_image, image_counter, analysis_text]
        )
        next_button.click(
            fn=go_to_next,
            inputs=[current_index, all_images, all_results],
            outputs=[current_index, current_image, image_counter, analysis_text]
        )
        # Optional: Add additional information
        with gr.Accordion("About", open=False):
            gr.Markdown("""
            ## About this demo
            This demo generates alt-text for uploaded images.
            - Upload one or more images using the upload button
            - Choose a model and response length for generation
            - Navigate through the images with the Previous and Next buttons
            - Download CSV with all results
            Developed by the Natural History Museum in Partnership with National Museums Liverpool.
            """)
    return demo
# Launch the app when run as a script (Spaces executes app.py directly).
if __name__ == "__main__":
    app = create_demo()
    app.launch()