varshithkumar commited on
Commit
07aa38c
Β·
1 Parent(s): 7001abb

Add Gradio app and requirements for WBC ResNet50

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -11,22 +11,31 @@ from huggingface_hub import snapshot_download
11
  MODEL_ID = "varshithkumar/wbc_resnet50" # your model repo id
12
  CLASS_NAMES = ['Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil']
13
 
 
 
 
 
14
  # === Load model ===
15
  def load_model():
 
16
  hf_token = os.environ.get("HF_TOKEN") # set in Space secrets if repo is private
17
  try:
18
  print(f"⏳ Downloading model from Hugging Face Hub: {MODEL_ID}")
19
  repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
20
  print("βœ… Model snapshot downloaded at:", repo_dir)
21
 
22
- # Try loading with tf.saved_model.load instead of keras
23
  model = tf.saved_model.load(repo_dir)
24
  print("βœ… Model loaded using tf.saved_model.load()")
25
- # πŸ” Debug: print signatures and IO structure
26
- print("Available signatures:", list(model.signatures.keys()))
27
  infer = model.signatures["serving_default"]
 
 
 
28
  print("Serving function inputs:", infer.structured_input_signature)
29
  print("Serving function outputs:", infer.structured_outputs)
 
30
  return model
31
  except Exception as e:
32
  print("❌ Failed to load model:", e)
@@ -35,7 +44,7 @@ def load_model():
35
 
36
  model = load_model()
37
  if model is None:
38
- print("WARNING: Model failed to load. Predictions will return zeros.")
39
 
40
  # === preprocessing & prediction ===
41
  def preprocess_image(img: Image.Image):
@@ -46,30 +55,27 @@ def preprocess_image(img: Image.Image):
46
  return arr
47
 
48
  def predict(image):
49
- if model is None:
50
- return {cls: 0.0 for cls in CLASS_NAMES}
 
51
  try:
52
  arr = preprocess_image(image)
 
53
 
54
- # SavedModel inference requires calling the serving function
55
- infer = model.signatures["serving_default"]
56
- preds = infer(tf.constant(arr)) # run inference
57
- preds = list(preds.values())[0].numpy() # extract tensor
58
  probs = preds[0].tolist()
59
-
60
  if len(probs) == len(CLASS_NAMES):
61
  out = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
62
  else:
63
- out = {f"class_{i}": float(p) for i, p in enumerate(probs)}
64
  return out
65
  except Exception as e:
66
  print("Prediction error:", e)
67
  traceback.print_exc()
68
- return {cls: 0.0 for cls in CLASS_NAMES}
69
 
70
  # === Gradio UI ===
71
  title = "WBC ResNet50 - White Blood Cell Classifier"
72
- description = "Upload a blood-smear image. Model resizes input to 224Γ—224. If model fails to load, all predictions will be 0."
73
 
74
  demo = gr.Interface(
75
  fn=predict,
 
11
  MODEL_ID = "varshithkumar/wbc_resnet50" # your model repo id
12
  CLASS_NAMES = ['Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil']
13
 
14
+ # Globals
15
+ model = None
16
+ infer = None # serving function
17
+
18
  # === Load model ===
19
  def load_model():
20
+ global infer
21
  hf_token = os.environ.get("HF_TOKEN") # set in Space secrets if repo is private
22
  try:
23
  print(f"⏳ Downloading model from Hugging Face Hub: {MODEL_ID}")
24
  repo_dir = snapshot_download(repo_id=MODEL_ID, repo_type="model", token=hf_token)
25
  print("βœ… Model snapshot downloaded at:", repo_dir)
26
 
27
+ # Load TF SavedModel
28
  model = tf.saved_model.load(repo_dir)
29
  print("βœ… Model loaded using tf.saved_model.load()")
30
+
31
+ # Get serving function
32
  infer = model.signatures["serving_default"]
33
+
34
+ # πŸ” Debug info
35
+ print("Available signatures:", list(model.signatures.keys()))
36
  print("Serving function inputs:", infer.structured_input_signature)
37
  print("Serving function outputs:", infer.structured_outputs)
38
+
39
  return model
40
  except Exception as e:
41
  print("❌ Failed to load model:", e)
 
44
 
45
  model = load_model()
46
  if model is None:
47
+ print("WARNING: Model failed to load. Predictions will return an error.")
48
 
49
  # === preprocessing & prediction ===
50
  def preprocess_image(img: Image.Image):
 
55
  return arr
56
 
57
  def predict(image):
58
+ global infer
59
+ if infer is None:
60
+ return {"error": "Model not loaded. Check Space logs."}
61
  try:
62
  arr = preprocess_image(image)
63
+ preds = infer(input_layer_2=tf.constant(arr))["output_0"].numpy()
64
 
 
 
 
 
65
  probs = preds[0].tolist()
 
66
  if len(probs) == len(CLASS_NAMES):
67
  out = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
68
  else:
69
+ out = {"class_" + str(i): float(p) for i, p in enumerate(probs)}
70
  return out
71
  except Exception as e:
72
  print("Prediction error:", e)
73
  traceback.print_exc()
74
+ return {"error": str(e)}
75
 
76
  # === Gradio UI ===
77
  title = "WBC ResNet50 - White Blood Cell Classifier"
78
+ description = "Upload a blood-smear image. Model resizes input to 224Γ—224. If model fails to load, predictions will error."
79
 
80
  demo = gr.Interface(
81
  fn=predict,