fossil_app / app.py
piperod91's picture
Update closest images display: convert gallery to table format with full specimen names
8dc677a
import os
import sys
from env import config_env
config_env()
import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import gradio as gr
import glob
from inference_sam import segmentation_sam
from explanations import explain
from inference_resnet import get_triplet_model
from inference_resnet_v2 import get_resnet_model,inference_resnet_embedding_v2,inference_resnet_finer_v2
from inference_beit import get_triplet_model_beit
import pathlib
import tensorflow as tf
import pandas as pd
import re
import random
from closest_sample import get_images,get_diagram
# Download the example images and the fossil dataset from the Hugging Face Hub
# when the local directories do not exist yet; existing directories are reused.
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='dataset', local_dir='images')

if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Report only whether a token is configured: printing the token itself
    # would leak the credential into the application logs.
    print(f"Read token set: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')
# HTML banner passed to gr.Markdown at the top of the first tab.
HEADER = '''
<div style='display: flex; align-items: baseline;'>
<h1 style='margin-right: 10px;'><b>Official Gradio Demo:</b></h1>
<h1>🍁 <a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks</b></a></h1>
</div>
'''
# Placeholder intro text. This is a bare string expression (not assigned to any
# name), so it has no effect at runtime.
"""
**Fossil** a brief intro to the project.
# ❗️❗️❗️**Important Notes:**
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users .
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users.
Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
"""
# Markdown user guide; the wrapping div carries the "user-guide-wrapper" class
# that custom_css styles for light and dark themes.
USER_GUIDE = """
<div class="user-guide-wrapper">
### ❗️ User Guide
Welcome to the interactive fossil exploration tool. Here's how to get started:
- **Upload an Image:** Drag and drop or choose from given samples to upload images of fossils.
- **Process Image:** After uploading, click the 'Process Image' button to analyze the image.
- **Explore Results:** Switch to the 'Workbench' tab to check out detailed analysis and results.
#### Tips
- Zoom into images on the workbench for finer details.
- Use the examples below as references for what types of images to upload.
Enjoy exploring! 🌟
</div>
"""
# Stand-alone copy of the tips section that also appears inside USER_GUIDE.
TIPS = """
## Tips
- Zoom into images on the workbench for finer details.
- Use the examples below as references for what types of images to upload.
Enjoy exploring!
"""
# Contact footer passed to gr.Markdown in the UI. The emoji here had been
# double-decoded into mojibake and is restored to the intended character.
CITATION = '''
📧 **Contact** <br>
If you have any questions, feel free to contact us at <b>ivan_felipe_rodriguez@brown.edu</b>.
'''
# Draft citation/license text. This is a bare string expression (not assigned
# to any name), so it is never displayed.
"""
📝 **Citation**
cite using this bibtex:...
```
```
📋 **License**
"""
def get_model(model_name):
    """Build the classifier registered under *model_name*.

    Returns a ``(model, n_classes)`` tuple. Raises ``ValueError`` when
    *model_name* is not one of the registered names. The BEiT variants that
    used to be registered here are disabled.
    """
    def build_triplet_resnet(class_count, weights_file):
        # The two 600x600 ResNet50V2 triplet models are built identically and
        # differ only in their class count and weight file.
        net = get_triplet_model(
            input_shape=(600, 600, 3),
            embedding_units=256,
            embedding_depth=2,
            backbone_class=tf.keras.applications.ResNet50V2,
            nb_classes=class_count,
            load_weights=False,
            finer_model=True,
            backbone_name='Resnet50v2',
        )
        net.load_weights(weights_file)
        return net

    if model_name == 'Mummified 170':
        n_classes = 170
        return build_triplet_resnet(n_classes, 'model_classification/mummified-170.h5'), n_classes

    if model_name == 'Rock 170':
        n_classes = 171
        return build_triplet_resnet(n_classes, 'model_classification/rock-170.h5'), n_classes

    if model_name == 'Fossils 142':  # new resnet
        n_classes = 142
        model, _, _ = get_resnet_model('model_classification/fossil-model.h5')
        return model, n_classes

    raise ValueError(f"Model name '{model_name}' is not recognized")
def segment_image(input_image):
    """Run SAM segmentation on *input_image* and return its result unchanged."""
    return segmentation_sam(input_image)
def classify_image(input_image, model_name):
    """Classify *input_image* with the model selected by *model_name*.

    Returns whatever the model-specific inference helper returns, or ``None``
    when *model_name* matches none of the known names. The model is rebuilt
    through ``get_model`` on every call.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        # Both triplet ResNet models share the same 600px inference path.
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)

    if model_name == 'Fossils BEiT':
        from inference_beit import inference_resnet_finer_beit
        # NOTE: get_model() currently has no 'Fossils BEiT' entry, so this
        # call raises ValueError until the BEiT model is registered again.
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_beit(input_image, model, size=384, n_classes=n_classes)

    if model_name == 'Fossils 142':
        # Uses the v2 ResNet helper imported at module level; the unused
        # import from inference_beit that used to sit here was dropped.
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_v2(input_image, model, size=384, n_classes=n_classes)

    return None
def get_embeddings(input_image, model_name):
    """Compute the embedding of *input_image* with the model named *model_name*.

    Returns whatever the model-specific embedding helper returns, or ``None``
    when *model_name* matches none of the known names. The model is rebuilt
    through ``get_model`` on every call.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        # Both triplet ResNet models share the same 600px embedding path.
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)

    if model_name == 'Fossils BEiT':
        from inference_beit import inference_resnet_embedding_beit
        # NOTE: get_model() currently has no 'Fossils BEiT' entry, so this
        # call raises ValueError until the BEiT model is registered again.
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_beit(input_image, model, size=384, n_classes=n_classes)

    if model_name == 'Fossils 142':
        # Uses the v2 ResNet helper imported at module level; the unused
        # import from inference_beit that used to sit here was dropped.
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_v2(input_image, model, size=384, n_classes=n_classes)

    return None
