piperod91 commited on
Commit
29697f1
·
1 Parent(s): 23d486f

Fix model loading: BEiT support, token fallback (HF_TOKEN), condition bug

Browse files
app.py CHANGED
@@ -143,21 +143,17 @@ def get_model(model_name):
143
  backbone_class=tf.keras.applications.ResNet50V2,
144
  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
145
  model.load_weights('model_classification/rock-170.h5')
146
- # elif model_name == 'Fossils 142': #BEiT
147
- # n_classes = 142
148
- # model = get_triplet_model_beit(input_shape = (384, 384, 3),
149
- # embedding_units = 256,
150
- # embedding_depth = 2,
151
- # n_classes = n_classes)
152
- # model.load_weights('model_classification/fossil-142.h5')
153
- # elif model_name == 'Fossils new': # BEiT-v2
154
- # n_classes = 142
155
- # model = get_triplet_model_beit(input_shape = (384, 384, 3),
156
- # embedding_units = 256,
157
- # embedding_depth = 2,
158
- # n_classes = n_classes)
159
- # model.load_weights('model_classification/fossil-new.h5')
160
- elif model_name == 'Fossils 142': # new resnet
161
  n_classes = 142
162
  from inference_resnet_v2 import get_resnet_model
163
  model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
@@ -246,7 +242,7 @@ def generate_diagram_closest(input_image,model_name,top_k):
246
  def explain_image(input_image,model_name,explain_method,nb_samples):
247
  model,n_classes= get_model(model_name)
248
  from explanations import explain
249
- if model_name=='Fossils BEiT' or 'Fossils 142':
250
  size = 384
251
  else:
252
  size = 600
@@ -559,10 +555,18 @@ custom_css = """
559
  margin-left: calc(50% - 50vw);
560
  margin-right: calc(50% - 50vw);
561
  padding: 48px 0 28px 0;
562
- background: #14161c !important; /* force dark header in light mode */
563
- color: #e9edf5 !important;
564
  border-bottom: 1px solid rgba(255,255,255,0.08);
565
  }
 
 
 
 
 
 
 
 
566
  .hero-inner {
567
  max-width: 1200px;
568
  margin: 0 auto;
@@ -575,21 +579,41 @@ custom_css = """
575
  letter-spacing: -0.03em;
576
  color: #e9edf5;
577
  }
 
 
 
 
 
578
  .hero-subtitle {
579
  margin-top: 10px;
580
  font-size: 18px;
581
  color: rgba(233,237,245,0.75);
582
  }
 
 
 
 
 
583
  .hero-links {
584
  margin-top: 18px;
585
  font-size: 14px;
586
  color: rgba(233,237,245,0.70);
587
  }
 
 
 
 
 
588
  .hero-links a {
589
  color: #9ae66e;
590
  text-decoration: none;
591
  font-weight: 600;
592
  }
 
 
 
 
 
593
  .hero-links a:hover {
594
  text-decoration: underline;
595
  }
@@ -597,11 +621,21 @@ custom_css = """
597
  margin: 0 10px;
598
  color: rgba(233,237,245,0.35);
599
  }
 
 
 
 
 
600
  .hero-guide {
601
  margin-top: 24px;
602
  padding-top: 18px;
603
  border-top: 1px solid rgba(233,237,245,0.12);
604
  }
 
 
 
 
 
605
  .hero-guide-grid {
606
  display: grid;
607
  grid-template-columns: 1fr 1fr;
@@ -615,6 +649,11 @@ custom_css = """
615
  text-transform: uppercase;
616
  margin-bottom: 8px;
617
  }
 
 
 
 
 
618
  .hero-guide-ul,
