Chris Addis commited on
Commit
ba4800b
·
1 Parent(s): 02cc108

fix additional models

Browse files
app-Copy2.py → .ipynb_checkpoints/app-Copy1-checkpoint.py RENAMED
@@ -28,7 +28,6 @@ from library.utils_html import *
28
  from library.utils_prompt import *
29
 
30
  OR = OpenRouterAPI()
31
- gemini = OpenRouterAPI(api_key = os.getenv("GEMINI_API_KEY"),base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
32
 
33
  # Path for storing user preferences
34
  PREFERENCES_FILE = "data/user_preferences.csv"
@@ -36,14 +35,38 @@ PREFERENCES_FILE = "data/user_preferences.csv"
36
  # Ensure directory exists
37
  os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
38
 
39
- def get_sys_prompt(length="medium"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if length == "short":
41
- dev_prompt = """You are a museum curator tasked with generating alt-text (as defined by W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
42
  elif length == "medium":
43
- dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be between 250-300 characters in length."""
44
  else: # long
45
- dev_prompt = """You are a museum curator tasked with generating long descriptions (as defined in W3C) of museum objects for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maxium of 450 characters."""
46
- return dev_prompt
 
47
 
48
  def create_csv_file_simple(results):
49
  """Create a CSV file from the results and return the path"""
@@ -62,7 +85,6 @@ def create_csv_file_simple(results):
62
  print(f"Error creating CSV: {e}")
63
  return None
64
 
65
-
66
  def get_base_filename(filepath):
67
  if not filepath:
68
  return ""
@@ -72,7 +94,7 @@ def get_base_filename(filepath):
72
 
73
  # Define the Gradio interface
74
  def create_demo():
75
- # --- Reintroduce CSS ---
76
  custom_css = """
77
  /* Container for the image component (#current-image-display is the elem_id of gr.Image) */
78
  #current-image-display {
@@ -93,6 +115,15 @@ def create_demo():
93
  height: auto; /* Use natural height unless constrained by max-height */
94
  display: block; /* Ensure image behaves predictably in flex */
95
  }
 
 
 
 
 
 
 
 
 
96
  """
97
  # --- Pass css to gr.Blocks ---
98
  with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
@@ -112,6 +143,31 @@ def create_demo():
112
  show_share_button=False, show_fullscreen_button=False,
113
  container=False, elem_id="nml-logo")
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  with gr.Row():
116
  # Left column: Controls and uploads
117
  with gr.Column(scale=1):
@@ -120,24 +176,48 @@ def create_demo():
120
  file_types=["image"],
121
  file_count="multiple"
122
  )
123
- model_choices = [
124
- ("Gemini 2.0 Flash (default)", "google/gemini-2.0-flash-001"),
125
- ("GPT-4.1 Nano", "gpt-4.1-nano"), ("GPT-4.1 Mini", "gpt-4.1-mini"),
126
- ("GPT-4.1", "gpt-4.1"), ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
127
- ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
128
- ("Llama 4 Maverick", "meta-llama/llama-4-maverick"),
129
- ("Gemini 2.5 Pro (Experimental, limited)", "gemini-2.5-pro-exp-03-25"),
130
- ("Gemini 2.0 Flash Thinking (Experimental, limited)", "gemini-2.0-flash-thinking-exp-01-21")
131
- ]
132
- default_model_internal_value = "google/gemini-2.0-flash-001"
133
  model_choice = gr.Dropdown(
134
- choices=model_choices, label="Select Model",
135
- value=default_model_internal_value, visible=True
 
136
  )
 
137
  length_choice = gr.Radio(
138
- choices=["short", "medium", "long"], label="Response Length",
139
- value="medium", info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  )
 
141
  gr.Markdown("### Uploaded Images")
