piperod91 commited on
Commit
22fee3c
·
1 Parent(s): 689cd27

Closest neighbors: start at 5 and allow 'Add K more neighbors' with dynamic K and tables

Browse files
Files changed (2) hide show
  1. app.py +56 -21
  2. closest_sample.py +7 -7
app.py CHANGED
@@ -76,7 +76,7 @@ USER_GUIDE = """
76
 
77
  - **Input**: Upload a fossil-leaf photograph (ideally a tightly cropped Florissant compression fossil with minimal background clutter).\n
78
  - **Output A — predicted families**: The model suggests the top-ranked plant families (as hypotheses to guide expert review).\n
79
- - **Output B — closest training fossils**: The most similar Florissant fossils in the reference set (for qualitative comparison).\n
80
  - **Output C — explanations**: Heatmaps showing image regions that support each suggested family (qualitative).\n
81
 
82
  ### How to use
@@ -251,17 +251,17 @@ def _top_predicted_class(class_predicted):
251
  return None
252
 
253
 
254
- def find_closest(input_image, model_name, predicted_class=None):
255
  embedding = get_embeddings(input_image, model_name)
256
  from closest_sample import get_images, get_images_fossils
257
- classes, paths, filenames = get_images(embedding, model_name, predicted_class=predicted_class)
258
  return classes, paths, filenames
259
 
260
 
261
- def find_closest_fossils(input_image, model_name, predicted_class=None):
262
  embedding = get_embeddings(input_image, model_name)
263
  from closest_sample import get_images_fossils
264
- return get_images_fossils(embedding, model_name, predicted_class=predicted_class)
265
 
266
  def generate_diagram_closest(input_image,model_name,top_k):
267
  embedding = get_embeddings(input_image,model_name)
@@ -487,17 +487,17 @@ def setup_examples():
487
  # Gradio Examples can handle URLs directly - they will fetch and display the images
488
  # Pass URLs as the first argument - Gradio will automatically fetch and display them
489
  # Note: Gradio downloads URLs to temp directory, which is normal behavior
490
- print(f"DEBUG: Final fossil_samples count: {len(fossil_samples)}")
491
  if len(fossil_samples) > 0:
492
  print(f"DEBUG: First fossil sample (should be URL): {fossil_samples[0]}")
493
  print(f"DEBUG: Is URL: {fossil_samples[0].startswith('http') if fossil_samples else False}")
494
 
495
  examples_fossils = gr.Examples(
496
- fossil_samples,
497
  inputs=input_image,
498
  examples_per_page=6, # Reduced for better spacing and organization
499
- label='Leaf fossil examples from the dataset',
500
- elem_id="fossil-examples"
501
  )
502
  return examples_fossils
503
 
@@ -842,7 +842,9 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
842
  with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
843
  gr.Markdown(
844
  "<p style='font-size: 13px; margin-bottom: 10px;'>"
845
- "These thumbnails are sourced from the Florissant dataset, from specimens where there are doubts about their family. "
 
 
846
  "For full context pages, use: "
847
  "<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
848
  "</p>"
@@ -971,20 +973,28 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
971
  # label_closest_image_4 = gr.Markdown('')
972
  # closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
973
  # find_closest_btn = gr.Button("Find Closest Images")
 
 
 
974
  with gr.Accordion('Closest Extant Leaves'):
975
  gr.Markdown(
976
- "5 closest extant leaves (2024 dataset). **Ordered from most to least similar.** "
 
 
977
  "Neighbors are chosen by *embedding* similarity (visual), not by predicted family; "
978
  "if you ran **Classify Image** first, we try to include at least one from the top-predicted family when possible."
979
  )
980
  closest_table = gr.HTML(label="Closest Extant Table")
981
  with gr.Accordion('Closest Fossils'):
982
  gr.Markdown(
983
- "5 closest fossil images (reference fossil dataset). **Ordered from most to least similar.** "
984
- "Same logic: embedding similarity + at least one from predicted family when you have run Classify Image."
 
 
985
  )
986
  closest_fossils_table = gr.HTML(label="Closest Fossils Table")
987
  find_closest_btn = gr.Button("Find Closest Images", icon="https://www.svgrepo.com/show/13672/play-button.svg")
 
988
 
989
  #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
990
  classify_image_button.click(classify_image, inputs=[original_image,model_name], outputs=class_predicted)
@@ -992,7 +1002,10 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
992
  # with gr.Accordion('Closest Leaves Images'):
993
  # gr.Markdown("5 closest leaves")
994
  with gr.Accordion("Family Distribution of Closest Samples "):
995
- gr.Markdown("Visualize plant family distribution of top-k closest samples in our dataset")
 
 
 
996
  with gr.Column():
997
  with gr.Row():
998
  diagram= gr.Image(label = 'Bar Chart')
@@ -1042,7 +1055,8 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
1042
  <thead><tr><th>Rank (1 = most similar)</th><th>Image</th><th>Plant Family</th><th>Specimen Name</th></tr></thead>
1043
  <tbody>
1044
  """
