File size: 21,086 Bytes
eec37f2
376d3e4
 
9883bdb
376d3e4
 
 
7b3d9d6
376d3e4
 
 
 
eec37f2
ee3d9fb
376d3e4
 
 
 
ee3d9fb
 
85c57b7
 
 
 
f074acd
376d3e4
bc55cd8
376d3e4
 
bc55cd8
 
376d3e4
 
e9a730f
4bd8cea
bc55cd8
 
376d3e4
 
91630c9
 
bc55cd8
91630c9
bc55cd8
 
 
d58103b
91630c9
 
 
 
 
054bd75
bc55cd8
91630c9
 
 
 
 
f7ed6e0
376d3e4
 
f7ed6e0
 
 
4b5c5ff
376d3e4
 
 
 
 
 
 
 
 
 
 
 
ee3d9fb
376d3e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee3d9fb
 
376d3e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee3d9fb
 
376d3e4
 
 
 
 
7b3d9d6
376d3e4
 
86cec04
376d3e4
 
 
86cec04
376d3e4
 
 
f074acd
376d3e4
f074acd
376d3e4
6949658
 
3830538
 
 
 
 
055d5f4
6949658
 
 
 
bc55cd8
 
6949658
bc55cd8
6949658
 
 
 
 
7b3d9d6
6949658
57f4443
 
 
 
0e13f7c
33e8ffe
f074acd
 
 
 
 
 
57f4443
376d3e4
f074acd
376d3e4
57f4443
376d3e4
57f4443
376d3e4
 
f074acd
 
376d3e4
 
 
f074acd
 
 
376d3e4
f074acd
 
4c4fdd7
93054df
055d5f4
ab8f25b
055d5f4
ab8f25b
055d5f4
39009df
 
f7ed6e0
39009df
 
 
055d5f4
93054df
8cfc55d
f074acd
 
8cfc55d
 
 
 
7b3d9d6
8cfc55d
376d3e4
b39d335
41f0936
 
7b3d9d6
41f0936
 
376d3e4
f074acd
 
 
 
376d3e4
 
 
 
 
 
 
 
f074acd
 
 
 
 
 
 
376d3e4
f074acd
376d3e4
 
 
 
f074acd
376d3e4
 
f074acd
 
376d3e4
 
 
 
 
f074acd
 
 
 
 
 
4c4fdd7
f074acd
 
 
86cec04
f074acd
 
 
 
 
 
 
 
bc55cd8
f074acd
 
 
 
 
 
 
 
 
376d3e4
f074acd
 
86cec04
f074acd
 
 
 
 
 
 
 
 
5a20aa7
f074acd
 
5a20aa7
f074acd
 
5a20aa7
376d3e4
5a20aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81abc3a
5a20aa7
 
 
f7ed6e0
5a20aa7
 
 
 
 
 
 
81abc3a
055d5f4
f074acd
 
5a20aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376d3e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81abc3a
376d3e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee3d9fb
f074acd
ee3d9fb
376d3e4
4c4fdd7
0ff4dd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
import gradio as gr
import numpy as np
from PIL import Image
import os
import requests
import json
from dotenv import load_dotenv
#import openai
import base64
import csv
import tempfile
import datetime

# import libraries
from library.utils_model import *
from library.utils_html import *
from library.utils_prompt import *

# Shared API client used for every caption request made by the app.
# NOTE(review): OpenRouterAPI is not defined in this file; it presumably comes
# from one of the wildcard `library.*` imports above - confirm which module.
OR = OpenRouterAPI()

# Get authorized users from environment variable/secret.
# AUTHORIZED_USER_IDS is read as a comma-separated list of usernames. Each
# entry is stripped and empty entries are dropped, so a value such as
# "alice, bob," yields {"alice", "bob"} instead of {"alice", " bob", ""}.
authorized_users_str = os.environ.get("AUTHORIZED_USER_IDS", "")
AUTHORIZED_USER_IDS = {
    user_id.strip()
    for user_id in authorized_users_str.split(',')
    if user_id.strip()
}

