# wbc-resnet50 / app.py — Hugging Face Space by varshithkumar
# Commit 7cfd6a9: "Add Gradio app and requirements for WBC ResNet50"
# app.py
import os
import traceback
import tensorflow as tf
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
# === CONFIG ===
# Hugging Face Hub repository holding the exported TensorFlow SavedModel.
MODEL_ID = "varshithkumar/wbc_resnet50"

# Human-readable labels, listed in the index order of the model's output
# vector (assumed to match the training label order — confirm).
CLASS_NAMES = [
    'Basophil',
    'Eosinophil',
    'Lymphocyte',
    'Monocyte',
    'Neutrophil',
]

# Module-level state:
#   model -> assigned from load_model()'s return value (None if loading fails)
#   infer -> serving signature function that load_model() binds for predict()
model = infer = None
# === Load model ===
def load_model():
    """Download the SavedModel from the Hub and bind its serving signature.

    On success, sets the module-level ``infer`` to the model's
    ``"serving_default"`` signature function and returns the loaded model.
    On any failure, prints the traceback, leaves ``infer`` as ``None`` and
    returns ``None`` so the app can still start and report the problem.

    The optional ``HF_TOKEN`` environment variable is forwarded to
    ``snapshot_download`` (needed when the model repo is private).
    """
    global infer
    hf_token = os.environ.get("HF_TOKEN")  # set in Space secrets if repo is private
    try:
        print(f"⏳ Downloading model from Hugging Face Hub: {MODEL_ID}")
        repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
        print("✅ Model snapshot downloaded at:", repo_dir)

        # Load TF SavedModel
        loaded = tf.saved_model.load(repo_dir)
        print("✅ Model loaded using tf.saved_model.load()")

        # Get serving function (kept local until everything has succeeded)
        serving_fn = loaded.signatures["serving_default"]

        # 🔍 Debug info: shows the signature's input/output names, which
        # predict() relies on when calling the serving function.
        print("Available signatures:", list(loaded.signatures.keys()))
        print("Serving function inputs:", serving_fn.structured_input_signature)
        print("Serving function outputs:", serving_fn.structured_outputs)
    except Exception as e:
        # Top-level boundary: log and degrade gracefully instead of crashing
        # the Space at import time.
        print("❌ Failed to load model:", e)
        traceback.print_exc()
        infer = None
        return None
    else:
        # Publish the serving function only once the whole load succeeded, so
        # `infer` is never set while this function reports failure.
        infer = serving_fn
        return loaded
# Load once at import time so every request reuses the same serving function.
model = load_model()
# predict() refuses to run when `infer` is unset, so warn on exactly that
# condition rather than on the separately tracked `model` object.
if infer is None:
    print("WARNING: Model failed to load. Predictions will return an error.")
# === preprocessing & prediction ===
def preprocess_image(img: "Image.Image", size=(224, 224)):
    """Convert a PIL image into a single-image float32 batch for the model.

    Args:
        img: Input image in any PIL mode; it is converted to RGB first.
        size: Target ``(width, height)`` passed to ``Image.resize``. Defaults
            to 224x224, the ResNet50 input size this app was built for.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)`` and dtype float32,
        with pixel values scaled from [0, 255] to [0.0, 1.0].
    """
    rgb = img.convert("RGB")
    resized = rgb.resize(tuple(size))
    # NOTE(review): plain [0, 1] scaling; confirm it matches the training
    # pipeline (Keras' resnet50.preprocess_input uses a different scheme).
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return pixels[np.newaxis, ...]  # add the leading batch dimension
def _serving_input_name(fn, default="input_layer_2"):
    """Return the keyword name of a signature function's single input tensor.

    Signature functions take inputs as keyword arguments. The name is read from
    ``fn.structured_input_signature`` (an ``(args, kwargs)`` pair) so the app
    survives a re-export that renames the input layer. Falls back to *default*
    when the signature does not expose exactly one keyword input.
    """
    try:
        _, kwarg_specs = fn.structured_input_signature
    except (AttributeError, TypeError, ValueError):
        return default
    if isinstance(kwarg_specs, dict) and len(kwarg_specs) == 1:
        return next(iter(kwarg_specs))
    return default


def _scores_tensor(outputs, key="output_0"):
    """Pick the score tensor from the signature's output dict.

    Prefers *key*; otherwise accepts the only output if there is exactly one.
    """
    if key in outputs:
        return outputs[key]
    if len(outputs) == 1:
        return next(iter(outputs.values()))
    raise KeyError(f"Unexpected model outputs: {sorted(outputs)}")


def predict(image):
    """Classify one blood-smear image with the loaded serving function.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        Dict mapping each name in ``CLASS_NAMES`` to the model's score for it
        (presumably softmax probabilities — confirm). If the model emits a
        different number of scores, keys fall back to ``"class_<i>"``.

    Raises:
        gr.Error: if the model is not loaded, no image was given, or inference
            fails. Gradio shows the message to the user. Returning an
            ``{"error": ...}`` dict instead would not render, because
            ``gr.Label`` expects numeric confidences as dict values.
    """
    if infer is None:
        raise gr.Error("Model not loaded. Check Space logs.")
    if image is None:
        raise gr.Error("Please upload an image first.")
    try:
        batch = tf.constant(preprocess_image(image))
        outputs = infer(**{_serving_input_name(infer): batch})
        probs = _scores_tensor(outputs).numpy()[0].tolist()
    except Exception as e:
        # Log full details server-side; surface a concise message in the UI.
        print("Prediction error:", e)
        traceback.print_exc()
        raise gr.Error(str(e)) from e
    if len(probs) == len(CLASS_NAMES):
        return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    return {f"class_{i}": float(p) for i, p in enumerate(probs)}
# === Gradio UI ===
title = "WBC ResNet50 - White Blood Cell Classifier"
description = "Upload a blood-smear image. Model resizes input to 224×224. If model fails to load, predictions will error."

# type="pil" hands predict() a PIL image; num_top_classes=None shows every class.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=gr.Label(num_top_classes=None, label="Predictions"),
    title=title,
    description=description,
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # update if the Space's pinned Gradio version changes.
    allow_flagging="never",
)

if __name__ == "__main__":
    # ✅ show_error=True displays errors raised inside predict() in the UI.
    demo.launch(show_error=True)