142
  input_gallery = gr.Gallery(
143
  label="Uploaded Image Previews", columns=3, height=150,
@@ -153,7 +233,7 @@ def create_demo():
153
  current_image = gr.Image(
154
  label="Current Image",
155
  type="filepath",
156
- elem_id="current-image-display", # ADDED - for CSS targeting
157
  show_fullscreen_button=True,
158
  show_download_button=False,
159
  show_share_button=False,
@@ -174,8 +254,57 @@ def create_demo():
174
  current_index = gr.State(0)
175
  all_images = gr.State([])
176
  all_results = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # --- Functions (handle_upload, analyze_images, navigators) remain the same ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  # Handle file uploads
180
  def handle_upload(files, current_paths, current_filenames):
181
  file_paths = []
@@ -194,11 +323,12 @@ def create_demo():
194
  )
195
 
196
  # Analyze images
197
- def analyze_images(image_paths, model_choice, length_choice, filenames):
198
  if not image_paths:
199
  return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None
200
 
201
- sys_prompt = get_sys_prompt(length_choice)
 
202
  image_results = []
203
  analysis_progress = gr.Progress(track_tqdm=True)
204
 
@@ -209,10 +339,6 @@ def create_demo():
209
  prompt0 = prompt_new()
210
  model_name = model_choice
211
  client_to_use = OR # Default client
212
- # Add logic here if you need to switch between OR and gemini clients based on model_name
213
- # Example:
214
- # if model_name.startswith("google/gemini") and gemini:
215
- # client_to_use = gemini
216
 
217
  result = client_to_use.generate_caption(
218
  img, model=model_name, max_image_size=512,
@@ -257,7 +383,7 @@ def create_demo():
257
  # Connect analyze button
258
  analyze_button.click(
259
  fn=analyze_images,
260
- inputs=[image_state, model_choice, length_choice, filename_state],
261
  outputs=[all_images, all_results, current_index, current_image, image_counter,
262
  analysis_text, csv_download]
263
  )
@@ -294,6 +420,5 @@ def create_demo():
294
 
295
  # Launch the app
296
  if __name__ == "__main__":
297
-
298
  app = create_demo()
299
  app.launch()
 
28
  from library.utils_prompt import *
29
 
30
  OR = OpenRouterAPI()
 
31
 
32
  # Path for storing user preferences
33
  PREFERENCES_FILE = "data/user_preferences.csv"
 
35
  # Ensure directory exists
36
  os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
37
 
38
+ # Define model pricing information (approximate costs per 100 image API calls)
39
+ MODEL_PRICING = {
40
+ "google/gemini-2.0-flash-001": "$0.03",
41
+ "gpt-4.1-mini": "$0.07",
42
+ "gpt-4.1": "$0.35",
43
+ "anthropic/claude-3.7-sonnet": "$0.70",
44
+ "google/gemini-2.5-pro-preview-03-25": "$1.20",
45
+ "google/gemini-2.5-flash-preview:thinking": "$0.35",
46
+ "gpt-4.1-nano": "$0.02",
47
+ "openai/chatgpt-4o-latest": "$0.75",
48
+ "meta-llama/llama-4-maverick": "$0.04"
49
+ }
50
+
51
+ def get_sys_prompt(length="medium", photograph=False):
52
+ extra_prompt = ""
53
+
54
+ if photograph:
55
+ object_type = "wildlife photography"
56
+ extra_prompt = " Do not guess the exact species of the animals in the photograph unless you are certain - simply use a broader terms to make less errors e.g. say Swan rather mute Swan or Whooper Swan unless you are certain."
57
+ else:
58
+ object_type = "museum objects"
59
+
60
+ dev_prompt = f"""You are a museum curator tasked with generating long descriptions (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows', 'This is an image of', 'The photograph'. Be precise, concise and avoid filler and subjective statements."""
61
+
62
  if length == "short":
63
+ dev_prompt = f"""You are a museum curator tasked with generating alt-text (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
64
  elif length == "medium":
65
+ dev_prompt += " Repsonses should be a maximum of 250-300 characters."
66
  else: # long
67
+ dev_prompt += " Repsonses should be a maximum of 450 characters."
68
+ return dev_prompt + extra_prompt
69
+
70
 
71
  def create_csv_file_simple(results):
72
  """Create a CSV file from the results and return the path"""
 
85
  print(f"Error creating CSV: {e}")
86
  return None
87
 
 
88
  def get_base_filename(filepath):
89
  if not filepath:
90
  return ""
 
94
 
95
  # Define the Gradio interface
96
  def create_demo():
97
+ # --- Updated CSS with model info styling ---
98
  custom_css = """
99
  /* Container for the image component (#current-image-display is the elem_id of gr.Image) */
100
  #current-image-display {
 
115
  height: auto; /* Use natural height unless constrained by max-height */
116
  display: block; /* Ensure image behaves predictably in flex */
117
  }
118
+
119
+ /* Custom style for model info display */
120
+ #model-info-display {
121
+ font-size: 0.85rem; /* Small font size */
122
+ color: #666; /* Subtle color */
123
+ margin-top: 0.5rem; /* Small top margin */
124
+ margin-bottom: 1rem; /* Bottom margin before next element */
125
+ padding-left: 0.5rem; /* Slight indentation */
126
+ }
127
  """
128
  # --- Pass css to gr.Blocks ---
129
  with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
 
143
  show_share_button=False, show_fullscreen_button=False,
144
  container=False, elem_id="nml-logo")
145
 
146
+ # Store model choices and state
147
+ show_all_models_state = gr.State(False)
148
+
149
+ # Define preferred and additional models directly in the function
150
+ preferred_models = [
151
+ ("Gemini 2.0 Flash (cheap)", "google/gemini-2.0-flash-001"),
152
+ ("GPT-4.1 Mini", "gpt-4.1-mini"),
153
+ ("GPT-4.1 (Recommended)", "gpt-4.1"),
154
+ ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
155
+ ("Gemini 2.5 Pro", "google/gemini-2.5-pro-preview-03-25"),
156
+ ("Gemini 2.5 Flash Thinking (Recommended)", "google/gemini-2.5-flash-preview:thinking")
157
+ ]
158
+
159
+ additional_models = [
160
+ ("GPT-4.1 Nano", "gpt-4.1-nano"),
161
+ ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
162
+ ("Llama 4 Maverick", "meta-llama/llama-4-maverick")
163
+ ]
164
+
165
+ # Calculate all models once
166
+ all_models_list = preferred_models + additional_models
167
+
168
+ # Default model value
169
+ default_model = "google/gemini-2.0-flash-001"
170
+
171
  with gr.Row():
172
  # Left column: Controls and uploads
173
  with gr.Column(scale=1):
 
176
  file_types=["image"],
177
  file_count="multiple"
178
  )
179
+
180
+ # Model dropdown
 
 
 
 
 
 
 
 
181
  model_choice = gr.Dropdown(
182
+ choices=preferred_models,
183
+ label="Select Model",
184
+ value=default_model
185
  )
186
+
187
  length_choice = gr.Radio(
188
+ choices=["short", "medium", "long"],
189
+ label="Response Length",
190
+ value="medium",
191
+ info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
192
+ )
193
+
194
+ # Advanced settings accordion
195
+ with gr.Accordion("Advanced Settings", open=False):
196
+ show_all_models = gr.Checkbox(
197
+ label="Show Additional Models",
198
+ value=False,
199
+ info="Display additional model options in the dropdown above"
200
+ )
201
+
202
+ content_type = gr.Radio(
203
+ choices=["Museum Object", "Photography"],
204
+ label="Content Type",
205
+ value="Museum Object"
206
+ )
207
+
208
+ # Find the default model's display name
209
+ default_model_name = "Unknown Model"
210
+ for name, value in preferred_models:
211
+ if value == default_model:
212
+ default_model_name = name
213
+ break
214
+
215
+ model_info = gr.Markdown(
216
+ f"""**Current Model**: {default_model_name}
217
+ **Estimated cost per 100 Images**: {MODEL_PRICING[default_model]}""",
218
+ elem_id="model-info-display"
219
  )
220
+
221
  gr.Markdown("### Uploaded Images")
222
  input_gallery = gr.Gallery(
223
  label="Uploaded Image Previews", columns=3, height=150,
 
233
  current_image = gr.Image(
234
  label="Current Image",
235
  type="filepath",
236
+ elem_id="current-image-display",
237
  show_fullscreen_button=True,
238
  show_download_button=False,
239
  show_share_button=False,
 
254
  current_index = gr.State(0)
255
  all_images = gr.State([])
256
  all_results = gr.State([])
257
+
258
+ # Handle checkbox change to update model dropdown - modern version
259
+ def toggle_models(show_all, current_model):
260
+ # Make a fresh copy of the models lists to avoid any reference issues
261
+ preferred_choices = list(preferred_models)
262
+ all_choices = list(all_models_list)
263
+
264
+ if show_all:
265
+ # When showing all models, use the fresh copy of all models
266
+ return gr.Dropdown(choices=all_choices, value=current_model)
267
+ else:
268
+ # Check if current model is in preferred models list
269
+ preferred_values = [value for _, value in preferred_choices]
270
+
271
+ if current_model in preferred_values:
272
+ # Keep the current model if it's in preferred models
273
+ return gr.Dropdown(choices=preferred_choices, value=current_model)
274
+ else:
275
+ # Reset to default model if current model is not in preferred models
276
+ return gr.Dropdown(choices=preferred_choices, value=default_model)
277
 
278
+ # Update model info when model selection changes
279
+ def update_model_info(model_value):
280
+ # Find display name
281
+ model_name = "Unknown Model"
282
+ for name, value in all_models_list:
283
+ if value == model_value:
284
+ model_name = name
285
+ break
286
+
287
+ # Get cost
288
+ cost = MODEL_PRICING.get(model_value, "Unknown")
289
+
290
+ # Create markdown
291
+ return f"""**Current Model**: {model_name}
292
+ **Estimated cost per 100 Images**: {cost}"""
293
+
294
+ # Connect checkbox to toggle model choices
295
+ show_all_models.change(
296
+ fn=toggle_models,
297
+ inputs=[show_all_models, model_choice],
298
+ outputs=[model_choice]
299
+ )
300
+
301
+ # Connect model selection to update info
302
+ model_choice.change(
303
+ fn=update_model_info,
304
+ inputs=[model_choice],
305
+ outputs=[model_info]
306
+ )
307
+
308
  # Handle file uploads
309
  def handle_upload(files, current_paths, current_filenames):
310
  file_paths = []
 
323
  )
324
 
325
  # Analyze images
326
+ def analyze_images(image_paths, model_choice, length_choice, filenames, content_type_choice):
327
  if not image_paths:
328
  return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None
329
 
330
+ is_photography = content_type_choice == "Photography"
331
+ sys_prompt = get_sys_prompt(length_choice, photograph=is_photography)
332
  image_results = []
333
  analysis_progress = gr.Progress(track_tqdm=True)
334
 
 
339
  prompt0 = prompt_new()
340
  model_name = model_choice
341
  client_to_use = OR # Default client
 
 
 
 
342
 
343
  result = client_to_use.generate_caption(
344
  img, model=model_name, max_image_size=512,
 
383
  # Connect analyze button
384
  analyze_button.click(
385
  fn=analyze_images,
386
+ inputs=[image_state, model_choice, length_choice, filename_state, content_type],
387
  outputs=[all_images, all_results, current_index, current_image, image_counter,
388
  analysis_text, csv_download]
389
  )
 
420
 
421
  # Launch the app
422
  if __name__ == "__main__":
 
423
  app = create_demo()
424
  app.launch()
README.md CHANGED
@@ -7,6 +7,7 @@ sdk: gradio
7
  sdk_version: 5.24.0
8
  app_file: app.py
9
  pinned: false
 
10
  license: mit
11
  short_description: Generate Alt-Text for Museum and Gallery Objects!
12
  ---
 
7
  sdk_version: 5.24.0
8
  app_file: app.py
9
  pinned: false
10
+ hf_oauth: true
11
  license: mit
12
  short_description: Generate Alt-Text for Museum and Gallery Objects!
13
  ---
app-Copy1.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import requests
7
+ import json
8
+ from dotenv import load_dotenv
9
+ import openai
10
+ import base64
11
+ import csv
12
+ import tempfile
13
+ import datetime
14
+
15
+ # Load environment variables from .env file if it exists (for local development)
16
+ # On Hugging Face Spaces, the secrets are automatically available as environment variables
17
+ if os.path.exists(".env"):
18
+ load_dotenv()
19
+
20
+ from io import BytesIO
21
+ import numpy as np
22
+ import requests
23
+ from PIL import Image
24
+
25
+ # import libraries
26
+ from library.utils_model import *
27
+ from library.utils_html import *
28
+ from library.utils_prompt import *
29
+
30
+ OR = OpenRouterAPI()
31
+
32
+ # Path for storing user preferences
33
+ PREFERENCES_FILE = "data/user_preferences.csv"
34
+
35
+ # Ensure directory exists
36
+ os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
37
+
38
+ # Define model pricing information (approximate costs per 100 image API calls)
39
+ MODEL_PRICING = {
40
+ "google/gemini-2.0-flash-001": "$0.03",
41
+ "gpt-4.1-mini": "$0.07",
42
+ "gpt-4.1": "$0.35",
43
+ "anthropic/claude-3.7-sonnet": "$0.70",
44
+ "google/gemini-2.5-pro-preview-03-25": "$1.20",
45
+ "google/gemini-2.5-flash-preview:thinking": "$0.35",
46
+ "gpt-4.1-nano": "$0.02",
47
+ "openai/chatgpt-4o-latest": "$0.75",
48
+ "meta-llama/llama-4-maverick": "$0.04"
49
+ }
50
+
51
+ def get_sys_prompt(length="medium", photograph=False):
52
+ extra_prompt = ""
53
+
54
+ if photograph:
55
+ object_type = "wildlife photography"
56
+ extra_prompt = " Do not guess the exact species of the animals in the photograph unless you are certain - simply use a broader terms to make less errors e.g. say Swan rather mute Swan or Whooper Swan unless you are certain."
57
+ else:
58
+ object_type = "museum objects"
59
+
60
+ dev_prompt = f"""You are a museum curator tasked with generating long descriptions (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows', 'This is an image of', 'The photograph'. Be precise, concise and avoid filler and subjective statements."""
61
+
62
+ if length == "short":
63
+ dev_prompt = f"""You are a museum curator tasked with generating alt-text (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Repsonses should be a maximum of 130 characters."""
64
+ elif length == "medium":
65
+ dev_prompt += " Repsonses should be a maximum of 250-300 characters."
66
+ else: # long
67
+ dev_prompt += " Repsonses should be a maximum of 450 characters."
68
+ return dev_prompt + extra_prompt
69
+
70
+
71
+ def create_csv_file_simple(results):
72
+ """Create a CSV file from the results and return the path"""
73
+ try:
74
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, newline='', encoding='utf-8') as f:
75
+ path = f.name
76
+ writer = csv.writer(f)
77
+ writer.writerow(['image_id', 'content'])
78
+ for result in results:
79
+ writer.writerow([
80
+ result.get('image_id', ''),
81
+ result.get('content', '')
82
+ ])
83
+ return path
84
+ except Exception as e:
85
+ print(f"Error creating CSV: {e}")
86
+ return None
87
+
88
+ def get_base_filename(filepath):
89
+ if not filepath:
90
+ return ""
91
+ basename = os.path.basename(filepath)
92
+ filename = os.path.splitext(basename)[0]
93
+ return filename
94
+
95
+ # Define the Gradio interface
96
+ def create_demo():
97
+ # --- Updated CSS with model info styling ---
98
+ custom_css = """
99
+ /* Container for the image component (#current-image-display is the elem_id of gr.Image) */
100
+ #current-image-display {
101
+ height: 600px; /* Define container height */
102
+ width: 100%; /* Define container width (takes column width) */
103
+ display: flex; /* Use flexbox for alignment */
104
+ justify-content: center; /* Center content horizontally */
105
+ align-items: center; /* Center content vertically */
106
+ overflow: hidden; /* Hide any potential overflow from container */
107
+ }
108
+
109
+ /* The actual <img> element inside the container */
110
+ #current-image-display img {
111
+ object-fit: contain !important; /* Scale keeping aspect ratio, within bounds */
112
+ max-width: 100%; /* Prevent image exceeding container width */
113
+ max-height: 600px !important; /* Prevent image exceeding container height */
114
+ width: auto; /* Use natural width unless constrained by max-width */
115
+ height: auto; /* Use natural height unless constrained by max-height */
116
+ display: block; /* Ensure image behaves predictably in flex */
117
+ }
118
+
119
+ /* Custom style for model info display */
120
+ #model-info-display {
121
+ font-size: 0.85rem; /* Small font size */
122
+ color: #666; /* Subtle color */
123
+ margin-top: 0.5rem; /* Small top margin */
124
+ margin-bottom: 1rem; /* Bottom margin before next element */
125
+ padding-left: 0.5rem; /* Slight indentation */
126
+ }
127
+ """
128
+ # --- Pass css to gr.Blocks ---
129
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
130
+ with gr.Row():
131
+ with gr.Column(scale=3):
132
+ gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
133
+ gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
134
+ gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
135
+ with gr.Column(scale=1):
136
+ with gr.Row():
137
+ gr.Image("images/nhm_logo.png", show_label=False, height=120,
138
+ interactive=False, show_download_button=False,
139
+ show_share_button=False, show_fullscreen_button=False,
140
+ container=False, elem_id="nhm-logo")
141
+ gr.Image("images/nml_logo.png", show_label=False, height=120,
142
+ interactive=False, show_download_button=False,
143
+ show_share_button=False, show_fullscreen_button=False,
144
+ container=False, elem_id="nml-logo")
145
+
146
+ # Store model choices and state
147
+ show_all_models_state = gr.State(False)
148
+
149
+ # Define preferred and additional models directly in the function
150
+ preferred_models = [
151
+ ("Gemini 2.0 Flash (cheap)", "google/gemini-2.0-flash-001"),
152
+ ("GPT-4.1 Mini", "gpt-4.1-mini"),
153
+ ("GPT-4.1 (Recommended)", "gpt-4.1"),
154
+ ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
155
+ ("Gemini 2.5 Pro", "google/gemini-2.5-pro-preview-03-25"),
156
+ ("Gemini 2.5 Flash Thinking (Recommended)", "google/gemini-2.5-flash-preview:thinking")
157
+ ]
158
+
159
+ additional_models = [
160
+ ("GPT-4.1 Nano", "gpt-4.1-nano"),
161
+ ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
162
+ ("Llama 4 Maverick", "meta-llama/llama-4-maverick")
163
+ ]
164
+
165
+ # Calculate all models once
166
+ all_models_list = preferred_models + additional_models
167
+
168
+ # Default model value
169
+ default_model = "google/gemini-2.0-flash-001"
170
+
171
+ with gr.Row():
172
+ # Left column: Controls and uploads
173
+ with gr.Column(scale=1):
174
+ upload_button = gr.UploadButton(
175
+ "Click to Upload Images",
176
+ file_types=["image"],
177
+ file_count="multiple"
178
+ )
179
+
180
+ # Model dropdown
181
+ model_choice = gr.Dropdown(
182
+ choices=preferred_models,
183
+ label="Select Model",
184
+ value=default_model
185
+ )
186
+
187
+ length_choice = gr.Radio(
188
+ choices=["short", "medium", "long"],
189
+ label="Response Length",
190
+ value="medium",
191
+ info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
192
+ )
193
+
194
+ # Advanced settings accordion
195
+ with gr.Accordion("Advanced Settings", open=False):
196
+ show_all_models = gr.Checkbox(
197
+ label="Show Additional Models",
198
+ value=False,
199
+ info="Display additional model options in the dropdown above"
200
+ )
201
+
202
+ content_type = gr.Radio(
203
+ choices=["Museum Object", "Photography"],
204
+ label="Content Type",
205
+ value="Museum Object"
206
+ )
207
+
208
+ # Find the default model's display name
209
+ default_model_name = "Unknown Model"
210
+ for name, value in preferred_models:
211
+ if value == default_model:
212
+ default_model_name = name
213
+ break
214
+
215
+ model_info = gr.Markdown(
216
+ f"""**Current Model**: {default_model_name}
217
+ **Estimated cost per 100 Images**: {MODEL_PRICING[default_model]}""",
218
+ elem_id="model-info-display"
219
+ )
220
+
221
+ gr.Markdown("### Uploaded Images")
222
+ input_gallery = gr.Gallery(
223
+ label="Uploaded Image Previews", columns=3, height=150,
224
+ object_fit="contain", show_label=False
225
+ )
226
+ analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
227
+ image_state = gr.State([])
228
+ filename_state = gr.State([])
229
+ csv_download = gr.File(label="Download CSV Results")
230
+
231
+ # Right column: Display area
232
+ with gr.Column(scale=2):
233
+ current_image = gr.Image(
234
+ label="Current Image",
235
+ type="filepath",
236
+ elem_id="current-image-display",
237
+ show_fullscreen_button=True,
238
+ show_download_button=False,
239
+ show_share_button=False,
240
+ show_label=False
241
+ )
242
+
243
+ with gr.Row():
244
+ prev_button = gr.Button("← Previous", size="sm")
245
+ image_counter = gr.Markdown("0 of 0", elem_id="image-counter")
246
+ next_button = gr.Button("Next →", size="sm")
247
+
248
+ gr.Markdown("### Generated Alt-text")
249
+ analysis_text = gr.Textbox(
250
+ label="Generated Text",
251
+ value="Upload images and click 'Generate Alt-Text'.",
252
+ lines=6, max_lines=10, interactive=True, show_label=False
253
+ )
254
+ current_index = gr.State(0)
255
+ all_images = gr.State([])
256
+ all_results = gr.State([])
257
+
258
+ # Handle checkbox change to update model dropdown - modern version
259
+ def toggle_models(show_all, current_model):
260
+ # Make a fresh copy of the models lists to avoid any reference issues
261
+ preferred_choices = list(preferred_models)
262
+ all_choices = list(all_models_list)
263
+
264
+ if show_all:
265
+ # When showing all models, use the fresh copy of all models
266
+ return gr.Dropdown(choices=all_choices, value=current_model)
267
+ else:
268
+ # Check if current model is in preferred models list
269
+ preferred_values = [value for _, value in preferred_choices]
270
+
271
+ if current_model in preferred_values:
272
+ # Keep the current model if it's in preferred models
273
+ return gr.Dropdown(choices=preferred_choices, value=current_model)
274
+ else:
275
+ # Reset to default model if current model is not in preferred models
276
+ return gr.Dropdown(choices=preferred_choices, value=default_model)
277
+
278
+ # Update model info when model selection changes
279
+ def update_model_info(model_value):
280
+ # Find display name
281
+ model_name = "Unknown Model"
282
+ for name, value in all_models_list:
283
+ if value == model_value:
284
+ model_name = name
285
+ break
286
+
287
+ # Get cost
288
+ cost = MODEL_PRICING.get(model_value, "Unknown")
289
+
290
+ # Create markdown
291
+ return f"""**Current Model**: {model_name}
292
+ **Estimated cost per 100 Images**: {cost}"""
293
+
294
+ # Connect checkbox to toggle model choices
295
+ show_all_models.change(
296
+ fn=toggle_models,
297
+ inputs=[show_all_models, model_choice],
298
+ outputs=[model_choice]
299
+ )
300
+
301
+ # Connect model selection to update info
302
+ model_choice.change(
303
+ fn=update_model_info,
304
+ inputs=[model_choice],
305
+ outputs=[model_info]
306
+ )
307
+
308
+ # Handle file uploads
309
+ def handle_upload(files, current_paths, current_filenames):
310
+ file_paths = []
311
+ file_names = []
312
+ if files:
313
+ for file in files:
314
+ file_paths.append(file.name)
315
+ file_names.append(get_base_filename(file.name))
316
+ return file_paths, file_paths, file_names, 0, None, "0 of 0", "Upload images and click 'Generate Alt-Text'."
317
+
318
+ upload_button.upload(
319
+ fn=handle_upload,
320
+ inputs=[upload_button, image_state, filename_state],
321
+ outputs=[input_gallery, image_state, filename_state,
322
+ current_index, current_image, image_counter, analysis_text]
323
+ )
324
+
325
+ # Analyze images
326
+ def analyze_images(image_paths, model_choice, length_choice, filenames, content_type_choice):
327
+ if not image_paths:
328
+ return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None
329
+
330
+ is_photography = content_type_choice == "Photography"
331
+ sys_prompt = get_sys_prompt(length_choice, photograph=is_photography)
332
+ image_results = []
333
+ analysis_progress = gr.Progress(track_tqdm=True)
334
+
335
+ for i, image_path in enumerate(analysis_progress.tqdm(image_paths, desc="Analyzing Images")):
336
+ image_id = filenames[i] if i < len(filenames) and filenames[i] else f"Image_{i+1}_{os.path.basename(image_path)}"
337
+ try:
338
+ img = Image.open(image_path)
339
+ prompt0 = prompt_new()
340
+ model_name = model_choice
341
+ client_to_use = OR # Default client
342
+
343
+ result = client_to_use.generate_caption(
344
+ img, model=model_name, max_image_size=512,
345
+ prompt=prompt0, prompt_dev=sys_prompt, temperature=1
346
+ )
347
+ image_results.append({"image_id": image_id, "content": result.strip()})
348
+ except FileNotFoundError:
349
+ error_message = f"Error: File not found at path '{image_path}'"
350
+ print(error_message)
351
+ image_results.append({"image_id": image_id, "content": error_message})
352
+ except Exception as e:
353
+ error_message = f"Error processing {image_id}: {str(e)}"
354
+ print(error_message)
355
+ image_results.append({"image_id": image_id, "content": error_message})
356
+
357
+ csv_path = create_csv_file_simple(image_results)
358
+ initial_image = image_paths[0] if image_paths else None
359
+ initial_counter = f"1 of {len(image_paths)}" if image_paths else "0 of 0"
360
+ initial_text = image_results[0]["content"] if image_results else "Analysis complete, but no results generated."
361
+
362
+ return (image_paths, image_results, 0, initial_image, initial_counter,
363
+ initial_text, csv_path)
364
+
365
+ # Navigate previous
366
+ def go_to_prev(current_idx, images, results):
367
+ if not images or not results or len(images) == 0:
368
+ return current_idx, None, "0 of 0", ""
369
+ new_idx = (current_idx - 1 + len(images)) % len(images)
370
+ counter_text = f"{new_idx + 1} of {len(images)}"
371
+ result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
372
+ return (new_idx, images[new_idx], counter_text, result_content)
373
+
374
+ # Navigate next
375
+ def go_to_next(current_idx, images, results):
376
+ if not images or not results or len(images) == 0:
377
+ return current_idx, None, "0 of 0", ""
378
+ new_idx = (current_idx + 1) % len(images)
379
+ counter_text = f"{new_idx + 1} of {len(images)}"
380
+ result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
381
+ return (new_idx, images[new_idx], counter_text, result_content)
382
+
383
+ # Connect analyze button
384
+ analyze_button.click(
385
+ fn=analyze_images,
386
+ inputs=[image_state, model_choice, length_choice, filename_state, content_type],
387
+ outputs=[all_images, all_results, current_index, current_image, image_counter,
388
+ analysis_text, csv_download]
389
+ )
390
+
391
+ # Connect navigation buttons
392
+ prev_button.click(
393
+ fn=go_to_prev, inputs=[current_index, all_images, all_results],
394
+ outputs=[current_index, current_image, image_counter, analysis_text], queue=False
395
+ )
396
+ next_button.click(
397
+ fn=go_to_next, inputs=[current_index, all_images, all_results],
398
+ outputs=[current_index, current_image, image_counter, analysis_text], queue=False
399
+ )
400
+
401
+ # About section
402
+ with gr.Accordion("About", open=False):
403
+ gr.Markdown("""
404
+ ## About MATCHA 🍵:
405
+
406
+ This demo generates alternative text for images.
407
+
408
+ - Upload one or more images using the upload button
409
+ - Choose a model and response length for generation
410
+ - Navigate through the images with the Previous and Next buttons
411
+ - Download CSV with all results
412
+
413
+ Developed by the Natural History Museum in Partnership with National Museums Liverpool.
414
+
415
+ If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
416
+ chris.addis@nhm.ac.uk
417
+ """)
418
+
419
+ return demo
420
+
421
+ # Launch the app
422
+ if __name__ == "__main__":
423
+ app = create_demo()
424
+ app.launch()
app.py CHANGED
@@ -1,425 +1,107 @@
1
  import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import io