# Define model pricing information (approximate costs per 100 image API calls),
# keyed by the model id that is sent to the API.
MODEL_PRICING = {
    "google/gemini-2.5-flash": "$0.08",
    "gpt-4.1-mini": "$0.07",
    "gpt-4.1": "$0.35",
    "anthropic/claude-sonnet-4": "$0.70",
    "google/gemini-2.5-pro": "$1.20",
    "gpt-4.1-nano": "$0.02",
    "openai/chatgpt-4o-latest": "$0.75",
    "meta-llama/llama-4-maverick": "$0.04",
    "meta-llama/llama-4-maverick:free": "Free",
    "openai/gpt-5-chat": "N/A",
    "openai/gpt-5-mini": "N/A"
}

# Models offered to authorized users by default.
# Every entry is (display label, model id); the model id must match a
# MODEL_PRICING key so the cost lookup works.
preferred_models_auth = [
    ("Gemini 2.5 Flash", "google/gemini-2.5-flash"),
    ("GPT-4.1 Mini", "gpt-4.1-mini"),
    ("GPT-4.1", "gpt-4.1"),
    ("Claude Sonnet 4", "anthropic/claude-sonnet-4"),
    ("Gemini 2.5 Pro", "google/gemini-2.5-pro"),
    # Was ("openai/gpt-5-chat", "GPT-5-chat"): label and id were swapped.
    ("GPT-5-chat", "openai/gpt-5-chat")
]

# Extra models that are only listed when "Show Additional Models" is ticked.
additional_models = [
    ("GPT-4.1 Nano", "gpt-4.1-nano"),
    ("ChatGPT Latest", "openai/chatgpt-4o-latest"),
    ("Llama 4 Maverick", "meta-llama/llama-4-maverick"),
    ("GPT-5-mini", "openai/gpt-5-mini")
]

# Calculate all models once
all_models_list = preferred_models_auth + additional_models

def get_sys_prompt(length="medium", nat_hist=False, filename=""):
    """Build the system (developer) prompt for the captioning model.

    Args:
        length: "short" requests W3C alt-text of at most 130 characters,
            "medium" a long description of 250-300 characters; any other
            value is treated as "long" (at most 450 characters).
        nat_hist: When true the prompt talks about natural history images and
            adds a hint not to over-specify species; otherwise it talks about
            museum objects.
        filename: Unused. Kept so existing callers that pass it keep working.

    Returns:
        The complete prompt string.
    """
    if nat_hist:
        object_type = "Natural History Images"
        extra_prompt = (
            " Do not guess the exact species of the animal in the image unless"
            " you are certain - simply use a broader term to make fewer errors"
            " e.g. say Swan rather than mute Swan or Whooper Swan unless you"
            " are certain."
        )
    else:
        object_type = "museum objects"
        extra_prompt = ""

    if length == "short":
        # The short variant asks for alt-text rather than a long description
        # and uses a shorter list of phrases to avoid.
        dev_prompt = (
            "You are a museum curator tasked with generating alt-text (as"
            f" defined by W3C) of {object_type} for visually impaired and"
            " blind users from images. Use British English and follow museum"
            " accessibility best practices. Do not start with phrases like"
            " 'The image shows' or 'This is an image of'. Be precise, concise"
            " and avoid filler and subjective statements. Responses should be"
            " a maximum of 130 characters."
        )
        return dev_prompt + extra_prompt

    dev_prompt = (
        "You are a museum curator tasked with generating long descriptions (as"
        f" defined by W3C) of {object_type} for visually impaired and blind"
        " users from images. Use British English and follow museum"
        " accessibility best practices. Do not start with phrases like 'The"
        " image shows', 'This is an image of', 'The photograph'. Be precise,"
        " concise and avoid filler and subjective statements."
    )
    if length == "medium":
        dev_prompt += " Responses should be a maximum of 250-300 characters."
    else:  # long (also the fallback for unrecognised values)
        dev_prompt += " Responses should be a maximum of 450 characters."
    return dev_prompt + extra_prompt


