Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Fix model loading: BEiT support, token fallback (HF_TOKEN), condition bug
Browse files- app.py +74 -20
- closest_sample.py +4 -4
- inference_resnet.py +3 -3
- inference_resnet_v2.py +3 -3
- inference_sam.py +3 -3
app.py
CHANGED
|
@@ -143,21 +143,17 @@ def get_model(model_name):
|
|
| 143 |
backbone_class=tf.keras.applications.ResNet50V2,
|
| 144 |
nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
|
| 145 |
model.load_weights('model_classification/rock-170.h5')
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
# embedding_depth = 2,
|
| 158 |
-
# n_classes = n_classes)
|
| 159 |
-
# model.load_weights('model_classification/fossil-new.h5')
|
| 160 |
-
elif model_name == 'Fossils 142': # new resnet
|
| 161 |
n_classes = 142
|
| 162 |
from inference_resnet_v2 import get_resnet_model
|
| 163 |
model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
|
|
@@ -246,7 +242,7 @@ def generate_diagram_closest(input_image,model_name,top_k):
|
|
| 246 |
def explain_image(input_image,model_name,explain_method,nb_samples):
|
| 247 |
model,n_classes= get_model(model_name)
|
| 248 |
from explanations import explain
|
| 249 |
-
if model_name
|
| 250 |
size = 384
|
| 251 |
else:
|
| 252 |
size = 600
|
|
@@ -559,10 +555,18 @@ custom_css = """
|
|
| 559 |
margin-left: calc(50% - 50vw);
|
| 560 |
margin-right: calc(50% - 50vw);
|
| 561 |
padding: 48px 0 28px 0;
|
| 562 |
-
background: #14161c
|
| 563 |
-
color: #e9edf5
|
| 564 |
border-bottom: 1px solid rgba(255,255,255,0.08);
|
| 565 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
.hero-inner {
|
| 567 |
max-width: 1200px;
|
| 568 |
margin: 0 auto;
|
|
@@ -575,21 +579,41 @@ custom_css = """
|
|
| 575 |
letter-spacing: -0.03em;
|
| 576 |
color: #e9edf5;
|
| 577 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
.hero-subtitle {
|
| 579 |
margin-top: 10px;
|
| 580 |
font-size: 18px;
|
| 581 |
color: rgba(233,237,245,0.75);
|
| 582 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
.hero-links {
|
| 584 |
margin-top: 18px;
|
| 585 |
font-size: 14px;
|
| 586 |
color: rgba(233,237,245,0.70);
|
| 587 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
.hero-links a {
|
| 589 |
color: #9ae66e;
|
| 590 |
text-decoration: none;
|
| 591 |
font-weight: 600;
|
| 592 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
.hero-links a:hover {
|
| 594 |
text-decoration: underline;
|
| 595 |
}
|
|
@@ -597,11 +621,21 @@ custom_css = """
|
|
| 597 |
margin: 0 10px;
|
| 598 |
color: rgba(233,237,245,0.35);
|
| 599 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
.hero-guide {
|
| 601 |
margin-top: 24px;
|
| 602 |
padding-top: 18px;
|
| 603 |
border-top: 1px solid rgba(233,237,245,0.12);
|
| 604 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
.hero-guide-grid {
|
| 606 |
display: grid;
|
| 607 |
grid-template-columns: 1fr 1fr;
|
|
@@ -615,6 +649,11 @@ custom_css = """
|
|
| 615 |
text-transform: uppercase;
|
| 616 |
margin-bottom: 8px;
|
| 617 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
.hero-guide-ul,
|
| 619 |
.hero-guide-ol {
|
| 620 |
margin: 0;
|
|
@@ -623,10 +662,26 @@ custom_css = """
|
|
| 623 |
font-size: 14px;
|
| 624 |
line-height: 1.45;
|
| 625 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
.hero-guide-ul b,
|
| 627 |
.hero-guide-ol b {
|
| 628 |
color: rgba(233,237,245,0.92);
|
| 629 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
@media (max-width: 900px) {
|
| 631 |
.hero-title { font-size: 52px; }
|
| 632 |
.hero-guide-grid { grid-template-columns: 1fr; }
|
|
@@ -693,8 +748,7 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
|
|
| 693 |
with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
|
| 694 |
gr.Markdown(
|
| 695 |
"<p style='font-size: 13px; margin-bottom: 10px;'>"
|
| 696 |
-
"These thumbnails are sourced from the Florissant
|
| 697 |
-
"and exclude <b>Not Applicable</b>. "
|
| 698 |
"For full context pages, use: "
|
| 699 |
"<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
|
| 700 |
"</p>"
|
|
|
|
| 143 |
backbone_class=tf.keras.applications.ResNet50V2,
|
| 144 |
nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
|
| 145 |
model.load_weights('model_classification/rock-170.h5')
|
| 146 |
+
elif model_name == 'Fossils BEiT':
|
| 147 |
+
n_classes = 142
|
| 148 |
+
from inference_resnet import _ensure_models_downloaded
|
| 149 |
+
from inference_beit import get_triplet_model_beit
|
| 150 |
+
_ensure_models_downloaded()
|
| 151 |
+
model = get_triplet_model_beit(input_shape=(384, 384, 3),
|
| 152 |
+
embedding_units=256,
|
| 153 |
+
embedding_depth=2,
|
| 154 |
+
n_classes=n_classes)
|
| 155 |
+
model.load_weights('model_classification/fossil-142.h5')
|
| 156 |
+
elif model_name == 'Fossils 142': # resnet v2 (full model from file)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
n_classes = 142
|
| 158 |
from inference_resnet_v2 import get_resnet_model
|
| 159 |
model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
|
|
|
|
| 242 |
def explain_image(input_image,model_name,explain_method,nb_samples):
|
| 243 |
model,n_classes= get_model(model_name)
|
| 244 |
from explanations import explain
|
| 245 |
+
if model_name in ('Fossils BEiT', 'Fossils 142'):
|
| 246 |
size = 384
|
| 247 |
else:
|
| 248 |
size = 600
|
|
|
|
| 555 |
margin-left: calc(50% - 50vw);
|
| 556 |
margin-right: calc(50% - 50vw);
|
| 557 |
padding: 48px 0 28px 0;
|
| 558 |
+
background: #14161c;
|
| 559 |
+
color: #e9edf5;
|
| 560 |
border-bottom: 1px solid rgba(255,255,255,0.08);
|
| 561 |
}
|
| 562 |
+
/* Light mode overrides */
|
| 563 |
+
body:not(.dark) .hero,
|
| 564 |
+
[data-theme="default"] .hero,
|
| 565 |
+
.gradio-container:not(.dark) .hero {
|
| 566 |
+
background: #f5f5f5;
|
| 567 |
+
color: #1a1a1a;
|
| 568 |
+
border-bottom: 1px solid rgba(0,0,0,0.1);
|
| 569 |
+
}
|
| 570 |
.hero-inner {
|
| 571 |
max-width: 1200px;
|
| 572 |
margin: 0 auto;
|
|
|
|
| 579 |
letter-spacing: -0.03em;
|
| 580 |
color: #e9edf5;
|
| 581 |
}
|
| 582 |
+
body:not(.dark) .hero-title,
|
| 583 |
+
[data-theme="default"] .hero-title,
|
| 584 |
+
.gradio-container:not(.dark) .hero-title {
|
| 585 |
+
color: #1a1a1a;
|
| 586 |
+
}
|
| 587 |
.hero-subtitle {
|
| 588 |
margin-top: 10px;
|
| 589 |
font-size: 18px;
|
| 590 |
color: rgba(233,237,245,0.75);
|
| 591 |
}
|
| 592 |
+
body:not(.dark) .hero-subtitle,
|
| 593 |
+
[data-theme="default"] .hero-subtitle,
|
| 594 |
+
.gradio-container:not(.dark) .hero-subtitle {
|
| 595 |
+
color: rgba(26,26,26,0.75);
|
| 596 |
+
}
|
| 597 |
.hero-links {
|
| 598 |
margin-top: 18px;
|
| 599 |
font-size: 14px;
|
| 600 |
color: rgba(233,237,245,0.70);
|
| 601 |
}
|
| 602 |
+
body:not(.dark) .hero-links,
|
| 603 |
+
[data-theme="default"] .hero-links,
|
| 604 |
+
.gradio-container:not(.dark) .hero-links {
|
| 605 |
+
color: rgba(26,26,26,0.70);
|
| 606 |
+
}
|
| 607 |
.hero-links a {
|
| 608 |
color: #9ae66e;
|
| 609 |
text-decoration: none;
|
| 610 |
font-weight: 600;
|
| 611 |
}
|
| 612 |
+
body:not(.dark) .hero-links a,
|
| 613 |
+
[data-theme="default"] .hero-links a,
|
| 614 |
+
.gradio-container:not(.dark) .hero-links a {
|
| 615 |
+
color: #2d7a32;
|
| 616 |
+
}
|
| 617 |
.hero-links a:hover {
|
| 618 |
text-decoration: underline;
|
| 619 |
}
|
|
|
|
| 621 |
margin: 0 10px;
|
| 622 |
color: rgba(233,237,245,0.35);
|
| 623 |
}
|
| 624 |
+
body:not(.dark) .hero-sep,
|
| 625 |
+
[data-theme="default"] .hero-sep,
|
| 626 |
+
.gradio-container:not(.dark) .hero-sep {
|
| 627 |
+
color: rgba(26,26,26,0.35);
|
| 628 |
+
}
|
| 629 |
.hero-guide {
|
| 630 |
margin-top: 24px;
|
| 631 |
padding-top: 18px;
|
| 632 |
border-top: 1px solid rgba(233,237,245,0.12);
|
| 633 |
}
|
| 634 |
+
body:not(.dark) .hero-guide,
|
| 635 |
+
[data-theme="default"] .hero-guide,
|
| 636 |
+
.gradio-container:not(.dark) .hero-guide {
|
| 637 |
+
border-top: 1px solid rgba(0,0,0,0.12);
|
| 638 |
+
}
|
| 639 |
.hero-guide-grid {
|
| 640 |
display: grid;
|
| 641 |
grid-template-columns: 1fr 1fr;
|
|
|
|
| 649 |
text-transform: uppercase;
|
| 650 |
margin-bottom: 8px;
|
| 651 |
}
|
| 652 |
+
body:not(.dark) .hero-guide-h,
|
| 653 |
+
[data-theme="default"] .hero-guide-h,
|
| 654 |
+
.gradio-container:not(.dark) .hero-guide-h {
|
| 655 |
+
color: rgba(26,26,26,0.92);
|
| 656 |
+
}
|
| 657 |
.hero-guide-ul,
|
| 658 |
.hero-guide-ol {
|
| 659 |
margin: 0;
|
|
|
|
| 662 |
font-size: 14px;
|
| 663 |
line-height: 1.45;
|
| 664 |
}
|
| 665 |
+
body:not(.dark) .hero-guide-ul,
|
| 666 |
+
body:not(.dark) .hero-guide-ol,
|
| 667 |
+
[data-theme="default"] .hero-guide-ul,
|
| 668 |
+
[data-theme="default"] .hero-guide-ol,
|
| 669 |
+
.gradio-container:not(.dark) .hero-guide-ul,
|
| 670 |
+
.gradio-container:not(.dark) .hero-guide-ol {
|
| 671 |
+
color: rgba(26,26,26,0.78);
|
| 672 |
+
}
|
| 673 |
.hero-guide-ul b,
|
| 674 |
.hero-guide-ol b {
|
| 675 |
color: rgba(233,237,245,0.92);
|
| 676 |
}
|
| 677 |
+
body:not(.dark) .hero-guide-ul b,
|
| 678 |
+
body:not(.dark) .hero-guide-ol b,
|
| 679 |
+
[data-theme="default"] .hero-guide-ul b,
|
| 680 |
+
[data-theme="default"] .hero-guide-ol b,
|
| 681 |
+
.gradio-container:not(.dark) .hero-guide-ul b,
|
| 682 |
+
.gradio-container:not(.dark) .hero-guide-ol b {
|
| 683 |
+
color: rgba(26,26,26,0.92);
|
| 684 |
+
}
|
| 685 |
@media (max-width: 900px) {
|
| 686 |
.hero-title { font-size: 52px; }
|
| 687 |
.hero-guide-grid { grid-template-columns: 1fr; }
|
|
|
|
| 748 |
with gr.Accordion("📸 Browse Florissant fossils (non-NA)", open=True):
|
| 749 |
gr.Markdown(
|
| 750 |
"<p style='font-size: 13px; margin-bottom: 10px;'>"
|
| 751 |
+
"These thumbnails are sourced from the Florissant dataset, from specimens where there are doubts about their family. "
|
|
|
|
| 752 |
"For full context pages, use: "
|
| 753 |
"<a href='https://serre-lab.github.io/FossilLeafLens/' target='_blank'>Fossil Leaf Lens</a>."
|
| 754 |
"</p>"
|
closest_sample.py
CHANGED
|
@@ -93,13 +93,13 @@ def load_pickle_safe(filepath):
|
|
| 93 |
if not os.path.exists("dataset"):
|
| 94 |
# Avoid downloading large datasets automatically during local runs.
|
| 95 |
# If needed, set READ_TOKEN and download manually, or run in Spaces.
|
| 96 |
-
token = os.environ.get("READ_TOKEN")
|
| 97 |
if token:
|
| 98 |
REPO_ID = "Serrelab/Fossils"
|
| 99 |
print(f"Read token:{token}")
|
| 100 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="dataset", local_dir="dataset")
|
| 101 |
else:
|
| 102 |
-
print("WARNING: dataset/ not found and READ_TOKEN is not set. Closest-sample features may fail until the dataset is available.")
|
| 103 |
|
| 104 |
|
| 105 |
fossils_pd= pd.read_csv(os.path.join(os.path.dirname(__file__), "data", "all_fossils_filtered_100.csv"))
|
|
@@ -196,7 +196,7 @@ def get_images(embedding,model_name):
|
|
| 196 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
| 197 |
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
|
| 198 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
| 199 |
-
elif model_name in ['Fossils 142']:
|
| 200 |
pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
|
| 201 |
pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
|
| 202 |
embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
|
|
@@ -265,7 +265,7 @@ def get_diagram(embedding,top_k,model_name):
|
|
| 265 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
| 266 |
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
|
| 267 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
| 268 |
-
elif model_name in ['Fossils 142']:
|
| 269 |
pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
|
| 270 |
pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
|
| 271 |
embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
|
|
|
|
| 93 |
if not os.path.exists("dataset"):
|
| 94 |
# Avoid downloading large datasets automatically during local runs.
|
| 95 |
# If needed, set READ_TOKEN and download manually, or run in Spaces.
|
| 96 |
+
token = os.environ.get("READ_TOKEN") or os.environ.get("HF_TOKEN")
|
| 97 |
if token:
|
| 98 |
REPO_ID = "Serrelab/Fossils"
|
| 99 |
print(f"Read token:{token}")
|
| 100 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="dataset", local_dir="dataset")
|
| 101 |
else:
|
| 102 |
+
print("WARNING: dataset/ not found and READ_TOKEN (or HF_TOKEN) is not set. Closest-sample features may fail until the dataset is available.")
|
| 103 |
|
| 104 |
|
| 105 |
fossils_pd= pd.read_csv(os.path.join(os.path.dirname(__file__), "data", "all_fossils_filtered_100.csv"))
|
|
|
|
| 196 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
| 197 |
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
|
| 198 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
| 199 |
+
elif model_name in ['Fossils 142', 'Fossils BEiT']:
|
| 200 |
pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
|
| 201 |
pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
|
| 202 |
embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
|
|
|
|
| 265 |
pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
|
| 266 |
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
|
| 267 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
| 268 |
+
elif model_name in ['Fossils 142', 'Fossils BEiT']:
|
| 269 |
pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
|
| 270 |
pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
|
| 271 |
embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
|
inference_resnet.py
CHANGED
|
@@ -22,11 +22,11 @@ def _ensure_models_downloaded():
|
|
| 22 |
if os.path.exists("model_classification"):
|
| 23 |
return
|
| 24 |
REPO_ID = "Serrelab/fossil_classification_models"
|
| 25 |
-
token = os.getenv("READ_TOKEN")
|
| 26 |
if token is None:
|
| 27 |
raise RuntimeError(
|
| 28 |
-
"model_classification/ is missing and READ_TOKEN is not set. "
|
| 29 |
-
"Set READ_TOKEN in .env
|
| 30 |
)
|
| 31 |
print("read token:", token)
|
| 32 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
|
|
|
|
| 22 |
if os.path.exists("model_classification"):
|
| 23 |
return
|
| 24 |
REPO_ID = "Serrelab/fossil_classification_models"
|
| 25 |
+
token = os.getenv("READ_TOKEN") or os.getenv("HF_TOKEN")
|
| 26 |
if token is None:
|
| 27 |
raise RuntimeError(
|
| 28 |
+
"model_classification/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
|
| 29 |
+
"Set READ_TOKEN in .env or HF_TOKEN on Spaces to download models."
|
| 30 |
)
|
| 31 |
print("read token:", token)
|
| 32 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
|
inference_resnet_v2.py
CHANGED
|
@@ -22,11 +22,11 @@ def _ensure_models_downloaded():
|
|
| 22 |
if os.path.exists("model_classification"):
|
| 23 |
return
|
| 24 |
REPO_ID = "Serrelab/fossil_classification_models"
|
| 25 |
-
token = os.getenv("READ_TOKEN")
|
| 26 |
if token is None:
|
| 27 |
raise RuntimeError(
|
| 28 |
-
"model_classification/ is missing and READ_TOKEN is not set. "
|
| 29 |
-
"Set READ_TOKEN in .env
|
| 30 |
)
|
| 31 |
print("read token:", token)
|
| 32 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
|
|
|
|
| 22 |
if os.path.exists("model_classification"):
|
| 23 |
return
|
| 24 |
REPO_ID = "Serrelab/fossil_classification_models"
|
| 25 |
+
token = os.getenv("READ_TOKEN") or os.getenv("HF_TOKEN")
|
| 26 |
if token is None:
|
| 27 |
raise RuntimeError(
|
| 28 |
+
"model_classification/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
|
| 29 |
+
"Set READ_TOKEN in .env or HF_TOKEN on Spaces to download models."
|
| 30 |
)
|
| 31 |
print("read token:", token)
|
| 32 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
|
inference_sam.py
CHANGED
|
@@ -29,11 +29,11 @@ def _ensure_sam_downloaded():
|
|
| 29 |
if os.path.exists("model"):
|
| 30 |
return
|
| 31 |
REPO_ID = "Serrelab/SAM_Leaves"
|
| 32 |
-
token = os.environ.get("READ_TOKEN")
|
| 33 |
if token is None:
|
| 34 |
raise RuntimeError(
|
| 35 |
-
"model/ is missing and READ_TOKEN is not set. "
|
| 36 |
-
"Set READ_TOKEN in .env
|
| 37 |
)
|
| 38 |
print(f"Read token:{token}")
|
| 39 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model")
|
|
|
|
| 29 |
if os.path.exists("model"):
|
| 30 |
return
|
| 31 |
REPO_ID = "Serrelab/SAM_Leaves"
|
| 32 |
+
token = os.environ.get("READ_TOKEN") or os.environ.get("HF_TOKEN")
|
| 33 |
if token is None:
|
| 34 |
raise RuntimeError(
|
| 35 |
+
"model/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
|
| 36 |
+
"Set READ_TOKEN in .env or HF_TOKEN on Spaces to download SAM weights."
|
| 37 |
)
|
| 38 |
print(f"Read token:{token}")
|
| 39 |
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model")
|