Hugging Face Spaces — Space status: Sleeping
# app.py
import os
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
from PIL import Image

# === CONFIG ===
MODEL_ID = "varshithkumar/wbc_resnet50"  # your model repo id
CLASS_NAMES = ['Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil']

# Globals populated by load_model() at import time.
model = None   # loaded SavedModel object (or None on failure)
infer = None   # serving function
# === Load model ===
def load_model():
    """Download the model snapshot from the Hub and bind its serving signature.

    Side effect: sets the module-level ``infer`` global to the SavedModel's
    ``serving_default`` concrete function.

    Returns:
        The loaded SavedModel object, or ``None`` if any step fails
        (the error is printed so it appears in the Space logs).
    """
    global infer
    # Token is set in Space secrets if repo is private.
    hf_token = os.environ.get("HF_TOKEN")
    try:
        print(f"β³ Downloading model from Hugging Face Hub: {MODEL_ID}")
        repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
        print("β Model snapshot downloaded at:", repo_dir)

        # Plain SavedModel load (repo stores a TF SavedModel, not a Keras file).
        loaded = tf.saved_model.load(repo_dir)
        print("β Model loaded using tf.saved_model.load()")

        # Bind the default serving function used by predict().
        infer = loaded.signatures["serving_default"]

        # Debug info to help diagnose signature/tensor-name mismatches.
        print("Available signatures:", list(loaded.signatures.keys()))
        print("Serving function inputs:", infer.structured_input_signature)
        print("Serving function outputs:", infer.structured_outputs)
        return loaded
    except Exception as e:
        print("β Failed to load model:", e)
        traceback.print_exc()
        return None
# Load once at import time so a broken model shows up in the Space logs
# immediately rather than on the first prediction.
model = load_model()
if model is None:
    print("WARNING: Model failed to load. Predictions will return an error.")
# === preprocessing & prediction ===
def preprocess_image(img: "Image.Image", size: tuple = (224, 224)) -> np.ndarray:
    """Convert a PIL image into a normalized float32 batch tensor.

    Args:
        img: input image (any mode; converted to RGB).
        size: target (width, height); defaults to 224x224, the input size
            this ResNet50 export was trained with.

    Returns:
        Array of shape (1, H, W, 3), dtype float32, values scaled to [0, 1].
    """
    img = img.convert("RGB")    # drop alpha/palette, force exactly 3 channels
    img = img.resize(size)
    # asarray with an explicit dtype avoids the extra copy of array().astype().
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return np.expand_dims(arr, 0)  # add batch dimension
def predict(image):
    """Run the serving signature on one image and return class probabilities.

    Args:
        image: PIL image from the Gradio input widget.

    Returns:
        Dict mapping class name -> probability, or ``{"error": ...}`` when the
        model is unavailable or inference fails (Gradio renders either).
    """
    if infer is None:
        return {"error": "Model not loaded. Check Space logs."}
    try:
        arr = preprocess_image(image)
        # Derive the signature's tensor names instead of hard-coding
        # "input_layer_2"/"output_0" — those break whenever the exported
        # SavedModel names its layers differently.
        # structured_input_signature is (args, kwargs); kwargs maps input names.
        input_name = next(iter(infer.structured_input_signature[1]))
        outputs = infer(**{input_name: tf.constant(arr)})
        output_name = next(iter(outputs))
        probs = outputs[output_name].numpy()[0].tolist()
        if len(probs) == len(CLASS_NAMES):
            return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
        # Fallback labels when the head size doesn't match CLASS_NAMES.
        return {"class_" + str(i): float(p) for i, p in enumerate(probs)}
    except Exception as e:
        print("Prediction error:", e)
        traceback.print_exc()
        return {"error": str(e)}
# === Gradio UI ===
title = "WBC ResNet50 - White Blood Cell Classifier"
description = (
    "Upload a blood-smear image. Model resizes input to 224Γ224. "
    "If model fails to load, predictions will error."
)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=gr.Label(num_top_classes=None, label="Predictions"),
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # show_error=True surfaces tracebacks in the UI for easier debugging.
    demo.launch(show_error=True)