Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Closest neighbors: start at 5 and allow 'Add K more neighbors' with dynamic K and tables
Browse files- app.py +56 -21
- closest_sample.py +7 -7
app.py
CHANGED
|
@@ -76,7 +76,7 @@ USER_GUIDE = """
|
|
| 76 |
|
| 77 |
- **Input**: Upload a fossil-leaf photograph (ideally a tightly cropped Florissant compression fossil with minimal background clutter).\n
|
| 78 |
- **Output A — predicted families**: The model suggests the top-ranked plant families (as hypotheses to guide expert review).\n
|
| 79 |
-
- **Output B — closest training fossils**: The most similar Florissant fossils in
|
| 80 |
- **Output C — explanations**: Heatmaps showing image regions that support each suggested family (qualitative).\n
|
| 81 |
|
| 82 |
### How to use
|
|
@@ -251,17 +251,17 @@ def _top_predicted_class(class_predicted):
|
|
| 251 |
return None
|
| 252 |
|
| 253 |
|
| 254 |
-
def find_closest(input_image, model_name, predicted_class=None):
|
| 255 |
embedding = get_embeddings(input_image, model_name)
|
| 256 |
from closest_sample import get_images, get_images_fossils
|
| 257 |
-
classes, paths, filenames = get_images(embedding, model_name, predicted_class=predicted_class)
|
| 258 |
return classes, paths, filenames
|
| 259 |
|
| 260 |
|
| 261 |
-
def find_closest_fossils(input_image, model_name, predicted_class=None):
|
| 262 |
embedding = get_embeddings(input_image, model_name)
|
| 263 |
from closest_sample import get_images_fossils
|
| 264 |
-
return get_images_fossils(embedding, model_name, predicted_class=predicted_class)
|
| 265 |
|
| 266 |
def generate_diagram_closest(input_image,model_name,top_k):
|
| 267 |
embedding = get_embeddings(input_image,model_name)
|
|
@@ -487,17 +487,17 @@ def setup_examples():
|
|
| 487 |
# Gradio Examples can handle URLs directly - they will fetch and display the images
|
| 488 |
# Pass URLs as the first argument - Gradio will automatically fetch and display them
|
| 489 |
# Note: Gradio downloads URLs to temp directory, which is normal behavior
|
| 490 |
-
|
| 491 |
if len(fossil_samples) > 0:
|
| 492 |
print(f"DEBUG: First fossil sample (should be URL): {fossil_samples[0]}")
|
| 493 |
print(f"DEBUG: Is URL: {fossil_samples[0].startswith('http') if fossil_samples else False}")
|
| 494 |
|
| 495 |
examples_fossils = gr.Examples(
|
| 496 |
-
fossil_samples,
|
| 497 |
inputs=input_image,
|
| 498 |
examples_per_page=6, # Reduced for better spacing and organization
|
| 499 |
-
label='Leaf fossil examples from
|
| 500 |
-
elem_id="fossil-examples"
|
| 501 |
)
|
| 502 |
return examples_fossils
|
| 503 |
|
|
@@ -842,7 +842,9 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 842 |
with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
|
| 843 |
gr.Markdown(
|
| 844 |
"<p style='font-size: 13px; margin-bottom: 10px;'>"
|
| 845 |
-
|
|
|
|
|
|
|
| 846 |
"For full context pages, use: "
|
| 847 |
"<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
|
| 848 |
"</p>"
|
|
@@ -971,20 +973,28 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 971 |
# label_closest_image_4 = gr.Markdown('')
|
| 972 |
# closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
|
| 973 |
# find_closest_btn = gr.Button("Find Closest Images")
|
|
|
|
|
|
|
|
|
|
| 974 |
with gr.Accordion('Closest Extant Leaves'):
|
| 975 |
gr.Markdown(
|
| 976 |
-
"5 closest extant leaves
|
|
|
|
|
|
|
| 977 |
"Neighbors are chosen by *embedding* similarity (visual), not by predicted family; "
|
| 978 |
"if you ran **Classify Image** first, we try to include at least one from the top-predicted family when possible."
|
| 979 |
)
|
| 980 |
closest_table = gr.HTML(label="Closest Extant Table")
|
| 981 |
with gr.Accordion('Closest Fossils'):
|
| 982 |
gr.Markdown(
|
| 983 |
-
"5 closest fossil images
|
| 984 |
-
"
|
|
|
|
|
|
|
| 985 |
)
|
| 986 |
closest_fossils_table = gr.HTML(label="Closest Fossils Table")
|
| 987 |
find_closest_btn = gr.Button("Find Closest Images", icon="https://www.svgrepo.com/show/13672/play-button.svg")
|
|
|
|
| 988 |
|
| 989 |
#segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
|
| 990 |
classify_image_button.click(classify_image, inputs=[original_image,model_name], outputs=class_predicted)
|
|
@@ -992,7 +1002,10 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 992 |
# with gr.Accordion('Closest Leaves Images'):
|
| 993 |
# gr.Markdown("5 closest leaves")
|
| 994 |
with gr.Accordion("Family Distribution of Closest Samples "):
|
| 995 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
| 996 |
with gr.Column():
|
| 997 |
with gr.Row():
|
| 998 |
diagram= gr.Image(label = 'Bar Chart')
|
|
@@ -1042,7 +1055,8 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 1042 |
<thead><tr><th>Rank (1 = most similar)</th><th>Image</th><th>Plant Family</th><th>Specimen Name</th></tr></thead>
|
| 1043 |
<tbody>
|
| 1044 |
"""
|
| 1045 |
-
|
|
|
|
| 1046 |
rank = i + 1
|
| 1047 |
img_src = ""
|
| 1048 |
if i < len(images):
|
|
@@ -1077,13 +1091,34 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 1077 |
table_html += "</tbody></table>"
|
| 1078 |
return table_html
|
| 1079 |
|
| 1080 |
-
def
|
| 1081 |
predicted_class = _top_predicted_class(class_predicted)
|
| 1082 |
-
extant_labels, extant_images, extant_filenames = find_closest(input_image, model_name, predicted_class=predicted_class)
|
| 1083 |
-
fossil_labels, fossil_images, fossil_filenames = find_closest_fossils(input_image, model_name, predicted_class=predicted_class)
|
| 1084 |
-
return
|
| 1085 |
-
|
| 1086 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1087 |
#classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
|
| 1088 |
|
| 1089 |
generate_diagram.click(generate_diagram_closest, inputs=[original_image,model_name,top_k], outputs=diagram)
|
|
|
|
| 76 |
|
| 77 |
- **Input**: Upload a fossil-leaf photograph (ideally a tightly cropped Florissant compression fossil with minimal background clutter).\n
|
| 78 |
- **Output A — predicted families**: The model suggests the top-ranked plant families (as hypotheses to guide expert review).\n
|
| 79 |
+
- **Output B — closest training fossils**: The most similar Florissant fossils in [The image dataset](https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698) (for qualitative comparison).\n
|
| 80 |
- **Output C — explanations**: Heatmaps showing image regions that support each suggested family (qualitative).\n
|
| 81 |
|
| 82 |
### How to use
|
|
|
|
| 251 |
return None
|
| 252 |
|
| 253 |
|
| 254 |
+
def find_closest(input_image, model_name, predicted_class=None, top_k=5):
|
| 255 |
embedding = get_embeddings(input_image, model_name)
|
| 256 |
from closest_sample import get_images, get_images_fossils
|
| 257 |
+
classes, paths, filenames = get_images(embedding, model_name, predicted_class=predicted_class, top_k=top_k)
|
| 258 |
return classes, paths, filenames
|
| 259 |
|
| 260 |
|
| 261 |
+
def find_closest_fossils(input_image, model_name, predicted_class=None, top_k=5):
|
| 262 |
embedding = get_embeddings(input_image, model_name)
|
| 263 |
from closest_sample import get_images_fossils
|
| 264 |
+
return get_images_fossils(embedding, model_name, predicted_class=predicted_class, top_k=top_k)
|
| 265 |
|
| 266 |
def generate_diagram_closest(input_image,model_name,top_k):
|
| 267 |
embedding = get_embeddings(input_image,model_name)
|
|
|
|
| 487 |
# Gradio Examples can handle URLs directly - they will fetch and display the images
|
| 488 |
# Pass URLs as the first argument - Gradio will automatically fetch and display them
|
| 489 |
# Note: Gradio downloads URLs to temp directory, which is normal behavior
|
| 490 |
+
print(f"DEBUG: Final fossil_samples count: {len(fossil_samples)}")
|
| 491 |
if len(fossil_samples) > 0:
|
| 492 |
print(f"DEBUG: First fossil sample (should be URL): {fossil_samples[0]}")
|
| 493 |
print(f"DEBUG: Is URL: {fossil_samples[0].startswith('http') if fossil_samples else False}")
|
| 494 |
|
| 495 |
examples_fossils = gr.Examples(
|
| 496 |
+
fossil_samples,
|
| 497 |
inputs=input_image,
|
| 498 |
examples_per_page=6, # Reduced for better spacing and organization
|
| 499 |
+
label='Leaf fossil examples from The image dataset',
|
| 500 |
+
elem_id="fossil-examples",
|
| 501 |
)
|
| 502 |
return examples_fossils
|
| 503 |
|
|
|
|
| 842 |
with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
|
| 843 |
gr.Markdown(
|
| 844 |
"<p style='font-size: 13px; margin-bottom: 10px;'>"
|
| 845 |
+
"These thumbnails are sourced from the Florissant subset of "
|
| 846 |
+
"<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>, "
|
| 847 |
+
"from specimens where there are doubts about their family. "
|
| 848 |
"For full context pages, use: "
|
| 849 |
"<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
|
| 850 |
"</p>"
|
|
|
|
| 973 |
# label_closest_image_4 = gr.Markdown('')
|
| 974 |
# closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
|
| 975 |
# find_closest_btn = gr.Button("Find Closest Images")
|
| 976 |
+
neighbors_k_state = gr.State(5)
|
| 977 |
+
neighbors_add_k = gr.Slider(1, 20, value=5, step=1, label="Add K neighbors", info="Number of additional neighbors to fetch on demand")
|
| 978 |
+
|
| 979 |
with gr.Accordion('Closest Extant Leaves'):
|
| 980 |
gr.Markdown(
|
| 981 |
+
"5 closest extant leaves from the extant subset of "
|
| 982 |
+
"<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>. "
|
| 983 |
+
"**Ordered from most to least similar.** "
|
| 984 |
"Neighbors are chosen by *embedding* similarity (visual), not by predicted family; "
|
| 985 |
"if you ran **Classify Image** first, we try to include at least one from the top-predicted family when possible."
|
| 986 |
)
|
| 987 |
closest_table = gr.HTML(label="Closest Extant Table")
|
| 988 |
with gr.Accordion('Closest Fossils'):
|
| 989 |
gr.Markdown(
|
| 990 |
+
"5 closest Florissant fossil images from the fossil subset of "
|
| 991 |
+
"<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>. "
|
| 992 |
+
"**Ordered from most to least similar.** "
|
| 993 |
+
"Same logic: embedding similarity + at least one from predicted family when you have run **Classify Image**."
|
| 994 |
)
|
| 995 |
closest_fossils_table = gr.HTML(label="Closest Fossils Table")
|
| 996 |
find_closest_btn = gr.Button("Find Closest Images", icon="https://www.svgrepo.com/show/13672/play-button.svg")
|
| 997 |
+
add_neighbors_btn = gr.Button("Add K more neighbors", icon="https://www.svgrepo.com/show/13672/play-button.svg")
|
| 998 |
|
| 999 |
#segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
|
| 1000 |
classify_image_button.click(classify_image, inputs=[original_image,model_name], outputs=class_predicted)
|
|
|
|
| 1002 |
# with gr.Accordion('Closest Leaves Images'):
|
| 1003 |
# gr.Markdown("5 closest leaves")
|
| 1004 |
with gr.Accordion("Family Distribution of Closest Samples "):
|
| 1005 |
+
gr.Markdown(
|
| 1006 |
+
"Visualize plant family distribution of top-k closest samples in "
|
| 1007 |
+
"<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>."
|
| 1008 |
+
)
|
| 1009 |
with gr.Column():
|
| 1010 |
with gr.Row():
|
| 1011 |
diagram= gr.Image(label = 'Bar Chart')
|
|
|
|
| 1055 |
<thead><tr><th>Rank (1 = most similar)</th><th>Image</th><th>Plant Family</th><th>Specimen Name</th></tr></thead>
|
| 1056 |
<tbody>
|
| 1057 |
"""
|
| 1058 |
+
n = min(len(images), len(labels), len(filenames)) if filenames is not None else min(len(images), len(labels))
|
| 1059 |
+
for i in range(n):
|
| 1060 |
rank = i + 1
|
| 1061 |
img_src = ""
|
| 1062 |
if i < len(images):
|
|
|
|
| 1091 |
table_html += "</tbody></table>"
|
| 1092 |
return table_html
|
| 1093 |
|
| 1094 |
+
def _compute_closest(input_image, model_name, class_predicted, k):
|
| 1095 |
predicted_class = _top_predicted_class(class_predicted)
|
| 1096 |
+
extant_labels, extant_images, extant_filenames = find_closest(input_image, model_name, predicted_class=predicted_class, top_k=k)
|
| 1097 |
+
fossil_labels, fossil_images, fossil_filenames = find_closest_fossils(input_image, model_name, predicted_class=predicted_class, top_k=k)
|
| 1098 |
+
return (
|
| 1099 |
+
_closest_table_html(extant_labels, extant_images, extant_filenames),
|
| 1100 |
+
_closest_table_html(fossil_labels, fossil_images, fossil_filenames),
|
| 1101 |
+
k,
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
def update_closest_outputs_initial(input_image, model_name, class_predicted):
|
| 1105 |
+
# Reset to 5 neighbors
|
| 1106 |
+
return _compute_closest(input_image, model_name, class_predicted, k=5)
|
| 1107 |
+
|
| 1108 |
+
def update_closest_outputs_add(input_image, model_name, class_predicted, current_k, add_k):
|
| 1109 |
+
new_k = int(current_k) + int(add_k)
|
| 1110 |
+
return _compute_closest(input_image, model_name, class_predicted, k=new_k)
|
| 1111 |
+
|
| 1112 |
+
find_closest_btn.click(
|
| 1113 |
+
fn=update_closest_outputs_initial,
|
| 1114 |
+
inputs=[original_image, model_name, class_predicted],
|
| 1115 |
+
outputs=[closest_table, closest_fossils_table, neighbors_k_state],
|
| 1116 |
+
)
|
| 1117 |
+
add_neighbors_btn.click(
|
| 1118 |
+
fn=update_closest_outputs_add,
|
| 1119 |
+
inputs=[original_image, model_name, class_predicted, neighbors_k_state, neighbors_add_k],
|
| 1120 |
+
outputs=[closest_table, closest_fossils_table, neighbors_k_state],
|
| 1121 |
+
)
|
| 1122 |
#classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
|
| 1123 |
|
| 1124 |
generate_diagram.click(generate_diagram_closest, inputs=[original_image,model_name,top_k], outputs=diagram)
|
closest_sample.py
CHANGED
|
@@ -277,7 +277,7 @@ def download_public_image(url, destination_path):
|
|
| 277 |
else:
|
| 278 |
print(f"Failed to download image from bucket. Status code: {response.status_code}")
|
| 279 |
|
| 280 |
-
def get_images(embedding, model_name, predicted_class=None):
|
| 281 |
if model_name in ['Rock 170', 'Mummified 170']:
|
| 282 |
pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
|
| 283 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
|
@@ -296,10 +296,10 @@ def get_images(embedding, model_name, predicted_class=None):
|
|
| 296 |
families.append(parts[-2] if len(parts) >= 2 else "Unknown")
|
| 297 |
if predicted_class:
|
| 298 |
pca_d = pca_distance_simple_with_predicted(
|
| 299 |
-
pca_fossils, embedding, embedding_fossils_2024,
|
| 300 |
)
|
| 301 |
else:
|
| 302 |
-
pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=
|
| 303 |
local_paths = []
|
| 304 |
classes = []
|
| 305 |
filenames = []
|
|
@@ -322,7 +322,7 @@ def get_images(embedding, model_name, predicted_class=None):
|
|
| 322 |
raise ValueError(f'{model_name} not recognized')
|
| 323 |
#pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
|
| 324 |
|
| 325 |
-
pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=
|
| 326 |
|
| 327 |
fossils_paths = fossils_pd['file_name'].values
|
| 328 |
|
|
@@ -374,7 +374,7 @@ def get_images(embedding, model_name, predicted_class=None):
|
|
| 374 |
return classes, local_paths, filenames
|
| 375 |
|
| 376 |
|
| 377 |
-
def get_images_fossils(embedding, model_name, predicted_class=None):
|
| 378 |
"""
|
| 379 |
Use 2024 fossil embeddings (from fossils path) when available;
|
| 380 |
otherwise fall back to original finer.npy + fossils_pd + GCS. Fossils 142 / BEiT only.
|
|
@@ -396,10 +396,10 @@ def get_images_fossils(embedding, model_name, predicted_class=None):
|
|
| 396 |
families.append(parts[-2] if len(parts) >= 2 else "Unknown")
|
| 397 |
if predicted_class:
|
| 398 |
pca_d = pca_distance_simple_with_predicted(
|
| 399 |
-
pca_fossils, embedding, embedding_fossils_2024,
|
| 400 |
)
|
| 401 |
else:
|
| 402 |
-
pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=
|
| 403 |
local_paths = []
|
| 404 |
classes = []
|
| 405 |
filenames = []
|
|
|
|
| 277 |
else:
|
| 278 |
print(f"Failed to download image from bucket. Status code: {response.status_code}")
|
| 279 |
|
| 280 |
+
def get_images(embedding, model_name, predicted_class=None, top_k=5):
|
| 281 |
if model_name in ['Rock 170', 'Mummified 170']:
|
| 282 |
pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
|
| 283 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
|
|
|
| 296 |
families.append(parts[-2] if len(parts) >= 2 else "Unknown")
|
| 297 |
if predicted_class:
|
| 298 |
pca_d = pca_distance_simple_with_predicted(
|
| 299 |
+
pca_fossils, embedding, embedding_fossils_2024, top_k, families, predicted_class
|
| 300 |
)
|
| 301 |
else:
|
| 302 |
+
pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=top_k)
|
| 303 |
local_paths = []
|
| 304 |
classes = []
|
| 305 |
filenames = []
|
|
|
|
| 322 |
raise ValueError(f'{model_name} not recognized')
|
| 323 |
#pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
|
| 324 |
|
| 325 |
+
pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
|
| 326 |
|
| 327 |
fossils_paths = fossils_pd['file_name'].values
|
| 328 |
|
|
|
|
| 374 |
return classes, local_paths, filenames
|
| 375 |
|
| 376 |
|
| 377 |
+
def get_images_fossils(embedding, model_name, predicted_class=None, top_k=5):
|
| 378 |
"""
|
| 379 |
Use 2024 fossil embeddings (from fossils path) when available;
|
| 380 |
otherwise fall back to original finer.npy + fossils_pd + GCS. Fossils 142 / BEiT only.
|
|
|
|
| 396 |
families.append(parts[-2] if len(parts) >= 2 else "Unknown")
|
| 397 |
if predicted_class:
|
| 398 |
pca_d = pca_distance_simple_with_predicted(
|
| 399 |
+
pca_fossils, embedding, embedding_fossils_2024, top_k, families, predicted_class
|
| 400 |
)
|
| 401 |
else:
|
| 402 |
+
pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=top_k)
|
| 403 |
local_paths = []
|
| 404 |
classes = []
|
| 405 |
filenames = []
|