# NOTE(review): the six lines below are Hugging Face web-page scrape residue
# (author picture caption, commit message, file-size widget), not program
# text; commented out so the module parses as Python.
# CarolinaSMarques's picture
# Update app.py
# 8f35791 verified
# raw
# history blame
# 29.9 kB
# -*- coding: utf-8 -*-
"""
Paleontology-facing dinosaur footprint classifier
(DenseNet-121 / EfficientNet-B0 + Morphometric models)
- Image Input: one or more images (Photograph / Binarized / Depth map)
- Morphometric Input: one or more .xlsx tables with track measurements
- Output: top-3 most probable classes + probabilities
Controls:
- Data type:
- Photograph
- Binarized
- Depth map
- Morphometric (xlsx morphometric measurements)
- Classification type:
- Ichnogenus (8 classes)
- Theropod/Ornithopod (2 classes)
"""
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import html
import zipfile
import tempfile
import cv2
import joblib
import numpy as np
import pandas as pd
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from tensorflow.keras.models import load_model
# =========================
# Config
# =========================
# Run on GPU when available; all models and input tensors are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Square input resolution fed to the CNNs.
IMAGE_SIZE = 224
# Maximum number of rows rendered in the HTML results table; the downloadable
# CSV always contains the full output.
MAX_ROWS_DISPLAY = 50
# Ichnogenus class labels — assumed to match the integer label encoding used
# at training time (see _map_model_classes_to_names); TODO confirm ordering.
CLASS_NAMES_ICHNO = [
    "Anomoepus",
    "Caririchnium",
    "Dinehichnus",
    "Grallator",
    "Iguanodontipus",
    "Kalohipus",
    "Kayentapus",
    "Megalosauripus",
]
# Binary trackmaker grouping labels.
CLASS_NAMES_TO = [
    "Ornithopod",
    "Theropod",
]
NUM_CLASSES_ICHNO = len(CLASS_NAMES_ICHNO)
NUM_CLASSES_TO = len(CLASS_NAMES_TO)
# Columns that every uploaded morphometric spreadsheet must contain, in the
# order the morphometric models expect them.
MORPH_FEATURE_COLS = [
    "FL",
    "FW",
    "FL_FW",
    "LII",
    "LIII",
    "LIV",
    "II_IV",
    "II_minus_IV",
]
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_DIR = os.path.join(THIS_DIR, "checkpoints")
# =========================
# Checkpoint paths
# =========================
# CNN checkpoints for the 8-class ichnogenus task, keyed by input data type.
CHECKPOINTS_ICHNO = {
    "Photograph": os.path.join(CHECKPOINT_DIR, "model_checkpoint_3.pth"),
    "Binarized": os.path.join(CHECKPOINT_DIR, "model_checkpoint_11.pth"),
    "Depth map": os.path.join(CHECKPOINT_DIR, "model_checkpoint_9.pth"),
}
# CNN checkpoints for the 2-class Theropod/Ornithopod task.
CHECKPOINTS_TO = {
    "Photograph": os.path.join(CHECKPOINT_DIR, "model_checkpoint1_7.pth"),
    "Binarized": os.path.join(CHECKPOINT_DIR, "model_checkpoint1_5.pth"),
    "Depth map": os.path.join(CHECKPOINT_DIR, "model_checkpoint1_3.pth"),
}
# Morphometric ichnogenus: scikit-learn Random Forest + its fitted scaler.
MORPHO_ICHNO_MODEL_PATH = os.path.join(
    CHECKPOINT_DIR, "model_Random Forest_StandardScaler_SMOTE.joblib"
)
MORPHO_ICHNO_SCALER_PATH = os.path.join(
    CHECKPOINT_DIR, "scaler_Random Forest_StandardScaler_SMOTE.joblib"
)
# Morphometric Theropod/Ornithopod: Keras network.
MORPHO_TO_KERAS_PATH = os.path.join(
    CHECKPOINT_DIR, "morpho_TO_keras_model.keras"
)
# =========================
# Image preprocessing
# =========================
# Standard ImageNet preprocessing: resize to the model input size and
# normalise with the ImageNet channel mean/std.
INFER_TRANSFORM = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# =========================
# CNN model definitions
# =========================
def create_densenet121(num_classes: int) -> nn.Module:
    """Build an ImageNet-pretrained DenseNet-121 with a fresh classifier head.

    The head is Dropout(0.2) followed by a linear layer producing
    ``num_classes`` logits; it must match the checkpoints loaded later.
    """
    net = models.densenet121(pretrained=True)
    head_in = net.classifier.in_features
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(head_in, num_classes),
    )
    return net
