abuhanzala commited on
Commit
465e140
·
verified ·
1 Parent(s): 04dd316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -17,6 +17,10 @@ CONFIDENCE_THRESHOLD = 0.7
17
 
18
  def predict_image(image):
19
  try:
 
 
 
 
20
  # Preprocess
21
  image = image.resize((224, 224)).convert("RGB")
22
  img_array = np.array(image, dtype=np.float32) / 255.0
@@ -27,22 +31,25 @@ def predict_image(image):
27
  interpreter.invoke()
28
  output = interpreter.get_tensor(output_details[0]['index'])[0] # shape (num_classes,)
29
 
30
- # Normalize if needed
31
  probs = tf.nn.softmax(output).numpy()
32
 
 
 
 
 
33
  # Convert to dict for Gradio Label
34
  probs_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
35
-
36
- return probs_dict
37
 
38
  except Exception as e:
39
- return {"Error": 1.0} # dummy error output
40
 
41
  # Gradio UI
42
  gr.Interface(
43
  fn=predict_image,
44
  inputs=gr.Image(type="pil"),
45
- outputs=gr.Label(num_top_classes=len(class_names)), # shows all classes with bars
46
  title="Cervical Cancer Classification",
47
- description="Upload an image. The model shows probabilities for each class."
48
  ).launch()
 
17
 
18
  def predict_image(image):
19
  try:
20
+ # Validate input
21
+ if image.mode != "RGB":
22
+ return {"Error": 1.0}, "⚠️ Please upload a valid cervical cell image (RGB image required)."
23
+
24
  # Preprocess
25
  image = image.resize((224, 224)).convert("RGB")
26
  img_array = np.array(image, dtype=np.float32) / 255.0
 
31
  interpreter.invoke()
32
  output = interpreter.get_tensor(output_details[0]['index'])[0] # shape (num_classes,)
33
 
34
+ # Normalize
35
  probs = tf.nn.softmax(output).numpy()
36
 
37
+ # Check if prediction is below confidence threshold
38
+ if np.max(probs) < CONFIDENCE_THRESHOLD:
39
+ return {"Error": 1.0}, "⚠️ The model is unsure. Please upload a clearer/correct image of cervical cells."
40
+
41
  # Convert to dict for Gradio Label
42
  probs_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
43
+ return probs_dict, f"✅ Prediction: {class_names[np.argmax(probs)]} ({np.max(probs)*100:.2f}%)"
 
44
 
45
  except Exception as e:
46
+ return {"Error": 1.0}, f"⚠️ Something went wrong. Please upload a correct cervical cell image. ({str(e)})"
47
 
48
  # Gradio UI
49
  gr.Interface(
50
  fn=predict_image,
51
  inputs=gr.Image(type="pil"),
52
+ outputs=[gr.Label(num_top_classes=len(class_names)), gr.Textbox()],
53
  title="Cervical Cancer Classification",
54
+ description="Upload an image. The model shows probabilities for each class and warns if the image is incorrect."
55
  ).launch()