619
  .hero-guide-ol {
620
  margin: 0;
@@ -623,10 +662,26 @@ custom_css = """
623
  font-size: 14px;
624
  line-height: 1.45;
625
  }
 
 
 
 
 
 
 
 
626
  .hero-guide-ul b,
627
  .hero-guide-ol b {
628
  color: rgba(233,237,245,0.92);
629
  }
 
 
 
 
 
 
 
 
630
  @media (max-width: 900px) {
631
  .hero-title { font-size: 52px; }
632
  .hero-guide-grid { grid-template-columns: 1fr; }
@@ -693,8 +748,7 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
693
  with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
694
  gr.Markdown(
695
  "<p style='font-size: 13px; margin-bottom: 10px;'>"
696
- "These thumbnails are sourced from the Florissant unknown-fossils list "
697
- "and exclude <b>Not Applicable</b>. "
698
  "For full context pages, use: "
699
  "<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
700
  "</p>"
 
143
  backbone_class=tf.keras.applications.ResNet50V2,
144
  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
145
  model.load_weights('model_classification/rock-170.h5')
146
+ elif model_name == 'Fossils BEiT':
147
+ n_classes = 142
148
+ from inference_resnet import _ensure_models_downloaded
149
+ from inference_beit import get_triplet_model_beit
150
+ _ensure_models_downloaded()
151
+ model = get_triplet_model_beit(input_shape=(384, 384, 3),
152
+ embedding_units=256,
153
+ embedding_depth=2,
154
+ n_classes=n_classes)
155
+ model.load_weights('model_classification/fossil-142.h5')
156
+ elif model_name == 'Fossils 142': # resnet v2 (full model from file)
 
 
 
 
157
  n_classes = 142
158
  from inference_resnet_v2 import get_resnet_model
159
  model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
 
242
  def explain_image(input_image,model_name,explain_method,nb_samples):
243
  model,n_classes= get_model(model_name)
244
  from explanations import explain
245
+ if model_name in ('Fossils BEiT', 'Fossils 142'):
246
  size = 384
247
  else:
248
  size = 600
 
555
  margin-left: calc(50% - 50vw);
556
  margin-right: calc(50% - 50vw);
557
  padding: 48px 0 28px 0;
558
+ background: #14161c;
559
+ color: #e9edf5;
560
  border-bottom: 1px solid rgba(255,255,255,0.08);
561
  }
562
+ /* Light mode overrides */
563
+ body:not(.dark) .hero,
564
+ [data-theme="default"] .hero,
565
+ .gradio-container:not(.dark) .hero {
566
+ background: #f5f5f5;
567
+ color: #1a1a1a;
568
+ border-bottom: 1px solid rgba(0,0,0,0.1);
569
+ }
570
  .hero-inner {
571
  max-width: 1200px;
572
  margin: 0 auto;
 
579
  letter-spacing: -0.03em;
580
  color: #e9edf5;
581
  }
582
+ body:not(.dark) .hero-title,
583
+ [data-theme="default"] .hero-title,
584
+ .gradio-container:not(.dark) .hero-title {
585
+ color: #1a1a1a;
586
+ }
587
  .hero-subtitle {
588
  margin-top: 10px;
589
  font-size: 18px;
590
  color: rgba(233,237,245,0.75);
591
  }
592
+ body:not(.dark) .hero-subtitle,
593
+ [data-theme="default"] .hero-subtitle,
594
+ .gradio-container:not(.dark) .hero-subtitle {
595
+ color: rgba(26,26,26,0.75);
596
+ }
597
  .hero-links {
598
  margin-top: 18px;
599
  font-size: 14px;
600
  color: rgba(233,237,245,0.70);
601
  }
602
+ body:not(.dark) .hero-links,
603
+ [data-theme="default"] .hero-links,
604
+ .gradio-container:not(.dark) .hero-links {
605
+ color: rgba(26,26,26,0.70);
606
+ }
607
  .hero-links a {
608
  color: #9ae66e;
609
  text-decoration: none;
610
  font-weight: 600;
611
  }
612
+ body:not(.dark) .hero-links a,
613
+ [data-theme="default"] .hero-links a,
614
+ .gradio-container:not(.dark) .hero-links a {
615
+ color: #2d7a32;
616
+ }
617
  .hero-links a:hover {
618
  text-decoration: underline;
619
  }
 
621
  margin: 0 10px;
622
  color: rgba(233,237,245,0.35);
623
  }
624
+ body:not(.dark) .hero-sep,
625
+ [data-theme="default"] .hero-sep,
626
+ .gradio-container:not(.dark) .hero-sep {
627
+ color: rgba(26,26,26,0.35);
628
+ }
629
  .hero-guide {
630
  margin-top: 24px;
631
  padding-top: 18px;
632
  border-top: 1px solid rgba(233,237,245,0.12);
633
  }
634
+ body:not(.dark) .hero-guide,
635
+ [data-theme="default"] .hero-guide,
636
+ .gradio-container:not(.dark) .hero-guide {
637
+ border-top: 1px solid rgba(0,0,0,0.12);
638
+ }
639
  .hero-guide-grid {
640
  display: grid;
641
  grid-template-columns: 1fr 1fr;
 
649
  text-transform: uppercase;
650
  margin-bottom: 8px;
651
  }
652
+ body:not(.dark) .hero-guide-h,
653
+ [data-theme="default"] .hero-guide-h,
654
+ .gradio-container:not(.dark) .hero-guide-h {
655
+ color: rgba(26,26,26,0.92);
656
+ }
657
  .hero-guide-ul,
658
  .hero-guide-ol {
659
  margin: 0;
 
662
  font-size: 14px;
663
  line-height: 1.45;
664
  }
665
+ body:not(.dark) .hero-guide-ul,
666
+ body:not(.dark) .hero-guide-ol,
667
+ [data-theme="default"] .hero-guide-ul,
668
+ [data-theme="default"] .hero-guide-ol,
669
+ .gradio-container:not(.dark) .hero-guide-ul,
670
+ .gradio-container:not(.dark) .hero-guide-ol {
671
+ color: rgba(26,26,26,0.78);
672
+ }
673
  .hero-guide-ul b,
674
  .hero-guide-ol b {
675
  color: rgba(233,237,245,0.92);
676
  }
677
+ body:not(.dark) .hero-guide-ul b,
678
+ body:not(.dark) .hero-guide-ol b,
679
+ [data-theme="default"] .hero-guide-ul b,
680
+ [data-theme="default"] .hero-guide-ol b,
681
+ .gradio-container:not(.dark) .hero-guide-ul b,
682
+ .gradio-container:not(.dark) .hero-guide-ol b {
683
+ color: rgba(26,26,26,0.92);
684
+ }
685
  @media (max-width: 900px) {
686
  .hero-title { font-size: 52px; }
687
  .hero-guide-grid { grid-template-columns: 1fr; }
 
748
  with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
749
  gr.Markdown(
750
  "<p style='font-size: 13px; margin-bottom: 10px;'>"
751
+ "These thumbnails are sourced from the Florissant dataset, from specimens where there are doubts about their family. "
 
752
  "For full context pages, use: "
753
  "<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
754
  "</p>"
closest_sample.py CHANGED
@@ -93,13 +93,13 @@ def load_pickle_safe(filepath):
93
  if not os.path.exists("dataset"):
94
  # Avoid downloading large datasets automatically during local runs.
95
  # If needed, set READ_TOKEN and download manually, or run in Spaces.
96
- token = os.environ.get("READ_TOKEN")
97
  if token:
98
  REPO_ID = "Serrelab/Fossils"
99
  print(f"Read token:{token}")
100
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="dataset", local_dir="dataset")
101
  else:
102
- print("WARNING: dataset/ not found and READ_TOKEN is not set. Closest-sample features may fail until the dataset is available.")
103
 
104
 
105
  fossils_pd= pd.read_csv(os.path.join(os.path.dirname(__file__), "data", "all_fossils_filtered_100.csv"))
@@ -196,7 +196,7 @@ def get_images(embedding,model_name):
196
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
197
  embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
198
  #embedding_leaves = np.load('embedding_leaves.npy')
199
- elif model_name in ['Fossils 142']:
200
  pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
201
  pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
202
  embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
@@ -265,7 +265,7 @@ def get_diagram(embedding,top_k,model_name):
265
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
266
  embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
267
  #embedding_leaves = np.load('embedding_leaves.npy')
268
- elif model_name in ['Fossils 142']:
269
  pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
270
  pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
271
  embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
 
93
  if not os.path.exists("dataset"):
94
  # Avoid downloading large datasets automatically during local runs.
95
  # If needed, set READ_TOKEN and download manually, or run in Spaces.
96
+ token = os.environ.get("READ_TOKEN") or os.environ.get("HF_TOKEN")
97
  if token:
98
  REPO_ID = "Serrelab/Fossils"
99
  print(f"Read token:{token}")
100
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="dataset", local_dir="dataset")
101
  else:
102
+ print("WARNING: dataset/ not found and READ_TOKEN (or HF_TOKEN) is not set. Closest-sample features may fail until the dataset is available.")
103
 
104
 
105
  fossils_pd= pd.read_csv(os.path.join(os.path.dirname(__file__), "data", "all_fossils_filtered_100.csv"))
 
196
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
197
  embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
198
  #embedding_leaves = np.load('embedding_leaves.npy')
199
+ elif model_name in ['Fossils 142', 'Fossils BEiT']:
200
  pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
201
  pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
202
  embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
 
265
  pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
266
  embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
267
  #embedding_leaves = np.load('embedding_leaves.npy')
268
+ elif model_name in ['Fossils 142', 'Fossils BEiT']:
269
  pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
270
  pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
271
  embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
inference_resnet.py CHANGED
@@ -22,11 +22,11 @@ def _ensure_models_downloaded():
22
  if os.path.exists("model_classification"):
23
  return
24
  REPO_ID = "Serrelab/fossil_classification_models"
25
- token = os.getenv("READ_TOKEN")
26
  if token is None:
27
  raise RuntimeError(
28
- "model_classification/ is missing and READ_TOKEN is not set. "
29
- "Set READ_TOKEN in .env to download models, or copy models into model_classification/."
30
  )
31
  print("read token:", token)
32
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
 
22
  if os.path.exists("model_classification"):
23
  return
24
  REPO_ID = "Serrelab/fossil_classification_models"
25
+ token = os.getenv("READ_TOKEN") or os.getenv("HF_TOKEN")
26
  if token is None:
27
  raise RuntimeError(
28
+ "model_classification/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
29
+ "Set READ_TOKEN in .env or HF_TOKEN on Spaces to download models."
30
  )
31
  print("read token:", token)
32
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
inference_resnet_v2.py CHANGED
@@ -22,11 +22,11 @@ def _ensure_models_downloaded():
22
  if os.path.exists("model_classification"):
23
  return
24
  REPO_ID = "Serrelab/fossil_classification_models"
25
- token = os.getenv("READ_TOKEN")
26
  if token is None:
27
  raise RuntimeError(
28
- "model_classification/ is missing and READ_TOKEN is not set. "
29
- "Set READ_TOKEN in .env to download models, or copy models into model_classification/."
30
  )
31
  print("read token:", token)
32
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
 
22
  if os.path.exists("model_classification"):
23
  return
24
  REPO_ID = "Serrelab/fossil_classification_models"
25
+ token = os.getenv("READ_TOKEN") or os.getenv("HF_TOKEN")
26
  if token is None:
27
  raise RuntimeError(
28
+ "model_classification/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
29
+ "Set READ_TOKEN in .env or HF_TOKEN on Spaces to download models."
30
  )
31
  print("read token:", token)
32
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
inference_sam.py CHANGED
@@ -29,11 +29,11 @@ def _ensure_sam_downloaded():
29
  if os.path.exists("model"):
30
  return
31
  REPO_ID = "Serrelab/SAM_Leaves"
32
- token = os.environ.get("READ_TOKEN")
33
  if token is None:
34
  raise RuntimeError(
35
- "model/ is missing and READ_TOKEN is not set. "
36
- "Set READ_TOKEN in .env to download SAM weights, or copy them into model/."
37
  )
38
  print(f"Read token:{token}")
39
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model")
 
29
  if os.path.exists("model"):
30
  return
31
  REPO_ID = "Serrelab/SAM_Leaves"
32
+ token = os.environ.get("READ_TOKEN") or os.environ.get("HF_TOKEN")
33
  if token is None:
34
  raise RuntimeError(
35
+ "model/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
36
+ "Set READ_TOKEN in .env or HF_TOKEN on Spaces to download SAM weights."
37
  )
38
  print(f"Read token:{token}")
39
  snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model")