def create_efficientnet_b0(num_classes: int) -> nn.Module:
    """Build an ImageNet-pretrained EfficientNet-B0 with a fresh classifier head.

    Mirrors create_densenet121: Dropout(0.2) + Linear(..., num_classes),
    sized from the backbone's existing classifier input width.
    """
    net = models.efficientnet_b0(pretrained=True)
    head_in = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(head_in, num_classes),
    )
    return net
def create_efficientnet_b01(num_classes: int) -> nn.Module:
    """EfficientNet-B0 variant used by the binarized-ichnogenus checkpoint.

    Unlike create_efficientnet_b0, the head is a single ``nn.Linear(1280,
    num_classes)`` wrapped in a Sequential (no dropout), so that the state
    dict keys (``classifier.0.*``) line up with the saved checkpoint.

    Fix: the original body first assigned a Dropout+Linear(..., 2) head and
    then immediately overwrote it with the Sequential below; that dead
    assignment has been removed (behaviour is unchanged).
    """
    model = models.efficientnet_b0(pretrained=True)
    # 1280 is EfficientNet-B0's pooled feature width (same value the removed
    # dead code read from model.classifier[1].in_features).
    model.classifier = nn.Sequential(
        nn.Linear(1280, num_classes),
    )
    return model
def get_checkpoint_dict(classification_type: str):
    """Return the data-type → checkpoint-path map for the requested task.

    Anything other than "Theropod/Ornithopod" falls back to the ichnogenus map.
    """
    is_to = classification_type == "Theropod/Ornithopod"
    return CHECKPOINTS_TO if is_to else CHECKPOINTS_ICHNO
def _safe_torch_load(path: str):
try:
return torch.load(path, map_location="cpu", weights_only=False)
except TypeError:
return torch.load(path, map_location="cpu")
def load_model_for_type(data_type: str, classification_type: str) -> nn.Module:
    """Build the architecture matching (task, data type) and load its checkpoint.

    Raises ValueError for an unknown data type and FileNotFoundError when the
    checkpoint file is absent. The returned model is on DEVICE, in eval mode.
    """
    ckpt_dict = get_checkpoint_dict(classification_type)
    if data_type not in ckpt_dict:
        raise ValueError(f"Unknown data type: {data_type}")
    ckpt_path = ckpt_dict[data_type]
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(
            f"Checkpoint not found for {classification_type} / {data_type}: {ckpt_path}"
        )
    # Architecture choice depends on both the task and the input modality.
    if classification_type == "Ichnogenus":
        num_classes = NUM_CLASSES_ICHNO
        builders = {
            "Photograph": create_densenet121,
            "Binarized": create_efficientnet_b01,
            "Depth map": create_efficientnet_b0,
        }
    else:
        num_classes = NUM_CLASSES_TO
        builders = {
            "Photograph": create_densenet121,
            "Binarized": create_densenet121,
            "Depth map": create_efficientnet_b0,
        }
    builder = builders.get(data_type)
    if builder is None:
        raise ValueError(f"Unsupported data type: {data_type}")
    model = builder(num_classes)
    ckpt = _safe_torch_load(ckpt_path)
    # Checkpoints may be either a raw state dict or a dict wrapping one.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model
# Cache of loaded CNNs keyed by (classification task, data type).
_MODELS = {}
def get_model(data_type: str, classification_type: str) -> nn.Module:
    """Return a cached model for the given task/data type, loading on first use.

    Unknown classification types fall back to "Ichnogenus"; unknown data
    types fall back to "Photograph".
    """
    data_type = data_type.strip()
    classification_type = classification_type.strip()
    if classification_type not in ("Ichnogenus", "Theropod/Ornithopod"):
        classification_type = "Ichnogenus"
    if data_type not in get_checkpoint_dict(classification_type):
        data_type = "Photograph"
    cache_key = (classification_type, data_type)
    try:
        return _MODELS[cache_key]
    except KeyError:
        model = load_model_for_type(data_type, classification_type)
        _MODELS[cache_key] = model
        return model
