Chris Addis committed on
Commit
b81c5d1
·
1 Parent(s): 1c02ad4
Files changed (2) hide show
  1. app-Copy1.py +387 -0
  2. app.py +230 -186
app-Copy1.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import requests
7
+ import json
8
+ from dotenv import load_dotenv
9
+ import openai
10
+ import base64
11
+ import csv
12
+ import tempfile
13
+ import datetime
14
+
15
+ # Load environment variables from .env file if it exists (for local development)
16
+ # On Hugging Face Spaces, the secrets are automatically available as environment variables
17
+ if os.path.exists(".env"):
18
+ load_dotenv()
19
+
20
+ from io import BytesIO
21
+ import numpy as np
22
+ import requests
23
+ from PIL import Image
24
+
25
+ # import libraries
26
+ from library.utils_model import *
27
+ from library.utils_html import *
28
+ from library.utils_prompt import *
29
+
30
+ OR = OpenRouterAPI()
31
+ gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
32
+
33
+ # Path for storing user preferences
34
+ PREFERENCES_FILE = "data/user_preferences.csv"
35
+
36
+ # Ensure directory exists
37
+ os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
38
+
39
+ def get_sys_prompt(length="medium"):
40
+ if length == "short":
41
+ dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
42
+ elif length == "medium":
43
+ dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
44
+ else:
45
+ dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
46
+ return dev_prompt
47
+
48
+ def create_csv_file_simple(results):
49
+ """Create a CSV file from the results and return the path"""
50
+ # Create a temporary file
51
+ fd, path = tempfile.mkstemp(suffix='.csv')
52
+
53
+ with os.fdopen(fd, 'w', newline='') as f:
54
+ writer = csv.writer(f)
55
+ # Write header
56
+ writer.writerow(['image_id', 'content'])
57
+ # Write data
58
+ for result in results:
59
+ writer.writerow([
60
+ result.get('image_id', ''),
61
+ result.get('content', '')
62
+ ])
63
+
64
+ return path
65
+
66
+ # Extract original filename without path or extension
67
+ def get_base_filename(filepath):
68
+ if not filepath:
69
+ return ""
70
+ # Get the basename (filename with extension)
71
+ basename = os.path.basename(filepath)
72
+ # Remove extension
73
+ filename = os.path.splitext(basename)[0]
74
+ return filename
75
+
76
+ custom_css = """
77
+ .image-container img {
78
+ object-fit: contain;
79
+ width: 100%;
80
+ height: 100%;
81
+ }
82
+ """
83
+
84
+ # Define the Gradio interface
85
+ def create_demo():
86
+ with gr.Blocks(theme=gr.themes.Monochrome(),css=custom_css) as demo:
87
+ # Replace the existing logo code section:
88
+ with gr.Row():
89
+ with gr.Column(scale=3):
90
+ gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
91
+ gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
92
+ gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
93
+ with gr.Column(scale=1):
94
+ with gr.Row():
95
+ # Use gr.Image with all interactive features disabled
96
+ gr.Image("images/nhm_logo.png", show_label=False, height=120,
97
+ interactive=False, show_download_button=False,
98
+ show_share_button=False, show_fullscreen_button=False,
99
+ container=False)
100
+ gr.Image("images/nml_logo.png", show_label=False, height=120,
101
+ interactive=False, show_download_button=False,
102
+ show_share_button=False, show_fullscreen_button=False,
103
+ container=False)
104
+
105
+ with gr.Row():
106
+ # Left column: Controls and uploads
107
+ with gr.Column(scale=1):
108
+ # Upload interface
109
+ upload_button = gr.UploadButton(
110
+ "Click to Upload Images",
111
+ file_types=["image"],
112
+ file_count="multiple"
113
+ )
114
+
115
+ # Define choices as a list of tuples: (Display Name, Internal Value)
116
+ model_choices = [
117
+ # Gemini
118
+ ("Gemini 2.0 Flash (default)", "google/gemini-2.0-flash-001"),
119
+ # GPT-4.1 Series
120
+ ("GPT-4.1 Nano", "gpt-4.1-nano"),
121
+ ("GPT-4.1 Mini", "gpt-4.1-mini"),
122
+ ("GPT-4.1", "gpt-4.1"),
123
+ ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
124
+ # Other Models
125
+ ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
126
+ ("Llama 4 Maverick", "meta-llama/llama-4-maverick"),
127
+ # Experimental Models
128
+ ("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
129
+ ("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
130
+ ]
131
+
132
+ # Find the internal value of the default choice
133
+ default_model_internal_value = "google/gemini-2.0-flash-001"
134
+
135
+ # Add model selection dropdown
136
+ model_choice = gr.Dropdown(
137
+ choices=model_choices,
138
+ label="Select Model",
139
+ value=default_model_internal_value, # Use the internal value for the default
140
+ # info="Choose the language model to use." # Optional: Add extra info tooltip
141
+ visible=True
142
+ )
143
+
144
+
145
+ # Add response length selection
146
+ length_choice = gr.Radio(
147
+ choices=["short", "medium", "long"],
148
+ label="Response Length",
149
+ value="medium",
150
+ info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
151
+ )
152
+
153
+ # Preview gallery for uploaded images
154
+ gr.Markdown("### Uploaded Images")
155
+ input_gallery = gr.Gallery(
156
+ label="",
157
+ columns=3,
158
+ height=150,
159
+ object_fit="contain"
160
+ )
161
+
162
+ # Analysis button
163
+ analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
164
+
165
+ # Hidden state component to store image info
166
+ image_state = gr.State([])
167
+ filename_state = gr.State([])
168
+
169
+ # CSV download component
170
+ csv_download = gr.File(label="CSV Results")
171
+
172
+ # Right column: Display area
173
+ with gr.Column(scale=2):
174
+ with gr.Column(elem_classes="image-container"):
175
+ current_image = gr.Image(
176
+ label="Current Image",
177
+ height=600, # Set the maximum desired height
178
+ width=1000,
179
+ type="filepath",
180
+ show_fullscreen_button=True,
181
+ show_download_button=False,
182
+ show_share_button=False,
183
+ elem_classes="image-container"
184
+ )
185
+
186
+ # Navigation row
187
+ with gr.Row():
188
+ prev_button = gr.Button("← Previous", size="sm")
189
+ image_counter = gr.Markdown("", elem_id="image-counter")
190
+ next_button = gr.Button("Next →", size="sm")
191
+
192
+ # Alt-text heading and output
193
+ gr.Markdown("### Generated Alt-text")
194
+
195
+ # Alt-text
196
+ analysis_text = gr.Textbox(
197
+ label="",
198
+ value="Upload images and select model to generate alt-text!",
199
+ lines=6,
200
+ max_lines=10,
201
+ interactive=False,
202
+ show_label=False
203
+ )
204
+
205
+ # Hidden state for gallery navigation
206
+ current_index = gr.State(0)
207
+ all_images = gr.State([])
208
+ all_results = gr.State([])
209
+
210
+ # Handle file uploads - store files for use during analysis
211
+ def handle_upload(files):
212
+ file_paths = []
213
+ file_names = []
214
+ for file in files:
215
+ file_paths.append(file.name)
216
+ # Extract filename without path or extension for later use
217
+ file_names.append(get_base_filename(file.name))
218
+ return file_paths, file_paths, file_names
219
+
220
+ upload_button.upload(
221
+ fn=handle_upload,
222
+ inputs=[upload_button],
223
+ outputs=[input_gallery, image_state, filename_state]
224
+ )
225
+
226
+ # Function to analyze images
227
+ # Modify the analyze_images function in your code:
228
+
229
+ def analyze_images(image_paths, model_choice, length_choice, filenames):
230
+ if not image_paths:
231
+ return [], [], 0, "", "No images", "", ""
232
+
233
+ # Get system prompt based on length selection
234
+ sys_prompt = get_sys_prompt(length_choice)
235
+
236
+ image_results = []
237
+
238
+ for i, image_path in enumerate(image_paths):
239
+ # Use original filename as image_id if available
240
+ if i < len(filenames) and filenames[i]:
241
+ image_id = filenames[i]
242
+ else:
243
+ image_id = f"Image {i+1}"
244
+
245
+ try:
246
+ # Open the image file for analysis
247
+ img = Image.open(image_path)
248
+ prompt0 = prompt_new() # Using the new prompt function
249
+
250
+ # Extract the actual model name (remove any labels like "(default)")
251
+ if " (" in model_choice:
252
+ model_name = model_choice.split(" (")[0]
253
+ else:
254
+ model_name = model_choice
255
+
256
+ # Check if this is one of the Gemini models that needs special handling
257
+ is_gemini_model = "gemini-2.5-pro" in model_name or "gemini-2.0-flash-thinking" in model_name
258
+
259
+ if is_gemini_model:
260
+ try:
261
+ # First try using the dedicated gemini client
262
+ result = gemini.generate_caption(
263
+ img,
264
+ model=model_name,
265
+ max_image_size=512,
266
+ prompt=prompt0,
267
+ prompt_dev=sys_prompt,
268
+ temperature=1
269
+ )
270
+ except Exception as gemini_error:
271
+ # If gemini client fails, fall back to standard OR client
272
+ result = OR.generate_caption(
273
+ img,
274
+ model=model_name,
275
+ max_image_size=512,
276
+ prompt=prompt0,
277
+ prompt_dev=sys_prompt,
278
+ temperature=1
279
+ )
280
+ else:
281
+ # For all other models, use OR client directly
282
+ result = OR.generate_caption(
283
+ img,
284
+ model=model_name,
285
+ max_image_size=512,
286
+ prompt=prompt0,
287
+ prompt_dev=sys_prompt,
288
+ temperature=1
289
+ )
290
+
291
+ # Add to results
292
+ image_results.append({
293
+ "image_id": image_id,
294
+ "content": result
295
+ })
296
+
297
+ except Exception as e:
298
+ error_message = f"Error: {str(e)}"
299
+ image_results.append({
300
+ "image_id": image_id,
301
+ "content": error_message
302
+ })
303
+
304
+ # Create a CSV file for download
305
+ csv_path = create_csv_file_simple(image_results)
306
+
307
+ # Set up initial display with first image
308
+ if len(image_paths) > 0:
309
+ initial_image = image_paths[0]
310
+ initial_counter = f"{1} of {len(image_paths)}"
311
+ initial_text = image_results[0]["content"]
312
+ else:
313
+ initial_image = ""
314
+ initial_text = "No images analyzed"
315
+ initial_counter = "0 of 0"
316
+
317
+ return (image_paths, image_results, 0, initial_image, initial_counter,
318
+ initial_text, csv_path)
319
+
320
+
321
+ # Function to navigate to previous image
322
+ def go_to_prev(current_idx, images, results):
323
+ if not images or len(images) == 0:
324
+ return current_idx, "", "0 of 0", ""
325
+
326
+ new_idx = (current_idx - 1) % len(images) if current_idx > 0 else len(images) - 1
327
+ counter_html = f"{new_idx + 1} of {len(images)}"
328
+
329
+ return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
330
+
331
+ # Function to navigate to next image
332
+ def go_to_next(current_idx, images, results):
333
+ if not images or len(images) == 0:
334
+ return current_idx, "", "0 of 0", ""
335
+
336
+ new_idx = (current_idx + 1) % len(images)
337
+ counter_html = f"{new_idx + 1} of {len(images)}"
338
+
339
+ return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
340
+
341
+ # Connect the analyze button
342
+ analyze_button.click(
343
+ fn=analyze_images,
344
+ inputs=[image_state, model_choice, length_choice, filename_state],
345
+ outputs=[
346
+ all_images, all_results, current_index, current_image, image_counter,
347
+ analysis_text, csv_download
348
+ ]
349
+ )
350
+
351
+ # Connect navigation buttons
352
+ prev_button.click(
353
+ fn=go_to_prev,
354
+ inputs=[current_index, all_images, all_results],
355
+ outputs=[current_index, current_image, image_counter, analysis_text]
356
+ )
357
+
358
+ next_button.click(
359
+ fn=go_to_next,
360
+ inputs=[current_index, all_images, all_results],
361
+ outputs=[current_index, current_image, image_counter, analysis_text]
362
+ )
363
+
364
+ # Optional: Add additional information
365
+ with gr.Accordion("About", open=False):
366
+ gr.Markdown("""
367
+ ## About this demo
368
+
369
+ This demo generates alternative text for images.
370
+
371
+ - Upload one or more images using the upload button
372
+ - Choose a model and response length for generation
373
+ - Navigate through the images with the Previous and Next buttons
374
+ - Download CSV with all results
375
+
376
+ Developed by the Natural History Museum in Partnership with National Museums Liverpool.
377
+
378
+ If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
379
+ chris.addis@nhm.ac.uk
380
+ """)
381
+
382
+ return demo
383
+
384
+ # Launch the app
385
+ if __name__ == "__main__":
386
+ app = create_demo()
387
+ app.launch()
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import requests
7
  import json
8
  from dotenv import load_dotenv
9
- import openai
10
  import base64
11
  import csv
12
  import tempfile
@@ -18,17 +18,32 @@ if os.path.exists(".env"):
18
  load_dotenv()
19
 
20
  from io import BytesIO
21
- import numpy as np
22
- import requests
23
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # import libraries
26
- from library.utils_model import *
27
- from library.utils_html import *
28
- from library.utils_prompt import *
29
 
30
  OR = OpenRouterAPI()
31
- gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
 
 
 
 
 
32
 
33
  # Path for storing user preferences
34
  PREFERENCES_FILE = "data/user_preferences.csv"
@@ -41,27 +56,31 @@ def get_sys_prompt(length="medium"):
41
  dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
42
  elif length == "medium":
43
  dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
44
- else:
45
  dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
46
  return dev_prompt
47
 
48
  def create_csv_file_simple(results):
49
  """Create a CSV file from the results and return the path"""
50
  # Create a temporary file
51
- fd, path = tempfile.mkstemp(suffix='.csv')
52
-
53
- with os.fdopen(fd, 'w', newline='') as f:
54
- writer = csv.writer(f)
55
- # Write header
56
- writer.writerow(['image_id', 'content'])
57
- # Write data
58
- for result in results:
59
- writer.writerow([
60
- result.get('image_id', ''),
61
- result.get('content', '')
62
- ])
63
-
64
- return path
 
 
 
 
65
 
66
  # Extract original filename without path or extension
67
  def get_base_filename(filepath):
@@ -73,17 +92,10 @@ def get_base_filename(filepath):
73
  filename = os.path.splitext(basename)[0]
74
  return filename
75
 
76
- custom_css = """
77
- .image-container img {
78
- object-fit: contain;
79
- width: 100%;
80
- height: 100%;
81
- }
82
- """
83
-
84
  # Define the Gradio interface
85
  def create_demo():
86
- with gr.Blocks(theme=gr.themes.Monochrome(),css=custom_css) as demo:
 
87
  # Replace the existing logo code section:
88
  with gr.Row():
89
  with gr.Column(scale=3):
@@ -93,25 +105,25 @@ def create_demo():
93
  with gr.Column(scale=1):
94
  with gr.Row():
95
  # Use gr.Image with all interactive features disabled
96
- gr.Image("images/nhm_logo.png", show_label=False, height=120,
97
- interactive=False, show_download_button=False,
98
- show_share_button=False, show_fullscreen_button=False,
99
- container=False)
100
- gr.Image("images/nml_logo.png", show_label=False, height=120,
101
- interactive=False, show_download_button=False,
102
- show_share_button=False, show_fullscreen_button=False,
103
- container=False)
104
-
105
  with gr.Row():
106
  # Left column: Controls and uploads
107
  with gr.Column(scale=1):
108
  # Upload interface
109
  upload_button = gr.UploadButton(
110
- "Click to Upload Images",
111
- file_types=["image"],
112
  file_count="multiple"
113
  )
114
-
115
  # Define choices as a list of tuples: (Display Name, Internal Value)
116
  model_choices = [
117
  # Gemini
@@ -128,10 +140,10 @@ def create_demo():
128
  ("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
129
  ("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
130
  ]
131
-
132
  # Find the internal value of the default choice
133
  default_model_internal_value = "google/gemini-2.0-flash-001"
134
-
135
  # Add model selection dropdown
136
  model_choice = gr.Dropdown(
137
  choices=model_choices,
@@ -141,7 +153,7 @@ def create_demo():
141
  visible=True
142
  )
143
 
144
-
145
  # Add response length selection
146
  length_choice = gr.Radio(
147
  choices=["short", "medium", "long"],
@@ -149,195 +161,204 @@ def create_demo():
149
  value="medium",
150
  info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
151
  )
152
-
153
  # Preview gallery for uploaded images
154
  gr.Markdown("### Uploaded Images")
155
  input_gallery = gr.Gallery(
156
- label="",
157
- columns=3,
158
- height=150,
159
- object_fit="contain"
 
160
  )
161
-
162
  # Analysis button
163
  analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
164
-
165
  # Hidden state component to store image info
166
  image_state = gr.State([])
167
  filename_state = gr.State([])
168
-
169
  # CSV download component
170
- csv_download = gr.File(label="CSV Results")
171
-
172
  # Right column: Display area
173
  with gr.Column(scale=2):
174
- with gr.Column(elem_classes="image-container"):
175
- current_image = gr.Image(
176
- label="Current Image",
177
- height=600, # Set the maximum desired height
178
- width=1000,
179
- type="filepath",
180
- show_fullscreen_button=True,
181
- show_download_button=False,
182
- show_share_button=False,
183
- elem_classes="image-container"
184
- )
185
-
 
 
 
186
  # Navigation row
187
  with gr.Row():
188
  prev_button = gr.Button("← Previous", size="sm")
189
- image_counter = gr.Markdown("", elem_id="image-counter")
190
  next_button = gr.Button("Next →", size="sm")
191
-
192
  # Alt-text heading and output
193
  gr.Markdown("### Generated Alt-text")
194
-
195
  # Alt-text
196
  analysis_text = gr.Textbox(
197
- label="",
198
- value="Upload images and select model to generate alt-text!",
199
  lines=6,
200
  max_lines=10,
201
- interactive=False,
202
- show_label=False
203
  )
204
-
205
  # Hidden state for gallery navigation
206
  current_index = gr.State(0)
207
  all_images = gr.State([])
208
  all_results = gr.State([])
209
-
210
  # Handle file uploads - store files for use during analysis
211
- def handle_upload(files):
 
 
212
  file_paths = []
213
  file_names = []
214
- for file in files:
215
- file_paths.append(file.name)
216
- # Extract filename without path or extension for later use
217
- file_names.append(get_base_filename(file.name))
218
- return file_paths, file_paths, file_names
219
-
 
 
220
  upload_button.upload(
221
  fn=handle_upload,
222
- inputs=[upload_button],
223
- outputs=[input_gallery, image_state, filename_state]
 
224
  )
225
-
226
- # Function to analyze images
227
- # Modify the analyze_images function in your code:
228
 
 
229
  def analyze_images(image_paths, model_choice, length_choice, filenames):
230
  if not image_paths:
231
- return [], [], 0, "", "No images", "", ""
232
-
 
233
  # Get system prompt based on length selection
234
  sys_prompt = get_sys_prompt(length_choice)
235
-
236
  image_results = []
237
-
238
- for i, image_path in enumerate(image_paths):
 
239
  # Use original filename as image_id if available
240
  if i < len(filenames) and filenames[i]:
241
  image_id = filenames[i]
242
  else:
243
- image_id = f"Image {i+1}"
244
-
 
 
245
  try:
246
  # Open the image file for analysis
247
  img = Image.open(image_path)
248
  prompt0 = prompt_new() # Using the new prompt function
249
-
250
- # Extract the actual model name (remove any labels like "(default)")
251
- if " (" in model_choice:
252
- model_name = model_choice.split(" (")[0]
253
- else:
254
- model_name = model_choice
255
-
256
  # Check if this is one of the Gemini models that needs special handling
257
- is_gemini_model = "gemini-2.5-pro" in model_name or "gemini-2.0-flash-thinking" in model_name
258
-
259
- if is_gemini_model:
260
- try:
261
- # First try using the dedicated gemini client
262
- result = gemini.generate_caption(
263
- img,
264
- model=model_name,
265
- max_image_size=512,
266
- prompt=prompt0,
267
- prompt_dev=sys_prompt,
268
- temperature=1
269
- )
270
- except Exception as gemini_error:
271
- # If gemini client fails, fall back to standard OR client
272
- result = OR.generate_caption(
273
- img,
274
- model=model_name,
275
- max_image_size=512,
276
- prompt=prompt0,
277
- prompt_dev=sys_prompt,
278
- temperature=1
279
- )
280
- else:
281
- # For all other models, use OR client directly
282
- result = OR.generate_caption(
283
- img,
284
- model=model_name,
285
- max_image_size=512,
286
- prompt=prompt0,
287
- prompt_dev=sys_prompt,
288
- temperature=1
289
- )
290
-
291
  # Add to results
292
  image_results.append({
293
  "image_id": image_id,
294
- "content": result
295
  })
296
-
 
 
 
 
297
  except Exception as e:
298
- error_message = f"Error: {str(e)}"
 
299
  image_results.append({
300
  "image_id": image_id,
301
  "content": error_message
302
  })
303
-
304
  # Create a CSV file for download
305
  csv_path = create_csv_file_simple(image_results)
306
-
307
- # Set up initial display with first image
308
- if len(image_paths) > 0:
309
  initial_image = image_paths[0]
310
- initial_counter = f"{1} of {len(image_paths)}"
311
  initial_text = image_results[0]["content"]
312
- else:
313
- initial_image = ""
314
- initial_text = "No images analyzed"
315
  initial_counter = "0 of 0"
316
-
317
- return (image_paths, image_results, 0, initial_image, initial_counter,
318
  initial_text, csv_path)
319
 
320
-
321
  # Function to navigate to previous image
322
  def go_to_prev(current_idx, images, results):
323
- if not images or len(images) == 0:
324
- return current_idx, "", "0 of 0", ""
325
-
326
- new_idx = (current_idx - 1) % len(images) if current_idx > 0 else len(images) - 1
327
- counter_html = f"{new_idx + 1} of {len(images)}"
328
-
329
- return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
330
-
 
 
 
 
331
  # Function to navigate to next image
332
  def go_to_next(current_idx, images, results):
333
- if not images or len(images) == 0:
334
- return current_idx, "", "0 of 0", ""
335
-
336
  new_idx = (current_idx + 1) % len(images)
337
- counter_html = f"{new_idx + 1} of {len(images)}"
338
-
339
- return (new_idx, images[new_idx], counter_html, results[new_idx]["content"])
340
-
 
 
 
341
  # Connect the analyze button
342
  analyze_button.click(
343
  fn=analyze_images,
@@ -347,41 +368,64 @@ def create_demo():
347
  analysis_text, csv_download
348
  ]
349
  )
350
-
351
  # Connect navigation buttons
352
  prev_button.click(
353
  fn=go_to_prev,
354
  inputs=[current_index, all_images, all_results],
355
- outputs=[current_index, current_image, image_counter, analysis_text]
 
 
356
  )
357
-
358
  next_button.click(
359
  fn=go_to_next,
360
  inputs=[current_index, all_images, all_results],
361
- outputs=[current_index, current_image, image_counter, analysis_text]
 
 
362
  )
363
-
364
  # Optional: Add additional information
365
  with gr.Accordion("About", open=False):
366
  gr.Markdown("""
367
  ## About this demo
368
-
369
- This demo generates alternative text for images.
370
-
371
- - Upload one or more images using the upload button
372
- - Choose a model and response length for generation
373
- - Navigate through the images with the Previous and Next buttons
374
- - Download CSV with all results
375
-
376
- Developed by the Natural History Museum in Partnership with National Museums Liverpool.
377
-
378
- If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
 
379
  chris.addis@nhm.ac.uk
380
  """)
381
-
382
  return demo
383
 
384
  # Launch the app
385
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  app = create_demo()
387
- app.launch()
 
6
  import requests
7
  import json
8
  from dotenv import load_dotenv
9
+ # import openai # Assuming openai is not directly used in this snippet anymore
10
  import base64
11
  import csv
12
  import tempfile
 
18
  load_dotenv()
19
 
20
  from io import BytesIO
21
+ # import numpy as np # Already imported
22
+ # import requests # Already imported
23
+ # from PIL import Image # Already imported
24
+
25
+ # Assume these are defined elsewhere or replace with actual implementations if needed
26
+ class OpenRouterAPI:
27
+ def __init__(self, api_key=None, base_url=None):
28
+ pass
29
+ def generate_caption(self, img, model, max_image_size, prompt, prompt_dev, temperature):
30
+ # Dummy implementation for testing
31
+ print(f"Generating caption with model: {model}")
32
+ return f"Generated caption for image using {model}."
33
+
34
+ def prompt_new():
35
+ # Dummy implementation
36
+ return "Describe this image."
37
+ # --- End Dummy implementations ---
38
 
 
 
 
 
39
 
40
  OR = OpenRouterAPI()
41
+ # Ensure GEMINI_API_KEY is set in your environment or .env file
42
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
43
+ if not gemini_api_key:
44
+ print("Warning: GEMINI_API_KEY environment variable not set. Using placeholder.")
45
+ # Handle the case where the key might be missing, perhaps disable the Gemini models or use a default key if applicable
46
+ gemini = OpenRouterAPI(api_key=gemini_api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/") # Note: This base_url looks like OpenAI, ensure it's correct for Gemini via OpenRouter or direct API
47
 
48
  # Path for storing user preferences
49
  PREFERENCES_FILE = "data/user_preferences.csv"
 
56
  dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be a maximum of 130 characters."""
57
  elif length == "medium":
58
  dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be between 250-300 characters in length."""
59
+ else: # long
60
  dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be a maximum of 450 characters."""
61
  return dev_prompt
62
 
63
  def create_csv_file_simple(results):
64
  """Create a CSV file from the results and return the path"""
65
  # Create a temporary file
66
+ try:
67
+ # Use NamedTemporaryFile to simplify cleanup
68
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, newline='', encoding='utf-8') as f:
69
+ path = f.name
70
+ writer = csv.writer(f)
71
+ # Write header
72
+ writer.writerow(['image_id', 'content'])
73
+ # Write data
74
+ for result in results:
75
+ writer.writerow([
76
+ result.get('image_id', ''),
77
+ result.get('content', '')
78
+ ])
79
+ return path
80
+ except Exception as e:
81
+ print(f"Error creating CSV: {e}")
82
+ return None
83
+
84
 
85
  # Extract original filename without path or extension
86
  def get_base_filename(filepath):
 
92
  filename = os.path.splitext(basename)[0]
93
  return filename
94
 
 
 
 
 
 
 
 
 
95
  # Define the Gradio interface
96
  def create_demo():
97
+ # Removed custom_css as we will use the built-in object_fit parameter
98
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo: # Removed css=custom_css
99
  # Replace the existing logo code section:
100
  with gr.Row():
101
  with gr.Column(scale=3):
 
105
  with gr.Column(scale=1):
106
  with gr.Row():
107
  # Use gr.Image with all interactive features disabled
108
+ gr.Image("images/nhm_logo.png", show_label=False, height=120,
109
+ interactive=False, show_download_button=False,
110
+ show_share_button=False, show_fullscreen_button=False,
111
+ container=False, elem_id="nhm-logo") # Added elem_id for clarity
112
+ gr.Image("images/nml_logo.png", show_label=False, height=120,
113
+ interactive=False, show_download_button=False,
114
+ show_share_button=False, show_fullscreen_button=False,
115
+ container=False, elem_id="nml-logo") # Added elem_id for clarity
116
+
117
  with gr.Row():
118
  # Left column: Controls and uploads
119
  with gr.Column(scale=1):
120
  # Upload interface
121
  upload_button = gr.UploadButton(
122
+ "Click to Upload Images",
123
+ file_types=["image"],
124
  file_count="multiple"
125
  )
126
+
127
  # Define choices as a list of tuples: (Display Name, Internal Value)
128
  model_choices = [
129
  # Gemini
 
140
  ("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
141
  ("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
142
  ]
143
+
144
  # Find the internal value of the default choice
145
  default_model_internal_value = "google/gemini-2.0-flash-001"
146
+
147
  # Add model selection dropdown
148
  model_choice = gr.Dropdown(
149
  choices=model_choices,
 
153
  visible=True
154
  )
155
 
156
+
157
  # Add response length selection
158
  length_choice = gr.Radio(
159
  choices=["short", "medium", "long"],
 
161
  value="medium",
162
  info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
163
  )
164
+
165
  # Preview gallery for uploaded images
166
  gr.Markdown("### Uploaded Images")
167
  input_gallery = gr.Gallery(
168
+ label="Uploaded Image Previews", # Added label
169
+ columns=3,
170
+ height=150, # Reduced height slightly if needed
171
+ object_fit="contain", # Ensure gallery previews also fit well
172
+ show_label=False # Hide the label text above the gallery
173
  )
174
+
175
  # Analysis button
176
  analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
177
+
178
  # Hidden state component to store image info
179
  image_state = gr.State([])
180
  filename_state = gr.State([])
181
+
182
  # CSV download component
183
+ csv_download = gr.File(label="Download CSV Results") # Clarified label
184
+
185
  # Right column: Display area
186
  with gr.Column(scale=2):
187
+ # Directly place the Image component here
188
+ # Use object_fit='contain' and set height. Width will adapt.
189
+ current_image = gr.Image(
190
+ label="Current Image",
191
+ height=600, # Set the maximum desired height
192
+ # width=1000, # REMOVED fixed width
193
+ type="filepath",
194
+ object_fit="contain", # ADDED: Scale image while preserving aspect ratio
195
+ show_fullscreen_button=True,
196
+ show_download_button=False, # Keep false as per original code
197
+ show_share_button=False, # Keep false as per original code
198
+ show_label=False # Hide the "Current Image" label above the image
199
+ # Removed elem_classes="image-container" as object_fit handles it
200
+ )
201
+
202
  # Navigation row
203
  with gr.Row():
204
  prev_button = gr.Button("← Previous", size="sm")
205
+ image_counter = gr.Markdown("0 of 0", elem_id="image-counter") # Default text
206
  next_button = gr.Button("Next →", size="sm")
207
+
208
  # Alt-text heading and output
209
  gr.Markdown("### Generated Alt-text")
210
+
211
  # Alt-text
212
  analysis_text = gr.Textbox(
213
+ label="Generated Text", # Added label
214
+ value="Upload images and click 'Generate Alt-Text'.", # Initial message
215
  lines=6,
216
  max_lines=10,
217
+ interactive=True, # Allow user to edit if desired? Set back to False if not.
218
+ show_label=False # Hide the label text
219
  )
220
+
221
  # Hidden state for gallery navigation
222
  current_index = gr.State(0)
223
  all_images = gr.State([])
224
  all_results = gr.State([])
225
+
226
  # Handle file uploads - store files for use during analysis
227
+ def handle_upload(files, current_paths, current_filenames):
228
+ # Append new files to existing ones if needed, or replace
229
+ # This version replaces existing uploads each time
230
  file_paths = []
231
  file_names = []
232
+ if files: # Check if files is not None
233
+ for file in files:
234
+ file_paths.append(file.name)
235
+ # Extract filename without path or extension for later use
236
+ file_names.append(get_base_filename(file.name))
237
+ # Reset view if new files are uploaded
238
+ return file_paths, file_paths, file_names, 0, None, "0 of 0", "Upload images and click 'Generate Alt-Text'."
239
+
240
  upload_button.upload(
241
  fn=handle_upload,
242
+ inputs=[upload_button, image_state, filename_state], # Pass current state if appending needed
243
+ outputs=[input_gallery, image_state, filename_state, # Outputs updated state
244
+ current_index, current_image, image_counter, analysis_text] # Reset display
245
  )
 
 
 
246
 
247
+ # Function to analyze images
248
  def analyze_images(image_paths, model_choice, length_choice, filenames):
249
  if not image_paths:
250
+ # Return state that clears/resets the output fields
251
+ return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None # No CSV path
252
+
253
  # Get system prompt based on length selection
254
  sys_prompt = get_sys_prompt(length_choice)
255
+
256
  image_results = []
257
+ analysis_progress = gr.Progress(track_tqdm=True) # Add progress bar
258
+
259
+ for i, image_path in enumerate(analysis_progress.tqdm(image_paths, desc="Analyzing Images")):
260
  # Use original filename as image_id if available
261
  if i < len(filenames) and filenames[i]:
262
  image_id = filenames[i]
263
  else:
264
+ # Fallback if filename extraction failed or list mismatch
265
+ image_id = f"Image_{i+1}_{os.path.basename(image_path)}"
266
+
267
+
268
  try:
269
  # Open the image file for analysis
270
  img = Image.open(image_path)
271
  prompt0 = prompt_new() # Using the new prompt function
272
+
273
+ # Determine the actual model name (strip extra labels)
274
+ # Using the selected internal value directly is safer
275
+ model_name = model_choice # Already the internal value from dropdown
276
+
 
 
277
  # Check if this is one of the Gemini models that needs special handling
278
+ # Note: This check might need adjustment based on how OpenRouterAPI handles different model endpoints/APIs
279
+ is_experimental_gemini = "gemini-2.5-pro" in model_name or "gemini-2.0-flash-thinking" in model_name
280
+ is_google_gemini = model_name.startswith("google/gemini")
281
+
282
+ client_to_use = OR # Default to standard OpenRouter client
283
+
284
+ # Example logic: Use dedicated client if API key and specific model match
285
+ # Adjust this based on your OpenRouterAPI class capabilities
286
+ # if is_experimental_gemini and gemini: # And potentially check if gemini client is configured
287
+ # client_to_use = gemini
288
+ # elif is_google_gemini and gemini:
289
+ # client_to_use = gemini # Or maybe all google models use the specific client?
290
+
291
+ result = client_to_use.generate_caption(
292
+ img,
293
+ model=model_name,
294
+ max_image_size=512, # Consider if this should be configurable
295
+ prompt=prompt0,
296
+ prompt_dev=sys_prompt,
297
+ temperature=1 # Consider if this should be configurable
298
+ )
299
+
 
 
 
 
 
 
 
 
 
 
 
 
300
  # Add to results
301
  image_results.append({
302
  "image_id": image_id,
303
+ "content": result.strip() # Trim whitespace
304
  })
305
+
306
+ except FileNotFoundError:
307
+ error_message = f"Error: File not found at path '{image_path}'"
308
+ print(error_message) # Log error
309
+ image_results.append({"image_id": image_id, "content": error_message})
310
  except Exception as e:
311
+ error_message = f"Error processing {image_id}: {str(e)}"
312
+ print(error_message) # Log error
313
  image_results.append({
314
  "image_id": image_id,
315
  "content": error_message
316
  })
317
+
318
  # Create a CSV file for download
319
  csv_path = create_csv_file_simple(image_results)
320
+
321
+ # Set up initial display with first image result
322
+ if image_results: # Check if there are results (even errors)
323
  initial_image = image_paths[0]
324
+ initial_counter = f"1 of {len(image_paths)}"
325
  initial_text = image_results[0]["content"]
326
+ else: # Should not happen if image_paths is not empty, but good fallback
327
+ initial_image = None
328
+ initial_text = "Analysis complete, but no results generated."
329
  initial_counter = "0 of 0"
330
+
331
+ return (image_paths, image_results, 0, initial_image, initial_counter,
332
  initial_text, csv_path)
333
 
334
+
335
  # Function to navigate to previous image
336
  def go_to_prev(current_idx, images, results):
337
+ if not images or not results or len(images) == 0: # Check results too
338
+ return current_idx, None, "0 of 0", "" # Return None for image path
339
+
340
+ # Calculate new index correctly wrapping around
341
+ new_idx = (current_idx - 1 + len(images)) % len(images)
342
+ counter_text = f"{new_idx + 1} of {len(images)}"
343
+
344
+ # Ensure result exists for the index
345
+ result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
346
+
347
+ return (new_idx, images[new_idx], counter_text, result_content)
348
+
349
  # Function to navigate to next image
350
  def go_to_next(current_idx, images, results):
351
+ if not images or not results or len(images) == 0: # Check results too
352
+ return current_idx, None, "0 of 0", "" # Return None for image path
353
+
354
  new_idx = (current_idx + 1) % len(images)
355
+ counter_text = f"{new_idx + 1} of {len(images)}"
356
+
357
+ # Ensure result exists for the index
358
+ result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
359
+
360
+ return (new_idx, images[new_idx], counter_text, result_content)
361
+
362
  # Connect the analyze button
363
  analyze_button.click(
364
  fn=analyze_images,
 
368
  analysis_text, csv_download
369
  ]
370
  )
371
+
372
  # Connect navigation buttons
373
  prev_button.click(
374
  fn=go_to_prev,
375
  inputs=[current_index, all_images, all_results],
376
+ outputs=[current_index, current_image, image_counter, analysis_text],
377
+ # Add queue=False if navigation should be instant and not wait for analysis
378
+ queue=False
379
  )
380
+
381
  next_button.click(
382
  fn=go_to_next,
383
  inputs=[current_index, all_images, all_results],
384
+ outputs=[current_index, current_image, image_counter, analysis_text],
385
+ # Add queue=False if navigation should be instant
386
+ queue=False
387
  )
388
+
389
  # Optional: Add additional information
390
  with gr.Accordion("About", open=False):
391
  gr.Markdown("""
392
  ## About this demo
393
+
394
+ This demo generates alternative text for museum object images using various AI models.
395
+
396
+ - Upload one or more images using the 'Click to Upload Images' button.
397
+ - Select the AI model and desired response length.
398
+ - Click 'Generate Alt-Text'. Processing time depends on the number of images and the selected model.
399
+ - View the generated text for each image using the Previous and Next buttons.
400
+ - Download a CSV file containing all results using the 'Download CSV Results' link.
401
+
402
+ Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme.
403
+
404
+ If you find any bugs, have problems, or have suggestions, please feel free to get in touch:
405
  chris.addis@nhm.ac.uk
406
  """)
407
+
408
  return demo
409
 
410
  # Launch the app
411
  if __name__ == "__main__":
412
+ # --- Dummy classes/functions for local execution ---
413
+ # You would remove these if running with your actual library files
414
+ # class OpenRouterAPI:
415
+ # def __init__(self, api_key=None, base_url=None): pass
416
+ # def generate_caption(self, img, model, max_image_size, prompt, prompt_dev, temperature): return f"Dummy caption for {model}"
417
+ # def prompt_new(): return "Describe."
418
+ # OR = OpenRouterAPI()
419
+ # gemini = OpenRouterAPI()
420
+ # --- End Dummy section ---
421
+
422
+ # Create dummy image files if they don't exist for local testing
423
+ os.makedirs("images", exist_ok=True)
424
+ if not os.path.exists("images/nhm_logo.png"):
425
+ Image.new('RGB', (60, 30), color = 'red').save('images/nhm_logo.png')
426
+ if not os.path.exists("images/nml_logo.png"):
427
+ Image.new('RGB', (60, 30), color = 'blue').save('images/nml_logo.png')
428
+
429
+
430
  app = create_demo()
431
+ app.launch() # Add share=True if you want a public link when running locally