def find_closest(input_image, model_name):
    """Embed *input_image* and look up its closest dataset images.

    Returns the ``(classes, paths, filenames)`` triple produced by
    ``get_images`` for the embedding.
    """
    query_embedding = get_embeddings(input_image, model_name)
    families, image_paths, specimen_files = get_images(query_embedding, model_name)
    return families, image_paths, specimen_files
def generate_diagram_closest(input_image, model_name, top_k):
    """Embed *input_image* and return what ``get_diagram`` yields for its *top_k* neighbours."""
    return get_diagram(get_embeddings(input_image, model_name), top_k, model_name)
def explain_image(input_image, model_name, explain_method, nb_samples):
    """Run the chosen explanation method on *input_image*.

    Builds the model via ``get_model`` and forwards it, the image, the image's
    height and width, the method name and the sample count to ``explain``.
    Returns the ``(classes, exp_list)`` pair that ``explain`` produces.
    """
    model, n_classes = get_model(model_name)
    # The previous test `model_name == 'Fossils BEiT' or 'Fossils 142'` was
    # always truthy (a non-empty string literal), so every model was explained
    # at 384px. A membership test restores 600px for the other models, the
    # size classify_image uses for them.
    size = 384 if model_name in ('Fossils BEiT', 'Fossils 142') else 600
    h, w = input_image.shape[:2]
    classes, exp_list = explain(model, input_image, h, w, explain_method, nb_samples, size=size, n_classes=n_classes)
    print('done')
    return classes, exp_list