# =========================
# Morphometric models
# =========================
# Lazily-populated caches for the morphometric models.
_MORPHO_ICHNO_MODEL = None
_MORPHO_ICHNO_SCALER = None
_MORPHO_TO_MODEL = None
def get_morpho_ichno_model_and_scaler():
    """Return (Random Forest model, scaler) for morphometric ichnogenus.

    Both artefacts are joblib files loaded once and cached in module globals.
    Raises FileNotFoundError if either file is missing.
    """
    global _MORPHO_ICHNO_MODEL, _MORPHO_ICHNO_SCALER
    if _MORPHO_ICHNO_MODEL is not None:
        return _MORPHO_ICHNO_MODEL, _MORPHO_ICHNO_SCALER
    if not os.path.exists(MORPHO_ICHNO_MODEL_PATH):
        raise FileNotFoundError(f"Morphometric Ichnogenus model not found: {MORPHO_ICHNO_MODEL_PATH}")
    if not os.path.exists(MORPHO_ICHNO_SCALER_PATH):
        raise FileNotFoundError(f"Morphometric Ichnogenus scaler not found: {MORPHO_ICHNO_SCALER_PATH}")
    _MORPHO_ICHNO_MODEL = joblib.load(MORPHO_ICHNO_MODEL_PATH)
    _MORPHO_ICHNO_SCALER = joblib.load(MORPHO_ICHNO_SCALER_PATH)
    return _MORPHO_ICHNO_MODEL, _MORPHO_ICHNO_SCALER
def get_morpho_to_model():
    """Return the cached Keras model for morphometric Theropod/Ornithopod."""
    global _MORPHO_TO_MODEL
    if _MORPHO_TO_MODEL is not None:
        return _MORPHO_TO_MODEL
    if not os.path.exists(MORPHO_TO_KERAS_PATH):
        raise FileNotFoundError(f"Morphometric TO Keras model not found: {MORPHO_TO_KERAS_PATH}")
    # compile=False: inference only, no optimizer/loss needed.
    _MORPHO_TO_MODEL = load_model(MORPHO_TO_KERAS_PATH, compile=False)
    return _MORPHO_TO_MODEL
def _extract_features_from_excel(path: str):
    """Read one spreadsheet and return (feature frame, full frame).

    Raises ValueError naming the file when any required morphometric column
    (MORPH_FEATURE_COLS) is absent.
    """
    frame = pd.read_excel(path)
    absent = [name for name in MORPH_FEATURE_COLS if name not in frame.columns]
    if absent:
        raise ValueError(f"{os.path.basename(path)} is missing required columns: {absent}")
    return frame[MORPH_FEATURE_COLS].copy(), frame
# =========================
# Helper functions
# =========================
def _map_model_classes_to_names(model_classes, classification_type: str):
    """Translate raw classifier labels into display names.

    Integer labels inside the valid range index into the task's class-name
    list; anything else is stringified as-is.
    """
    if classification_type == "Theropod/Ornithopod":
        names = CLASS_NAMES_TO
    else:
        names = CLASS_NAMES_ICHNO
    out = []
    for label in model_classes:
        in_range = isinstance(label, (int, np.integer)) and 0 <= int(label) < len(names)
        out.append(names[int(label)] if in_range else str(label))
    return out
