jayn95 commited on
Commit
bcbbf99
·
verified ·
1 Parent(s): 68709d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -41
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
- from PIL import Image
4
- from PIL import ImageOps, ImageEnhance
5
  import numpy as np
6
  import tempfile
7
 
@@ -10,61 +9,58 @@ model_swelling = YOLO("models/swelling/best.pt")
10
  model_redness = YOLO("models/redness/best.pt")
11
  model_bleeding = YOLO("models/bleeding/best.pt")
12
 
13
- def save_temp_image(img):
14
- """Save PIL image to temp file and return the file path."""
15
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
16
  img.save(temp.name, format="JPEG")
17
  return temp.name
18
 
19
- def preprocess_image(image, max_size=1024, contrast_factor=1.05):
20
- if image is None:
21
- return image
22
-
23
  if isinstance(image, np.ndarray):
24
  image = Image.fromarray(image)
25
 
26
- image = ImageOps.exif_transpose(image)
27
- image = image.convert("RGB")
28
 
 
29
  w, h = image.size
30
- max_side = max(w, h)
31
- if max_side > max_size:
32
- scale = max_size / float(max_side)
33
- new_size = (int(w * scale), int(h * scale))
34
- image = image.resize(new_size, Image.LANCZOS)
35
 
36
- enhancer = ImageEnhance.Contrast(image)
37
- image = enhancer.enhance(contrast_factor)
38
 
39
  return image
40
 
41
- def detect_gingivitis(image, conf_threshold=0.5, iou_threshold=0.5):
42
- image = preprocess_image(image)
43
 
44
- # Swelling
45
- res_swelling = model_swelling.predict(image, conf=conf_threshold, iou=iou_threshold)
46
- swelling_img = Image.fromarray(res_swelling[0].plot()[..., ::-1])
47
- swelling_path = save_temp_image(swelling_img)
48
 
49
- # Redness
50
- res_redness = model_redness.predict(image, conf=conf_threshold, iou=iou_threshold)
51
- redness_img = Image.fromarray(res_redness[0].plot()[..., ::-1])
52
- redness_path = save_temp_image(redness_img)
53
 
54
- # Bleeding
55
- res_bleeding = model_bleeding.predict(image, conf=conf_threshold, iou=iou_threshold)
56
- bleeding_img = Image.fromarray(res_bleeding[0].plot()[..., ::-1])
57
- bleeding_path = save_temp_image(bleeding_img)
58
 
59
- # Diagnosis text
60
- has_swelling = len(res_swelling[0].boxes) > 0
61
- has_redness = len(res_redness[0].boxes) > 0
62
- has_bleeding = len(res_bleeding[0].boxes) > 0
63
 
64
- diagnosis = "Gingivitis Detected" if (has_swelling and has_redness and has_bleeding) else "No Gingivitis Detected"
65
 
66
- # Must return EXACTLY what Flask expects:
67
- return [swelling_path, redness_path, bleeding_path, diagnosis]
 
68
 
69
  # Gradio Interface
70
  interface = gr.Interface(
@@ -78,10 +74,9 @@ interface = gr.Interface(
78
  gr.File(label="Swelling Detection"),
79
  gr.File(label="Redness Detection"),
80
  gr.File(label="Bleeding Detection"),
81
- gr.Textbox(label="Gingivitis Diagnosis")
82
  ],
83
- title="Gingivitis Detection (3 Models)",
84
- description="Detect swelling, redness, and bleeding using 3 YOLO models.",
85
  )
86
 
87
  interface.launch()
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
+ from PIL import Image, ImageOps, ImageEnhance
 
4
  import numpy as np
5
  import tempfile
6
 
 
9
  model_redness = YOLO("models/redness/best.pt")
10
  model_bleeding = YOLO("models/bleeding/best.pt")
11
 
12
+ def save_temp_file(img):
13
+ """Save PIL image to a temporary file and return the file path."""
14
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
15
  img.save(temp.name, format="JPEG")
16
  return temp.name
17
 
18
+ def preprocess(image):
 
 
 
19
  if isinstance(image, np.ndarray):
20
  image = Image.fromarray(image)
21
 
22
+ image = ImageOps.exif_transpose(image).convert("RGB")
 
23
 
24
+ # Resize if needed
25
  w, h = image.size
26
+ max_dim = max(w, h)
27
+ if max_dim > 1024:
28
+ scale = 1024 / max_dim
29
+ image = image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
 
30
 
31
+ # Light contrast boost
32
+ image = ImageEnhance.Contrast(image).enhance(1.05)
33
 
34
  return image
35
 
36
+ def detect_gingivitis(image, conf=0.4, iou=0.5):
37
+ image = preprocess(image)
38
 
39
+ # >> Run models <<
40
+ sw_res = model_swelling.predict(image, conf=conf, iou=iou)
41
+ rd_res = model_redness.predict(image, conf=conf, iou=iou)
42
+ bl_res = model_bleeding.predict(image, conf=conf, iou=iou)
43
 
44
+ # >> Create annotated PIL images <<
45
+ img_sw = Image.fromarray(sw_res[0].plot()[..., ::-1])
46
+ img_rd = Image.fromarray(rd_res[0].plot()[..., ::-1])
47
+ img_bl = Image.fromarray(bl_res[0].plot()[..., ::-1])
48
 
49
+ # >> Save to temporary files <<
50
+ sw_path = save_temp_file(img_sw)
51
+ rd_path = save_temp_file(img_rd)
52
+ bl_path = save_temp_file(img_bl)
53
 
54
+ # >> Diagnosis <<
55
+ has_sw = len(sw_res[0].boxes) > 0
56
+ has_rd = len(rd_res[0].boxes) > 0
57
+ has_bl = len(bl_res[0].boxes) > 0
58
 
59
+ diagnosis = "Gingivitis Detected" if (has_sw and has_rd and has_bl) else "No Gingivitis Detected"
60
 
61
+ # *** CRITICAL ***
62
+ # Must return FILE PATHS + diagnosis string in ONE LIST
63
+ return [sw_path, rd_path, bl_path, diagnosis]
64
 
65
  # Gradio Interface
66
  interface = gr.Interface(
 
74
  gr.File(label="Swelling Detection"),
75
  gr.File(label="Redness Detection"),
76
  gr.File(label="Bleeding Detection"),
77
+ gr.Textbox(label="Diagnosis")
78
  ],
79
+ title="Gingivitis Detection"
 
80
  )
81
 
82
  interface.launch()