def setup_examples():
    """
    Build the clickable example gallery from the fossil-responses CSV.

    Reads ``fossil_responses_with_images.csv`` located next to this file,
    extracts the URL from each 'Image URL' cell, and keeps up to 15 rows whose
    'User Selection' is 'Plausible' plus up to 8 that are 'Not Sure'. The URLs
    are shuffled and handed to ``gr.Examples`` bound to the global
    ``input_image`` component. If the CSV is missing or cannot be processed,
    the gallery is created with no examples.
    """
    def extract_url(cell):
        # Cells look like =HYPERLINK("url", "text"): take the first https URL,
        # stopping at a quote or a comma.
        if pd.isna(cell) or not cell:
            return None
        found = re.search(r'https://[^",\']+', str(cell))
        return found.group(0) if found else None

    # Resolve relative to this file so the working directory does not matter.
    csv_path = os.path.join(os.path.dirname(__file__), 'fossil_responses_with_images.csv')
    fossil_samples = []

    print(f"DEBUG: Looking for CSV at: {csv_path}")
    print(f"DEBUG: CSV exists: {os.path.exists(csv_path)}")
    if not os.path.exists(csv_path):
        print(f"DEBUG: CSV file not found at {csv_path}")
    else:
        try:
            df = pd.read_csv(csv_path)
            print(f"DEBUG: CSV file found with {len(df)} rows")

            # Keep only the rows from which a URL could be extracted.
            df['Image_URL'] = df['Image URL'].apply(extract_url)
            df_valid = df[df['Image_URL'].notna()].copy()
            print(f"DEBUG: Found {len(df_valid)} entries with valid URLs")

            if len(df_valid) == 0:
                print("DEBUG: No valid URLs found in CSV")
            else:
                # 'Plausible' rows first, then 'Not Sure' rows.
                selection = df_valid['User Selection']
                plausible = df_valid[selection == 'Plausible'].head(15)
                not_sure = df_valid[selection == 'Not Sure'].head(8)
                fossil_samples = [*plausible['Image_URL'].tolist(), *not_sure['Image_URL'].tolist()]
                # Randomize the display order.
                random.shuffle(fossil_samples)
                print(f"DEBUG: Loaded {len(fossil_samples)} fossil examples from CSV (shuffled)")
                print(f"DEBUG: - {len(plausible)} Plausible entries")
                print(f"DEBUG: - {len(not_sure)} Not Sure entries")
                if fossil_samples:
                    print(f"DEBUG: - Sample URL: {fossil_samples[0]}")
        except Exception as e:
            print(f"DEBUG: Error loading CSV examples: {e}")
            import traceback
            traceback.print_exc()
            fossil_samples = []

    # There is no fallback source: without the CSV the gallery stays empty.
    if not fossil_samples:
        print("WARNING: No fossil samples loaded from CSV. Examples will be empty.")

    print(f"DEBUG: Final fossil_samples count: {len(fossil_samples)}")
    if fossil_samples:
        print(f"DEBUG: First fossil sample (should be URL): {fossil_samples[0]}")
        print(f"DEBUG: Is URL: {fossil_samples[0].startswith('http') if fossil_samples else False}")

    # The URLs are passed to gr.Examples as-is.
    return gr.Examples(
        fossil_samples,
        inputs=input_image,
        examples_per_page=6,
        label='Leaf fossil examples from the dataset',
        elem_id="fossil-examples"
    )
def preprocess_image(image, output_size=(300, 300)):
    """
    Pad an image to a square with black borders and resize it for display.

    Accepts a NumPy array or an object exposing ``size`` (e.g. a PIL image),
    which is converted with ``np.array``. Grayscale, RGBA and 3-channel
    (assumed RGB) inputs are converted to BGR before padding. Returns the
    padded image resized to *output_size*.

    Raises:
        ValueError: if *image* is neither a NumPy array nor convertible as above.
    """
    if not isinstance(image, np.ndarray):
        # ndarrays also expose `.size`, so test for ndarray first to avoid
        # copying an array that needs no conversion.
        if not hasattr(image, 'size'):
            raise ValueError(f"Expected numpy array or PIL Image, got {type(image)}")
        image = np.array(image)

    if image.ndim == 2:
        # Grayscale: add the channel dimension.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.ndim == 3 and image.shape[2] == 4:
        # RGBA: drop alpha while converting.
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    elif image.ndim == 3 and image.shape[2] == 3:
        # Assume RGB, convert to BGR for OpenCV.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    h, w = image.shape[:2]
    # Split the size difference over both sides and give the odd pixel to the
    # second side. Padding `diff // 2` on both sides left odd differences one
    # pixel short of square, which the resize then stretched.
    diff = abs(h - w)
    first, second = diff // 2, diff - diff // 2
    if h > w:
        top, bottom, left, right = 0, 0, first, second
    else:
        top, bottom, left, right = first, second, 0, 0
    image_padded = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])

    image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
    return image_resized