5
  import os
6
- import requests
7
- import json
8
- from dotenv import load_dotenv
9
- import openai
10
- import base64
11
- import csv
12
- import tempfile
13
- import datetime
14
 
15
- # Load environment variables from .env file if it exists (for local development)
16
- # On Hugging Face Spaces, the secrets are automatically available as environment variables
17
- if os.path.exists(".env"):
18
- load_dotenv()
19
-
20
- from io import BytesIO
21
- import numpy as np
22
- import requests
23
- from PIL import Image
24
-
25
- # import libraries
26
- from library.utils_model import *
27
- from library.utils_html import *
28
- from library.utils_prompt import *
29
-
30
- OR = OpenRouterAPI()
31
-
32
- # Path for storing user preferences
33
- PREFERENCES_FILE = "data/user_preferences.csv"
34
-
35
- # Ensure directory exists
36
- os.makedirs(os.path.dirname(PREFERENCES_FILE), exist_ok=True)
37
-
38
- # Define model pricing information (approximate costs per 100 image API calls)
39
- MODEL_PRICING = {
40
- "google/gemini-2.0-flash-001": "$0.03",
41
- "gpt-4.1-mini": "$0.07",
42
- "gpt-4.1": "$0.35",
43
- "anthropic/claude-3.7-sonnet": "$0.70",
44
- "google/gemini-2.5-pro-preview-03-25": "$1.20",
45
- "google/gemini-2.5-flash-preview:thinking": "$0.35",
46
- "gpt-4.1-nano": "$0.02",
47
- "openai/chatgpt-4o-latest": "$0.75",
48
- "meta-llama/llama-4-maverick": "$0.04"
49
- }
50
-
51
def get_sys_prompt(length="medium", photograph=False):
    """Build the system prompt for alt-text / long-description generation.

    Args:
        length: "short" (alt-text, max 130 chars), "medium" (250-300 chars),
            or anything else for long (max 450 chars).
        photograph: When True, write the prompt for wildlife photography and
            tell the model not to guess exact species.

    Returns:
        The system prompt string to pass as the developer/system message.
    """
    extra_prompt = ""

    if photograph:
        object_type = "wildlife photography"
        # Species-level identification is error-prone; steer the model
        # towards broader taxa (genus/family) instead.
        extra_prompt = (
            " Do not guess the exact species of the animals in the photograph"
            " unless you are certain - simply use broader terms, e.g. the"
            " genus or family, to make fewer errors."
        )
    else:
        object_type = "museum objects"

    dev_prompt = f"""You are a museum curator tasked with generating long descriptions (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows', 'This is an image of', 'The photograph'. Be precise, concise and avoid filler and subjective statements."""

    if length == "short":
        # Short mode uses a dedicated alt-text prompt rather than appending
        # a length constraint to the long-description prompt.
        dev_prompt = f"""You are a museum curator tasked with generating alt-text (as defined by W3C) of {object_type} for visually impaired and blind users from images. Use British English and follow museum accessibility best practices. Do not start with phrases like 'The image shows' or 'This is an image of'. Be precise, concise and avoid filler and subjective statements. Responses should be a maximum of 130 characters."""
    elif length == "medium":
        dev_prompt += " Responses should be a maximum of 250-300 characters."
    else:  # long
        dev_prompt += " Responses should be a maximum of 450 characters."
    return dev_prompt + extra_prompt
