Vedag812 ibrahim313 commited on
Commit
1698179
·
verified ·
1 Parent(s): 489a078

Update app.py (#6)

Browse files

- Update app.py (fddd2289eb49d81e570be22d6b265f316194ce72)


Co-authored-by: MUHAMMAD IBRAHIM <ibrahim313@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +41 -65
app.py CHANGED
@@ -3,117 +3,93 @@ import os, glob, traceback
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
6
- import tensorflow as tf
7
 
8
- # try to use Keras 3 loader if present (many models saved with Keras 3 need this)
9
- KERAS3_AVAILABLE = False
10
- try:
11
- import keras # pip package "keras" (v3.x)
12
- KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
13
- except Exception:
14
- keras = None
15
 
16
  HF_MODEL_ID = "Vedag812/xray_cnn"
17
  CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
18
 
19
-
20
  def load_model():
21
- from huggingface_hub import hf_hub_download
22
  model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
23
- # Keras 3 path
24
- if KERAS3_AVAILABLE:
25
- os.environ.setdefault("KERAS_BACKEND", "tensorflow")
26
- try:
27
- return keras.saving.load_model(model_path, compile=False, safe_mode=False)
28
- except Exception:
29
- # fall back to tf.keras if that fails for any reason
30
- pass
31
- # tf.keras path
32
- return tf.keras.models.load_model(model_path, compile=False)
33
 
34
  def _infer_input_shape(model):
35
- """returns (H, W, C) with integers if available, else defaults to (150,150,1)"""
36
- shape = None
37
  try:
38
- shape = tuple(model.inputs[0].shape.as_list()) # works on many TF models
39
  except Exception:
40
- try:
41
- shape = tuple(model.input_shape)
42
- except Exception:
43
- pass
44
- if not shape or len(shape) < 4:
45
  return 150, 150, 1
46
- H = int(shape[1]) if shape[1] else 150
47
- W = int(shape[2]) if shape[2] else 150
48
- C = int(shape[3]) if shape[3] else 1
49
  return H, W, C
50
 
51
- def preprocess(pil_img: Image.Image, target_hw_c):
52
- H, W, C = target_hw_c
53
- # always start from grayscale so intensity stays consistent
54
  g = pil_img.convert("L").resize((W, H))
55
- g_arr = np.array(g).astype("float32") / 255.0 # (H,W)
56
  if C == 1:
57
- x = np.expand_dims(g_arr, axis=(0, -1)) # (1,H,W,1)
58
  elif C == 3:
59
- x3 = np.stack([g_arr, g_arr, g_arr], axis=-1) # (H,W,3)
60
- x = np.expand_dims(x3, axis=0) # (1,H,W,3)
61
  else:
62
- # unexpected channel count. tile to that count safely
63
- xC = np.repeat(g_arr[..., None], C, axis=-1)
64
- x = np.expand_dims(xC, axis=0)
65
  return x
66
 
67
  def predict_fn(pil_img: Image.Image):
68
  try:
 
 
69
  model = load_model()
70
  H, W, C = _infer_input_shape(model)
71
  x = preprocess(pil_img, (H, W, C))
72
- preds = model.predict(x, verbose=0)
73
- # handle models that output shape (1,1) or (1,)
74
- prob = float(preds.ravel()[0])
75
- pred_idx = int(prob > 0.5)
76
- confidence = prob if pred_idx == 1 else 1 - prob
77
  probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
78
- msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
79
  return probs, msg
80
  except Exception as e:
81
- # show a readable error with a tip
82
  tip = (
83
- "Tip: if this keeps happening, the Space may need keras>=3 to load a model "
84
- "saved with newer Keras. I handled both paths here, but if your model was saved "
85
- "with a very new version, updating the Space deps can help."
86
  )
87
- err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
88
- return {"NORMAL": 0.0, "PNEUMONIA": 0.0}, err_text
 
 
89
 
90
  def list_examples():
91
  files = []
92
- for pattern in ["images/*.jpeg", "images/*.jpg", "images/*.png"]:
93
- files.extend(glob.glob(pattern))
94
- files = sorted(files)
95
- return [[p] for p in files]
96
 
97
  with gr.Blocks(css="""
98
  .gradio-container {max-width: 980px !important; margin: auto;}
99
  #title {text-align:center;}
100
- .card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
101
  """) as demo:
102
  gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
103
- gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")
104
 
105
  with gr.Row():
106
  with gr.Column(scale=2):
107
  inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
108
  with gr.Row():
109
  btn = gr.Button("Predict", variant="primary")
110
- gr.ClearButton(components=[inp], value="Clear")
111
  gr.Markdown("### Samples")
112
- gr.Examples(
113
- examples=list_examples(),
114
- inputs=inp,
115
- examples_per_page=12,
116
- )
117
  with gr.Column(scale=1):
118
  probs = gr.Label(num_top_classes=2, label="Class probabilities")
119
  out_text = gr.Markdown()
 
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
 
6
 
7
+ # Use Keras 3 with the TensorFlow backend
8
+ os.environ.setdefault("KERAS_BACKEND", "tensorflow")
9
+ import keras # Keras 3.x
10
+ from huggingface_hub import hf_hub_download
 
 
 
11
 
12
  HF_MODEL_ID = "Vedag812/xray_cnn"
13
  CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
14
 
15
+ @gr.cache_resource
16
  def load_model():
 
17
  model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
18
+ # safe_mode=False allows loading models saved with older or custom configs
19
+ model = keras.saving.load_model(model_path, compile=False, safe_mode=False)
20
+ return model
 
 
 
 
 
 
 
21
 
22
  def _infer_input_shape(model):
23
+ # returns (H, W, C)
 
24
  try:
25
+ shp = tuple(model.inputs[0].shape)
26
  except Exception:
27
+ shp = getattr(model, "input_shape", None)
28
+ if shp is None:
29
+ return 150, 150, 1
30
+ if len(shp) < 4:
 
31
  return 150, 150, 1
32
+ H = int(shp[1]) if shp[1] is not None else 150
33
+ W = int(shp[2]) if shp[2] is not None else 150
34
+ C = int(shp[3]) if shp[3] is not None else 1
35
  return H, W, C
36
 
37
+ def preprocess(pil_img: Image.Image, target):
38
+ H, W, C = target
 
39
  g = pil_img.convert("L").resize((W, H))
40
+ arr = np.array(g).astype("float32") / 255.0 # (H, W)
41
  if C == 1:
42
+ x = np.expand_dims(arr, axis=(0, -1)) # (1,H,W,1)
43
  elif C == 3:
44
+ x = np.expand_dims(np.stack([arr]*3, axis=-1), 0) # (1,H,W,3)
 
45
  else:
46
+ x = np.expand_dims(np.repeat(arr[..., None], C, axis=-1), 0)
 
 
47
  return x
48
 
49
  def predict_fn(pil_img: Image.Image):
50
  try:
51
+ if pil_img is None:
52
+ return {"NORMAL": 0.0, "PNEUMONIA": 0.0}, "Please upload an image or pick a sample."
53
  model = load_model()
54
  H, W, C = _infer_input_shape(model)
55
  x = preprocess(pil_img, (H, W, C))
56
+ y = model.predict(x, verbose=0)
57
+ prob = float(np.ravel(y)[0]) # sigmoid output
58
+ idx = int(prob > 0.5)
59
+ conf = prob if idx == 1 else 1 - prob
 
60
  probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
61
+ msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
62
  return probs, msg
63
  except Exception as e:
 
64
  tip = (
65
+ "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
 
 
66
  )
67
+ err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
68
+ # Optional: uncomment next line to print full stack to the Space logs
69
+ # print(traceback.format_exc())
70
+ return {"NORMAL": 0.0, "PNEUMONIA": 0.0}, err
71
 
72
  def list_examples():
73
  files = []
74
+ for pat in ("images/*.jpeg", "images/*.jpg", "images/*.png"):
75
+ files.extend(glob.glob(pat))
76
+ return [[p] for p in sorted(files)]
 
77
 
78
  with gr.Blocks(css="""
79
  .gradio-container {max-width: 980px !important; margin: auto;}
80
  #title {text-align:center;}
 
81
  """) as demo:
82
  gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
83
+ gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")
84
 
85
  with gr.Row():
86
  with gr.Column(scale=2):
87
  inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
88
  with gr.Row():
89
  btn = gr.Button("Predict", variant="primary")
90
+ gr.ClearButton(components=[inp])
91
  gr.Markdown("### Samples")
92
+ gr.Examples(examples=list_examples(), inputs=inp, examples_per_page=12)
 
 
 
 
93
  with gr.Column(scale=1):
94
  probs = gr.Label(num_top_classes=2, label="Class probabilities")
95
  out_text = gr.Markdown()