def increase_brightness(img, value=30):
    """Brighten a BGR image by adding *value* to its HSV value channel.

    Pixels whose value channel is above ``255 - value`` are set to 255; the
    rest are incremented by *value*. Returns the result converted back to BGR.
    """
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    ceiling = 255 - value
    # Clamp first so the increment below only touches pixels with headroom.
    val[val > ceiling] = 255
    val[val <= ceiling] += value
    brightened_hsv = cv2.merge((hue, sat, val))
    return cv2.cvtColor(brightened_hsv, cv2.COLOR_HSV2BGR)
def update_display(image):
    """Prepare an uploaded image for the workbench and reset the workbench widgets.

    Returns a 12-tuple: the original image, the preprocessed image twice, an
    instruction message, then the default model name, explain method, sampling
    size and top-k, followed by four ``None`` values. When *image* is ``None``
    or preprocessing fails, the three image slots are ``None`` and the message
    describes the problem.
    """
    # Defaults for: model name, explain method, sampling size, top-k, and the
    # four result slots (predicted class, explanation gallery, closest table,
    # diagram), which are cleared with None.
    widget_defaults = ("Fossils 142", "Rise", 10, 50, None, None, None, None)

    if image is None:
        return (None, None, None, "Please upload or select an image first.") + widget_defaults

    try:
        print(f"DEBUG: update_display called with image type: {type(image)}")
        if hasattr(image, 'shape'):
            print(f"DEBUG: Image shape: {image.shape}")

        original_image = image
        processed_image = preprocess_image(image)
        # preprocess_image yields BGR; convert back to RGB for display.
        if len(processed_image.shape) == 3:
            processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)

        instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
        print("DEBUG: Image processed successfully")
    except Exception as e:
        print(f"DEBUG: Error in update_display: {e}")
        import traceback
        traceback.print_exc()
        return (None, None, None, f"Error processing image: {str(e)}") + widget_defaults

    return (original_image, processed_image, processed_image, instruction) + widget_defaults
def update_slider_visibility(explain_method):
    """Show the sampling-size slider only while "Rise" is the selected method.

    Returns a Gradio update dict keyed by the ``sampling_size`` component.
    Only the visibility is changed. The previous implementation rebuilt the
    slider with a 1-5000 range and value 2000, which contradicted the declared
    10-3000 slider and overwrote the user's current value on every change.
    """
    # Named `show_slider` rather than `bool` to avoid shadowing the builtin.
    show_slider = explain_method == "Rise"
    return {sampling_size: gr.update(visible=show_slider)}
