jayn95 commited on
Commit
367cd24
·
verified ·
1 Parent(s): ddfb6ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -14
app.py CHANGED
@@ -27,28 +27,31 @@ def get_models():
27
  return model_swelling, model_redness, model_bleeding
28
 
29
 
30
- # --- Helpers ---
 
 
31
  def preprocess(image):
 
32
  if isinstance(image, np.ndarray):
33
  image = Image.fromarray(image)
34
 
35
  image = ImageOps.exif_transpose(image).convert("RGB")
36
 
37
- # Resize if needed
38
  w, h = image.size
39
  max_dim = max(w, h)
40
  if max_dim > 1024:
41
  scale = 1024 / max_dim
42
  image = image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
43
 
44
- # Light contrast boost
45
  image = ImageEnhance.Contrast(image).enhance(1.05)
46
 
47
  return image
48
 
49
 
50
  def np_to_base64(img_np, format="JPEG"):
51
- """Convert numpy RGB image to Base64 string."""
52
  pil_img = Image.fromarray(img_np)
53
  buffer = io.BytesIO()
54
  pil_img.save(buffer, format=format)
@@ -56,11 +59,13 @@ def np_to_base64(img_np, format="JPEG"):
56
 
57
 
58
  def base64_to_pil(b64_str):
59
- """Convert Base64 string back to PIL image (for Gradio display)."""
60
  return Image.open(io.BytesIO(base64.b64decode(b64_str)))
61
 
62
 
63
- # --- Main detection ---
 
 
64
  def detect_gingivitis(image, conf=0.25, iou=0.5):
65
  try:
66
  if image is None:
@@ -69,7 +74,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
69
  # Load models (only once)
70
  sw_model, rd_model, bl_model = get_models()
71
 
72
- # Preprocess image
73
  image = preprocess(image)
74
 
75
  # Run detections
@@ -77,7 +82,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
77
  rd_res = rd_model.predict(image, conf=conf, iou=iou)
78
  bl_res = bl_model.predict(image, conf=conf, iou=iou)
79
 
80
- # Convert to images
81
  img_sw = sw_res[0].plot(labels=False)[:, :, ::-1]
82
  img_rd = rd_res[0].plot(labels=False)[:, :, ::-1]
83
  img_bl = bl_res[0].plot(labels=False)[:, :, ::-1]
@@ -86,7 +91,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
86
  rd_pil = base64_to_pil(np_to_base64(img_rd))
87
  bl_pil = base64_to_pil(np_to_base64(img_bl))
88
 
89
- # Determine diagnosis
90
  has_sw = len(sw_res[0].boxes) > 0
91
  has_rd = len(rd_res[0].boxes) > 0
92
  has_bl = len(bl_res[0].boxes) > 0
@@ -103,12 +108,12 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
103
  return [sw_pil, rd_pil, bl_pil, diagnosis]
104
 
105
  except Exception as e:
106
- # Catch all errors and return a friendly message
107
  return [None, None, None, f"❌ Error during processing: {str(e)}"]
108
 
109
 
110
-
111
- # --- Gradio Interface ---
 
112
  interface = gr.Interface(
113
  fn=detect_gingivitis,
114
  inputs=[
@@ -125,6 +130,7 @@ interface = gr.Interface(
125
  title="Gingivitis Detection"
126
  )
127
 
 
128
  # =========================================================
129
  # Warm-start: preload models on startup
130
  # =========================================================
@@ -132,6 +138,9 @@ print("🔥 Preloading models to reduce Render cold start...")
132
  get_models()
133
  print("✅ Gingivitis models ready")
134
 
135
- # interface.launch()
 
 
 
136
  if __name__ == "__main__":
137
- interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
27
  return model_swelling, model_redness, model_bleeding
28
 
29
 
30
+ # =========================================================
31
+ # Helper functions
32
+ # =========================================================
33
  def preprocess(image):
34
+ """Resize, fix orientation, improve contrast."""
35
  if isinstance(image, np.ndarray):
36
  image = Image.fromarray(image)
37
 
38
  image = ImageOps.exif_transpose(image).convert("RGB")
39
 
40
+ # Resize if too large
41
  w, h = image.size
42
  max_dim = max(w, h)
43
  if max_dim > 1024:
44
  scale = 1024 / max_dim
45
  image = image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
46
 
47
+ # Slight contrast enhancement
48
  image = ImageEnhance.Contrast(image).enhance(1.05)
49
 
50
  return image
51
 
52
 
53
  def np_to_base64(img_np, format="JPEG"):
54
+ """Convert numpy RGB image to Base64."""
55
  pil_img = Image.fromarray(img_np)
56
  buffer = io.BytesIO()
57
  pil_img.save(buffer, format=format)
 
59
 
60
 
61
  def base64_to_pil(b64_str):
62
+ """Convert Base64 string to PIL image."""
63
  return Image.open(io.BytesIO(base64.b64decode(b64_str)))
64
 
65
 
66
+ # =========================================================
67
+ # Main detection function
68
+ # =========================================================
69
  def detect_gingivitis(image, conf=0.25, iou=0.5):
70
  try:
71
  if image is None:
 
74
  # Load models (only once)
75
  sw_model, rd_model, bl_model = get_models()
76
 
77
+ # Preprocess
78
  image = preprocess(image)
79
 
80
  # Run detections
 
82
  rd_res = rd_model.predict(image, conf=conf, iou=iou)
83
  bl_res = bl_model.predict(image, conf=conf, iou=iou)
84
 
85
+ # Convert YOLO output → numpy → PIL
86
  img_sw = sw_res[0].plot(labels=False)[:, :, ::-1]
87
  img_rd = rd_res[0].plot(labels=False)[:, :, ::-1]
88
  img_bl = bl_res[0].plot(labels=False)[:, :, ::-1]
 
91
  rd_pil = base64_to_pil(np_to_base64(img_rd))
92
  bl_pil = base64_to_pil(np_to_base64(img_bl))
93
 
94
+ # Diagnosis logic
95
  has_sw = len(sw_res[0].boxes) > 0
96
  has_rd = len(rd_res[0].boxes) > 0
97
  has_bl = len(bl_res[0].boxes) > 0
 
108
  return [sw_pil, rd_pil, bl_pil, diagnosis]
109
 
110
  except Exception as e:
 
111
  return [None, None, None, f"❌ Error during processing: {str(e)}"]
112
 
113
 
114
+ # =========================================================
115
+ # Gradio Interface
116
+ # =========================================================
117
  interface = gr.Interface(
118
  fn=detect_gingivitis,
119
  inputs=[
 
130
  title="Gingivitis Detection"
131
  )
132
 
133
+
134
  # =========================================================
135
  # Warm-start: preload models on startup
136
  # =========================================================
 
138
  get_models()
139
  print("✅ Gingivitis models ready")
140
 
141
+
142
+ # =========================================================
143
+ # Start server
144
+ # =========================================================
145
  if __name__ == "__main__":
146
+ interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)