69
-
70
-
71
def create_csv_file_simple(results):
    """Write the analysis results to a temporary CSV file and return its path.

    Each result is a dict; the "image_id" and "content" keys are written,
    with missing keys emitted as empty strings. Returns None on failure.
    """
    try:
        tmp = tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False, newline='', encoding='utf-8'
        )
        with tmp as handle:
            out = csv.writer(handle)
            out.writerow(['image_id', 'content'])
            out.writerows(
                [entry.get('image_id', ''), entry.get('content', '')]
                for entry in results
            )
        return tmp.name
    except Exception as err:
        print(f"Error creating CSV: {err}")
        return None
87
-
88
def get_base_filename(filepath):
    """Return the file name without directory or extension ('' for falsy input)."""
    if not filepath:
        return ""
    stem, _ext = os.path.splitext(os.path.basename(filepath))
    return stem
94
-
95
- # Define the Gradio interface
96
- def create_demo():
97
- # --- Updated CSS with model info styling ---
98
- custom_css = """
99
- /* Container for the image component (#current-image-display is the elem_id of gr.Image) */
100
- #current-image-display {
101
- height: 600px; /* Define container height */
102
- width: 100%; /* Define container width (takes column width) */
103
- display: flex; /* Use flexbox for alignment */
104
- justify-content: center; /* Center content horizontally */
105
- align-items: center; /* Center content vertically */
106
- overflow: hidden; /* Hide any potential overflow from container */
107
- }
108
-
109
- /* The actual <img> element inside the container */
110
- #current-image-display img {
111
- object-fit: contain !important; /* Scale keeping aspect ratio, within bounds */
112
- max-width: 100%; /* Prevent image exceeding container width */
113
- max-height: 600px !important; /* Prevent image exceeding container height */
114
- width: auto; /* Use natural width unless constrained by max-width */
115
- height: auto; /* Use natural height unless constrained by max-height */
116
- display: block; /* Ensure image behaves predictably in flex */
117
- }
118
-
119
- /* Custom style for model info display */
120
- #model-info-display {
121
- font-size: 0.85rem; /* Small font size */
122
- color: #666; /* Subtle color */
123
- margin-top: 0.5rem; /* Small top margin */
124
- margin-bottom: 1rem; /* Bottom margin before next element */
125
- padding-left: 0.5rem; /* Slight indentation */
126
- }
127
  """