# Extra CSS passed to gr.Blocks alongside the minimalist theme. It styles the
# "user-guide-wrapper" div used in USER_GUIDE: a light grey card by default,
# and a dark card with white text under the dark-mode selectors listed below.
custom_css = """
.user-guide-wrapper {
padding: 20px;
border-radius: 10px;
border: 1px solid rgba(128, 128, 128, 0.3);
background-color: #f0f0f0 !important;
}
.dark .user-guide-wrapper,
[data-theme*="dark"] .user-guide-wrapper,
.gradio-container.dark .user-guide-wrapper,
body.dark .user-guide-wrapper {
background-color: #1e1e1e !important;
color: #ffffff !important;
}
.dark .user-guide-wrapper h3,
.dark .user-guide-wrapper h4,
.dark .user-guide-wrapper p,
.dark .user-guide-wrapper ul,
.dark .user-guide-wrapper li,
[data-theme*="dark"] .user-guide-wrapper h3,
[data-theme*="dark"] .user-guide-wrapper h4,
[data-theme*="dark"] .user-guide-wrapper p,
[data-theme*="dark"] .user-guide-wrapper ul,
[data-theme*="dark"] .user-guide-wrapper li {
color: #ffffff !important;
}
"""
# Two-tab UI: an upload tab and the "Specimen Workbench" analysis tab.
# NOTE(review): the nesting of rows/columns below is inferred from the order of
# the statements - confirm against the rendered layout.
# Changes: the example accordion label carried a mojibake emoji and the tab
# title misspelled "Florissant"; both user-facing strings are corrected.
# Widgets that were only present as commented-out code (SAM segmentation
# column, per-image closest widgets, diffuser accordion) are omitted.
with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
    # ---- Tab 1: upload an image or pick an example ----
    with gr.Tab(" Florissant Fossils"):
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column():
                gr.Markdown(USER_GUIDE)
            with gr.Column(scale=2):
                instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
                input_image = gr.Image(label="Input", width="100%", container=True)
                process_button = gr.Button("Process Image", icon="https://www.svgrepo.com/show/13672/play-button.svg")
            with gr.Column(scale=1):
                with gr.Accordion("📸 Example Fossils", open=True):
                    gr.Markdown("<p style='font-size: 14px; margin-bottom: 10px;'>Click on any example below to load it:</p>")
                    # setup_examples() reads the global input_image, so it
                    # must run after that component has been created.
                    examples_fossils = setup_examples()
        gr.Markdown(CITATION)

    # ---- Tab 2: classification, explanations and nearest neighbours ----
    with gr.Tab("Specimen Workbench"):
        with gr.Row():
            with gr.Column():
                # Hidden holder for the unprocessed upload; the analysis
                # callbacks take it as their image input.
                original_image = gr.Image(visible=False)
                workbench_image = gr.Image(label="Workbench Image")
                classify_image_button = gr.Button("Classify Image", icon="https://www.svgrepo.com/show/13672/play-button.svg")
            with gr.Column():
                model_name = gr.Dropdown(
                    ["Fossils 142"],  # "Mummified 170", "Rock 170", "Fossils BEiT" removed
                    multiselect=False,
                    value="Fossils 142",  # default option
                    label="Model",
                    interactive=True,
                    info="Choose the model you'd like to use"
                )
                explain_method = gr.Dropdown(
                    ["Sobol", "HSIC", "Rise", "Saliency"],
                    multiselect=False,
                    value="Rise",  # default option
                    label="Explain method",
                    interactive=True,
                    info="Choose one method to explain the model"
                )
                sampling_size = gr.Slider(10, 3000, value=10, label="Sampling Size in Rise", interactive=True, visible=True,
                                          info="Choose between 10 and 3000")
                top_k = gr.Slider(10, 200, value=50, label="Number of Closest Samples for Distribution Chart", interactive=True, info="Choose between 10 and 200")
                explain_method.change(
                    fn=update_slider_visibility,
                    inputs=explain_method,
                    outputs=sampling_size
                )
        with gr.Row():
            with gr.Column(scale=1):
                class_predicted = gr.Label(label='Plant Family Predicted', num_top_classes=10)
            with gr.Column(scale=4):
                with gr.Accordion("Explanations "):
                    gr.Markdown("Computing Explanations from the model for Top 5 Predicted Plant Families")
                    with gr.Column():
                        with gr.Row():
                            exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False, elem_id="gallery", columns=[5], rows=[1], height='auto', allow_preview=True, preview=None)
                        generate_explanations = gr.Button("Generate Explanations", icon="https://www.svgrepo.com/show/13672/play-button.svg")
                with gr.Accordion('Closest Fossil Images'):
                    gr.Markdown("Finding 5 closest images in the dataset")
                    closest_table = gr.HTML(label="Closest Images Table")
                    find_closest_btn = gr.Button("Find Closest Images", icon="https://www.svgrepo.com/show/13672/play-button.svg")
                classify_image_button.click(classify_image, inputs=[original_image, model_name], outputs=class_predicted)
                with gr.Accordion("Family Distribution of Closest Samples "):
                    gr.Markdown("Visualize plant family distribution of top-k closest samples in our dataset")
                    with gr.Column():
                        with gr.Row():
                            diagram = gr.Image(label='Bar Chart')
                        generate_diagram = gr.Button("Generate Diagram", icon="https://www.svgrepo.com/show/13672/play-button.svg")