def _top3_from_prob_vector(prob_vec: np.ndarray, class_names):
prob_vec = np.asarray(prob_vec).reshape(-1)
n = min(len(prob_vec), len(class_names))
if n == 0:
return [None, None, None], [None, None, None]
prob_vec = prob_vec[:n]
k = min(3, n)
top_idx = np.argsort(-prob_vec)[:k]
top_classes = [class_names[i] for i in top_idx]
top_probs = [float(prob_vec[i]) for i in top_idx]
while len(top_classes) < 3:
top_classes.append(None)
top_probs.append(None)
return top_classes, top_probs
# =========================
# Image prediction helpers
# =========================
@torch.no_grad()
def predict_top3_from_pil(pil_img: Image.Image, data_type: str, classification_type: str):
    """Run one PIL image through the cached CNN and return top-3 predictions.

    Returns the ([classes], [probabilities]) pair from _top3_from_prob_vector.
    """
    net = get_model(data_type, classification_type)
    if classification_type == "Theropod/Ornithopod":
        class_names = CLASS_NAMES_TO
    else:
        class_names = CLASS_NAMES_ICHNO
    batch = INFER_TRANSFORM(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    probabilities = torch.softmax(net(batch), dim=1)[0].cpu().numpy()
    return _top3_from_prob_vector(probabilities, class_names)
def _predict_top_idx_from_pil(pil_img: Image.Image, data_type: str, classification_type: str) -> int:
    """Return the argmax class index for one image (used to target Grad-CAM)."""
    net = get_model(data_type, classification_type)
    net.eval()
    batch = INFER_TRANSFORM(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = net(batch)
    return int(torch.argmax(scores, dim=1).item())
# =========================
# Grad-CAM helpers
# =========================
def get_target_layer_for_gradcam(model):
    """Choose the feature layer whose activations Grad-CAM will use.

    DenseNets hook the whole ``features`` trunk; EfficientNets (and any other
    model exposing ``features``) hook its last sub-module when indexable.
    Raises ValueError when no suitable layer can be found.
    """
    cls_name = type(model).__name__.lower()
    has_features = hasattr(model, "features")
    if has_features and "densenet" in cls_name:
        return model.features
    if has_features and "efficientnet" in cls_name:
        return model.features[-1]
    if has_features:
        try:
            return model.features[-1]
        except Exception:
            # ``features`` exists but is not indexable; use it directly.
            return model.features
    raise ValueError("Could not determine a Grad-CAM target layer for this model.")
def generate_gradcam_from_pil(
    pil_img: Image.Image,
    data_type: str,
    classification_type: str,
    target_class_idx=None,
):
    """Compute Grad-CAM for one image.

    Returns (overlay PIL image, heatmap PIL image) at the original image
    resolution. The CAM targets ``target_class_idx`` (default: the model's
    argmax class) on the layer picked by get_target_layer_for_gradcam.
    """
    model = get_model(data_type, classification_type)
    model.eval()
    # Keep the full-resolution RGB array so the heatmap can be resized back
    # onto the untouched input for the overlay.
    rgb = np.array(pil_img.convert("RGB"))
    H, W = rgb.shape[:2]
    x = INFER_TRANSFORM(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    x.requires_grad_(True)
    target_layer = get_target_layer_for_gradcam(model)
    activations = []
    gradients = []
    def forward_hook(module, inp, out):
        # Record the forward activation and attach a tensor hook so the
        # gradient flowing back through this same tensor is captured too.
        activations.append(out)
        def save_grad(grad):
            gradients.append(grad.detach())
        out.register_hook(save_grad)
    handle_fwd = target_layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad(set_to_none=True)
        logits = model(x)
        if target_class_idx is None:
            target_class_idx = int(torch.argmax(logits, dim=1).item())
        # Backprop only the chosen class score to get class-specific gradients.
        score = logits[0, target_class_idx]
        score.backward()
        act = activations[0].detach()
        grad = gradients[0]
        # Grad-CAM: per-channel weights are the spatial mean of the gradients;
        # the CAM is the ReLU of the weighted activation sum.
        weights = grad.mean(dim=(2, 3), keepdim=True)
        cam = (weights * act).sum(dim=1)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        # Normalise to [0, 1], guarding against an all-zero map.
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        cam = cv2.resize(cam.astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
        heatmap_uint8 = np.uint8(255 * cam)
        heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
        # Blend 55% original image with 45% heatmap for the overlay.
        overlay = cv2.addWeighted(rgb, 0.55, heatmap_rgb, 0.45, 0)
        gradcam_overlay_pil = Image.fromarray(overlay)
        gradcam_heatmap_pil = Image.fromarray(heatmap_rgb)
        return gradcam_overlay_pil, gradcam_heatmap_pil
    finally:
        # Always detach the forward hook, even if CAM computation fails.
        handle_fwd.remove()
# =========================
# HTML table helper
# =========================
def df_to_html_italic(df: pd.DataFrame, classification_type: str) -> str:
    """Render a predictions DataFrame as an HTML table.

    Floats are shown with three decimals, None/NaN as empty cells, and — for
    the Ichnogenus task — values in ``*_class`` columns are italicised
    (taxonomic convention) unless they are error strings.
    """
    if df.empty:
        return "<p>No predictions to display yet.</p>"

    def render_cell(col_name, value):
        # Normalise the raw value into display text.
        if value is None or (isinstance(value, float) and np.isnan(value)):
            text = ""
        elif isinstance(value, float):
            text = f"{value:.3f}"
        else:
            text = str(value)
        italicise = (
            classification_type == "Ichnogenus"
            and col_name.endswith("_class")
            and text != ""
            and not text.startswith("Error")
        )
        if italicise:
            return f"<td><i>{html.escape(text)}</i></td>"
        return f"<td>{html.escape(text)}</td>"

    columns = list(df.columns)
    head = "".join(f"<th>{html.escape(str(c))}</th>" for c in columns)
    body = "".join(
        "<tr>" + "".join(render_cell(c, row[c]) for c in columns) + "</tr>"
        for _, row in df.iterrows()
    )
    return (
        "<div class='pred-table'>"
        "<table>"
        "<thead><tr>"
        f"{head}"
        "</tr></thead>"
        "<tbody>"
        f"{body}"
        "</tbody></table>"
        "</div>"
    )
# =========================
# Morphometric classification
# =========================
def classify_morphometric_batch(xlsx_paths, classification_type: str):
    """Classify every row of every uploaded morphometric spreadsheet.

    Returns the 6-tuple wired to the Gradio outputs:
    (results HTML, status text, CSV path, overlay gallery, heatmap gallery,
    Grad-CAM zip path). The last three are always empty/None here since
    Grad-CAM applies only to imagery.
    """
    cols = [
        "sample_id",
        "top1_class", "top1_prob",
        "top2_class", "top2_prob",
        "top3_class", "top3_prob",
    ]
    if not xlsx_paths:
        # Nothing uploaded: show an empty table plus a prompt, no CSV.
        empty_df = pd.DataFrame(columns=cols)
        html_table = df_to_html_italic(empty_df, classification_type)
        return (
            html_table,
            "Please upload at least one .xlsx file for Morphometric data.",
            None,
            [],
            [],
            None,
        )
    rows = []
    total_rows = 0
    for path in xlsx_paths:
        base_name = os.path.basename(str(path))
        try:
            X, _ = _extract_features_from_excel(path)
            n_samples = len(X)
            if n_samples == 0:
                continue
            if classification_type == "Ichnogenus":
                # Random-Forest pipeline: scale features, then predict_proba.
                rf_model, scaler = get_morpho_ichno_model_and_scaler()
                X_scaled = scaler.transform(X.values.astype(np.float32))
                probs = rf_model.predict_proba(X_scaled)
                if hasattr(rf_model, "classes_"):
                    class_names = _map_model_classes_to_names(rf_model.classes_, classification_type)
                else:
                    class_names = CLASS_NAMES_ICHNO
            else:
                # Theropod/Ornithopod: Keras network on unscaled features.
                nn_model = get_morpho_to_model()
                X_np = X.values.astype(np.float32)
                probs = nn_model.predict(X_np, verbose=0)
                class_names = CLASS_NAMES_TO
            for idx in range(n_samples):
                try:
                    prob_vec = probs[idx]
                    top_classes, top_probs = _top3_from_prob_vector(prob_vec, class_names)
                    # Sample id encodes source file and 1-based row number.
                    sample_name = f"{base_name}#row{idx+1:02d}"
                    rows.append({
                        "sample_id": sample_name,
                        "top1_class": top_classes[0],
                        "top1_prob": top_probs[0],
                        "top2_class": top_classes[1],
                        "top2_prob": top_probs[1],
                        "top3_class": top_classes[2],
                        "top3_prob": top_probs[2],
                    })
                    total_rows += 1
                except Exception as row_e:
                    # Per-row failure: record an error row but keep going.
                    sample_name = f"{base_name}#row{idx+1:02d}"
                    rows.append({
                        "sample_id": sample_name,
                        "top1_class": f"Error (row): {row_e}",
                        "top1_prob": None,
                        "top2_class": None,
                        "top2_prob": None,
                        "top3_class": None,
                        "top3_prob": None,
                    })
                    total_rows += 1
        except Exception as e:
            # Whole-file failure (unreadable file, missing columns, model
            # load error): one error row for the file, not counted in
            # total_rows.
            rows.append({
                "sample_id": base_name,
                "top1_class": f"Error (file): {e}",
                "top1_prob": None,
                "top2_class": None,
                "top2_prob": None,
                "top3_class": None,
                "top3_prob": None,
            })
    df = pd.DataFrame(rows)
    status = (
        f"Processed {total_rows} morphometric row(s) "
        f"from {len(xlsx_paths)} file(s) using '{classification_type}' classification."
    )
    if len(df) > MAX_ROWS_DISPLAY:
        status += f" Showing first {MAX_ROWS_DISPLAY} rows in the table; full results are in the CSV."
    # Full (untruncated) results always go to a CSV in a fresh temp dir.
    tmpdir = tempfile.mkdtemp()
    csv_path = os.path.join(
        tmpdir,
        f"predictions_morphometric_{classification_type.replace('/','_')}.csv"
    )
    df.to_csv(csv_path, index=False)
    html_table = df_to_html_italic(df.head(MAX_ROWS_DISPLAY), classification_type)
    return html_table, status, csv_path, [], [], None
# =========================
# Image classification + Grad-CAM
# =========================
def classify_images_batch(uploaded_files, data_type, classification_type):
    """Classify uploaded specimen images and build Grad-CAM visualisations.

    Returns the 6-tuple wired to the Gradio outputs:
    (results HTML, status text, CSV path, Grad-CAM overlay images,
    Grad-CAM heatmap images, path to a zip of all Grad-CAM PNGs).
    """
    cols = [
        "image_name",
        "top1_class", "top1_prob",
        "top2_class", "top2_prob",
        "top3_class", "top3_prob",
    ]
    # Keep only files with a recognised image extension.
    image_paths = []
    if uploaded_files:
        for p in uploaded_files:
            p_str = str(p)
            ext = os.path.splitext(p_str)[1].lower()
            if ext in [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]:
                image_paths.append(p_str)
    if not image_paths:
        empty_df = pd.DataFrame(columns=cols)
        html_table = df_to_html_italic(empty_df, classification_type)
        return html_table, "Please upload at least one image file.", None, [], [], None
    rows = []
    gradcam_overlays = []
    gradcam_heatmaps = []
    # All artefacts (PNGs, zip, CSV) live in one fresh temp directory.
    tmpdir = tempfile.mkdtemp()
    gradcam_zip_path = os.path.join(tmpdir, "gradcam_outputs.zip")
    with zipfile.ZipFile(gradcam_zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in image_paths:
            base_name = os.path.splitext(os.path.basename(str(path)))[0]
            try:
                pil = Image.open(path).convert("RGB")
                top_classes, top_probs = predict_top3_from_pil(
                    pil, data_type, classification_type
                )
                rows.append({
                    "image_name": os.path.basename(str(path)),
                    "top1_class": top_classes[0],
                    "top1_prob": top_probs[0],
                    "top2_class": top_classes[1],
                    "top2_prob": top_probs[1],
                    "top3_class": top_classes[2],
                    "top3_prob": top_probs[2],
                })
                try:
                    # Grad-CAM is best-effort: a failure here must not lose
                    # the classification result, hence the nested guard.
                    top_idx = _predict_top_idx_from_pil(pil, data_type, classification_type)
                    cam_overlay, cam_heatmap = generate_gradcam_from_pil(
                        pil,
                        data_type,
                        classification_type,
                        target_class_idx=top_idx,
                    )
                    gradcam_overlays.append(cam_overlay)
                    gradcam_heatmaps.append(cam_heatmap)
                    overlay_path = os.path.join(tmpdir, f"{base_name}_gradcam_overlay.png")
                    heatmap_path = os.path.join(tmpdir, f"{base_name}_gradcam_heatmap.png")
                    cam_overlay.save(overlay_path)
                    cam_heatmap.save(heatmap_path)
                    zf.write(overlay_path, arcname=os.path.basename(overlay_path))
                    zf.write(heatmap_path, arcname=os.path.basename(heatmap_path))
                except Exception as cam_e:
                    print(f"Grad-CAM failed for {base_name}: {cam_e}")
            except Exception as e:
                # Per-image failure: record an error row and continue.
                rows.append({
                    "image_name": os.path.basename(str(path)),
                    "top1_class": f"Error: {e}",
                    "top1_prob": None,
                    "top2_class": None,
                    "top2_prob": None,
                    "top3_class": None,
                    "top3_prob": None,
                })
    df = pd.DataFrame(rows)
    status = (
        f"Processed {len(rows)} specimen image(s) "
        f"using '{classification_type}' classification / '{data_type}' data type."
    )
    if len(df) > MAX_ROWS_DISPLAY:
        status += f" Showing first {MAX_ROWS_DISPLAY} rows in the table; full results are in the CSV."
    if len(gradcam_overlays) == 0:
        # No CAMs were produced; suppress the (empty) zip download.
        gradcam_zip_path = None
        status += " Grad-CAM could not be generated for the uploaded images."
    csv_path = os.path.join(
        tmpdir,
        f"predictions_{classification_type.replace('/','_')}_{data_type}.csv"
    )
    df.to_csv(csv_path, index=False)
    html_table = df_to_html_italic(df.head(MAX_ROWS_DISPLAY), classification_type)
    return html_table, status, csv_path, gradcam_overlays, gradcam_heatmaps, gradcam_zip_path
# =========================
# Unified runner
# =========================
def run_classifier(all_files, data_type, classification_type):
    """Dispatch uploads to the morphometric or image pipeline by data type.

    Filters the mixed upload list down to the extensions the chosen pipeline
    accepts; both pipelines return the same 6-tuple of Gradio outputs.
    """
    def _paths_with_ext(files, extensions):
        selected = []
        for item in files or []:
            name = str(item)
            if os.path.splitext(name)[1].lower() in extensions:
                selected.append(name)
        return selected

    if data_type == "Morphometric":
        tables = _paths_with_ext(all_files, (".xlsx", ".xls"))
        return classify_morphometric_batch(tables, classification_type)
    images = _paths_with_ext(all_files, (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"))
    return classify_images_batch(images, data_type, classification_type)
# =========================
# UI
# =========================
# Warm colour palette to match the palaeontology theme.
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="amber",
    neutral_hue="gray",
)
# The CSS string below is passed verbatim to Gradio; it styles the page
# wrapper, header, panels, and the scrollable prediction table.
with gr.Blocks(theme=theme, css="""
.gradio-container {
    font-family: 'Georgia', 'Times New Roman', serif;
}
.app-wrapper {
    max-width: 1100px;
    margin: 0 auto;
    padding: 1.5rem 1rem 2rem 1rem;
}
.app-header {
    text-align: center;
    margin-bottom: 1.2rem;
}
.app-header h1 {
    font-size: 2.1rem;
    margin-bottom: 0.3rem;
}
.app-header h2 {
    font-size: 1.1rem;
    font-weight: normal;
    opacity: 0.9;
}
.app-panel {
    background: rgba(255, 255, 255, 0.85);
    border-radius: 14px;
    padding: 1.2rem 1.5rem;
    margin-bottom: 1rem;
    border: 1px solid rgba(120, 82, 45, 0.18);
}
.pred-table {
    width: 100%;
    max-height: 400px;
    overflow-y: auto;
    overflow-x: auto;
}
.pred-table table {
    width: 100%;
    min-width: 650px;
    border-collapse: collapse;
    margin-top: 0.5rem;
    font-size: 0.9rem;
}
.pred-table thead {
    background: #e0cfb3;
}
.pred-table th, .pred-table td {
    border: 1px solid #d0b897;
    padding: 0.4rem 0.6rem;
    text-align: center;
    color: #000000;
    white-space: nowrap;
}
.pred-table td i {
    color: #000000 !important;
}
.pred-table th {
    font-weight: 600;
}
.pred-table tbody tr:nth-child(even) {
    background: #f7eee2;
}
.pred-table tbody tr:nth-child(odd) {
    background: #fbf4ea;
}
.pred-table td:first-child {
    text-align: left;
}
""") as demo:
    # Raw HTML open/close tags are used to wrap Gradio components in styled
    # <div>s (e.g. 'app-wrapper', 'app-panel'); each open tag below has a
    # matching gr.HTML("</div>") later.
    gr.HTML("<div class='app-wrapper'>")
    gr.HTML("""
<div class="app-header">
<h1>🦖 Dinosaur Track Classifier</h1>
<h2>Multi-data machine learning assisted ichnological identifications</h2>
<p style="margin-top:0.5rem; font-size:0.9rem;">
Imagery models finetuned from a model trained on data obtained by the
<a href="https://zenodo.org/records/15092442" target="_blank">Deep Tracks</a>
App.<br>
Developed by <b>Carolina S. Marques</b>
(<a href="https://orcid.org/0000-0002-5936-9342" target="_blank">ORCID</a>)
as part of her PhD research, funded by CEAUL through FCT - Fundação para a Ciência e Tecnologia
(<a href="https://doi.org/10.54499/UI/BD/154258/2022" target="_blank">DOI</a>).
</p>
</div>
""")
    with gr.Row():
        # Left column: task configuration and uploads.
        with gr.Column(scale=1):
            gr.HTML("<div class='app-panel'>")
            gr.Markdown(
                "#### 1. Choose classification task\n"
                "- **Ichnogenus** → 8 ichnotaxa\n"
                " (*Anomoepus*, *Caririchnium*, *Dinehichnus*, *Grallator*, "
                "*Iguanodontipus*, *Kalohipus*, *Kayentapus*, *Megalosauripus*)\n"
                "- **Theropod/Ornithopod** → broader trackmaker grouping"
            )
            classification_type = gr.Radio(
                choices=["Ichnogenus", "Theropod/Ornithopod"],
                value="Ichnogenus",
                label="Classification level",
            )
            gr.Markdown(
                "#### 2. Select input data type\n"
                "- **Photograph**: raw field or lab photograph\n"
                "- **Binarized**: cleaned binary outline / mask\n"
                "- **Depth map**: elevation / depth information\n"
                "- **Morphometric**: .xlsx tables with track measurements "
                "(must contain: FL, FW, FL_FW, LII, LIII, LIV, II_IV, II_minus_IV)"
            )
            data_type = gr.Radio(
                choices=["Photograph", "Binarized", "Depth map", "Morphometric"],
                value="Photograph",
                label="Input data type",
            )
            gr.HTML("</div>")
            gr.HTML("<div class='app-panel'>")
            gr.Markdown(
                "#### 3. Upload data\n"
                "Use the same upload area for either images (_jpeg, png, tiff, bmp_) "
                "or morphometric tables (_.xlsx, .xls_)."
            )
            data_files = gr.Files(
                label="Specimen images or morphometric spreadsheets",
                file_types=["image", ".xlsx", ".xls"],
                file_count="multiple",
                type="filepath",
            )
            run_btn = gr.Button("Run classifier", variant="primary")
            gr.HTML("</div>")
        # Right column: results table, Grad-CAM galleries, downloads.
        # NOTE(review): newer Gradio versions require an integer `scale`;
        # 1.4 may warn or error — confirm against the pinned gradio version.
        with gr.Column(scale=1.4):
            gr.HTML("<div class='app-panel'>")
            gr.Markdown("#### Predicted classes and probabilities")
            results_html = gr.HTML(label="Top-3 predictions per image / row")
            gr.Markdown(
                "_How to read the table:_\n"
                "- **top1_class** / **top1_prob**: class with the highest predicted probability.\n"
                "- **top2_class** / **top2_prob**: second most probable class.\n"
                "- **top3_class** / **top3_prob**: third most probable class, if available.\n"
                "- For **Theropod/Ornithopod**, top3 entries are empty.\n"
                "- Only the first displayed rows are shown in the table; the CSV contains the full output."
            )
            gr.HTML("</div>")
            gr.HTML("<div class='app-panel'>")
            gr.Markdown("#### Grad-CAM (imagery only)")
            gr.Markdown(
                "For image-based inputs, Grad-CAM highlights the image regions that contributed most strongly "
                "to the top-1 prediction."
            )
            gradcam_gallery = gr.Gallery(
                label="Grad-CAM overlays",
                columns=2,
                allow_preview=True,
            )
            gradcam_heatmap_gallery = gr.Gallery(
                label="Grad-CAM heatmaps",
                columns=2,
                allow_preview=True,
            )
            gradcam_zip = gr.File(
                label="Download Grad-CAM images as ZIP",
                file_types=[".zip"],
            )
            gr.HTML("</div>")
            gr.HTML("<div class='app-panel'>")
            status_md = gr.Markdown()
            df_file = gr.File(
                label="Download full predictions as CSV",
                file_types=[".csv"],
            )
            gr.Markdown(
                "_Note_: CSV export uses plain text (no italics), suitable for further analysis and plotting."
            )
            gr.HTML("</div>")
    gr.Markdown(" **Researchers involved in this study:** Carolina S. Marques, Diego Castanera, Matteo Belvedere, Ignacio Díaz-Martínez, Josué García-Cobeña, Elisabete Malafaia, Afonso Mota, Soraia Pereira, Vanda F. Santos, Lara Sciscio and Emmanuel Dufourq.")
    gr.Markdown(" **Acknowledgments:** We want to acknowledge the following institutions that have allowed the access to the analyzed specimens: Jurassica Museum, Museo Aragonés de Paleontologia, Museo de Ciencias Naturales de la Universidad de Zaragoza, Museo Numantino de Soria, and Sociedade de História Natural de Torres Vedras.")
    gr.HTML("</div>")
    # Wire the single button to the unified runner; outputs map 1:1 onto the
    # 6-tuple returned by run_classifier.
    run_btn.click(
        fn=run_classifier,
        inputs=[data_files, data_type, classification_type],
        outputs=[
            results_html,
            status_md,
            df_file,
            gradcam_gallery,
            gradcam_heatmap_gallery,
            gradcam_zip,
        ],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue so long-running inferences are serialised
    # rather than dropped, then start the web server.
    demo.queue()
    demo.launch()