Chris Addis committed on
Commit
a1d918e
·
1 Parent(s): af23186
Files changed (2) hide show
  1. .ipynb_checkpoints/app-Copy1-checkpoint.py +387 -0
  2. app.py +10 -7
.ipynb_checkpoints/app-Copy1-checkpoint.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import requests
7
+ import json
8
+ from dotenv import load_dotenv
9
+ import openai
10
+ import base64
11
+ import csv
12
+ import tempfile
13
+ import datetime
14
+
15
+ # Load environment variables from .env file if it exists (for local development)
16
+ # On Hugging Face Spaces, the secrets are automatically available as environment variables
17
+ if os.path.exists(".env"):
18
+ load_dotenv()
19
+
20
+ from io import BytesIO
21
+ import numpy as np
22
+ import requests
23
+ from PIL import Image
24
+
25
+ # import libraries
26
+ from library.utils_model import *
27
+ from library.utils_html import *
28
+ from library.utils_prompt import *
29
+
30
+ OR = OpenRouterAPI()
31
+ gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
32
+
33
+ # Path for storing user preferences
34
+ PREFERENCES_FILE = "data/user_preferences.csv"
35
+
36
+ # Ensure directory exists
37
+ os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
38
+
39
def get_sys_prompt(length="medium"):
    """Return the system (developer) prompt for the requested response length.

    Args:
        length: ``"short"`` selects the alt-text prompt (max 130 characters),
            ``"medium"`` the long-description prompt (250-300 characters).
            Any other value, including ``"long"``, selects the
            long-description prompt capped at 450 characters.

    Returns:
        The complete prompt string.
    """
    # The three prompts differ only in the task wording and the length limit,
    # so they are assembled from one template; this keeps the shared wording
    # (and any future correction to it) in a single place.
    if length == "short":
        task = "alt-text (as defined by W3C)"
        limit = "Responses should be a maximum of 130 characters."
    elif length == "medium":
        task = "long descriptions (as defined in W3C)"
        limit = "Responses should be between 250-300 characters in length."
    else:
        task = "long descriptions (as defined in W3C)"
        limit = "Responses should be a maximum of 450 characters."

    return (
        f"You are a museum curator tasked with generating {task} of museum "
        "objects for visually impaired and blind users from images. "
        "Use British English and follow museum accessibility best practices. "
        "Do not start with phrases like 'The image shows' or "
        "'This is an image of'. "
        "Be precise, concise and avoid filler and subjective statements. "
        f"{limit}"
    )
47
+
48
def create_csv_file_simple(results):
    """Write *results* to a new temporary CSV file and return its path.

    Args:
        results: Iterable of dicts. The ``"image_id"`` and ``"content"`` keys
            are written as the two columns; a missing key becomes an empty
            cell.

    Returns:
        Path of the created ``.csv`` file. The file is not deleted here;
        the caller owns it.
    """
    fd, path = tempfile.mkstemp(suffix=".csv")

    # Encoding is pinned to UTF-8: generated descriptions routinely contain
    # non-ASCII characters (dashes, accents), and the locale default encoding
    # used otherwise may be unable to represent them.
    with os.fdopen(fd, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "content"])
        writer.writerows(
            [result.get("image_id", ""), result.get("content", "")]
            for result in results
        )

    return path
65
+
66
+ # Extract original filename without path or extension
67
def get_base_filename(filepath):
    """Return the file name of *filepath* without directories or its last extension.

    A falsy *filepath* (``None`` or ``""``) yields an empty string.
    """
    if filepath:
        # splitext only strips the final suffix, so "a.tar.gz" -> "a.tar".
        stem, _extension = os.path.splitext(os.path.basename(filepath))
        return stem
    return ""