def update_exp_outputs(input_image, model_name, explain_method, nb_samples):
    """Build the (image, caption) pairs shown in the explanation gallery.

    Calls ``explain_image`` and pairs each returned heatmap with its label.
    At most five pairs are returned. Fewer are returned when fewer
    explanations are available, where the fixed ``range(5)`` loop used to
    raise IndexError. Captions keep the original zero-based numbering.
    """
    labels, images = explain_image(input_image, model_name, explain_method, nb_samples)
    top_pairs = list(zip(images, labels))[:5]
    return [
        (image, f"Predicted Plant Family {index}: {label}")
        for index, (image, label) in enumerate(top_pairs)
    ]
# "Generate Explanations": explain the stored original image with the selected
# model, method and sampling size, and show the captioned heatmaps in the gallery.
generate_explanations.click(fn=update_exp_outputs, inputs=[original_image,model_name,explain_method,sampling_size], outputs=[exp_gallery])
# Superseded wiring, kept for reference: it filled five separate label/image
# components instead of the single HTML table used now.
#find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
# Static opening of the closest-images table: stylesheet plus header row.
_CLOSEST_TABLE_HEAD = """
<style>
.closest-images-table {
    width: 100%;
    border-collapse: collapse;
    margin: 20px 0;
}
.closest-images-table th {
    background-color: #f0f0f0;
    padding: 12px;
    text-align: left;
    border: 1px solid #ddd;
    font-weight: bold;
}
.closest-images-table td {
    padding: 12px;
    border: 1px solid #ddd;
    vertical-align: middle;
}
.closest-images-table tr:nth-child(even) {
    background-color: inherit;
}
.closest-images-table img {
    max-width: 200px;
    max-height: 200px;
    object-fit: contain;
    border-radius: 4px;
    display: block;
    margin: 0 auto;
}
.specimen-name {
    font-size: 16px;
    font-weight: bold;
    color: #0066cc;
    font-family: monospace;
}
.plant-family {
    font-size: 14px;
    font-weight: 500;
}
</style>
<table class="closest-images-table">
<thead>
<tr>
<th>Rank</th>
<th>Image</th>
<th>Plant Family</th>
<th>Specimen Name</th>
</tr>
</thead>
<tbody>
"""

# Static closing of the closest-images table.
_CLOSEST_TABLE_TAIL = """
</tbody>
</table>
"""


def _closest_image_src(image):
    """Return a base64 ``data:`` URI for *image*, or "" when it cannot be embedded.

    Handles an existing file path, a NumPy array, or any object with a
    ``save`` method (e.g. a PIL image). Failures are printed and yield "" so
    one bad image does not break the whole table.
    """
    import base64
    import io
    import mimetypes
    import os
    import numpy as np

    def as_data_uri(payload, mime="image/jpeg"):
        return f"data:{mime};base64,{base64.b64encode(payload).decode('utf-8')}"

    def jpeg_bytes(pil_image):
        # JPEG cannot store alpha or palette data; flatten such images first,
        # otherwise saving raises and the image is lost.
        if getattr(pil_image, 'mode', 'RGB') in ('RGBA', 'LA', 'P'):
            pil_image = pil_image.convert('RGB')
        buffer = io.BytesIO()
        pil_image.save(buffer, format='JPEG')
        return buffer.getvalue()

    if isinstance(image, str) and os.path.exists(image):
        try:
            with open(image, 'rb') as f:
                payload = f.read()
        except OSError as e:
            print(f"Error loading image {image}: {e}")
            return ""
        # Label the bytes with the file's real type (e.g. PNG) and fall back
        # to JPEG when the extension is unknown or not an image type.
        guessed = mimetypes.guess_type(image)[0]
        mime = guessed if guessed and guessed.startswith('image/') else 'image/jpeg'
        return as_data_uri(payload, mime)

    if isinstance(image, np.ndarray):
        try:
            from PIL import Image
            return as_data_uri(jpeg_bytes(Image.fromarray(image)))
        except Exception as e:
            print(f"Error converting numpy array to image: {e}")
            return ""

    if hasattr(image, 'save'):
        try:
            return as_data_uri(jpeg_bytes(image))
        except Exception as e:
            print(f"Error converting PIL image: {e}")
            return ""

    return ""