128
- # --- Pass css to gr.Blocks ---
129
- with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
130
- with gr.Row():
131
- with gr.Column(scale=3):
132
- gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
133
- gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
134
- gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
135
- with gr.Column(scale=1):
136
- with gr.Row():
137
- gr.Image("images/nhm_logo.png", show_label=False, height=120,
138
- interactive=False, show_download_button=False,
139
- show_share_button=False, show_fullscreen_button=False,
140
- container=False, elem_id="nhm-logo")
141
- gr.Image("images/nml_logo.png", show_label=False, height=120,
142
- interactive=False, show_download_button=False,
143
- show_share_button=False, show_fullscreen_button=False,
144
- container=False, elem_id="nml-logo")
145
-
146
- # Store model choices and state
147
- show_all_models_state = gr.State(False)
148
-
149
- # Define preferred and additional models directly in the function
150
- preferred_models = [
151
- ("Gemini 2.0 Flash (cheap)", "google/gemini-2.0-flash-001"),
152
- ("GPT-4.1 Mini", "gpt-4.1-mini"),
153
- ("GPT-4.1 (Recommended)", "gpt-4.1"),
154
- ("Claude 3.7 Sonnet", "anthropic/claude-3.7-sonnet"),
155
- ("Gemini 2.5 Pro", "google/gemini-2.5-pro-preview-03-25"),
156
- ("Gemini 2.5 Flash Thinking (Recommended)", "google/gemini-2.5-flash-preview:thinking")
157
- ]
158
-
159
- additional_models = [
160
- ("GPT-4.1 Nano", "gpt-4.1-nano"),
161
- ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
162
- ("Llama 4 Maverick", "meta-llama/llama-4-maverick")
163
- ]
164
-
165
- # Calculate all models once
166
- all_models_list = preferred_models + additional_models
167
-
168
- # Default model value
169
- default_model = "google/gemini-2.0-flash-001"
170
-
171
- with gr.Row():
172
- # Left column: Controls and uploads
173
- with gr.Column(scale=1):
174
- upload_button = gr.UploadButton(
175
- "Click to Upload Images",
176
- file_types=["image"],
177
- file_count="multiple"
178
- )
179
-
180
- # Model dropdown
181
- model_choice = gr.Dropdown(
182
- choices=preferred_models,
183
- label="Select Model",
184
- value=default_model
185
- )
186
-
187
- length_choice = gr.Radio(
188
- choices=["short", "medium", "long"],
189
- label="Response Length",
190
- value="medium",
191
- info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
192
- )
193
-
194
- # Advanced settings accordion
195
- with gr.Accordion("Advanced Settings", open=False):
196
- show_all_models = gr.Checkbox(
197
- label="Show Additional Models",
198
- value=False,
199
- info="Display additional model options in the dropdown above"
200
- )
201
-
202
- content_type = gr.Radio(
203
- choices=["Museum Object", "Photography"],
204
- label="Content Type",
205
- value="Museum Object"
206
- )
207
-
208
- # Find the default model's display name
209
- default_model_name = "Unknown Model"
210
- for name, value in preferred_models:
211
- if value == default_model:
212
- default_model_name = name
213
- break
214
-
215
- # Model info outside the accordion
216
- model_info = gr.Markdown(
217
- f"""**Current Model**: {default_model_name}
218
- **Cost per 100 Images**: {MODEL_PRICING[default_model]}""",
219
- elem_id="model-info-display"
220
- )
221
-
222
- gr.Markdown("### Uploaded Images")
223
- input_gallery = gr.Gallery(
224
- label="Uploaded Image Previews", columns=3, height=150,
225
- object_fit="contain", show_label=False
226
- )
227
- analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
228
- image_state = gr.State([])
229
- filename_state = gr.State([])
230
- csv_download = gr.File(label="Download CSV Results")
231
-
232
- # Right column: Display area
233
- with gr.Column(scale=2):
234
- current_image = gr.Image(
235
- label="Current Image",
236
- type="filepath",
237
- elem_id="current-image-display",
238
- show_fullscreen_button=True,
239
- show_download_button=False,
240
- show_share_button=False,
241
- show_label=False
242
- )
243
-
244
- with gr.Row():
245
- prev_button = gr.Button("← Previous", size="sm")
246
- image_counter = gr.Markdown("0 of 0", elem_id="image-counter")
247
- next_button = gr.Button("Next →", size="sm")
248
-
249
- gr.Markdown("### Generated Alt-text")
250
- analysis_text = gr.Textbox(
251
- label="Generated Text",
252
- value="Upload images and click 'Generate Alt-Text'.",
253
- lines=6, max_lines=10, interactive=True, show_label=False
254
- )
255
- current_index = gr.State(0)
256
- all_images = gr.State([])
257
- all_results = gr.State([])
258
-
259
- # Handle checkbox change to update model dropdown - modern version
260
def toggle_models(show_all, current_model):
    """Swap the model dropdown between the preferred-only and full model lists.

    Keeps the current selection when it is still a valid choice; otherwise
    falls back to the default model.
    """
    # Copy the closure lists so the Dropdown never shares mutable state.
    preferred_choices = list(preferred_models)
    if show_all:
        return gr.Dropdown(choices=list(all_models_list), value=current_model)
    # Restricting to preferred models: only keep the selection if it is one.
    still_valid = current_model in {value for _, value in preferred_choices}
    chosen = current_model if still_valid else default_model
    return gr.Dropdown(choices=preferred_choices, value=chosen)