1045
- for i in range(5):
 
1046
  rank = i + 1
1047
  img_src = ""
1048
  if i < len(images):
@@ -1077,13 +1091,34 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
1077
  table_html += "</tbody></table>"
1078
  return table_html
1079
 
1080
- def update_closest_outputs(input_image, model_name, class_predicted):
1081
  predicted_class = _top_predicted_class(class_predicted)
1082
- extant_labels, extant_images, extant_filenames = find_closest(input_image, model_name, predicted_class=predicted_class)
1083
- fossil_labels, fossil_images, fossil_filenames = find_closest_fossils(input_image, model_name, predicted_class=predicted_class)
1084
- return _closest_table_html(extant_labels, extant_images, extant_filenames), _closest_table_html(fossil_labels, fossil_images, fossil_filenames)
1085
-
1086
- find_closest_btn.click(fn=update_closest_outputs, inputs=[original_image, model_name, class_predicted], outputs=[closest_table, closest_fossils_table])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
  #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
1088
 
1089
  generate_diagram.click(generate_diagram_closest, inputs=[original_image,model_name,top_k], outputs=diagram)
 
76
 
77
  - **Input**: Upload a fossil-leaf photograph (ideally a tightly cropped Florissant compression fossil with minimal background clutter).\n
78
  - **Output A — predicted families**: The model suggests the top-ranked plant families (as hypotheses to guide expert review).\n
79
+ - **Output B — closest training fossils**: The most similar Florissant fossils in [The image dataset](https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698) (for qualitative comparison).\n
80
  - **Output C — explanations**: Heatmaps showing image regions that support each suggested family (qualitative).\n
81
 
82
  ### How to use
 
251
  return None
252
 
253
 
254
+ def find_closest(input_image, model_name, predicted_class=None, top_k=5):
255
  embedding = get_embeddings(input_image, model_name)
256
  from closest_sample import get_images, get_images_fossils
257
+ classes, paths, filenames = get_images(embedding, model_name, predicted_class=predicted_class, top_k=top_k)
258
  return classes, paths, filenames
259
 
260
 
261
+ def find_closest_fossils(input_image, model_name, predicted_class=None, top_k=5):
262
  embedding = get_embeddings(input_image, model_name)
263
  from closest_sample import get_images_fossils
264
+ return get_images_fossils(embedding, model_name, predicted_class=predicted_class, top_k=top_k)
265
 
266
  def generate_diagram_closest(input_image,model_name,top_k):
267
  embedding = get_embeddings(input_image,model_name)
 
487
  # Gradio Examples can handle URLs directly - they will fetch and display the images
488
  # Pass URLs as the first argument - Gradio will automatically fetch and display them
489
  # Note: Gradio downloads URLs to temp directory, which is normal behavior
490
+ print(f"DEBUG: Final fossil_samples count: {len(fossil_samples)}")
491
  if len(fossil_samples) > 0:
492
  print(f"DEBUG: First fossil sample (should be URL): {fossil_samples[0]}")
493
  print(f"DEBUG: Is URL: {fossil_samples[0].startswith('http') if fossil_samples else False}")
494
 
495
  examples_fossils = gr.Examples(
496
+ fossil_samples,
497
  inputs=input_image,
498
  examples_per_page=6, # Reduced for better spacing and organization
499
+ label='Leaf fossil examples from The image dataset',
500
+ elem_id="fossil-examples",
501
  )
502
  return examples_fossils
503
 
 
842
  with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
