File size: 2,939 Bytes
767b8c3
 
 
 
 
 
 
84c156c
767b8c3
 
84c156c
 
767b8c3
07aa38c
 
 
 
84c156c
767b8c3
07aa38c
84c156c
767b8c3
84c156c
767b8c3
84c156c
 
07aa38c
84c156c
 
07aa38c
 
7001abb
07aa38c
 
 
7001abb
 
07aa38c
84c156c
767b8c3
84c156c
767b8c3
 
 
77dd53d
767b8c3
07aa38c
767b8c3
 
 
 
84c156c
767b8c3
 
 
 
 
07aa38c
 
 
767b8c3
 
07aa38c
84c156c
767b8c3
 
 
 
07aa38c
767b8c3
 
 
 
07aa38c
767b8c3
 
 
07aa38c
767b8c3
 
 
 
0638dc4
767b8c3
 
 
 
 
 
7cfd6a9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# app.py
import os
import traceback
import tensorflow as tf
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download

# === CONFIG ===
# Hugging Face Hub repository holding the exported TF SavedModel (your model repo id).
MODEL_ID = "varshithkumar/wbc_resnet50"

# Human-readable labels. predict() pairs them with the model's output scores
# index by index, so this order is assumed to match the training label order.
CLASS_NAMES = [
    'Basophil',
    'Eosinophil',
    'Lymphocyte',
    'Monocyte',
    'Neutrophil',
]

# Globals: the loaded SavedModel and its serving function; both are filled in
# when load_model() runs at import time.
model = infer = None

# === Load model ===
def load_model():
    """Download the SavedModel from the Hub, load it and publish its serving function.

    Reads the optional ``HF_TOKEN`` environment variable and passes it to
    ``snapshot_download`` (needed when the model repo is private).

    Returns:
        The object returned by ``tf.saved_model.load`` on success, or ``None``
        if any step fails (the traceback is printed). On success the
        module-level ``infer`` is set to the model's ``serving_default``
        signature; on failure it is set to ``None``. This keeps
        ``infer is None`` equivalent to "model not loaded".
    """
    global infer
    hf_token = os.environ.get("HF_TOKEN")  # set in Space secrets if repo is private
    try:
        print(f"⏳ Downloading model from Hugging Face Hub: {MODEL_ID}")
        repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
        print("✅ Model snapshot downloaded at:", repo_dir)

        # Load TF SavedModel
        loaded = tf.saved_model.load(repo_dir)
        print("✅ Model loaded using tf.saved_model.load()")

        # Get serving function (raises KeyError if the export has no such signature)
        serving_fn = loaded.signatures["serving_default"]

        # 🔍 Debug info: the input/output names printed here are the keys
        # predict() must use when calling the serving function.
        print("Available signatures:", list(loaded.signatures.keys()))
        print("Serving function inputs:", serving_fn.structured_input_signature)
        print("Serving function outputs:", serving_fn.structured_outputs)
    except Exception as e:
        print("❌ Failed to load model:", e)
        traceback.print_exc()
        infer = None
        return None

    # Publish the serving function only after every step above succeeded, so a
    # failure part-way through can never leave `infer` set while we return None.
    infer = serving_fn
    return loaded

model = load_model()
if model is None:
    print("WARNING: Model failed to load. Predictions will return an error.")

# === preprocessing & prediction ===
def preprocess_image(img: "Image.Image", size=(224, 224)):
    """Turn a PIL image into a single-image float32 batch for the model.

    Args:
        img: Input image in any PIL mode; it is converted to RGB first.
        size: Target ``(width, height)`` handed to ``Image.resize``. The
            default 224x224 is the input size this app's ResNet50 expects.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)`` and dtype float32,
        with 8-bit pixel values scaled into ``[0.0, 1.0]``.
    """
    rgb = img.convert("RGB")
    resized = rgb.resize(size)
    # NOTE(review): plain /255 scaling must match how the model was trained — confirm.
    arr = np.array(resized).astype(np.float32) / 255.0
    # Prepend a batch axis of length 1: (H, W, 3) -> (1, H, W, 3).
    return np.expand_dims(arr, 0)

def predict(image):
    """Classify one uploaded image and return scores for the Gradio ``Label`` output.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; may be
            ``None`` if the user submits without uploading anything.

    Returns:
        Dict mapping each name in ``CLASS_NAMES`` to its float score. If the
        model's output length differs from ``len(CLASS_NAMES)``, keys fall
        back to ``"class_<index>"``.

    Raises:
        gr.Error: if the model is not loaded, no image was provided, or
            preprocessing/inference fails. Gradio shows the message to the user.
    """
    # Errors are raised as gr.Error rather than returned as {"error": "..."}:
    # gr.Label needs numeric confidences, so a string value cannot be shown.
    # `infer` is only read here, so no `global` statement is needed.
    if infer is None:
        raise gr.Error("Model not loaded. Check Space logs.")
    if image is None:
        raise gr.Error("Please upload an image first.")
    try:
        batch = preprocess_image(image)
        # Argument and output names must match the SavedModel's serving_default
        # signature (load_model prints it at startup).
        preds = infer(input_layer_2=tf.constant(batch))["output_0"].numpy()

        probs = preds[0].tolist()  # batch of one -> first row
        if len(probs) == len(CLASS_NAMES):
            return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
        return {f"class_{i}": float(p) for i, p in enumerate(probs)}
    except Exception as e:
        print("Prediction error:", e)
        traceback.print_exc()
        raise gr.Error(str(e)) from e

# === Gradio UI ===
title = "WBC ResNet50 - White Blood Cell Classifier"
description = "Upload a blood-smear image. Model resizes input to 224×224. If model fails to load, predictions will error."

demo = gr.Interface(
    fn=predict,
    # type="pil" hands predict() a PIL image, which preprocess_image() expects.
    inputs=gr.Image(type="pil", label="Upload image"),
    # num_top_classes=None: no top-k cap, so presumably every class is listed.
    outputs=gr.Label(num_top_classes=None, label="Predictions"),
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # show_error=True: display errors raised in predict() in the browser UI.
    demo.launch(show_error=True)