278
-
279
- # Update model info when model selection changes
280
def update_model_info(model_value):
    """Render the markdown block showing the selected model's name and cost."""
    # Resolve the display name from the (name, value) pairs.
    model_name = next(
        (name for name, value in all_models_list if value == model_value),
        "Unknown Model",
    )
    cost = MODEL_PRICING.get(model_value, "Unknown")
    return f"""**Current Model**: {model_name}
**Cost per 100 Images**: {cost}"""
294
-
295
- # Connect checkbox to toggle model choices
296
- show_all_models.change(
297
- fn=toggle_models,
298
- inputs=[show_all_models, model_choice],
299
- outputs=[model_choice]
300
- )
301
-
302
- # Connect model selection to update info
303
- model_choice.change(
304
- fn=update_model_info,
305
- inputs=[model_choice],
306
- outputs=[model_info]
307
- )
308
-
309
- # Handle file uploads
310
def handle_upload(files, current_paths, current_filenames):
    """Reset the viewer state from a fresh batch of uploaded files.

    The previous paths/filenames arguments are ignored; a new upload always
    replaces the existing selection and clears the result display.
    """
    uploads = files or []
    file_paths = [item.name for item in uploads]
    file_names = [get_base_filename(path) for path in file_paths]
    return (
        file_paths, file_paths, file_names, 0, None, "0 of 0",
        "Upload images and click 'Generate Alt-Text'.",
    )
