jayn95 commited on
Commit
25f3eef
·
verified ·
1 Parent(s): a2418a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -132
app.py CHANGED
@@ -1,133 +1,65 @@
1
- from flask import Flask, request, jsonify
2
- from gradio_client import Client, handle_file
3
- import tempfile, base64, os, threading
4
- from huggingface_hub import InferenceClient
5
-
6
- from flask_cors import CORS # ✅ allow mobile apps to call this API
7
-
8
- app = Flask(__name__)
9
- CORS(app) # enable CORS for all routes
10
-
11
- # Your Hugging Face Spaces
12
- GINGIVITIS_SPACE = "jayn95/deepdent_gingivitis"
13
- PERIODONTITIS_SPACE = "jayn95/deepdent_periodontitis"
14
-
15
-
16
- def call_huggingface(space_name, image_path, labels=None, flatten=False, timeout_seconds=120):
17
- """Call HF Space in a separate thread with timeout."""
18
- client = Client(space_name)
19
- result_container = {}
20
-
21
- def run_predict():
22
- result_container["data"] = client.predict(
23
- handle_file(image_path),
24
- 0.4,
25
- 0.5,
26
- api_name="/predict"
27
- )
28
-
29
- thread = threading.Thread(target=run_predict)
30
- thread.start()
31
- thread.join(timeout=timeout_seconds)
32
-
33
- if thread.is_alive():
34
- raise TimeoutError(f"Hugging Face request to {space_name} timed out after {timeout_seconds}s")
35
-
36
- result = result_container.get("data", [])
37
-
38
- # Flatten nested list if needed
39
- flat_result = []
40
- if flatten:
41
- for r in result:
42
- if isinstance(r, (list, tuple)):
43
- flat_result.extend(r)
44
- else:
45
- flat_result.append(r)
46
- else:
47
- flat_result = result
48
-
49
- # Auto-generate labels if None
50
- if labels is None:
51
- labels = []
52
- if space_name == PERIODONTITIS_SPACE:
53
- num_teeth = len(flat_result) // 2
54
- for i in range(num_teeth):
55
- for m in ["cej", "abc"]:
56
- labels.append(f"tooth{i+1}_{m}")
57
  else:
58
- labels = [f"output{i+1}" for i in range(len(flat_result))]
59
-
60
- # Encode results as base64
61
- encoded_results = {}
62
- for label, path in zip(labels, flat_result):
63
- if os.path.exists(path):
64
- with open(path, "rb") as f:
65
- encoded_results[label] = base64.b64encode(f.read()).decode("utf-8")
66
- else:
67
- encoded_results[label] = None
68
-
69
- return encoded_results
70
-
71
-
72
- @app.route("/")
73
- def home():
74
- return jsonify({"status": "DeepDent backend running successfully!"})
75
-
76
-
77
- @app.route("/predict/gingivitis", methods=["POST"])
78
- def predict_gingivitis():
79
- try:
80
- image = request.files.get("image")
81
- if not image:
82
- return jsonify({"error": "No image provided"}), 400
83
-
84
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
85
- image.save(temp_file.name)
86
- temp_path = temp_file.name
87
-
88
- encoded_results = call_huggingface(
89
- GINGIVITIS_SPACE,
90
- temp_path,
91
- labels=["swelling", "redness", "bleeding"]
92
- )
93
-
94
- os.remove(temp_path)
95
- return jsonify({"images": encoded_results})
96
-
97
- except TimeoutError as te:
98
- return jsonify({"error": str(te)}), 504
99
- except Exception as e:
100
- return jsonify({"error": str(e)}), 500
101
-
102
-
103
- @app.route("/predict/periodontitis", methods=["POST"])
104
- def predict_periodontitis():
105
- try:
106
- image = request.files.get("image")
107
- if not image:
108
- return jsonify({"error": "No image provided"}), 400
109
-
110
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
111
- image.save(temp_file.name)
112
- temp_path = temp_file.name
113
-
114
- encoded_results = call_huggingface(
115
- PERIODONTITIS_SPACE,
116
- temp_path,
117
- labels=None,
118
- flatten=True
119
- )
120
-
121
- os.remove(temp_path)
122
- return jsonify({"images": encoded_results})
123
-
124
- except TimeoutError as te:
125
- return jsonify({"error": str(te)}), 504
126
- except Exception as e:
127
- return jsonify({"error": str(e)}), 500
128
-
129
-
130
- # 🚀 Required function for Hugging Face Spaces
131
- def start():
132
- # HF Spaces automatically assigns host/port
133
- app.run()
 
1
+ # app.py
2
+ import gradio as gr
3
+ import cv2
4
+ from periodontitis_detection import SimpleDentalSegmentationNoEnhance
5
+
6
+ # ==========================
7
+ # 1️⃣ Load models once
8
+ # ==========================
9
+ model = SimpleDentalSegmentationNoEnhance(
10
+ unet_model_path="unet.keras", # same filenames as your repo
11
+ yolo_model_path="best2.pt"
12
+ )
13
+
14
+ # ==========================
15
+ # 2️⃣ Define wrapper for Gradio
16
+ # ==========================
17
+ def detect_periodontitis(image_np):
18
+ """
19
+ Gradio sends image as a NumPy RGB array.
20
+ We temporarily save it to a file path since analyze_image() needs a path.
21
+ """
22
+ temp_path = "temp_input.jpg"
23
+ cv2.imwrite(temp_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
24
+
25
+ # Run full pipeline
26
+ results = model.analyze_image(temp_path)
27
+
28
+ # Convert OpenCV BGR → RGB for Gradio display
29
+ combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
30
+
31
+ # Optional: summarize measurements for text output
32
+ summaries = []
33
+ for tooth in results["distance_analyses"]:
34
+ tooth_id = tooth["tooth_id"]
35
+ analysis = tooth["analysis"]
36
+ if analysis:
37
+ mean_d = analysis["mean_distance"]
38
+ summaries.append(f"Tooth {tooth_id}: mean={mean_d:.2f}px")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  else:
40
+ summaries.append(f"Tooth {tooth_id}: no valid CEJ–ABC measurement")
41
+
42
+ summary_text = "\n".join(summaries) if summaries else "No detections found."
43
+
44
+ return combined_rgb, summary_text
45
+
46
+
47
+ # ==========================
48
+ # 3️⃣ Build Gradio Interface
49
+ # ==========================
50
+ demo = gr.Interface(
51
+ fn=detect_periodontitis,
52
+ inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
53
+ outputs=[
54
+ gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
55
+ gr.Textbox(label="Analysis Summary"),
56
+ ],
57
+ title="🦷 Periodontitis Detection & Analysis",
58
+ description=(
59
+ "Automatically detects teeth (YOLOv8), segments CEJ/ABC (U-Net), "
60
+ "and measures CEJ–ABC distances per tooth to assess bone loss."
61
+ ),
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)