def create_csv_file_simple(results):
    """Dump *results* into a temporary CSV file and return the file's path.

    Each item of *results* is read with ``.get('image_id', '')`` and
    ``.get('content', '')``; the file starts with an ``image_id,content``
    header row. The temporary file is created with ``delete=False`` so it
    still exists after this function returns.

    Returns:
        The path of the written file, or ``None`` if anything went wrong (the
        error is printed rather than raised).
    """
    header = ['image_id', 'content']
    try:
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False, newline='', encoding='utf-8'
        ) as handle:
            out = csv.writer(handle)
            out.writerow(header)
            out.writerows(
                [entry.get('image_id', ''), entry.get('content', '')]
                for entry in results
            )
            csv_path = handle.name
    except Exception as e:
        print(f"Error creating CSV: {e}")
        return None
    return csv_path

def get_base_filename(filepath):
    """Return the final path component of *filepath* without its last extension.

    A falsy *filepath* (``None`` or an empty string) gives an empty string.
    """
    if filepath:
        _, leaf = os.path.split(filepath)
        stem, _extension = os.path.splitext(leaf)
        return stem
    return ""

# Define the Gradio interface
def create_demo():
    """Build the MATCHA Gradio interface and return the ``gr.Blocks`` app.

    The page has a header row (title and logos), a left column with login,
    upload, model/length controls and the CSV download, and a right column
    that shows one image at a time with its generated text.

    Model access depends on the logged-in user: usernames listed in
    ``AUTHORIZED_USER_IDS`` are offered ``preferred_models_auth`` (or
    ``all_models_list`` when "Show Additional Models" is ticked); everybody
    else is only offered the free model, whatever the checkbox says.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    custom_css = """
    /* Container for the image component (#current-image-display is the elem_id of gr.Image) */
        #current-image-display {
            height: 600px;           /* Define container height */
            width: 100%;             /* Define container width (takes column width) */
            display: flex;           /* Use flexbox for alignment */
            justify-content: center; /* Center content horizontally */
            align-items: center;     /* Center content vertically */
            overflow: hidden;        /* Hide any potential overflow from container */
        }
    
        /* The actual <img> element inside the container */
        #current-image-display img {
            object-fit: contain !important;    /* Scale keeping aspect ratio, within bounds */
            max-width: 100%;        /* Prevent image exceeding container width */
            max-height: 600px !important;       /* Prevent image exceeding container height */
            width: auto;            /* Use natural width unless constrained by max-width */
            height: auto;           /* Use natural height unless constrained by max-height */
            display: block;         /* Ensure image behaves predictably in flex */
        }
        
        /* Custom style for model info display */
        #model-info-display {
            font-size: 0.85rem;     /* Small font size */
            color: #666;            /* Subtle color */
            margin-top: 0.5rem;     /* Small top margin */
            margin-bottom: 1rem;    /* Bottom margin before next element */
            padding-left: 0.5rem;   /* Slight indentation */
        }
    """

    # Choices for users who are not authorized, as (display label, model id).
    preferred_models = [
        ("Llama 4 Maverick (free)", "meta-llama/llama-4-maverick:free")
    ]
    free_default_model = "meta-llama/llama-4-maverick:free"
    auth_default_model = "google/gemini-2.5-flash"

    # model id -> display label for every model the UI can offer. The free
    # model is not part of all_models_list, so it is added here; setdefault
    # keeps the first label when an id appears more than once.
    model_labels = {}
    for label, value in all_models_list + preferred_models:
        model_labels.setdefault(value, label)

    def is_authorized(profile):
        """Return True when *profile* is a logged-in user on the allow-list."""
        return profile is not None and profile.username in AUTHORIZED_USER_IDS

    def choices_for(profile, show_all):
        """Return a fresh list of dropdown choices this user may select."""
        if not is_authorized(profile):
            # The checkbox must never reveal paid models to other users.
            return list(preferred_models)
        return list(all_models_list) if show_all else list(preferred_models_auth)

    def describe_model(model_value):
        """Return the Markdown shown under the dropdown for *model_value*."""
        model_name = model_labels.get(model_value, "Unknown Model")
        cost = MODEL_PRICING.get(model_value, "Unknown")
        # The two trailing spaces before the newline are a Markdown hard line
        # break; without them both fields are rendered on one line.
        return (f"**Current Model**: {model_name}  \n"
                f"**Estimated cost per 100 Images**: {cost}")

    def show_at_offset(current_idx, images, results, step):
        """Move *step* images forward/backward (wrapping) and return the
        values for (current_index, current_image, image_counter, analysis_text)."""
        if not images or not results:
            return current_idx, None, "0 of 0", ""
        new_idx = (current_idx + step) % len(images)
        counter_text = f"{new_idx + 1} of {len(images)}"
        result_content = results[new_idx]["content"] if new_idx < len(results) else "Error: Result not found"
        return new_idx, images[new_idx], counter_text, result_content

    # --- Pass css to gr.Blocks ---
    with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("# MATCHA: Museum Alt-Text for Cultural Heritage with AI 🍵 🌿")
                gr.Markdown("Upload one or more images to generate accessible alternative text (designed to meet WCAG Guidelines)")
                gr.Markdown("Developed by the Natural History Museum in Partnership with National Museums Liverpool. Funded by the DCMS Pilot Scheme")
                auth_state = gr.Markdown()
            with gr.Column(scale=1):
                with gr.Row():
                    gr.Image("images/nhm_logo.png", show_label=False, height=100,
                             interactive=False, show_download_button=False,
                             show_share_button=False, show_fullscreen_button=False,
                             container=False, elem_id="nhm-logo")
                    gr.Image("images/nml_logo.png", show_label=False, height=100,
                             interactive=False, show_download_button=False,
                             show_share_button=False, show_fullscreen_button=False,
                             container=False, elem_id="nml-logo")

        with gr.Row():
            # Left column: Controls and uploads
            with gr.Column(scale=1):
                # Function to check authorization
                def check_authorization(profile: gr.OAuthProfile | None):
                    """Return updates for (model dropdown, model info, auth banner).

                    Gradio supplies *profile* from the ``gr.OAuthProfile``
                    annotation; it is ``None`` when nobody is logged in.
                    """
                    if is_authorized(profile):
                        default_model = auth_default_model
                        status = f"Logged in as: {profile.username}"
                    elif profile is None:
                        default_model = free_default_model
                        status = "Free version - please email chris.addis@nhm.ac.uk about full access."
                    else:
                        # Logged in, but not on the allow-list.
                        default_model = free_default_model
                        status = "Free version - please email chris.addis@nhm.ac.uk for full access."

                    dropdown_update = gr.update(
                        choices=choices_for(profile, False),
                        label="Select Model",
                        value=default_model,
                    )
                    return dropdown_update, describe_model(default_model), status

                login_button = gr.LoginButton()#visible=False

                upload_button = gr.UploadButton(
                    "Click to Upload Images",
                    file_types=["image"],
                    file_count="multiple"
                )

                model_choice = gr.Dropdown(
                    choices=preferred_models,
                    label="Select Model",
                    value="meta-llama/llama-4-maverick:free"
                )

                length_choice = gr.Radio(
                    choices=["short", "medium", "long"],
                    label="Response Length",
                    value="medium",
                    info="Short: max 130 chars | Medium: 250-300 chars | Long: max 450 chars"
                )

                # Advanced settings accordion
                with gr.Accordion("Advanced Settings", open=False):
                    show_all_models = gr.Checkbox(
                        label="Show Additional Models",
                        value=False,
                        info="Display additional model options in the dropdown above"
                    )

                    use_filename_in_prompt = gr.Checkbox(
                        label="Include filename as metadata",
                        value=False,
                        info="Useful for inputing species data if appropiate"
                    )

                    content_type = gr.Radio(
                        choices=["Museum Object", "Natural History"],
                        label="Content Type",
                        value="Museum Object"
                    )

                #markdown for current model costings
                model_info = gr.Markdown("",
                    elem_id="model-info-display"
                )

                demo.load(
                    fn=check_authorization,
                    inputs=None,
                    outputs=[model_choice,model_info,auth_state]
                )

                login_button.click(
                    fn=check_authorization,
                    inputs=None, # The user profile is automatically passed on login
                    outputs=[model_choice, model_info,auth_state]
                )

                gr.Markdown("### Uploaded Images")
                input_gallery = gr.Gallery(
                    label="Uploaded Image Previews", columns=3, height=150,
                    object_fit="contain", show_label=False
                )
                analyze_button = gr.Button("Generate Alt-Text", variant="primary", size="lg")
                image_state = gr.State([])
                filename_state = gr.State([])
                csv_download = gr.File(label="Download CSV Results")

            # Right column: Display area
            with gr.Column(scale=2):
                current_image = gr.Image(
                    label="Current Image",
                    type="filepath",
                    elem_id="current-image-display",
                    show_fullscreen_button=True,
                    show_download_button=False,
                    show_share_button=False,
                    show_label=False
                 )

                with gr.Row():
                    prev_button = gr.Button("← Previous", size="sm")
                    image_counter = gr.Markdown("0 of 0", elem_id="image-counter")
                    next_button = gr.Button("Next →", size="sm")

                gr.Markdown("### Generated Alt-text")
                analysis_text = gr.Textbox(
                    label="Generated Text",
                    value="Upload images and click 'Generate Alt-Text'.",
                    lines=6, max_lines=10, interactive=True, show_label=False
                )
                current_index = gr.State(0)
                all_images = gr.State([])
                all_results = gr.State([])

        # Handle checkbox change to update model dropdown
        def toggle_models(show_all, current_model, profile: gr.OAuthProfile | None):
            """Rebuild the dropdown for the current user and checkbox state.

            The current selection is kept when it is still offered; otherwise
            the first offered model is selected, so the dropdown value is
            always one of its choices.
            """
            choices = choices_for(profile, show_all)
            offered_values = [value for _, value in choices]
            selected = current_model if current_model in offered_values else offered_values[0]
            return gr.Dropdown(choices=choices, value=selected)

        # Update model info when model selection changes
        def update_model_info(model_value):
            """Return the name/cost Markdown for the newly selected model."""
            return describe_model(model_value)

        # Connect checkbox to toggle model choices
        show_all_models.change(
            fn=toggle_models,
            inputs=[show_all_models, model_choice],
            outputs=[model_choice]
        )

        # Connect model selection to update info
        model_choice.change(
            fn=update_model_info,
            inputs=[model_choice],
            outputs=[model_info]
        )

        # Handle file uploads
        def handle_upload(files, current_paths, current_filenames):
            """Replace the uploaded set and reset the viewer.

            *current_paths* and *current_filenames* are accepted because the
            event passes the two states, but a new upload replaces them rather
            than extending them. The last two return values clear the
            previously analysed images/results so Previous/Next cannot show
            results that belong to an older upload.
            """
            file_paths = [file.name for file in files] if files else []
            file_names = [get_base_filename(path) for path in file_paths]
            return (file_paths, file_paths, file_names, 0, None, "0 of 0",
                    "Upload images and click 'Generate Alt-Text'.", [], [])

        upload_button.upload(
            fn=handle_upload,
            inputs=[upload_button, image_state, filename_state],
            outputs=[input_gallery, image_state, filename_state,
                     current_index, current_image, image_counter, analysis_text,
                     all_images, all_results]
        )

        # Analyze images
        # NOTE(review): the selected model id is not re-checked against the
        # user's allowed list here; only the dropdown restricts it.
        def analyze_images(image_paths, model_choice, length_choice, filenames,
                           content_type_choice, include_filename,
                           progress=gr.Progress(track_tqdm=True)):
            """Caption every uploaded image and return the viewer/CSV updates.

            *progress* is not an event input: Gradio fills it in because its
            default value is a ``gr.Progress`` instance, which is the
            documented way to attach a progress tracker to a handler.

            Returns values for (all_images, all_results, current_index,
            current_image, image_counter, analysis_text, csv_download). A
            failure on one image is recorded as that image's text instead of
            aborting the whole run.
            """
            if not image_paths:
                return [], [], 0, None, "0 of 0", "No images uploaded to analyze.", None

            sys_prompt = get_sys_prompt(length_choice, nat_hist=content_type_choice == "Natural History")
            image_results = []

            for i, image_path in enumerate(progress.tqdm(image_paths, desc="Analyzing Images")):
                image_id = filenames[i] if i < len(filenames) and filenames[i] else f"Image_{i+1}_{os.path.basename(image_path)}"
                try:
                    # The context manager closes the image file once the
                    # caption has been generated.
                    with Image.open(image_path) as img:
                        user_prompt_filename = image_id if include_filename else None
                        prompt0 = prompt_new(user_prompt_filename)
                        result = OR.generate_caption(
                            img, model=model_choice, max_image_size=512,
                            prompt=prompt0, prompt_dev=sys_prompt, temperature=1
                        )
                    image_results.append({"image_id": image_id, "content": result.strip()})
                except FileNotFoundError:
                    error_message = f"Error: File not found at path '{image_path}'"
                    print(error_message)
                    image_results.append({"image_id": image_id, "content": error_message})
                except Exception as e:
                    error_message = f"Error processing {image_id}: {str(e)}"
                    print(error_message)
                    image_results.append({"image_id": image_id, "content": error_message})

            csv_path = create_csv_file_simple(image_results)
            initial_image = image_paths[0] if image_paths else None
            initial_counter = f"1 of {len(image_paths)}" if image_paths else "0 of 0"
            initial_text = image_results[0]["content"] if image_results else "Analysis complete, but no results generated."

            return (image_paths, image_results, 0, initial_image, initial_counter,
                    initial_text, csv_path)

        # Navigate previous
        def go_to_prev(current_idx, images, results):
            """Show the previous image, wrapping from the first to the last."""
            return show_at_offset(current_idx, images, results, -1)

        # Navigate next
        def go_to_next(current_idx, images, results):
            """Show the next image, wrapping from the last to the first."""
            return show_at_offset(current_idx, images, results, 1)

        # Connect analyze button
        analyze_button.click(
            fn=analyze_images,
            inputs=[image_state, model_choice, length_choice, filename_state, content_type, use_filename_in_prompt],
            outputs=[all_images, all_results, current_index, current_image, image_counter,
                     analysis_text, csv_download]
        )

        # Connect navigation buttons
        prev_button.click(
            fn=go_to_prev, inputs=[current_index, all_images, all_results],
            outputs=[current_index, current_image, image_counter, analysis_text], queue=False
        )
        next_button.click(
            fn=go_to_next, inputs=[current_index, all_images, all_results],
            outputs=[current_index, current_image, image_counter, analysis_text], queue=False
        )

        # About section
        with gr.Accordion("About", open=False):
             gr.Markdown("""
            ## About MATCHA 🍵:
            
            This demo generates alternative text for images.
            
            - Upload one or more images using the upload button
            - Choose a model and response length for generation
            - Navigate through the images with the Previous and Next buttons
            - Download CSV with all results
            
            Developed by the Natural History Museum in Partnership with National Museums Liverpool.
            
            If you find any bugs/have any problems/have any suggestions please feel free to get in touch:
            chris.addis@nhm.ac.uk
            """)

    return demo

# Launch the app
if __name__ == "__main__":
    # Build and serve the UI only when this file is executed as a script, so
    # importing the module does not launch the Gradio app.
    app = create_demo()
    app.launch()