abuhanzala commited on
Commit
05e9c1c
·
verified ·
1 Parent(s): 1ef3fee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -43
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import numpy as np
3
- from PIL import Image, ImageFilter, ImageStat
4
  import tensorflow as tf
5
 
6
  # Load TFLite model
@@ -15,38 +15,8 @@ output_details = interpreter.get_output_details()
15
  class_names = ['Dyskeratotic', 'Koilocytotic', 'Metaplastic', 'Parabasal', 'Superficial-Intermediat']
16
  CONFIDENCE_THRESHOLD = 0.25
17
 
18
- def is_valid_cervical_image(image):
19
- """Basic smart validation: checks variance, edges, brightness"""
20
- # Convert to grayscale
21
- gray = image.convert("L")
22
- stat = ImageStat.Stat(gray)
23
-
24
- # Variance check (texture)
25
- variance = stat.var[0]
26
- if variance < 500: # threshold, tweak as needed
27
- return False, "Image lacks texture. Upload a proper cervical cell image."
28
-
29
- # Edge detection check
30
- edges = gray.filter(ImageFilter.FIND_EDGES)
31
- edge_data = np.array(edges)
32
- edge_pixels = np.sum(edge_data > 50)
33
- if edge_pixels < 1000: # too few edges
34
- return False, "Image has too few edges. Upload a clear cervical cell image."
35
-
36
- # Brightness check
37
- brightness = stat.mean[0]
38
- if brightness < 30 or brightness > 220:
39
- return False, "Image brightness/contrast is not suitable. Adjust and upload again."
40
-
41
- return True, ""
42
-
43
  def predict_image(image):
44
  try:
45
- # Validate input
46
- valid, message = is_valid_cervical_image(image)
47
- if not valid:
48
- return {"Error": 1.0}, f"⚠️ {message}"
49
-
50
  # Preprocess
51
  image = image.resize((224, 224)).convert("RGB")
52
  img_array = np.array(image, dtype=np.float32) / 255.0
@@ -55,27 +25,35 @@ def predict_image(image):
55
  # Run inference
56
  interpreter.set_tensor(input_details[0]['index'], img_array)
57
  interpreter.invoke()
58
- output = interpreter.get_tensor(output_details[0]['index'])[0]
59
 
60
- # Normalize
61
  probs = tf.nn.softmax(output).numpy()
62
 
63
- # Confidence check
64
- if np.max(probs) < CONFIDENCE_THRESHOLD:
65
- return {"Error": 1.0}, "⚠️ The model is unsure. Please upload a clearer cervical cell image."
 
 
 
 
 
 
 
66
 
67
- # Convert to dict for Gradio Label
68
- probs_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
69
- return probs_dict, f"✅ Prediction: {class_names[np.argmax(probs)]} ({np.max(probs)*100:.2f}%)"
 
70
 
71
  except Exception as e:
72
- return {"Error": 1.0}, f"⚠️ Something went wrong. Please upload a correct cervical cell image. ({str(e)})"
73
 
74
  # Gradio UI
75
  gr.Interface(
76
  fn=predict_image,
77
  inputs=gr.Image(type="pil"),
78
- outputs=[gr.Label(num_top_classes=len(class_names)), gr.Textbox()],
79
- title="Cervical Cancer Classification",
80
- description="Upload a cervical cell image. The model shows probabilities for each class and warns if the image is incorrect."
81
  ).launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ from PIL import Image
4
  import tensorflow as tf
5
 
6
  # Load TFLite model
 
15
  class_names = ['Dyskeratotic', 'Koilocytotic', 'Metaplastic', 'Parabasal', 'Superficial-Intermediat']
16
  CONFIDENCE_THRESHOLD = 0.25
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def predict_image(image):
19
  try:
 
 
 
 
 
20
  # Preprocess
21
  image = image.resize((224, 224)).convert("RGB")
22
  img_array = np.array(image, dtype=np.float32) / 255.0
 
25
  # Run inference
26
  interpreter.set_tensor(input_details[0]['index'], img_array)
27
  interpreter.invoke()
28
+ output = interpreter.get_tensor(output_details[0]['index'])[0] # shape (num_classes,)
29
 
30
+ # Normalize if needed (sometimes TFLite outputs logits)
31
  probs = tf.nn.softmax(output).numpy()
32
 
33
+ # Get predicted class
34
+ class_idx = int(np.argmax(probs))
35
+ confidence = float(np.max(probs))
36
+
37
+ # Format output (show every class probability)
38
+ results = []
39
+ for i, prob in enumerate(probs):
40
+ results.append(f"{class_names[i]}: {prob*100:.2f}%")
41
+
42
+ results_text = "\n".join(results)
43
 
44
+ if confidence < CONFIDENCE_THRESHOLD:
45
+ return f"⚠️ Low confidence ({confidence:.2f}). The model is unsure.\n\nProbabilities:\n{results_text}"
46
+ else:
47
+ return f"✅ Prediction: {class_names[class_idx]} ({confidence*100:.2f}%)\n\nProbabilities:\n{results_text}"
48
 
49
  except Exception as e:
50
+ return f"Error: {str(e)}"
51
 
52
  # Gradio UI
53
  gr.Interface(
54
  fn=predict_image,
55
  inputs=gr.Image(type="pil"),
56
+ outputs="text",
57
+ title="Muscle Disease Detection",
58
+ description="Upload an MRI image to detect muscle conditions."
59
  ).launch()