phiaaa commited on
Commit
3bb9819
·
verified ·
1 Parent(s): ffc503e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -3,53 +3,79 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # Load the TFLite model
7
  interpreter = tf.lite.Interpreter(model_path="stool_model.tflite")
8
  interpreter.allocate_tensors()
9
 
10
  input_details = interpreter.get_input_details()
11
  output_details = interpreter.get_output_details()
12
 
13
- # Define label names (adjust if your model uses different ones)
14
  labels = ["bloody", "hard stool", "normal", "parasite", "watery"]
15
 
 
16
  def preprocess_image(img: Image.Image):
17
  img = img.convert("RGB").resize((128, 128))
18
  arr = np.asarray(img).astype(np.float32) / 255.0
19
  arr = np.expand_dims(arr, axis=0)
20
  return arr
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def classify_image(image):
23
  try:
24
- img = image.convert("RGB")
25
- arr = np.asarray(img).astype(np.float32)
26
- brightness = arr.mean()
27
- contrast = arr.std()
28
-
29
- # 🧠 Simple sanity check for stool-like features (brownish tone + moderate contrast)
30
- # You can adjust these thresholds depending on your dataset
31
- if brightness > 220 or brightness < 20 or contrast < 25:
32
  return {"Not stool image": 1.0}
33
 
34
- # Normal prediction process
35
  input_data = preprocess_image(image)
36
  interpreter.set_tensor(input_details[0]['index'], input_data)
37
  interpreter.invoke()
38
  output_data = interpreter.get_tensor(output_details[0]['index'])[0]
 
 
39
  results = {labels[i]: float(output_data[i]) for i in range(len(labels))}
40
  sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
 
 
 
 
 
 
41
  return sorted_results
42
 
43
  except Exception as e:
44
- return {"error": str(e)}
45
-
46
 
 
47
  demo = gr.Interface(
48
  fn=classify_image,
49
- inputs=gr.Image(type="pil", label="Upload stool image"),
50
- outputs=gr.Label(num_top_classes=3, label="Predictions"),
51
- title="Stool Diagnosis Model",
52
- description="Upload a stool image for classification."
53
  )
54
 
55
  if __name__ == "__main__":
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # 🧠 Load your TFLite model
7
  interpreter = tf.lite.Interpreter(model_path="stool_model.tflite")
8
  interpreter.allocate_tensors()
9
 
10
  input_details = interpreter.get_input_details()
11
  output_details = interpreter.get_output_details()
12
 
13
+ # 🏷️ Define your classes
14
  labels = ["bloody", "hard stool", "normal", "parasite", "watery"]
15
 
16
+ # 🧩 Image preprocessing
17
  def preprocess_image(img: Image.Image):
18
  img = img.convert("RGB").resize((128, 128))
19
  arr = np.asarray(img).astype(np.float32) / 255.0
20
  arr = np.expand_dims(arr, axis=0)
21
  return arr
22
 
23
+ # 🚫 Detect if the uploaded image is NOT a stool image
24
+ def is_not_stool_image(image):
25
+ arr = np.asarray(image.convert("RGB")).astype(np.float32)
26
+ brightness = arr.mean()
27
+ contrast = arr.std()
28
+ avg_color = arr.mean(axis=(0, 1))
29
+
30
+ # 🧠 Basic heuristic checks
31
+ # These values are adjustable based on your dataset
32
+ if brightness > 220 or brightness < 25:
33
+ return True # too bright or dark
34
+ if contrast < 25:
35
+ return True # too flat / low texture
36
+ if avg_color[0] > 180 and avg_color[1] < 80 and avg_color[2] < 80:
37
+ return True # too red
38
+ if avg_color[0] < 50 and avg_color[1] > 180:
39
+ return True # too greenish
40
+ if avg_color[2] > 200:
41
+ return True # too blueish
42
+
43
+ return False
44
+
45
+ # 🧠 Classification function
46
  def classify_image(image):
47
  try:
48
+ # 🚫 Check if this is not stool
49
+ if is_not_stool_image(image):
 
 
 
 
 
 
50
  return {"Not stool image": 1.0}
51
 
52
+ # Proceed with model prediction
53
  input_data = preprocess_image(image)
54
  interpreter.set_tensor(input_details[0]['index'], input_data)
55
  interpreter.invoke()
56
  output_data = interpreter.get_tensor(output_details[0]['index'])[0]
57
+
58
+ # Sort predictions
59
  results = {labels[i]: float(output_data[i]) for i in range(len(labels))}
60
  sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
61
+
62
+ # Extra sanity rule: if top score < 0.4, label as uncertain
63
+ top_label, top_score = list(sorted_results.items())[0]
64
+ if top_score < 0.4:
65
+ return {"Uncertain / unclear stool image": top_score}
66
+
67
  return sorted_results
68
 
69
  except Exception as e:
70
+ return {"Error": str(e)}
 
71
 
72
+ # 🎨 Gradio UI
73
  demo = gr.Interface(
74
  fn=classify_image,
75
+ inputs=gr.Image(type="pil", label="📸 Upload stool image"),
76
+ outputs=gr.Label(num_top_classes=3, label="Predicted diagnosis"),
77
+ title="🐾 Stool Diagnosis AI",
78
+ description="Upload a stool image for analysis. The model predicts stool type or rejects unrelated photos."
79
  )
80
 
81
  if __name__ == "__main__":