def update_closest_outputs(input_image, model_name):
    """Render the closest dataset images as an HTML table.

    Calls ``find_closest`` and emits one row per result (at most five) with
    the rank, the embedded image, the plant family and the specimen file name.
    Fewer rows are produced when fewer results are available, where the fixed
    ``range(5)`` loop used to raise IndexError. Family and specimen text is
    HTML-escaped before being interpolated into the markup.
    """
    import html

    labels, images, filenames = find_closest(input_image, model_name)

    rows = []
    for i, image in enumerate(list(images)[:5]):
        rank = i + 1
        img_src = _closest_image_src(image)
        plant_family = labels[i] if (i < len(labels) and labels[i]) else "Unknown"
        specimen_name = filenames[i] if (i < len(filenames) and filenames[i]) else "N/A"

        # Debug output
        print(f"DEBUG: Rank {rank} - Family: {plant_family}, Specimen: {specimen_name}, Image: {images[i]}")
        print(f"DEBUG: filenames array length: {len(filenames)}, filenames[{i}]: {filenames[i] if i < len(filenames) else 'OUT OF RANGE'}")

        family_html = html.escape(str(plant_family))
        specimen_html = html.escape(str(specimen_name))
        rows.append(f"""
<tr>
<td style="text-align: center; font-weight: bold; width: 60px; font-size: 18px;">{rank}</td>
<td style="text-align: center;">
<img src="{img_src}" alt="Closest image {rank}" />
<div style="margin-top: 8px; font-size: 12px; color: #666; word-break: break-all;">{specimen_html}</div>
</td>
<td class="plant-family">{family_html}</td>
<td class="specimen-name" style="word-break: break-all;">{specimen_html}</td>
</tr>
""")

    return _CLOSEST_TABLE_HEAD + "".join(rows) + _CLOSEST_TABLE_TAIL
# "Find Closest Images": look up neighbours of the stored original image and
# render them into the HTML table component.
find_closest_btn.click(fn=update_closest_outputs, inputs=[original_image,model_name], outputs=[closest_table])
# Disabled: classification of a SAM-segmented image (the segmentation widgets
# are commented out in the layout).
#classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
# "Generate Diagram": pass the image, model and top-k to
# generate_diagram_closest and show its result in the diagram component.
generate_diagram.click(generate_diagram_closest, inputs=[original_image,model_name,top_k], outputs=diagram)
# "Process Image": update_display returns twelve values that map, in order, to
# the twelve components listed in `outputs`. Note that input_image itself is
# among the outputs, so the upload widget is overwritten with the preprocessed image.
process_button.click(
fn=update_display,
inputs=input_image,
outputs=[original_image,input_image,workbench_image,instruction_text,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_table,diagram]
)
# Enable the request queue, then start the app. When the SYSTEM environment
# variable is 'spaces' the app is additionally launched with width='40%'.
demo.queue()
launch_options = {'debug': True}
if os.getenv('SYSTEM') == 'spaces':
    # Optional basic auth, currently disabled:
    # auth=(os.environ.get('USERNAME'), os.environ.get('PASSWORD'))
    launch_options['width'] = '40%'
demo.launch(**launch_options)