318
-
319
- upload_button.upload(
320
- fn=handle_upload,
321
- inputs=[upload_button, image_state, filename_state],
322
- outputs=[input_gallery, image_state, filename_state,
323
- current_index, current_image, image_counter, analysis_text]
324
- )
325
-
326
- # Analyze images
327
def analyze_images(image_paths, model_choice, length_choice, filenames, content_type_choice):
    """Caption every uploaded image and return the full viewer state.

    Returns (paths, results, index, first image, counter, first text, csv path).
    Per-image failures are captured as error strings in the results rather
    than aborting the whole batch.
    """
    if not image_paths:
        return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None

    sys_prompt = get_sys_prompt(
        length_choice, photograph=(content_type_choice == "Photography")
    )
    progress = gr.Progress(track_tqdm=True)
    image_results = []

    for idx, path in enumerate(progress.tqdm(image_paths, desc="Analyzing Images")):
        # Prefer the user-supplied filename; otherwise derive a stable id.
        if idx < len(filenames) and filenames[idx]:
            image_id = filenames[idx]
        else:
            image_id = f"Image_{idx+1}_{os.path.basename(path)}"
        try:
            caption = OR.generate_caption(
                Image.open(path), model=model_choice, max_image_size=512,
                prompt=prompt_new(), prompt_dev=sys_prompt, temperature=1,
            )
            content = caption.strip()
        except FileNotFoundError:
            content = f"Error: File not found at path '{path}'"
            print(content)
        except Exception as e:
            content = f"Error processing {image_id}: {str(e)}"
            print(content)
        image_results.append({"image_id": image_id, "content": content})

    csv_path = create_csv_file_simple(image_results)
    first_text = (
        image_results[0]["content"] if image_results
        else "Analysis complete, but no results generated."
    )
    # image_paths is guaranteed non-empty here (early return above).
    return (image_paths, image_results, 0, image_paths[0],
            f"1 of {len(image_paths)}", first_text, csv_path)
