Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
| 1 |
-
# Enhanced Face-Based Lab Test Predictor with AI Models for 30 Lab Metrics
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
import cv2
|
| 5 |
import numpy as np
|
| 6 |
import mediapipe as mp
|
| 7 |
from sklearn.linear_model import LinearRegression
|
| 8 |
import random
|
|
|
|
| 9 |
|
|
|
|
| 10 |
mp_face_mesh = mp.solutions.face_mesh
|
| 11 |
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
|
| 12 |
|
|
|
|
| 13 |
def extract_features(image, landmarks):
|
| 14 |
red_channel = image[:, :, 2]
|
| 15 |
green_channel = image[:, :, 1]
|
|
@@ -21,6 +22,7 @@ def extract_features(image, landmarks):
|
|
| 21 |
|
| 22 |
return [red_percent, green_percent, blue_percent]
|
| 23 |
|
|
|
|
| 24 |
def train_model(output_range):
|
| 25 |
X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
|
| 26 |
random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
|
|
@@ -29,14 +31,12 @@ def train_model(output_range):
|
|
| 29 |
model = LinearRegression().fit(X, y)
|
| 30 |
return model
|
| 31 |
|
| 32 |
-
|
| 33 |
hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
|
| 34 |
-
|
| 35 |
-
hemoglobin_r2 = 0.385
|
| 36 |
-
import joblib
|
| 37 |
spo2_model = joblib.load("spo2_model_simulated.pkl")
|
| 38 |
hr_model = joblib.load("heart_rate_model.pkl")
|
| 39 |
|
|
|
|
| 40 |
models = {
|
| 41 |
"Hemoglobin": hemoglobin_model,
|
| 42 |
"WBC Count": train_model((4.0, 11.0)),
|
|
@@ -59,6 +59,7 @@ models = {
|
|
| 59 |
"Temperature": train_model((97, 99))
|
| 60 |
}
|
| 61 |
|
|
|
|
| 62 |
def get_risk_color(value, normal_range):
|
| 63 |
low, high = normal_range
|
| 64 |
if value < low:
|
|
@@ -68,6 +69,7 @@ def get_risk_color(value, normal_range):
|
|
| 68 |
else:
|
| 69 |
return ("Normal", "✅", "#CCFFCC")
|
| 70 |
|
|
|
|
| 71 |
def build_table(title, rows):
|
| 72 |
html = (
|
| 73 |
f'<div style="margin-bottom: 24px;">'
|
|
@@ -81,9 +83,9 @@ def build_table(title, rows):
|
|
| 81 |
html += '</tbody></table></div>'
|
| 82 |
return html
|
| 83 |
|
|
|
|
| 84 |
def analyze_video(video_path):
|
| 85 |
import matplotlib.pyplot as plt
|
| 86 |
-
from PIL import Image
|
| 87 |
cap = cv2.VideoCapture(video_path)
|
| 88 |
brightness_vals = []
|
| 89 |
green_vals = []
|
|
@@ -99,135 +101,26 @@ def analyze_video(video_path):
|
|
| 99 |
brightness_vals.append(np.mean(gray))
|
| 100 |
green_vals.append(np.mean(green))
|
| 101 |
cap.release()
|
| 102 |
-
|
|
|
|
| 103 |
brightness_std = np.std(brightness_vals) / 255
|
| 104 |
green_std = np.std(green_vals) / 255
|
| 105 |
tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
|
| 106 |
hr_features = [brightness_std, green_std, tone_index]
|
| 107 |
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
| 108 |
-
|
| 109 |
-
brightness_variation = np.std(cv2.cvtColor(frame_sample, cv2.COLOR_BGR2GRAY)) / 255
|
| 110 |
-
spo2_features = [heart_rate, brightness_variation, skin_tone_index]
|
| 111 |
-
spo2 = spo2_model.predict([spo2_features])[0]
|
| 112 |
-
rr = int(12 + abs(heart_rate % 5 - 2))
|
| 113 |
-
plt.figure(figsize=(6, 2))
|
| 114 |
-
plt.plot(brightness_vals, label='rPPG Signal')
|
| 115 |
-
plt.title("Simulated rPPG Signal")
|
| 116 |
-
plt.xlabel("Frame")
|
| 117 |
-
plt.ylabel("Brightness")
|
| 118 |
-
plt.legend()
|
| 119 |
-
plt.tight_layout()
|
| 120 |
-
plot_path = "/tmp/ppg_plot.png"
|
| 121 |
-
plt.savefig(plot_path)
|
| 122 |
-
plt.close()
|
| 123 |
-
# Reuse frame_sample for full analysis
|
| 124 |
-
frame_rgb = cv2.cvtColor(frame_sample, cv2.COLOR_BGR2RGB)
|
| 125 |
-
result = face_mesh.process(frame_rgb)
|
| 126 |
-
if not result.multi_face_landmarks:
|
| 127 |
-
return "<div style='color:red;'>⚠️ Face not detected in video.</div>", frame_rgb
|
| 128 |
-
landmarks = result.multi_face_landmarks[0].landmark
|
| 129 |
-
features = extract_features(frame_rgb, landmarks)
|
| 130 |
-
test_values = {}
|
| 131 |
-
r2_scores = {}
|
| 132 |
-
for label in models:
|
| 133 |
-
if label == "Hemoglobin":
|
| 134 |
-
prediction = models[label].predict([features])[0]
|
| 135 |
-
test_values[label] = prediction
|
| 136 |
-
r2_scores[label] = hemoglobin_r2
|
| 137 |
-
else:
|
| 138 |
-
value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
|
| 139 |
-
test_values[label] = value
|
| 140 |
-
r2_scores[label] = 0.0
|
| 141 |
-
html_output = "".join([
|
| 142 |
-
f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
|
| 143 |
-
build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
|
| 144 |
-
build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
|
| 145 |
-
build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
|
| 146 |
-
build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
|
| 147 |
-
build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
|
| 148 |
-
build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
|
| 149 |
-
build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
|
| 150 |
-
])
|
| 151 |
-
summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
|
| 152 |
-
summary += "<h4>📝 Summary for You</h4><ul>"
|
| 153 |
-
if test_values["Hemoglobin"] < 13.5:
|
| 154 |
-
summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
|
| 155 |
-
if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
|
| 156 |
-
summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
|
| 157 |
-
if test_values["Bilirubin"] > 1.2:
|
| 158 |
-
summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
|
| 159 |
-
if test_values["HbA1c"] > 5.7:
|
| 160 |
-
summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
|
| 161 |
-
if spo2 < 95:
|
| 162 |
-
summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
|
| 163 |
-
summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
|
| 164 |
-
html_output += summary
|
| 165 |
-
html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
|
| 166 |
-
html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
|
| 167 |
-
html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
|
| 168 |
-
return html_output, frame_rgb
|
| 169 |
-
|
| 170 |
-
def analyze_face(image):
|
| 171 |
-
if image is None:
|
| 172 |
-
return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
|
| 173 |
-
frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 174 |
-
result = face_mesh.process(frame_rgb)
|
| 175 |
-
if not result.multi_face_landmarks:
|
| 176 |
-
return "<div style='color:red;'>⚠️ Error: Face not detected.</div>", None
|
| 177 |
-
landmarks = result.multi_face_landmarks[0].landmark
|
| 178 |
-
features = extract_features(frame_rgb, landmarks)
|
| 179 |
-
test_values = {}
|
| 180 |
-
r2_scores = {}
|
| 181 |
-
for label in models:
|
| 182 |
-
if label == "Hemoglobin":
|
| 183 |
-
prediction = models[label].predict([features])[0]
|
| 184 |
-
test_values[label] = prediction
|
| 185 |
-
r2_scores[label] = hemoglobin_r2
|
| 186 |
-
else:
|
| 187 |
-
value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
|
| 188 |
-
test_values[label] = value
|
| 189 |
-
r2_scores[label] = 0.0 # simulate other 7D inputs
|
| 190 |
-
gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)
|
| 191 |
-
green_std = np.std(frame_rgb[:, :, 1]) / 255
|
| 192 |
-
brightness_std = np.std(gray) / 255
|
| 193 |
-
tone_index = np.mean(frame_rgb[100:150, 100:150]) / 255 if frame_rgb[100:150, 100:150].size else 0.5
|
| 194 |
-
hr_features = [brightness_std, green_std, tone_index]
|
| 195 |
-
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
| 196 |
-
skin_patch = frame_rgb[100:150, 100:150]
|
| 197 |
-
skin_tone_index = np.mean(skin_patch) / 255 if skin_patch.size else 0.5
|
| 198 |
-
brightness_variation = np.std(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)) / 255
|
| 199 |
-
spo2_features = [heart_rate, brightness_variation, skin_tone_index]
|
| 200 |
spo2 = spo2_model.predict([spo2_features])[0]
|
| 201 |
-
|
|
|
|
| 202 |
html_output = "".join([
|
| 203 |
-
|
| 204 |
-
build_table("
|
| 205 |
-
build_table("
|
| 206 |
-
build_table("
|
| 207 |
-
build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
|
| 208 |
-
build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
|
| 209 |
-
build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
|
| 210 |
-
build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
|
| 211 |
])
|
| 212 |
-
|
| 213 |
-
summary += "<h4>📝 Summary for You</h4><ul>"
|
| 214 |
-
if test_values["Hemoglobin"] < 13.5:
|
| 215 |
-
summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
|
| 216 |
-
if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
|
| 217 |
-
summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
|
| 218 |
-
if test_values["Bilirubin"] > 1.2:
|
| 219 |
-
summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
|
| 220 |
-
if test_values["HbA1c"] > 5.7:
|
| 221 |
-
summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
|
| 222 |
-
if spo2 < 95:
|
| 223 |
-
summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
|
| 224 |
-
summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
|
| 225 |
-
html_output += summary
|
| 226 |
-
html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
|
| 227 |
-
html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
|
| 228 |
-
html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
|
| 229 |
-
return html_output, frame_rgb
|
| 230 |
|
|
|
|
| 231 |
with gr.Blocks() as demo:
|
| 232 |
gr.Markdown("""
|
| 233 |
# 🧠 Face-Based Lab Test AI Report (Video Mode)
|
|
@@ -244,7 +137,7 @@ with gr.Blocks() as demo:
|
|
| 244 |
result_image = gr.Image(label="📷 Key Frame Snapshot")
|
| 245 |
|
| 246 |
def route_inputs(mode, image, video):
|
| 247 |
-
return analyze_video(video) if mode == "Video" else
|
| 248 |
|
| 249 |
submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
|
| 250 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import mediapipe as mp
|
| 5 |
from sklearn.linear_model import LinearRegression
|
| 6 |
import random
|
| 7 |
+
import joblib
|
| 8 |
|
| 9 |
+
# Setup for Face Mesh detection
|
| 10 |
mp_face_mesh = mp.solutions.face_mesh
|
| 11 |
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
|
| 12 |
|
| 13 |
+
# Function to extract color features from the image
|
| 14 |
def extract_features(image, landmarks):
|
| 15 |
red_channel = image[:, :, 2]
|
| 16 |
green_channel = image[:, :, 1]
|
|
|
|
| 22 |
|
| 23 |
return [red_percent, green_percent, blue_percent]
|
| 24 |
|
| 25 |
+
# Mock models training (for demonstration)
|
| 26 |
def train_model(output_range):
|
| 27 |
X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
|
| 28 |
random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
|
|
|
|
| 31 |
model = LinearRegression().fit(X, y)
|
| 32 |
return model
|
| 33 |
|
| 34 |
+
# Load pre-trained models for Hemoglobin, SPO2, and Heart Rate
|
| 35 |
hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
|
|
|
|
|
|
|
|
|
|
| 36 |
spo2_model = joblib.load("spo2_model_simulated.pkl")
|
| 37 |
hr_model = joblib.load("heart_rate_model.pkl")
|
| 38 |
|
| 39 |
+
# Model dictionary setup for other tests
|
| 40 |
models = {
|
| 41 |
"Hemoglobin": hemoglobin_model,
|
| 42 |
"WBC Count": train_model((4.0, 11.0)),
|
|
|
|
| 59 |
"Temperature": train_model((97, 99))
|
| 60 |
}
|
| 61 |
|
| 62 |
+
# Function to determine risk level
|
| 63 |
def get_risk_color(value, normal_range):
|
| 64 |
low, high = normal_range
|
| 65 |
if value < low:
|
|
|
|
| 69 |
else:
|
| 70 |
return ("Normal", "✅", "#CCFFCC")
|
| 71 |
|
| 72 |
+
# Function to build an HTML table for displaying test results
|
| 73 |
def build_table(title, rows):
|
| 74 |
html = (
|
| 75 |
f'<div style="margin-bottom: 24px;">'
|
|
|
|
| 83 |
html += '</tbody></table></div>'
|
| 84 |
return html
|
| 85 |
|
| 86 |
+
# Analyzing video for health metrics
|
| 87 |
def analyze_video(video_path):
|
| 88 |
import matplotlib.pyplot as plt
|
|
|
|
| 89 |
cap = cv2.VideoCapture(video_path)
|
| 90 |
brightness_vals = []
|
| 91 |
green_vals = []
|
|
|
|
| 101 |
brightness_vals.append(np.mean(gray))
|
| 102 |
green_vals.append(np.mean(green))
|
| 103 |
cap.release()
|
| 104 |
+
|
| 105 |
+
# Simulate heart rate and SPO2 estimation
|
| 106 |
brightness_std = np.std(brightness_vals) / 255
|
| 107 |
green_std = np.std(green_vals) / 255
|
| 108 |
tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
|
| 109 |
hr_features = [brightness_std, green_std, tone_index]
|
| 110 |
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
| 111 |
+
spo2_features = [heart_rate, np.std(brightness_vals), np.mean(frame_sample[100:150, 100:150])]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
spo2 = spo2_model.predict([spo2_features])[0]
|
| 113 |
+
|
| 114 |
+
# Generating the health card with test results
|
| 115 |
html_output = "".join([
|
| 116 |
+
build_table("🩸 Hematology", [("Hemoglobin", models["Hemoglobin"].predict([hr_features])[0], (13.5, 17.5))]),
|
| 117 |
+
build_table("🧬 Iron Panel", [("Iron", models["Iron"].predict([hr_features])[0], (60, 170))]),
|
| 118 |
+
build_table("🧪 Electrolytes", [("Sodium", models["Sodium"].predict([hr_features])[0], (135, 145))]),
|
| 119 |
+
build_table("❤️ Vitals", [("Heart Rate", heart_rate, (60, 100)), ("SpO2", spo2, (95, 100))]),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
])
|
| 121 |
+
return html_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
# Gradio Interface setup
|
| 124 |
with gr.Blocks() as demo:
|
| 125 |
gr.Markdown("""
|
| 126 |
# 🧠 Face-Based Lab Test AI Report (Video Mode)
|
|
|
|
| 137 |
result_image = gr.Image(label="📷 Key Frame Snapshot")
|
| 138 |
|
| 139 |
def route_inputs(mode, image, video):
|
| 140 |
+
return analyze_video(video) if mode == "Video" else analyze_video(image)
|
| 141 |
|
| 142 |
submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
|
| 143 |
|