843
  gr.Markdown(
844
  "<p style='font-size: 13px; margin-bottom: 10px;'>"
845
+ "These thumbnails are sourced from the Florissant subset of "
846
+ "<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>, "
847
+ "from specimens where there are doubts about their family. "
848
  "For full context pages, use: "
849
  "<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
850
  "</p>"
 
973
  # label_closest_image_4 = gr.Markdown('')
974
  # closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
975
  # find_closest_btn = gr.Button("Find Closest Images")
976
+ neighbors_k_state = gr.State(5)
977
+ neighbors_add_k = gr.Slider(1, 20, value=5, step=1, label="Add K neighbors", info="Number of additional neighbors to fetch on demand")
978
+
979
  with gr.Accordion('Closest Extant Leaves'):
980
  gr.Markdown(
981
+ "5 closest extant leaves from the extant subset of "
982
+ "<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>. "
983
+ "**Ordered from most to least similar.** "
984
  "Neighbors are chosen by *embedding* similarity (visual), not by predicted family; "
985
  "if you ran **Classify Image** first, we try to include at least one from the top-predicted family when possible."
986
  )
987
  closest_table = gr.HTML(label="Closest Extant Table")
988
  with gr.Accordion('Closest Fossils'):
989
  gr.Markdown(
990
+ "5 closest Florissant fossil images from the fossil subset of "
991
+ "<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>. "
992
+ "**Ordered from most to least similar.** "
993
+ "Same logic: embedding similarity + at least one from predicted family when you have run **Classify Image**."
994
  )
995
  closest_fossils_table = gr.HTML(label="Closest Fossils Table")
996
  find_closest_btn = gr.Button("Find Closest Images", icon="https://www.svgrepo.com/show/13672/play-button.svg")
997
+ add_neighbors_btn = gr.Button("Add K more neighbors", icon="https://www.svgrepo.com/show/13672/play-button.svg")
998
 
999
  #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
1000
  classify_image_button.click(classify_image, inputs=[original_image,model_name], outputs=class_predicted)
 
1002
  # with gr.Accordion('Closest Leaves Images'):
1003
  # gr.Markdown("5 closest leaves")
1004
  with gr.Accordion("Family Distribution of Closest Samples "):
1005
+ gr.Markdown(
1006
+ "Visualize plant family distribution of top-k closest samples in "
1007
+ "<a href='https://plus.figshare.com/articles/dataset/Image_collection_and_supporting_data_for_An_image_dataset_of_cleared_x-rayed_and_fossil_leaves_vetted_to_plant_family_for_human_and_machine_learning/14980698' target='_blank'>The image dataset</a>."
1008
+ )
1009
  with gr.Column():
1010
  with gr.Row():
1011
  diagram= gr.Image(label = 'Bar Chart')
 
1055
  <thead><tr><th>Rank (1 = most similar)</th><th>Image</th><th>Plant Family</th><th>Specimen Name</th></tr></thead>
1056
  <tbody>
1057
  """
1058
+ n = min(len(images), len(labels), len(filenames)) if filenames is not None else min(len(images), len(labels))
1059
+ for i in range(n):
1060
  rank = i + 1
1061
  img_src = ""
1062
  if i < len(images):
 
1091
  table_html += "</tbody></table>"
1092
  return table_html
1093
 
1094
+ def _compute_closest(input_image, model_name, class_predicted, k):
1095
  predicted_class = _top_predicted_class(class_predicted)
1096
+ extant_labels, extant_images, extant_filenames = find_closest(input_image, model_name, predicted_class=predicted_class, top_k=k)
1097
+ fossil_labels, fossil_images, fossil_filenames = find_closest_fossils(input_image, model_name, predicted_class=predicted_class, top_k=k)
1098
+ return (
1099
+ _closest_table_html(extant_labels, extant_images, extant_filenames),
1100
+ _closest_table_html(fossil_labels, fossil_images, fossil_filenames),
1101
+ k,
1102
+ )
1103
+
1104
+ def update_closest_outputs_initial(input_image, model_name, class_predicted):
1105
+ # Reset to 5 neighbors
1106
+ return _compute_closest(input_image, model_name, class_predicted, k=5)
1107
+
1108
+ def update_closest_outputs_add(input_image, model_name, class_predicted, current_k, add_k):
1109
+ new_k = int(current_k) + int(add_k)
1110
+ return _compute_closest(input_image, model_name, class_predicted, k=new_k)
1111
+
1112
+ find_closest_btn.click(
1113
+ fn=update_closest_outputs_initial,
1114
+ inputs=[original_image, model_name, class_predicted],
1115
+ outputs=[closest_table, closest_fossils_table, neighbors_k_state],
1116
+ )
1117
+ add_neighbors_btn.click(
1118
+ fn=update_closest_outputs_add,
1119
+ inputs=[original_image, model_name, class_predicted, neighbors_k_state, neighbors_add_k],
1120
+ outputs=[closest_table, closest_fossils_table, neighbors_k_state],
1121
+ )
1122
  #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
1123
 
1124
  generate_diagram.click(generate_diagram_closest, inputs=[original_image,model_name,top_k], outputs=diagram)
closest_sample.py CHANGED
@@ -277,7 +277,7 @@ def download_public_image(url, destination_path):
277
  else:
278
  print(f"Failed to download image from bucket. Status code: {response.status_code}")
279
 
280
- def get_images(embedding, model_name, predicted_class=None):
281
  if model_name in ['Rock 170', 'Mummified 170']:
282
  pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
283
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
@@ -296,10 +296,10 @@ def get_images(embedding, model_name, predicted_class=None):
296
  families.append(parts[-2] if len(parts) >= 2 else "Unknown")
297
  if predicted_class:
298
  pca_d = pca_distance_simple_with_predicted(
299
- pca_fossils, embedding, embedding_fossils_2024, 5, families, predicted_class
300
  )
301
  else:
302
- pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=5)
303
  local_paths = []
304
  classes = []
305
  filenames = []
@@ -322,7 +322,7 @@ def get_images(embedding, model_name, predicted_class=None):
322
  raise ValueError(f'{model_name} not recognized')
323
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
324
 
325
- pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
326
 
327
  fossils_paths = fossils_pd['file_name'].values
328
 
@@ -374,7 +374,7 @@ def get_images(embedding, model_name, predicted_class=None):
374
  return classes, local_paths, filenames
375
 
376
 
377
- def get_images_fossils(embedding, model_name, predicted_class=None):
378
  """
379
  Use 2024 fossil embeddings (from fossils path) when available;
380
  otherwise fall back to original finer.npy + fossils_pd + GCS. Fossils 142 / BEiT only.
@@ -396,10 +396,10 @@ def get_images_fossils(embedding, model_name, predicted_class=None):
396
  families.append(parts[-2] if len(parts) >= 2 else "Unknown")
397
  if predicted_class:
398
  pca_d = pca_distance_simple_with_predicted(
399
- pca_fossils, embedding, embedding_fossils_2024, 5, families, predicted_class
400
  )
401
  else:
402
- pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=5)
403
  local_paths = []