365
-
366
- # Navigate previous
367
def go_to_prev(current_idx, images, results):
    """Step back to the previous image, wrapping from the first to the last.

    Returns (index, image, counter label, result text) for the Gradio outputs.
    """
    if not images or not results:
        return current_idx, None, "0 of 0", ""
    total = len(images)
    idx = (current_idx - 1 + total) % total
    if idx < len(results):
        content = results[idx]["content"]
    else:
        content = "Error: Result not found"
    return (idx, images[idx], f"{idx + 1} of {total}", content)
374
-
375
- # Navigate next
376
- def go_to_next(current_idx, images, results):
377
- if not images or not results or len(images) == 0:
378
- return current_idx, None, "0 of 0", ""
379
- new_idx = (current_idx + 1) % len(images)
380
- counter_text = f"{new_idx + 1} of {len(images)}"
381
- result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
382
- return (new_idx, images[new_idx], counter_text, result_content)
383
-
384
- # Connect analyze button
385
- analyze_button.click(
386
- fn=analyze_images,
387
- inputs=[image_state, model_choice, length_choice, filename_state, content_type],
388
- outputs=[all_images, all_results, current_index, current_image, image_counter,
389
- analysis_text, csv_download]
390
- )
391
-
392
- # Connect navigation buttons
393
- prev_button.click(
394
- fn=go_to_prev, inputs=[current_index, all_images, all_results],
395
- outputs=[current_index, current_image, image_counter, analysis_text], queue=False
396
- )
397
- next_button.click(
398
- fn=go_to_next, inputs=[current_index, all_images, all_results],
399
- outputs=[current_index, current_image, image_counter, analysis_text], queue=False
400
  )
401
-
402
- # About section
403
- with gr.Accordion("About", open=False):
404
- gr.Markdown("""
405
- ## About MATCHA 🍵:
406
-
407
- This demo generates alternative text for images.
408
-
409
- - Upload one or more images using the upload button
410
- - Choose a model and response length for generation
411
- - Navigate through the images with the Previous and Next buttons
412
- - Download CSV with all results
413
-
414
- Developed by the Natural History Museum in Partnership with National Museums Liverpool.
415
-
416
- If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
417
- chris.addis@nhm.ac.uk
418
- """)
419
-
420
- return demo
421
-
422
- # Launch the app
423
- if __name__ == "__main__":
424
- app = create_demo()
425
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
2
  import os
 
 
 
 
 
 
 
 
3
 
4
# Hugging Face user IDs granted full access, read from a comma-separated
# environment variable, e.g. AUTHORIZED_USER_IDS="1234,5678". The original
# code referenced this name without ever defining it, which raised a
# NameError for every logged-in user; empty by default so nobody is
# accidentally granted full access.
AUTHORIZED_USER_IDS = {
    uid.strip()
    for uid in os.environ.get("AUTHORIZED_USER_IDS", "").split(",")
    if uid.strip()
}


def check_access(request: gr.Request):
    """
    Checks if the user is logged in via HF OAuth and if they are authorized
    for full access.

    Returns updates for (login prompt, full content, limited content,
    access message, status text), in that order.
    """
    # NOTE(review): relies on request.auth being populated when hf_oauth is
    # enabled for the Space — confirm against the deployed Gradio version.
    user_info = request.auth

    if user_info is None:
        # Not logged in: show the sign-in prompt and only the limited view.
        return (
            gr.update(visible=True),   # Show login prompt
            gr.update(visible=False),  # Hide full content
            gr.update(visible=True),   # Show limited content
            "Please sign in to check your access.",
            "Not Logged In",
        )

    # 'sub' is the standard OpenID Connect claim for the user ID.
    user_id = user_info.get('sub')
    username = user_info.get('preferred_username', 'N/A')
    print(f"Logged in user: Username={username}, ID={user_id}")  # Debug log

    if user_id in AUTHORIZED_USER_IDS:
        # Authorized: reveal the full-access content.
        return (
            gr.update(visible=False),  # Hide login prompt
            gr.update(visible=True),   # Show full content
            gr.update(visible=False),  # Hide limited content
            f"Welcome, {username}! You have Full Access.",
            f"Logged in as: {username} (ID: {user_id})",
        )

    # Logged in but not on the authorized list: limited access only.
    return (
        gr.update(visible=False),  # Hide login prompt
        gr.update(visible=False),  # Hide full content
        gr.update(visible=True),   # Show limited content
        f"Welcome, {username}. You have Limited Access.",
        f"Logged in as: {username} (ID: {user_id})",
    )
45
+
46
# --- Gradio Interface ---

with gr.Blocks() as demo:
    login_prompt = gr.Markdown("### Sign in with Hugging Face to check your access.")
    # The LoginButton initiates the HF OAuth flow; after the redirect,
    # request.auth is populated on subsequent interactions. (The original
    # comments referenced this button but never actually created it.)
    login_button = gr.LoginButton()
    # Separate components for the welcome message and the raw login status:
    # check_access returns five values, so each needs its own output slot.
    # (The original wired login_status_text twice, so the welcome message
    # was immediately overwritten by the status line.)
    access_message = gr.Markdown("")
    login_status_text = gr.Textbox(label="Status", interactive=False)

    with gr.Column(visible=False) as limited_content:
        gr.Markdown("## Limited Access Content")
        gr.Textbox(
            value="This is content available to all users (or non-authorized logged-in users).",
            interactive=False,
        )
        # Add other limited features here

    with gr.Column(visible=False) as full_content:
        gr.Markdown("## Full Access Content")
        gr.Textbox(
            value="🥳 This is special content for authorized users! 🥳",
            interactive=False,
        )
        # Add other full features here

    # A page-load hook is unreliable around the OAuth redirect, so the
    # access check is triggered manually once the user has (re)loaded.
    check_access_button = gr.Button("Check My Access Level")

    check_access_button.click(
        fn=check_access,
        inputs=None,  # the gr.Request object is injected automatically
        outputs=[login_prompt, full_content, limited_content,
                 access_message, login_status_text],
        api_name="check_access",
    )

# Guard the launch so importing this module (e.g. in tests) has no side effects.
if __name__ == "__main__":
    demo.launch()