Chris Addis committed on
Commit ·
b81c5d1
1
Parent(s): 1c02ad4
Matcha 2
Browse files- app-Copy1.py +387 -0
- app.py +230 -186
app-Copy1.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
import openai
|
| 10 |
+
import base64
|
| 11 |
+
import csv
|
| 12 |
+
import tempfile
|
| 13 |
+
import datetime
|
| 14 |
+
|
| 15 |
+
# Load environment variables from .env file if it exists (for local development)
|
| 16 |
+
# On Hugging Face Spaces, the secrets are automatically available as environment variables
|
| 17 |
+
if os.path.exists(".env"):
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
from io import BytesIO
|
| 21 |
+
import numpy as np
|
| 22 |
+
import requests
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
# import libraries
|
| 26 |
+
from library.utils_model import *
|
| 27 |
+
from library.utils_html import *
|
| 28 |
+
from library.utils_prompt import *
|
| 29 |
+
|
| 30 |
+
OR = OpenRouterAPI()
|
| 31 |
+
gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
| 32 |
+
|
| 33 |
+
# Path for storing user preferences
|
| 34 |
+
PREFERENCES_FILE = "data/user_preferences.csv"
|
| 35 |
+
|
| 36 |
+
# Ensure directory exists
|
| 37 |
+
os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
|
| 38 |
+
|
| 39 |
+
def get_sys_prompt(length="medium"):
|
| 40 |
+
if length == "short":
|
| 41 |
+
dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
|
| 42 |
+
elif length == "medium":
|
| 43 |
+
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
|
| 44 |
+
else:
|
| 45 |
+
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
|
| 46 |
+
return dev_prompt
|
| 47 |
+
|
| 48 |
+
def create_csv_file_simple(results):
|
| 49 |
+
"""Create a CSV file from the results and return the path"""
|
| 50 |
+
# Create a temporary file
|
| 51 |
+
fd, path = tempfile.mkstemp(suffix='.csv')
|
| 52 |
+
|
| 53 |
+
with os.fdopen(fd, 'w', newline='') as f:
|
| 54 |
+
writer = csv.writer(f)
|
| 55 |
+
# Write header
|
| 56 |
+
writer.writerow(['image_id', 'content'])
|
| 57 |
+
# Write data
|
| 58 |
+
for result in results:
|
| 59 |
+
writer.writerow([
|
| 60 |
+
result.get('image_id', ''),
|
| 61 |
+
result.get('content', '')
|
| 62 |
+
])
|
| 63 |
+
|
| 64 |
+
return path
|
| 65 |
+
|
| 66 |
+
# Extract original filename without path or extension
|
| 67 |
+
def get_base_filename(filepath):
|
| 68 |
+
if not filepath:
|
| 69 |
+
return ""
|
| 70 |
+
# Get the basename (filename with extension)
|
| 71 |
+
basename = os.path.basename(filepath)
|
| 72 |
+
# Remove extension
|
| 73 |
+
filename = os.path.splitext(basename)[0]
|
| 74 |
+
return filename
|
| 75 |
+
|
| 76 |
+
custom_css = """
|
| 77 |
+
.image-container img {
|
| 78 |
+
object-fit: contain;
|
| 79 |
+
width: 100%;
|
| 80 |
+
height: 100%;
|
| 81 |
+
}
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
# Define the Gradio interface
|
| 85 |
+
def create_demo():
|
| 86 |
+
with gr.Blocks(theme=gr.themes.Monochrome(),css=custom_css) as demo:
|
| 87 |
+
# Replace the existing logo code section:
|
| 88 |
+
with gr.Row():
|
| 89 |
+
with gr.Column(scale=3):
|
| 90 |
+
gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
|
| 91 |
+
gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
|
| 92 |
+
gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
|
| 93 |
+
with gr.Column(scale=1):
|
| 94 |
+
with gr.Row():
|
| 95 |
+
# Use gr.Image with all interactive features disabled
|
| 96 |
+
gr.Image("images/nhm_logo.png", show_label=False, height=120,
|
| 97 |
+
interactive=False, show_download_button=False,
|
| 98 |
+
show_share_button=False, show_fullscreen_button=False,
|
| 99 |
+
container=False)
|
| 100 |
+
gr.Image("images/nml_logo.png", show_label=False, height=120,
|
| 101 |
+
interactive=False, show_download_button=False,
|
| 102 |
+
show_share_button=False, show_fullscreen_button=False,
|
| 103 |
+
container=False)
|
| 104 |
+
|
| 105 |
+
with gr.Row():
|
| 106 |
+
# Left column: Controls and uploads
|
| 107 |
+
with gr.Column(scale=1):
|
| 108 |
+
# Upload interface
|
| 109 |
+
upload_button = gr.UploadButton(
|
| 110 |
+
"Click to Upload Images",
|
| 111 |
+
file_types=["image"],
|
| 112 |
+
file_count="multiple"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Define choices as a list of tuples: (Display Name, Internal Value)
|
| 116 |
+
model_choices = [
|
| 117 |
+
# Gemini
|
| 118 |
+
("Gemini 2.0 Flash (default)", "google/gemini-2.0-flash-001"),
|
| 119 |
+
# GPT-4.1 Series
|
| 120 |
+
("GPT-4.1 Nano", "gpt-4.1-nano"),
|
| 121 |
+
("GPT-4.1 Mini", "gpt-4.1-mini"),
|
| 122 |
+
("GPT-4.1", "gpt-4.1"),
|
| 123 |
+
("ChatGPT Latest", "openai/chatgpt-4o-latest"),
|
| 124 |
+
# Other Models
|
| 125 |
+
("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
|
| 126 |
+
("Llama 4 Maverick", "meta-llama/llama-4-maverick"),
|
| 127 |
+
# Experimental Models
|
| 128 |
+
("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
|
| 129 |
+
("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
# Find the internal value of the default choice
|
| 133 |
+
default_model_internal_value = "google/gemini-2.0-flash-001"
|
| 134 |
+
|
| 135 |
+
# Add model selection dropdown
|
| 136 |
+
model_choice = gr.Dropdown(
|
| 137 |
+
choices=model_choices,
|
| 138 |
+
label="Select Model",
|
| 139 |
+
value=default_model_internal_value, # Use the internal value for the default
|
| 140 |
+
# info="Choose the language model to use." # Optional: Add extra info tooltip
|
| 141 |
+
visible=True
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Add response length selection
|
| 146 |
+
length_choice = gr.Radio(
|
| 147 |
+
choices=["short", "medium", "long"],
|
| 148 |
+
label="Response Length",
|
| 149 |
+
value="medium",
|
| 150 |
+
info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Preview gallery for uploaded images
|
| 154 |
+
gr.Markdown("### Uploaded Images")
|
| 155 |
+
input_gallery = gr.Gallery(
|
| 156 |
+
label="",
|
| 157 |
+
columns=3,
|
| 158 |
+
height=150,
|
| 159 |
+
object_fit="contain"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Analysis button
|
| 163 |
+
analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
|
| 164 |
+
|
| 165 |
+
# Hidden state component to store image info
|
| 166 |
+
image_state = gr.State([])
|
| 167 |
+
filename_state = gr.State([])
|
| 168 |
+
|
| 169 |
+
# CSV download component
|
| 170 |
+
csv_download = gr.File(label="CSV Results")
|
| 171 |
+
|
| 172 |
+
# Right column: Display area
|
| 173 |
+
with gr.Column(scale=2):
|
| 174 |
+
with gr.Column(elem_classes="image-container"):
|
| 175 |
+
current_image = gr.Image(
|
| 176 |
+
label="Current Image",
|
| 177 |
+
height=600, # Set the maximum desired height
|
| 178 |
+
width=1000,
|
| 179 |
+
type="filepath",
|
| 180 |
+
show_fullscreen_button=True,
|
| 181 |
+
show_download_button=False,
|
| 182 |
+
show_share_button=False,
|
| 183 |
+
elem_classes="image-container"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Navigation row
|
| 187 |
+
with gr.Row():
|
| 188 |
+
prev_button = gr.Button("← Previous", size="sm")
|
| 189 |
+
image_counter = gr.Markdown("", elem_id="image-counter")
|
| 190 |
+
next_button = gr.Button("Next →", size="sm")
|
| 191 |
+
|
| 192 |
+
# Alt-text heading and output
|
| 193 |
+
gr.Markdown("### Generated Alt-text")
|
| 194 |
+
|
| 195 |
+
# Alt-text
|
| 196 |
+
analysis_text = gr.Textbox(
|
| 197 |
+
label="",
|
| 198 |
+
value="Upload images and select model to generate alt-text!",
|
| 199 |
+
lines=6,
|
| 200 |
+
max_lines=10,
|
| 201 |
+
interactive=False,
|
| 202 |
+
show_label=False
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Hidden state for gallery navigation
|
| 206 |
+
current_index = gr.State(0)
|
| 207 |
+
all_images = gr.State([])
|
| 208 |
+
all_results = gr.State([])
|
| 209 |
+
|
| 210 |
+
# Handle file uploads - store files for use during analysis
|
| 211 |
+
def handle_upload(files):
|
| 212 |
+
file_paths = []
|
| 213 |
+
file_names = []
|
| 214 |
+
for file in files:
|
| 215 |
+
file_paths.append(file.name)
|
| 216 |
+
# Extract filename without path or extension for later use
|
| 217 |
+
file_names.append(get_base_filename(file.name))
|
| 218 |
+
return file_paths, file_paths, file_names
|
| 219 |
+
|
| 220 |
+
upload_button.upload(
|
| 221 |
+
fn=handle_upload,
|
| 222 |
+
inputs=[upload_button],
|
| 223 |
+
outputs=[input_gallery, image_state, filename_state]
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Function to analyze images
|
| 227 |
+
# Modify the analyze_images function in your code:
|
| 228 |
+
|
| 229 |
+
def analyze_images(image_paths, model_choice, length_choice, filenames):
|
| 230 |
+
if not image_paths:
|
| 231 |
+
return [], [], 0, "", "No images", "", ""
|
| 232 |
+
|
| 233 |
+
# Get system prompt based on length selection
|
| 234 |
+
sys_prompt = get_sys_prompt(length_choice)
|
| 235 |
+
|
| 236 |
+
image_results = []
|
| 237 |
+
|
| 238 |
+
for i, image_path in enumerate(image_paths):
|
| 239 |
+
# Use original filename as image_id if available
|
| 240 |
+
if i < len(filenames) and filenames[i]:
|
| 241 |
+
image_id = filenames[i]
|
| 242 |
+
else:
|
| 243 |
+
image_id = f"Image {i+1}"
|
| 244 |
+
|
| 245 |
+
try:
|
| 246 |
+
# Open the image file for analysis
|
| 247 |
+
img = Image.open(image_path)
|
| 248 |
+
prompt0 = prompt_new() # Using the new prompt function
|
| 249 |
+
|
| 250 |
+
# Extract the actual model name (remove any labels like "(default)")
|
| 251 |
+
if " (" in model_choice:
|
| 252 |
+
model_name = model_choice.split(" (")[0]
|
| 253 |
+
else:
|
| 254 |
+
model_name = model_choice
|
| 255 |
+
|
| 256 |
+
# Check if this is one of the Gemini models that needs special handling
|
| 257 |
+
is_gemini_model = "gemini-2.5-pro" in model_name or "gemini-2.0-flash-thinking" in model_name
|
| 258 |
+
|
| 259 |
+
if is_gemini_model:
|
| 260 |
+
try:
|
| 261 |
+
# First try using the dedicated gemini client
|
| 262 |
+
result = gemini.generate_caption(
|
| 263 |
+
img,
|
| 264 |
+
model=model_name,
|
| 265 |
+
max_image_size=512,
|
| 266 |
+
prompt=prompt0,
|
| 267 |
+
prompt_dev=sys_prompt,
|
| 268 |
+
temperature=1
|
| 269 |
+
)
|
| 270 |
+
except Exception as gemini_error:
|
| 271 |
+
# If gemini client fails, fall back to standard OR client
|
| 272 |
+
result = OR.generate_caption(
|
| 273 |
+
img,
|
| 274 |
+
model=model_name,
|
| 275 |
+
max_image_size=512,
|
| 276 |
+
prompt=prompt0,
|
| 277 |
+
prompt_dev=sys_prompt,
|
| 278 |
+
temperature=1
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
# For all other models, use OR client directly
|
| 282 |
+
result = OR.generate_caption(
|
| 283 |
+
img,
|
| 284 |
+
model=model_name,
|
| 285 |
+
max_image_size=512,
|
| 286 |
+
prompt=prompt0,
|
| 287 |
+
prompt_dev=sys_prompt,
|
| 288 |
+
temperature=1
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Add to results
|
| 292 |
+
image_results.append({
|
| 293 |
+
"image_id": image_id,
|
| 294 |
+
"content": result
|
| 295 |
+
})
|
| 296 |
+
|
| 297 |
+
except Exception as e:
|
| 298 |
+
error_message = f"Error: {str(e)}"
|
| 299 |
+
image_results.append({
|
| 300 |
+
"image_id": image_id,
|
| 301 |
+
"content": error_message
|
| 302 |
+
})
|
| 303 |
+
|
| 304 |
+
# Create a CSV file for download
|
| 305 |
+
csv_path = create_csv_file_simple(image_results)
|
| 306 |
+
|
| 307 |
+
# Set up initial display with first image
|
| 308 |
+
if len(image_paths) > 0:
|
| 309 |
+
initial_image = image_paths[0]
|
| 310 |
+
initial_counter = f"{1} of {len(image_paths)}"
|
| 311 |
+
initial_text = image_results[0]["content"]
|
| 312 |
+
else:
|
| 313 |
+
initial_image = ""
|
| 314 |
+
initial_text = "No images analyzed"
|
| 315 |
+
initial_counter = "0 of 0"
|
| 316 |
+
|
| 317 |
+
return (image_paths, image_results, 0, initial_image, initial_counter,
|
| 318 |
+
initial_text, csv_path)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# Function to navigate to previous image
|
| 322 |
+
def go_to_prev(current_idx, images, results):
|
| 323 |
+
if not images or len(images) == 0:
|
| 324 |
+
return current_idx, "", "0 of 0", ""
|
| 325 |
+
|
| 326 |
+
new_idx = (current_idx - 1) % len(images) if current_idx > 0 else len(images) - 1
|
| 327 |
+
counter_html = f"{new_idx + 1} of {len(images)}"
|
| 328 |
+
|
| 329 |
+
return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
|
| 330 |
+
|
| 331 |
+
# Function to navigate to next image
|
| 332 |
+
def go_to_next(current_idx, images, results):
|
| 333 |
+
if not images or len(images) == 0:
|
| 334 |
+
return current_idx, "", "0 of 0", ""
|
| 335 |
+
|
| 336 |
+
new_idx = (current_idx + 1) % len(images)
|
| 337 |
+
counter_html = f"{new_idx + 1} of {len(images)}"
|
| 338 |
+
|
| 339 |
+
return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
|
| 340 |
+
|
| 341 |
+
# Connect the analyze button
|
| 342 |
+
analyze_button.click(
|
| 343 |
+
fn=analyze_images,
|
| 344 |
+
inputs=[image_state, model_choice, length_choice, filename_state],
|
| 345 |
+
outputs=[
|
| 346 |
+
all_images, all_results, current_index, current_image, image_counter,
|
| 347 |
+
analysis_text, csv_download
|
| 348 |
+
]
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# Connect navigation buttons
|
| 352 |
+
prev_button.click(
|
| 353 |
+
fn=go_to_prev,
|
| 354 |
+
inputs=[current_index, all_images, all_results],
|
| 355 |
+
outputs=[current_index, current_image, image_counter, analysis_text]
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
next_button.click(
|
| 359 |
+
fn=go_to_next,
|
| 360 |
+
inputs=[current_index, all_images, all_results],
|
| 361 |
+
outputs=[current_index, current_image, image_counter, analysis_text]
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Optional: Add additional information
|
| 365 |
+
with gr.Accordion("About", open=False):
|
| 366 |
+
gr.Markdown("""
|
| 367 |
+
## About this demo
|
| 368 |
+
|
| 369 |
+
This demo generates alternative text for images.
|
| 370 |
+
|
| 371 |
+
- Upload one or more images using the upload button
|
| 372 |
+
- Choose a model and response length for generation
|
| 373 |
+
- Navigate through the images with the Previous and Next buttons
|
| 374 |
+
- Download CSV with all results
|
| 375 |
+
|
| 376 |
+
Developed by the Natural History Museum in Partnership with National Museums Liverpool.
|
| 377 |
+
|
| 378 |
+
If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
|
| 379 |
+
chris.addis@nhm.ac.uk
|
| 380 |
+
""")
|
| 381 |
+
|
| 382 |
+
return demo
|
| 383 |
+
|
| 384 |
+
# Launch the app
|
| 385 |
+
if __name__ == "__main__":
|
| 386 |
+
app = create_demo()
|
| 387 |
+
app.launch()
|
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import os
|
|
| 6 |
import requests
|
| 7 |
import json
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
-
import openai
|
| 10 |
import base64
|
| 11 |
import csv
|
| 12 |
import tempfile
|
|
@@ -18,17 +18,32 @@ if os.path.exists(".env"):
|
|
| 18 |
load_dotenv()
|
| 19 |
|
| 20 |
from io import BytesIO
|
| 21 |
-
import numpy as np
|
| 22 |
-
import requests
|
| 23 |
-
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
# import libraries
|
| 26 |
-
from library.utils_model import *
|
| 27 |
-
from library.utils_html import *
|
| 28 |
-
from library.utils_prompt import *
|
| 29 |
|
| 30 |
OR = OpenRouterAPI()
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Path for storing user preferences
|
| 34 |
PREFERENCES_FILE = "data/user_preferences.csv"
|
|
@@ -41,27 +56,31 @@ def get_sys_prompt(length="medium"):
|
|
| 41 |
dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
|
| 42 |
elif length == "medium":
|
| 43 |
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
|
| 44 |
-
else:
|
| 45 |
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
|
| 46 |
return dev_prompt
|
| 47 |
|
| 48 |
def create_csv_file_simple(results):
|
| 49 |
"""Create a CSV file from the results and return the path"""
|
| 50 |
# Create a temporary file
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Extract original filename without path or extension
|
| 67 |
def get_base_filename(filepath):
|
|
@@ -73,17 +92,10 @@ def get_base_filename(filepath):
|
|
| 73 |
filename = os.path.splitext(basename)[0]
|
| 74 |
return filename
|
| 75 |
|
| 76 |
-
custom_css = """
|
| 77 |
-
.image-container img {
|
| 78 |
-
object-fit: contain;
|
| 79 |
-
width: 100%;
|
| 80 |
-
height: 100%;
|
| 81 |
-
}
|
| 82 |
-
"""
|
| 83 |
-
|
| 84 |
# Define the Gradio interface
|
| 85 |
def create_demo():
|
| 86 |
-
|
|
|
|
| 87 |
# Replace the existing logo code section:
|
| 88 |
with gr.Row():
|
| 89 |
with gr.Column(scale=3):
|
|
@@ -93,25 +105,25 @@ def create_demo():
|
|
| 93 |
with gr.Column(scale=1):
|
| 94 |
with gr.Row():
|
| 95 |
# Use gr.Image with all interactive features disabled
|
| 96 |
-
gr.Image("images/nhm_logo.png", show_label=False, height=120,
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
gr.Image("images/nml_logo.png", show_label=False, height=120,
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
with gr.Row():
|
| 106 |
# Left column: Controls and uploads
|
| 107 |
with gr.Column(scale=1):
|
| 108 |
# Upload interface
|
| 109 |
upload_button = gr.UploadButton(
|
| 110 |
-
"Click to Upload Images",
|
| 111 |
-
file_types=["image"],
|
| 112 |
file_count="multiple"
|
| 113 |
)
|
| 114 |
-
|
| 115 |
# Define choices as a list of tuples: (Display Name, Internal Value)
|
| 116 |
model_choices = [
|
| 117 |
# Gemini
|
|
@@ -128,10 +140,10 @@ def create_demo():
|
|
| 128 |
("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
|
| 129 |
("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
|
| 130 |
]
|
| 131 |
-
|
| 132 |
# Find the internal value of the default choice
|
| 133 |
default_model_internal_value = "google/gemini-2.0-flash-001"
|
| 134 |
-
|
| 135 |
# Add model selection dropdown
|
| 136 |
model_choice = gr.Dropdown(
|
| 137 |
choices=model_choices,
|
|
@@ -141,7 +153,7 @@ def create_demo():
|
|
| 141 |
visible=True
|
| 142 |
)
|
| 143 |
|
| 144 |
-
|
| 145 |
# Add response length selection
|
| 146 |
length_choice = gr.Radio(
|
| 147 |
choices=["short", "medium", "long"],
|
|
@@ -149,195 +161,204 @@ def create_demo():
|
|
| 149 |
value="medium",
|
| 150 |
info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
|
| 151 |
)
|
| 152 |
-
|
| 153 |
# Preview gallery for uploaded images
|
| 154 |
gr.Markdown("### Uploaded Images")
|
| 155 |
input_gallery = gr.Gallery(
|
| 156 |
-
label="",
|
| 157 |
-
columns=3,
|
| 158 |
-
height=150,
|
| 159 |
-
object_fit="contain"
|
|
|
|
| 160 |
)
|
| 161 |
-
|
| 162 |
# Analysis button
|
| 163 |
analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
|
| 164 |
-
|
| 165 |
# Hidden state component to store image info
|
| 166 |
image_state = gr.State([])
|
| 167 |
filename_state = gr.State([])
|
| 168 |
-
|
| 169 |
# CSV download component
|
| 170 |
-
csv_download = gr.File(label="CSV Results")
|
| 171 |
-
|
| 172 |
# Right column: Display area
|
| 173 |
with gr.Column(scale=2):
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
| 186 |
# Navigation row
|
| 187 |
with gr.Row():
|
| 188 |
prev_button = gr.Button("← Previous", size="sm")
|
| 189 |
-
image_counter = gr.Markdown("", elem_id="image-counter")
|
| 190 |
next_button = gr.Button("Next →", size="sm")
|
| 191 |
-
|
| 192 |
# Alt-text heading and output
|
| 193 |
gr.Markdown("### Generated Alt-text")
|
| 194 |
-
|
| 195 |
# Alt-text
|
| 196 |
analysis_text = gr.Textbox(
|
| 197 |
-
label="",
|
| 198 |
-
value="Upload images and
|
| 199 |
lines=6,
|
| 200 |
max_lines=10,
|
| 201 |
-
interactive=
|
| 202 |
-
show_label=False
|
| 203 |
)
|
| 204 |
-
|
| 205 |
# Hidden state for gallery navigation
|
| 206 |
current_index = gr.State(0)
|
| 207 |
all_images = gr.State([])
|
| 208 |
all_results = gr.State([])
|
| 209 |
-
|
| 210 |
# Handle file uploads - store files for use during analysis
|
| 211 |
-
def handle_upload(files):
|
|
|
|
|
|
|
| 212 |
file_paths = []
|
| 213 |
file_names = []
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
| 220 |
upload_button.upload(
|
| 221 |
fn=handle_upload,
|
| 222 |
-
inputs=[upload_button],
|
| 223 |
-
outputs=[input_gallery, image_state, filename_state
|
|
|
|
| 224 |
)
|
| 225 |
-
|
| 226 |
-
# Function to analyze images
|
| 227 |
-
# Modify the analyze_images function in your code:
|
| 228 |
|
|
|
|
| 229 |
def analyze_images(image_paths, model_choice, length_choice, filenames):
|
| 230 |
if not image_paths:
|
| 231 |
-
|
| 232 |
-
|
|
|
|
| 233 |
# Get system prompt based on length selection
|
| 234 |
sys_prompt = get_sys_prompt(length_choice)
|
| 235 |
-
|
| 236 |
image_results = []
|
| 237 |
-
|
| 238 |
-
|
|
|
|
| 239 |
# Use original filename as image_id if available
|
| 240 |
if i < len(filenames) and filenames[i]:
|
| 241 |
image_id = filenames[i]
|
| 242 |
else:
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
try:
|
| 246 |
# Open the image file for analysis
|
| 247 |
img = Image.open(image_path)
|
| 248 |
prompt0 = prompt_new() # Using the new prompt function
|
| 249 |
-
|
| 250 |
-
#
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
model_name = model_choice
|
| 255 |
-
|
| 256 |
# Check if this is one of the Gemini models that needs special handling
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
)
|
| 280 |
-
else:
|
| 281 |
-
# For all other models, use OR client directly
|
| 282 |
-
result = OR.generate_caption(
|
| 283 |
-
img,
|
| 284 |
-
model=model_name,
|
| 285 |
-
max_image_size=512,
|
| 286 |
-
prompt=prompt0,
|
| 287 |
-
prompt_dev=sys_prompt,
|
| 288 |
-
temperature=1
|
| 289 |
-
)
|
| 290 |
-
|
| 291 |
# Add to results
|
| 292 |
image_results.append({
|
| 293 |
"image_id": image_id,
|
| 294 |
-
"content": result
|
| 295 |
})
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
except Exception as e:
|
| 298 |
-
error_message = f"Error: {str(e)}"
|
|
|
|
| 299 |
image_results.append({
|
| 300 |
"image_id": image_id,
|
| 301 |
"content": error_message
|
| 302 |
})
|
| 303 |
-
|
| 304 |
# Create a CSV file for download
|
| 305 |
csv_path = create_csv_file_simple(image_results)
|
| 306 |
-
|
| 307 |
-
# Set up initial display with first image
|
| 308 |
-
if
|
| 309 |
initial_image = image_paths[0]
|
| 310 |
-
initial_counter = f"
|
| 311 |
initial_text = image_results[0]["content"]
|
| 312 |
-
else:
|
| 313 |
-
initial_image =
|
| 314 |
-
initial_text = "
|
| 315 |
initial_counter = "0 of 0"
|
| 316 |
-
|
| 317 |
-
return (image_paths, image_results, 0, initial_image, initial_counter,
|
| 318 |
initial_text, csv_path)
|
| 319 |
|
| 320 |
-
|
| 321 |
# Function to navigate to previous image
|
| 322 |
def go_to_prev(current_idx, images, results):
|
| 323 |
-
if not images or len(images) == 0:
|
| 324 |
-
return current_idx,
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
# Function to navigate to next image
|
| 332 |
def go_to_next(current_idx, images, results):
|
| 333 |
-
if not images or len(images) == 0:
|
| 334 |
-
return current_idx,
|
| 335 |
-
|
| 336 |
new_idx = (current_idx + 1) % len(images)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
| 341 |
# Connect the analyze button
|
| 342 |
analyze_button.click(
|
| 343 |
fn=analyze_images,
|
|
@@ -347,41 +368,64 @@ def create_demo():
|
|
| 347 |
analysis_text, csv_download
|
| 348 |
]
|
| 349 |
)
|
| 350 |
-
|
| 351 |
# Connect navigation buttons
|
| 352 |
prev_button.click(
|
| 353 |
fn=go_to_prev,
|
| 354 |
inputs=[current_index, all_images, all_results],
|
| 355 |
-
outputs=[current_index, current_image, image_counter, analysis_text]
|
|
|
|
|
|
|
| 356 |
)
|
| 357 |
-
|
| 358 |
next_button.click(
|
| 359 |
fn=go_to_next,
|
| 360 |
inputs=[current_index, all_images, all_results],
|
| 361 |
-
outputs=[current_index, current_image, image_counter, analysis_text]
|
|
|
|
|
|
|
| 362 |
)
|
| 363 |
-
|
| 364 |
# Optional: Add additional information
|
| 365 |
with gr.Accordion("About", open=False):
|
| 366 |
gr.Markdown("""
|
| 367 |
## About this demo
|
| 368 |
-
|
| 369 |
-
This demo generates alternative text for images.
|
| 370 |
-
|
| 371 |
-
- Upload one or more images using the
|
| 372 |
-
-
|
| 373 |
-
-
|
| 374 |
-
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
| 379 |
chris.addis@nhm.ac.uk
|
| 380 |
""")
|
| 381 |
-
|
| 382 |
return demo
|
| 383 |
|
| 384 |
# Launch the app
|
| 385 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
app = create_demo()
|
| 387 |
-
app.launch()
|
|
|
|
| 6 |
import requests
|
| 7 |
import json
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
+
# import openai # Assuming openai is not directly used in this snippet anymore
|
| 10 |
import base64
|
| 11 |
import csv
|
| 12 |
import tempfile
|
|
|
|
| 18 |
load_dotenv()
|
| 19 |
|
| 20 |
from io import BytesIO
|
| 21 |
+
# import numpy as np # Already imported
|
| 22 |
+
# import requests # Already imported
|
| 23 |
+
# from PIL import Image # Already imported
|
| 24 |
+
|
| 25 |
+
# Assume these are defined elsewhere or replace with actual implementations if needed
|
| 26 |
+
class OpenRouterAPI:
|
| 27 |
+
def __init__(self, api_key=None, base_url=None):
|
| 28 |
+
pass
|
| 29 |
+
def generate_caption(self, img, model, max_image_size, prompt, prompt_dev, temperature):
|
| 30 |
+
# Dummy implementation for testing
|
| 31 |
+
print(f"Generating caption with model: {model}")
|
| 32 |
+
return f"Generated caption for image using {model}."
|
| 33 |
+
|
| 34 |
+
def prompt_new():
|
| 35 |
+
# Dummy implementation
|
| 36 |
+
return "Describe this image."
|
| 37 |
+
# --- End Dummy implementations ---
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
OR = OpenRouterAPI()
|
| 41 |
+
# Ensure GEMINI_API_KEY is set in your environment or .env file
|
| 42 |
+
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 43 |
+
if not gemini_api_key:
|
| 44 |
+
print("Warning: GEMINI_API_KEY environment variable not set. Using placeholder.")
|
| 45 |
+
# Handle the case where the key might be missing, perhaps disable the Gemini models or use a default key if applicable
|
| 46 |
+
gemini = OpenRouterAPI(api_key=gemini_api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/") # Note: This base_url looks like OpenAI, ensure it's correct for Gemini via OpenRouter or direct API
|
| 47 |
|
| 48 |
# Path for storing user preferences
|
| 49 |
PREFERENCES_FILE = "data/user_preferences.csv"
|
|
|
|
| 56 |
dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
|
| 57 |
elif length == "medium":
|
| 58 |
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
|
| 59 |
+
else: # long
|
| 60 |
dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
|
| 61 |
return dev_prompt
|
| 62 |
|
| 63 |
def create_csv_file_simple(results):
    """Write *results* to a temporary CSV file and return its path.

    Each result is a dict; the 'image_id' and 'content' keys become one
    row (missing keys are written as empty strings). A header row is
    always emitted. Returns None if the file could not be written.
    """
    try:
        # delete=False so the file survives for the Gradio download widget.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False,
                                         newline='', encoding='utf-8') as handle:
            csv_path = handle.name
            body = [[item.get('image_id', ''), item.get('content', '')]
                    for item in results]
            csv.writer(handle).writerows([['image_id', 'content']] + body)
        return csv_path
    except Exception as e:
        # Best-effort: log and signal failure to the caller with None.
        print(f"Error creating CSV: {e}")
        return None
|
| 83 |
+
|
| 84 |
|
| 85 |
# Extract original filename without path or extension
|
| 86 |
def get_base_filename(filepath):
|
|
|
|
| 92 |
filename = os.path.splitext(basename)[0]
|
| 93 |
return filename
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
# Define the Gradio interface
|
| 96 |
def create_demo():
|
| 97 |
+
# Removed custom_css as we will use the built-in object_fit parameter
|
| 98 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo: # Removed css=custom_css
|
| 99 |
# Replace the existing logo code section:
|
| 100 |
with gr.Row():
|
| 101 |
with gr.Column(scale=3):
|
|
|
|
| 105 |
with gr.Column(scale=1):
|
| 106 |
with gr.Row():
|
| 107 |
# Use gr.Image with all interactive features disabled
|
| 108 |
+
gr.Image("images/nhm_logo.png", show_label=False, height=120,
|
| 109 |
+
interactive=False, show_download_button=False,
|
| 110 |
+
show_share_button=False, show_fullscreen_button=False,
|
| 111 |
+
container=False, elem_id="nhm-logo") # Added elem_id for clarity
|
| 112 |
+
gr.Image("images/nml_logo.png", show_label=False, height=120,
|
| 113 |
+
interactive=False, show_download_button=False,
|
| 114 |
+
show_share_button=False, show_fullscreen_button=False,
|
| 115 |
+
container=False, elem_id="nml-logo") # Added elem_id for clarity
|
| 116 |
+
|
| 117 |
with gr.Row():
|
| 118 |
# Left column: Controls and uploads
|
| 119 |
with gr.Column(scale=1):
|
| 120 |
# Upload interface
|
| 121 |
upload_button = gr.UploadButton(
|
| 122 |
+
"Click to Upload Images",
|
| 123 |
+
file_types=["image"],
|
| 124 |
file_count="multiple"
|
| 125 |
)
|
| 126 |
+
|
| 127 |
# Define choices as a list of tuples: (Display Name, Internal Value)
|
| 128 |
model_choices = [
|
| 129 |
# Gemini
|
|
|
|
| 140 |
("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
|
| 141 |
("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
|
| 142 |
]
|
| 143 |
+
|
| 144 |
# Find the internal value of the default choice
|
| 145 |
default_model_internal_value = "google/gemini-2.0-flash-001"
|
| 146 |
+
|
| 147 |
# Add model selection dropdown
|
| 148 |
model_choice = gr.Dropdown(
|
| 149 |
choices=model_choices,
|
|
|
|
| 153 |
visible=True
|
| 154 |
)
|
| 155 |
|
| 156 |
+
|
| 157 |
# Add response length selection
|
| 158 |
length_choice = gr.Radio(
|
| 159 |
choices=["short", "medium", "long"],
|
|
|
|
| 161 |
value="medium",
|
| 162 |
info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
|
| 163 |
)
|
| 164 |
+
|
| 165 |
# Preview gallery for uploaded images
|
| 166 |
gr.Markdown("### Uploaded Images")
|
| 167 |
input_gallery = gr.Gallery(
|
| 168 |
+
label="Uploaded Image Previews", # Added label
|
| 169 |
+
columns=3,
|
| 170 |
+
height=150, # Reduced height slightly if needed
|
| 171 |
+
object_fit="contain", # Ensure gallery previews also fit well
|
| 172 |
+
show_label=False # Hide the label text above the gallery
|
| 173 |
)
|
| 174 |
+
|
| 175 |
# Analysis button
|
| 176 |
analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
|
| 177 |
+
|
| 178 |
# Hidden state component to store image info
|
| 179 |
image_state = gr.State([])
|
| 180 |
filename_state = gr.State([])
|
| 181 |
+
|
| 182 |
# CSV download component
|
| 183 |
+
csv_download = gr.File(label="Download CSV Results") # Clarified label
|
| 184 |
+
|
| 185 |
# Right column: Display area
|
| 186 |
with gr.Column(scale=2):
|
| 187 |
+
# Directly place the Image component here
|
| 188 |
+
# Use object_fit='contain' and set height. Width will adapt.
|
| 189 |
+
current_image = gr.Image(
|
| 190 |
+
label="Current Image",
|
| 191 |
+
height=600, # Set the maximum desired height
|
| 192 |
+
# width=1000, # REMOVED fixed width
|
| 193 |
+
type="filepath",
|
| 194 |
+
object_fit="contain", # ADDED: Scale image while preserving aspect ratio
|
| 195 |
+
show_fullscreen_button=True,
|
| 196 |
+
show_download_button=False, # Keep false as per original code
|
| 197 |
+
show_share_button=False, # Keep false as per original code
|
| 198 |
+
show_label=False # Hide the "Current Image" label above the image
|
| 199 |
+
# Removed elem_classes="image-container" as object_fit handles it
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
# Navigation row
|
| 203 |
with gr.Row():
|
| 204 |
prev_button = gr.Button("← Previous", size="sm")
|
| 205 |
+
image_counter = gr.Markdown("0 of 0", elem_id="image-counter") # Default text
|
| 206 |
next_button = gr.Button("Next →", size="sm")
|
| 207 |
+
|
| 208 |
# Alt-text heading and output
|
| 209 |
gr.Markdown("### Generated Alt-text")
|
| 210 |
+
|
| 211 |
# Alt-text
|
| 212 |
analysis_text = gr.Textbox(
|
| 213 |
+
label="Generated Text", # Added label
|
| 214 |
+
value="Upload images and click 'Generate Alt-Text'.", # Initial message
|
| 215 |
lines=6,
|
| 216 |
max_lines=10,
|
| 217 |
+
interactive=True, # Allow user to edit if desired? Set back to False if not.
|
| 218 |
+
show_label=False # Hide the label text
|
| 219 |
)
|
| 220 |
+
|
| 221 |
# Hidden state for gallery navigation
|
| 222 |
current_index = gr.State(0)
|
| 223 |
all_images = gr.State([])
|
| 224 |
all_results = gr.State([])
|
| 225 |
+
|
| 226 |
# Handle file uploads - store files for use during analysis
|
| 227 |
+
def handle_upload(files, current_paths, current_filenames):
    """Replace the current upload set with *files* and reset the viewer.

    Earlier uploads are discarded rather than appended to. Returns the
    gallery items, the path and base-filename state lists, a zeroed
    index, a cleared image display, the counter text and the initial
    prompt message.
    """
    file_paths, file_names = [], []
    for uploaded in (files or []):
        file_paths.append(uploaded.name)
        # Base filename (no path/extension) doubles as the image id later.
        file_names.append(get_base_filename(uploaded.name))
    return file_paths, file_paths, file_names, 0, None, "0 of 0", "Upload images and click 'Generate Alt-Text'."
|
| 239 |
+
|
| 240 |
upload_button.upload(
|
| 241 |
fn=handle_upload,
|
| 242 |
+
inputs=[upload_button, image_state, filename_state], # Pass current state if appending needed
|
| 243 |
+
outputs=[input_gallery, image_state, filename_state, # Outputs updated state
|
| 244 |
+
current_index, current_image, image_counter, analysis_text] # Reset display
|
| 245 |
)
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
# Function to analyze images
|
| 248 |
def analyze_images(image_paths, model_choice, length_choice, filenames):
    """Generate alt-text for every uploaded image and build the UI state.

    Parameters:
        image_paths: list of uploaded image file paths.
        model_choice: internal model identifier from the dropdown.
        length_choice: "short" | "medium" | "long" (selects the system prompt).
        filenames: base filenames used as image ids (parallel to image_paths).

    Returns a 7-tuple: (image paths, result dicts, current index, first
    image path, counter text, first caption text, CSV download path).
    Per-image failures are recorded as error strings in the results
    rather than aborting the whole batch.
    """
    if not image_paths:
        # Nothing to analyze: reset every output field, no CSV.
        return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None

    sys_prompt = get_sys_prompt(length_choice)
    image_results = []
    # NOTE(review): instantiating gr.Progress() inside the handler (rather
    # than injecting it via a default parameter) may not attach to the UI
    # in all Gradio versions — confirm against the installed release.
    analysis_progress = gr.Progress(track_tqdm=True)

    for i, image_path in enumerate(analysis_progress.tqdm(image_paths, desc="Analyzing Images")):
        # Prefer the original filename as the id; fall back to a generated one.
        if i < len(filenames) and filenames[i]:
            image_id = filenames[i]
        else:
            image_id = f"Image_{i+1}_{os.path.basename(image_path)}"

        try:
            img = Image.open(image_path)
            prompt0 = prompt_new()
            # model_choice is already the internal model value from the
            # dropdown. All models are routed through the default OpenRouter
            # client; if a dedicated Gemini client is ever required, select it
            # here based on the model name (e.g. names containing
            # "gemini-2.5-pro"/"gemini-2.0-flash-thinking" or starting with
            # "google/gemini").
            result = OR.generate_caption(
                img,
                model=model_choice,
                max_image_size=512,  # presumably an API/image-size limit — TODO confirm
                prompt=prompt0,
                prompt_dev=sys_prompt,
                temperature=1,
            )
            image_results.append({"image_id": image_id, "content": result.strip()})
        except FileNotFoundError:
            error_message = f"Error: File not found at path '{image_path}'"
            print(error_message)  # log, then keep the error as the caption
            image_results.append({"image_id": image_id, "content": error_message})
        except Exception as e:
            error_message = f"Error processing {image_id}: {str(e)}"
            print(error_message)  # log, then keep the error as the caption
            image_results.append({"image_id": image_id, "content": error_message})

    # CSV for the download widget (None if writing failed).
    csv_path = create_csv_file_simple(image_results)

    # Seed the viewer with the first image and its caption.
    if image_results:
        initial_image = image_paths[0]
        initial_counter = f"1 of {len(image_paths)}"
        initial_text = image_results[0]["content"]
    else:
        # Defensive fallback; unreachable while image_paths is non-empty.
        initial_image = None
        initial_counter = "0 of 0"
        initial_text = "Analysis complete, but no results generated."

    return (image_paths, image_results, 0, initial_image, initial_counter,
            initial_text, csv_path)
|
| 333 |
|
| 334 |
+
|
| 335 |
# Function to navigate to previous image
|
| 336 |
def go_to_prev(current_idx, images, results):
    """Step the viewer back one image (wrapping) and return the new state.

    Returns (index, image path, counter text, caption). When there is
    nothing to show, the index is unchanged and the display is cleared.
    """
    if not (images and results):
        return current_idx, None, "0 of 0", ""

    total = len(images)
    new_idx = (current_idx - 1) % total  # Python % wraps negatives correctly
    if new_idx < len(results):
        caption = results[new_idx]["content"]
    else:
        # Guard against a results list shorter than the image list.
        caption = "Error: Result not found"
    return new_idx, images[new_idx], f"{new_idx + 1} of {total}", caption
|
| 348 |
+
|
| 349 |
# Function to navigate to next image
|
| 350 |
def go_to_next(current_idx, images, results):
    """Step the viewer forward one image (wrapping) and return the new state.

    Returns (index, image path, counter text, caption). When there is
    nothing to show, the index is unchanged and the display is cleared.
    """
    if not (images and results):
        return current_idx, None, "0 of 0", ""

    total = len(images)
    new_idx = (current_idx + 1) % total
    caption = (results[new_idx]["content"]
               if new_idx < len(results)
               # Guard against a results list shorter than the image list.
               else "Error: Result not found")
    return new_idx, images[new_idx], f"{new_idx + 1} of {total}", caption
|
| 361 |
+
|
| 362 |
# Connect the analyze button
|
| 363 |
analyze_button.click(
|
| 364 |
fn=analyze_images,
|
|
|
|
| 368 |
analysis_text, csv_download
|
| 369 |
]
|
| 370 |
)
|
| 371 |
+
|
| 372 |
# Connect navigation buttons
|
| 373 |
prev_button.click(
|
| 374 |
fn=go_to_prev,
|
| 375 |
inputs=[current_index, all_images, all_results],
|
| 376 |
+
outputs=[current_index, current_image, image_counter, analysis_text],
|
| 377 |
+
# Add queue=False if navigation should be instant and not wait for analysis
|
| 378 |
+
queue=False
|
| 379 |
)
|
| 380 |
+
|
| 381 |
next_button.click(
|
| 382 |
fn=go_to_next,
|
| 383 |
inputs=[current_index, all_images, all_results],
|
| 384 |
+
outputs=[current_index, current_image, image_counter, analysis_text],
|
| 385 |
+
# Add queue=False if navigation should be instant
|
| 386 |
+
queue=False
|
| 387 |
)
|
| 388 |
+
|
| 389 |
# Optional: Add additional information
|
| 390 |
with gr.Accordion("About", open=False):
|
| 391 |
gr.Markdown("""
|
| 392 |
## About this demo
|
| 393 |
+
|
| 394 |
+
This demo generates alternative text for museum object images using various AI models.
|
| 395 |
+
|
| 396 |
+
- Upload one or more images using the 'Click to Upload Images' button.
|
| 397 |
+
- Select the AI model and desired response length.
|
| 398 |
+
- Click 'Generate Alt-Text'. Processing time depends on the number of images and the selected model.
|
| 399 |
+
- View the generated text for each image using the Previous and Next buttons.
|
| 400 |
+
- Download a CSV file containing all results using the 'Download CSV Results' link.
|
| 401 |
+
|
| 402 |
+
Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme.
|
| 403 |
+
|
| 404 |
+
If you find any bugs, have problems, or have suggestions, please feel free to get in touch:
|
| 405 |
chris.addis@nhm.ac.uk
|
| 406 |
""")
|
| 407 |
+
|
| 408 |
return demo
|
| 409 |
|
| 410 |
# Launch the app
|
| 411 |
if __name__ == "__main__":
    # Local-run convenience: make sure the logo assets exist so the layout
    # renders; solid-colour placeholders are generated when they are missing.
    os.makedirs("images", exist_ok=True)
    for logo_path, fill in (("images/nhm_logo.png", "red"),
                            ("images/nml_logo.png", "blue")):
        if not os.path.exists(logo_path):
            Image.new('RGB', (60, 30), color=fill).save(logo_path)

    app = create_demo()
    app.launch()  # pass share=True for a public link when running locally
|