# app.py
"""Gradio Space that classifies white-blood-cell images with a ResNet50 SavedModel.

The model is downloaded from the Hugging Face Hub at import time. If loading
fails the UI still starts, and every prediction reports the failure.
"""
import os
import traceback

import tensorflow as tf
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download

# === CONFIG ===
MODEL_ID = "varshithkumar/wbc_resnet50"  # your model repo id
# NOTE(review): assumed to match the label order used at training time -- confirm.
CLASS_NAMES = ['Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil']
IMAGE_SIZE = (224, 224)  # ResNet50 expected input size
# Fallback tensor names for the serving signature. They are used only when the
# signature does not expose a single, unambiguous input / output name.
DEFAULT_INPUT_KEY = "input_layer_2"
DEFAULT_OUTPUT_KEY = "output_0"

# Globals (populated by load_model on success)
model = None
infer = None  # serving function
input_key = DEFAULT_INPUT_KEY
output_key = DEFAULT_OUTPUT_KEY


# === Load model ===
def _resolve_io_keys(serving_fn):
    """Return ``(input_name, output_name)`` to use when calling *serving_fn*.

    Signature functions only accept keyword arguments, so the input tensor
    name must be known. The name is taken from the signature when there is
    exactly one input. Otherwise the configured default is used. For the
    output, the default name wins if present; otherwise a sole output is used.
    """
    _, input_specs = serving_fn.structured_input_signature
    if isinstance(input_specs, dict) and len(input_specs) == 1:
        in_key = next(iter(input_specs))
    else:
        in_key = DEFAULT_INPUT_KEY

    outputs = serving_fn.structured_outputs
    if isinstance(outputs, dict) and DEFAULT_OUTPUT_KEY not in outputs and len(outputs) == 1:
        out_key = next(iter(outputs))
    else:
        out_key = DEFAULT_OUTPUT_KEY
    return in_key, out_key


def load_model():
    """Download the SavedModel from the Hub and bind its serving signature.

    On success:
        - sets the module globals ``infer``, ``input_key`` and ``output_key``;
        - returns the loaded SavedModel object.

    On any failure:
        - prints the error with a traceback;
        - leaves those globals untouched and returns ``None``, so the app can
          still start and report the problem.
    """
    global infer, input_key, output_key
    hf_token = os.environ.get("HF_TOKEN")  # set in Space secrets if repo is private
    try:
        print(f"⏳ Downloading model from Hugging Face Hub: {MODEL_ID}")
        repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
        print("✅ Model snapshot downloaded at:", repo_dir)

        # Load TF SavedModel
        loaded = tf.saved_model.load(repo_dir)
        print("✅ Model loaded using tf.saved_model.load()")

        # Get serving function
        serving_fn = loaded.signatures["serving_default"]

        # 🔍 Debug info
        print("Available signatures:", list(loaded.signatures.keys()))
        print("Serving function inputs:", serving_fn.structured_input_signature)
        print("Serving function outputs:", serving_fn.structured_outputs)

        in_key, out_key = _resolve_io_keys(serving_fn)
        print(f"Using input key {in_key!r} and output key {out_key!r}")
    except Exception as e:
        print("❌ Failed to load model:", e)
        traceback.print_exc()
        return None

    # Publish only after every step succeeded, so a partial failure cannot
    # leave `infer` set while `model` is None.
    infer, input_key, output_key = serving_fn, in_key, out_key
    return loaded


model = load_model()
if model is None:
    print("WARNING: Model failed to load. Predictions will return an error.")


# === preprocessing & prediction ===
def preprocess_image(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a float32 batch of shape ``(1, 224, 224, 3)``.

    The image is forced to RGB, resized to ``IMAGE_SIZE`` and scaled to [0, 1].
    """
    img = img.convert("RGB")
    img = img.resize(IMAGE_SIZE)
    # NOTE(review): assumes the model was trained on inputs scaled by 1/255
    # (not Keras `resnet50.preprocess_input`) -- confirm against training code.
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return np.expand_dims(arr, 0)


def predict(image):
    """Classify *image* and return a ``{label: probability}`` mapping for ``gr.Label``.

    When the output length matches ``CLASS_NAMES`` the class names are used as
    labels. Otherwise generic ``class_<i>`` labels are used.

    Raises:
        gr.Error: if the model is not loaded, no image was supplied, or
            inference fails. Gradio shows the message in the UI. Returning an
            ``{"error": ...}`` dict instead would break ``gr.Label``, which
            expects numeric confidences.
    """
    if infer is None:
        raise gr.Error("Model not loaded. Check Space logs.")
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        arr = preprocess_image(image)
        outputs = infer(**{input_key: tf.constant(arr)})
        probs = outputs[output_key].numpy()[0].tolist()
    except Exception as e:
        print("Prediction error:", e)
        traceback.print_exc()
        raise gr.Error(str(e)) from e

    if len(probs) == len(CLASS_NAMES):
        return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    return {f"class_{i}": float(p) for i, p in enumerate(probs)}


# === Gradio UI ===
title = "WBC ResNet50 - White Blood Cell Classifier"
description = "Upload a blood-smear image. Model resizes input to 224×224. If model fails to load, predictions will error."

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=gr.Label(num_top_classes=None, label="Predictions"),
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch(show_error=True)  # ✅ enable verbose error reporting