404
  classes = []
405
  filenames = []
 
277
  else:
278
  print(f"Failed to download image from bucket. Status code: {response.status_code}")
279
 
280
+ def get_images(embedding, model_name, predicted_class=None, top_k=5):
281
  if model_name in ['Rock 170', 'Mummified 170']:
282
  pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
283
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
 
296
  families.append(parts[-2] if len(parts) >= 2 else "Unknown")
297
  if predicted_class:
298
  pca_d = pca_distance_simple_with_predicted(
299
+ pca_fossils, embedding, embedding_fossils_2024, top_k, families, predicted_class
300
  )
301
  else:
302
+ pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=top_k)
303
  local_paths = []
304
  classes = []
305
  filenames = []
 
322
  raise ValueError(f'{model_name} not recognized')
323
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
324
 
325
+ pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
326
 
327
  fossils_paths = fossils_pd['file_name'].values
328
 
 
374
  return classes, local_paths, filenames
375
 
376
 
377
+ def get_images_fossils(embedding, model_name, predicted_class=None, top_k=5):
378
  """
379
  Use 2024 fossil embeddings (from fossils path) when available;
380
  otherwise fall back to original finer.npy + fossils_pd + GCS. Fossils 142 / BEiT only.
 
396
  families.append(parts[-2] if len(parts) >= 2 else "Unknown")
397
  if predicted_class:
398
  pca_d = pca_distance_simple_with_predicted(
399
+ pca_fossils, embedding, embedding_fossils_2024, top_k, families, predicted_class
400
  )
401
  else:
402
+ pca_d = pca_distance_simple(pca_fossils, embedding, embedding_fossils_2024, top_k=top_k)
403
  local_paths = []
404
  classes = []
405
  filenames = []