75
+
76
+ custom_css = """
77
+ .image-container img {
78
+ object-fit: contain;
79
+ width: 100%;
80
+ height: 100%;
81
+ }
82
+ """
83
+
84
+ # Define the Gradio interface
85
def create_demo():
    """Build the MATCHA Gradio interface and return the ``gr.Blocks`` app.

    The page has a header (title, credits, partner logos), a left column with
    the upload button, model and response-length selectors, a preview gallery,
    the generate button and the CSV download, and a right column that shows
    one analysed image at a time with Previous/Next navigation and its
    generated text.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller is responsible for
        calling ``launch()`` on it.
    """
    with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
        # ------------------------------------------------------------------
        # Header: title and credits on the left, partner logos on the right.
        # ------------------------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
                gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
                gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
            with gr.Column(scale=1):
                with gr.Row():
                    # Logos are static: every interactive feature is disabled.
                    for logo_path in ("images/nhm_logo.png", "images/nml_logo.png"):
                        gr.Image(
                            logo_path,
                            show_label=False,
                            height=120,
                            interactive=False,
                            show_download_button=False,
                            show_share_button=False,
                            show_fullscreen_button=False,
                            container=False,
                        )

        with gr.Row():
            # --------------------------------------------------------------
            # Left column: controls and uploads.
            # --------------------------------------------------------------
            with gr.Column(scale=1):
                upload_button = gr.UploadButton(
                    "Click to Upload Images",
                    file_types=["image"],
                    file_count="multiple",
                )

                # Choices are (display name, internal value) tuples; the
                # internal value is what the event handlers receive.
                model_choices = [
                    # Gemini
                    ("Gemini 2.0 Flash (default)", "google/gemini-2.0-flash-001"),
                    # GPT-4.1 Series
                    ("GPT-4.1 Nano", "gpt-4.1-nano"),
                    ("GPT-4.1 Mini", "gpt-4.1-mini"),
                    ("GPT-4.1", "gpt-4.1"),
                    ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
                    # Other Models
                    ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
                    ("Llama 4 Maverick", "meta-llama/llama-4-maverick"),
                    # Experimental Models
                    ("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
                    ("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21"),
                ]

                # Internal value of the default choice.
                default_model_internal_value = "google/gemini-2.0-flash-001"

                model_choice = gr.Dropdown(
                    choices=model_choices,
                    label="Select Model",
                    value=default_model_internal_value,
                    visible=True,
                )

                length_choice = gr.Radio(
                    choices=["short", "medium", "long"],
                    label="Response Length",
                    value="medium",
                    info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars",
                )

                # Preview gallery for uploaded images
                gr.Markdown("### Uploaded Images")
                input_gallery = gr.Gallery(
                    label="",
                    columns=3,
                    height=150,
                    object_fit="contain",
                )

                analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")

                # Hidden state: uploaded file paths and their base names.
                image_state = gr.State([])
                filename_state = gr.State([])

                csv_download = gr.File(label="CSV Results")

            # --------------------------------------------------------------
            # Right column: display area.
            # --------------------------------------------------------------
            with gr.Column(scale=2):
                with gr.Column(elem_classes="image-container"):
                    current_image = gr.Image(
                        label="Current Image",
                        height=600,  # maximum desired height
                        width=1000,
                        type="filepath",
                        show_fullscreen_button=True,
                        show_download_button=False,
                        show_share_button=False,
                        elem_classes="image-container",
                    )

                # Navigation row
                with gr.Row():
                    prev_button = gr.Button("← Previous", size="sm")
                    image_counter = gr.Markdown("", elem_id="image-counter")
                    next_button = gr.Button("Next →", size="sm")

                gr.Markdown("### Generated Alt-text")

                analysis_text = gr.Textbox(
                    label="",
                    value="Upload images and select model to generate alt-text!",
                    lines=6,
                    max_lines=10,
                    interactive=False,
                    show_label=False,
                )

                # Hidden state for navigating the analysed images.
                current_index = gr.State(0)
                all_images = gr.State([])
                all_results = gr.State([])

        # ------------------------------------------------------------------
        # Event handlers
        # ------------------------------------------------------------------
        def handle_upload(files):
            """Return (gallery paths, stored paths, base names) for the uploaded files."""
            # `files` may be None when nothing was selected.
            file_paths = [file.name for file in files or []]
            file_names = [get_base_filename(file_path) for file_path in file_paths]
            return file_paths, file_paths, file_names

        upload_button.upload(
            fn=handle_upload,
            inputs=[upload_button],
            outputs=[input_gallery, image_state, filename_state],
        )

        def request_caption(img, model_choice, sys_prompt):
            """Ask the appropriate client to caption *img* and return the result."""
            user_prompt = prompt_new()

            # Strip a trailing display label such as " (default)" if one is
            # present; a plain internal value passes through unchanged.
            model_name = model_choice.split(" (")[0]

            request = dict(
                model=model_name,
                max_image_size=512,
                prompt=user_prompt,
                prompt_dev=sys_prompt,
                temperature=1,
            )

            # These experimental Gemini models are tried on the dedicated
            # Gemini client first.
            is_gemini_model = (
                "gemini-2.5-pro" in model_name
                or "gemini-2.0-flash-thinking" in model_name
            )
            if is_gemini_model:
                try:
                    return gemini.generate_caption(img, **request)
                except Exception as gemini_error:
                    # Deliberate best-effort: report the failure, then fall
                    # back to the standard OpenRouter client below.
                    print(f"Gemini client failed ({gemini_error}); falling back to OpenRouter.")

            return OR.generate_caption(img, **request)

        def analyze_images(image_paths, model_choice, length_choice, filenames):
            """Caption every uploaded image and prepare the initial display.

            Returns a 7-tuple matching the outputs wired to the analyse
            button: image paths, per-image results, current index, displayed
            image, counter text, displayed text, CSV path.
            """
            if not image_paths:
                # None (not "") is what clears the image and file outputs.
                return [], [], 0, None, "No images", "", None

            sys_prompt = get_sys_prompt(length_choice)
            image_results = []

            for i, image_path in enumerate(image_paths):
                # Use the original filename as image_id when available.
                if i < len(filenames) and filenames[i]:
                    image_id = filenames[i]
                else:
                    image_id = f"Image {i+1}"

                try:
                    # The context manager releases the file handle once the
                    # caption has been produced.
                    with Image.open(image_path) as img:
                        content = request_caption(img, model_choice, sys_prompt)
                except Exception as e:
                    # One failing image must not abort the batch; the error
                    # is shown in place of that image's text.
                    content = f"Error: {str(e)}"

                image_results.append({"image_id": image_id, "content": content})

            csv_path = create_csv_file_simple(image_results)

            # image_paths is non-empty here, so the first image is shown.
            return (
                image_paths,
                image_results,
                0,
                image_paths[0],
                f"{1} of {len(image_paths)}",
                image_results[0]["content"],
                csv_path,
            )

        def show_image(idx, images, results):
            """Return the (index, image, counter, text) outputs for position *idx*."""
            return idx, images[idx], f"{idx + 1} of {len(images)}", results[idx]["content"]

        def go_to_prev(current_idx, images, results):
            """Step to the previous image, wrapping from the first to the last."""
            if not images:
                return current_idx, None, "0 of 0", ""
            return show_image((current_idx - 1) % len(images), images, results)

        def go_to_next(current_idx, images, results):
            """Step to the next image, wrapping from the last to the first."""
            if not images:
                return current_idx, None, "0 of 0", ""
            return show_image((current_idx + 1) % len(images), images, results)

        # ------------------------------------------------------------------
        # Event wiring
        # ------------------------------------------------------------------
        analyze_button.click(
            fn=analyze_images,
            inputs=[image_state, model_choice, length_choice, filename_state],
            outputs=[
                all_images, all_results, current_index, current_image, image_counter,
                analysis_text, csv_download,
            ],
        )

        navigation_inputs = [current_index, all_images, all_results]
        navigation_outputs = [current_index, current_image, image_counter, analysis_text]

        prev_button.click(fn=go_to_prev, inputs=navigation_inputs, outputs=navigation_outputs)
        next_button.click(fn=go_to_next, inputs=navigation_inputs, outputs=navigation_outputs)

        # NOTE(review): the indentation inside this literal is presumably
        # stripped by gr.Markdown before rendering — confirm the list renders.
        with gr.Accordion("About", open=False):
            gr.Markdown("""
            ## About this demo

            This demo generates alternative text for images.

            - Upload one or more images using the upload button
            - Choose a model and response length for generation
            - Navigate through the images with the Previous and Next buttons
            - Download CSV with all results

            Developed by the Natural History Museum in Partnership with National Museums Liverpool.

            If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
            chris.addis@nhm.ac.uk
            """)

    return demo
383
+
384
+ # Launch the app
385
+ if __name__ == "__main__":
386
+ app = create_demo()
387
+ app.launch()
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import requests
7
  import json
8
  from dotenv import load_dotenv
9
- # import openai
10
  import base64
11
  import csv
12
  import tempfile
@@ -18,14 +18,17 @@ if os.path.exists(".env"):
18
  load_dotenv()
19
 
20
  from io import BytesIO
 
 
 
 
 
 
 
 
21
 
22
  OR = OpenRouterAPI()
23
- # Ensure GEMINI_API_KEY is set in your environment or .env file
24
- gemini_api_key = os.getenv("GEMINI_API_KEY")
25
- if not gemini_api_key:
26
- print("Warning: GEMINI_API_KEY environment variable not set. Using placeholder.")
27
- # Handle the case where the key might be missing
28
- gemini = OpenRouterAPI(api_key=gemini_api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
29
 
30
  # Path for storing user preferences
31
  PREFERENCES_FILE = "data/user_preferences.csv"
 
6
  import requests
7
  import json
8
  from dotenv import load_dotenv
9
+ import openai
10
  import base64
11
  import csv
12
  import tempfile
 
18
  load_dotenv()
19
 
20
  from io import BytesIO
21
+ import numpy as np
22
+ import requests
23
+ from PIL import Image
24
+
25
+ # import libraries
26
+ from library.utils_model import *
27
+ from library.utils_html import *
28
+ from library.utils_prompt import *
29
 
30
  OR = OpenRouterAPI()
31
+ gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
 
 
 
 
 
32
 
33
  # Path for storing user preferences
34
  PREFERENCES_FILE = "data/user_preferences.csv"