# fossil_app / inference_resnet_v2.py
# Last commit by piperod91 (29697f1):
# "Fix model loading: BEiT support, token fallback (HF_TOKEN), condition bug"
import tensorflow as tf
# Hide every GPU from TensorFlow so the models used by this module run on the
# CPU. TensorFlow only accepts this before its runtime has initialized devices,
# which is presumably why it is called right after the first import.
tf.config.set_visible_devices([], 'GPU')
# Disabled alternative: keep the first GPU visible but enable on-demand memory
# growth instead of hiding it.
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# if gpu_devices:
# tf.config.experimental.set_memory_growth(gpu_devices[0], True)
# else:
# print(f"TensorFlow device: {gpu_devices}")
from keras.applications import resnet
import tensorflow as tf
import keras
import os
import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_140
import numpy as np
def _ensure_models_downloaded():
if os.path.exists("model_classification"):
return
REPO_ID = "Serrelab/fossil_classification_models"
token = os.getenv("READ_TOKEN") or os.getenv("HF_TOKEN")
if token is None:
raise RuntimeError(
"model_classification/ is missing and READ_TOKEN (or HF_TOKEN) is not set. "
"Set READ_TOKEN in .env or HF_TOKEN on Spaces to download models."
)
print("read token:", token)
snapshot_download(repo_id=REPO_ID, token=token, repo_type="model", local_dir="model_classification")
def get_resnet_model(model_path):
    """Load a saved Keras classifier and split it into two sub-models.

    Makes sure the checkpoints are present locally first, then loads the
    model at ``model_path`` with ``cce`` registered as a custom object.

    Returns:
        A tuple ``(model, g, h)``:
          - ``model``: the full loaded model;
          - ``g``: maps the model input to the output of ``layers[2]``;
          - ``h``: maps the input of ``layers[3]`` to the final output.
        Presumably ``g`` is a feature extractor and ``h`` the classification
        head; the split relies on the saved architecture's layer order —
        NOTE(review): confirm indices 2/3 if the checkpoint changes.
    """
    _ensure_models_downloaded()
    custom = {"cce": tf.keras.losses.categorical_crossentropy}
    full_model = keras.models.load_model(model_path, custom_objects=custom)
    layers = full_model.layers
    feature_model = keras.Model(full_model.input, layers[2].output)
    head_model = keras.Model(layers[3].input, layers[-1].output)
    return full_model, feature_model, head_model
def select_top_n(preds, n=10):
    """Return the indices of the ``n`` largest scores, largest first.

    Args:
        preds: 1-D array-like of scores (assumed 1-D — other ranks are not
            handled meaningfully).
        n: number of indices to return. ``n <= 0`` yields an empty result;
            ``n`` larger than ``len(preds)`` returns every index.

    Returns:
        NumPy integer array of indices ordered by descending score.
    """
    # Reverse first, then take the head: the old ``[-n:]`` form turned n == 0
    # into ``[0:]`` and returned every index instead of none.
    descending = np.argsort(preds)[::-1]
    return descending[:max(n, 0)]
def parse_results(top_n, logits):
    """Build a ``{label: score}`` dict for the selected class indices.

    Each index in ``top_n`` is translated to its name via ``lookup_140`` and
    paired with ``logits[index]`` converted to a plain Python float. If two
    indices map to the same label, the later one wins.
    """
    return {lookup_140[idx]: float(logits[idx]) for idx in top_n}
def inference_resnet_embedding_v2(x, model, size=384, n_classes=140, n_top=10):
    """Run ``model`` on a single image and return its embedding output.

    Args:
        x: one image as a 3-channel tensor/array; presumably 0-255 pixel
            values given the ``/ 255`` scaling — TODO confirm.
        model: Keras model used for prediction.
        size: side length the image is resized to before inference.
        n_classes, n_top: unused; kept so existing call sites keep working.

    Returns:
        ``model.predict(batch)[0][0]`` for the one-image batch (presumably the
        first output of the model for that image — confirm against the model).
    """
    x = tf.image.resize(x, (size, size))
    # Reshape with ``size`` too: the old hard-coded 384 crashed for any other size.
    x = tf.reshape(x, (size, size, 3)) / 255
    embedding = model.predict(np.array([x]))[0][0]
    return embedding
def inference_resnet_finer_v2(x, model, size=384, n_classes=142, n_top=10):
    """Classify a single image and return its top-``n_top`` labels with probabilities.

    Args:
        x: one image as a 3-channel tensor/array; presumably 0-255 pixel
            values given the ``/ 255`` scaling — TODO confirm.
        model: Keras model whose ``predict`` output has the class logits at
            index 1 (as used below) — NOTE(review): confirm output order.
        size: side length the image is resized to before inference.
        n_classes: unused; kept so existing call sites keep working.
        n_top: number of highest-probability classes to return.

    Returns:
        Dict mapping label name to softmax probability, built by
        ``parse_results`` from the ``n_top`` best class indices.
    """
    x = tf.image.resize(x, (size, size))
    # Reshape with ``size`` too: the old hard-coded 384 crashed for any other size.
    x = tf.reshape(x, (size, size, 3)) / 255
    outputs = model.predict(np.array([x]))
    # ``.numpy()`` alone converts the eager tensor; the PyTorch-style ``.cpu()``
    # call is a deprecated/removed TF API and served no purpose here.
    probs = tf.nn.softmax(outputs[1][0]).numpy()
    top_n = select_top_n(probs, n=n_top)
    return parse_results(top_n, probs)