Spaces:
Sleeping
Sleeping
File size: 23,127 Bytes
a8ab250 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 | import base64
import io
import json
import os
import re
import secrets
from functools import wraps
import matplotlib
import matplotlib.cm as cm
import numpy as np
import requests as req
import tensorflow as tf
from flask import Flask, g, jsonify, request
from flask_cors import CORS
from PIL import Image
from dotenv import load_dotenv
from pymongo import MongoClient
from pymongo.server_api import ServerApi
from werkzeug.security import check_password_hash, generate_password_hash
# Use Matplotlib's non-interactive Agg backend.
matplotlib.use("Agg")
# Directory containing this file; the .env file and the model weights are resolved against it.
BASE_DIR = os.path.dirname(__file__)
# Load environment variables from the adjacent .env file before any os.getenv call below.
load_dotenv(os.path.join(BASE_DIR, ".env"))
MODEL_PATH = os.path.join(BASE_DIR, "densenet121_covid_final.keras")
# OpenRouter settings. When the API key is empty, /ai-advice returns local fallback guidance.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-3-4b-it:free")
# Comma-separated list of additional models tried after OPENROUTER_MODEL; blank entries are dropped.
OPENROUTER_FALLBACK_MODELS = [
    model_name.strip()
    for model_name in os.getenv(
        "OPENROUTER_FALLBACK_MODELS",
        "meta-llama/llama-3.1-8b-instruct:free,mistralai/mistral-7b-instruct:free",
    ).split(",")
    if model_name.strip()
]
# MongoDB connection string and database name; an empty URI makes create_mongo_database raise.
MONGODB_URI = os.getenv("MONGODB_URI", "")
MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME", "pulmoai")
# Raw CORS setting: "*" or a comma-separated list of origins (parsed by parse_frontend_origins).
FRONTEND_ORIGIN = os.getenv("FRONTEND_ORIGIN", "*")
# Port used by app.run when this file is executed directly.
PORT = int(os.getenv("PORT", "5000"))
# Origins appended to any explicitly configured origin list.
# NOTE(review): the first entry spells "pulomonary"; presumably kept on purpose for an
# existing deployment URL - confirm before removing.
DEFAULT_FRONTEND_ORIGINS = [
    "https://pulomonary-ai-care-system.vercel.app",
    "https://pulmonary-ai-care-system.vercel.app",
    "http://localhost:5173",
]
def parse_frontend_origins(origin_value):
    """Parse the FRONTEND_ORIGIN setting into the CORS allow-list.

    Args:
        origin_value: Comma-separated origins, or "*" to allow every origin.
            ``None`` is treated like an empty string.

    Returns:
        The string "*" when the wildcard is requested (alone, padded with
        whitespace, or as one entry of the list); otherwise a de-duplicated
        list holding the configured origins (trailing slashes removed)
        followed by DEFAULT_FRONTEND_ORIGINS.
    """
    origins = [
        origin.strip().rstrip("/")
        for origin in (origin_value or "").split(",")
        if origin.strip()
    ]
    # A wildcard anywhere in the setting means "allow all". Keeping it as a
    # literal list entry would make the exact-match check in
    # is_allowed_origin reject every real origin.
    if "*" in origins:
        return "*"
    # dict.fromkeys removes duplicates while preserving first-seen order.
    return list(dict.fromkeys([*origins, *DEFAULT_FRONTEND_ORIGINS]))
# Resolved once at import time: either the string "*" or a list of allowed origins.
FRONTEND_ORIGINS = parse_frontend_origins(FRONTEND_ORIGIN)
def is_allowed_origin(origin):
    """Return True when *origin* may receive CORS headers.

    Every origin is accepted when FRONTEND_ORIGINS is the wildcard "*";
    otherwise the origin must be a member of FRONTEND_ORIGINS.
    """
    if FRONTEND_ORIGINS == "*":
        return True
    return origin in FRONTEND_ORIGINS
# Flask application with flask-cors applied to every route.
app = Flask(__name__)
# FRONTEND_ORIGINS is already the literal "*" in the wildcard case, so the
# same mapping serves both the wildcard and the explicit-list configuration.
CORS(
    app,
    supports_credentials=True,
    resources={r"/*": {"origins": FRONTEND_ORIGINS}},
)
@app.after_request
def add_cors_headers(response):
    """Attach CORS headers to *response* when the request origin is allowed.

    The request's ``Origin`` header (with any trailing slash removed) is
    checked with is_allowed_origin. When it is allowed, the origin is echoed
    back and the credential, header and method allowances are set. Responses
    to requests without an allowed origin are returned unchanged.

    Args:
        response: The outgoing Flask response object.

    Returns:
        The same response object, possibly with extra headers.
    """
    origin = (request.headers.get("Origin") or "").rstrip("/")
    if origin and is_allowed_origin(origin):
        response.headers["Access-Control-Allow-Origin"] = origin
        # Add to Vary rather than assigning it, so a value already placed on
        # the response (for example by another extension) is not discarded.
        response.vary.add("Origin")
        response.headers["Access-Control-Allow-Credentials"] = "true"
        response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type"
        response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
    return response
# Class keys indexed by model output position (/predict uses CLASSES[argmax]).
# NOTE(review): the order presumably matches the label order used in training - confirm.
CLASSES = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
# Human-readable label for each class key in CLASSES.
DISPLAY_NAMES = {
    "COVID": "COVID-19",
    "Lung_Opacity": "Lung Opacity",
    "Normal": "Normal",
    "Viral Pneumonia": "Viral Pneumonia",
}
# Severity label reported next to each finding in the /predict response.
SEVERITY = {
    "COVID": "High",
    "Lung_Opacity": "Moderate",
    "Normal": "Low",
    "Viral Pneumonia": "High",
}
# Predefined guidance keyed by display name (the values of DISPLAY_NAMES).
# Used whenever AI-generated advice is unavailable.
LOCAL_ADVICE_BY_DISEASE = {
    "Normal": {
        "summary": "The X-ray pattern appears within normal limits based on the model output. Clinical correlation is still important if symptoms such as fever, chest pain, or shortness of breath are present.",
        "urgency": "Low",
        "precautions": [
            "Continue routine health monitoring and seek care if new respiratory symptoms develop.",
            "Maintain good hydration and adequate sleep to support overall recovery and wellness.",
            "Avoid smoking and secondhand smoke exposure.",
            "Use a mask and hand hygiene in crowded settings if respiratory infections are circulating.",
            "Keep preventive vaccinations and routine checkups up to date.",
        ],
        "medications": [
            {"name": "No routine medication", "dosage": "As advised", "frequency": "Not routinely needed", "purpose": "No medication is typically needed for a normal study without symptoms."}
        ],
        "lifestyle": [
            "Stay physically active as tolerated.",
            "Maintain a balanced diet rich in fluids, fruits, and protein.",
            "Follow up with a clinician if symptoms persist despite a normal image result.",
        ],
        "followUp": "Routine follow-up only if symptoms continue or worsen.",
    },
    "Lung Opacity": {
        "summary": "Lung opacity can reflect infection, inflammation, fluid, or other causes and should be interpreted with symptoms and examination findings. A clinician should review the result to determine whether further imaging, oxygen assessment, or treatment is needed.",
        "urgency": "Moderate",
        "precautions": [
            "Seek prompt medical review, especially if you have fever, cough, chest pain, or breathlessness.",
            "Monitor oxygen saturation if available and seek urgent care for worsening shortness of breath.",
            "Avoid smoking, vaping, and dusty or polluted environments.",
            "Rest adequately and limit strenuous activity until reviewed.",
            "Maintain hydration unless a clinician has advised fluid restriction.",
        ],
        "medications": [
            {"name": "Symptom-directed treatment", "dosage": "As prescribed", "frequency": "As directed", "purpose": "Management depends on whether the opacity is due to infection, inflammation, or another cause."}
        ],
        "lifestyle": [
            "Prioritize rest and breathing comfort.",
            "Use steam inhalation or humidified air only if comfortable and clinician-approved.",
            "Keep follow-up imaging or lab appointments if recommended.",
        ],
        "followUp": "Arrange clinical review within 24-48 hours, sooner if symptoms are significant or worsening.",
    },
    "Viral Pneumonia": {
        "summary": "The model suggests viral pneumonia, which may cause cough, fever, fatigue, and reduced oxygen levels. Clinical assessment is important to determine severity and whether home care or hospital treatment is needed.",
        "urgency": "High",
        "precautions": [
            "Seek medical evaluation promptly, especially for shortness of breath or persistent fever.",
            "Check oxygen saturation if available and seek urgent care for low readings or breathing difficulty.",
            "Rest, isolate if infection is suspected, and use a mask around others.",
            "Drink fluids regularly unless you have been told to restrict intake.",
            "Go to emergency care immediately for confusion, bluish lips, or rapidly worsening breathing.",
        ],
        "medications": [
            {"name": "Supportive care", "dosage": "As prescribed", "frequency": "As directed", "purpose": "Treatment often focuses on fever control, hydration, and condition-specific therapy after clinician review."}
        ],
        "lifestyle": [
            "Reduce exertion and allow time for recovery.",
            "Avoid smoking and alcohol during illness.",
            "Track temperature, breathing symptoms, and oxygen levels if possible.",
        ],
        "followUp": "Same-day clinical review is recommended if symptoms are active; emergency care is needed for worsening breathlessness or low oxygen.",
    },
    "COVID-19": {
        "summary": "The model suggests a pattern concerning for COVID-19-related lung involvement. Severity can vary, so symptoms, oxygen level, and risk factors should guide whether home care, urgent review, or hospital evaluation is needed.",
        "urgency": "High",
        "precautions": [
            "Isolate according to local guidance and wear a mask around others.",
            "Seek prompt medical care for worsening cough, persistent fever, chest pain, or breathlessness.",
            "Monitor oxygen saturation if available and seek urgent care for low readings.",
            "Rest well and maintain hydration unless a clinician advises otherwise.",
            "Get emergency help for severe shortness of breath, confusion, or bluish lips.",
        ],
        "medications": [
            {"name": "Supportive care", "dosage": "As prescribed", "frequency": "As directed", "purpose": "Clinical treatment depends on symptom severity, timing of illness, and patient risk factors."}
        ],
        "lifestyle": [
            "Rest and avoid strenuous physical activity while symptomatic.",
            "Use good ventilation at home if isolating.",
            "Follow clinician advice on testing, antiviral eligibility, and recovery monitoring.",
        ],
        "followUp": "Medical review should be arranged promptly if symptoms are present, with urgent care for breathing issues or low oxygen.",
    },
}


def build_local_advice(disease, confidence):
    """Build fallback guidance for *disease* from LOCAL_ADVICE_BY_DISEASE.

    Args:
        disease: Display name of the predicted condition. Unknown or empty
            names fall back to the "Lung Opacity" entry.
        confidence: Model confidence. It is formatted verbatim followed by
            "%", so it is presumably already a percentage - confirm against
            the caller. Values that cannot be converted to ``float`` (for
            example ``None`` or free text from a request body) are tolerated.

    Returns:
        A new dict with the selected advice entry plus "source" and
        "disclaimer" keys. The summary has the confidence sentence appended
        when the confidence is numeric; otherwise the summary is unchanged.
    """
    advice = LOCAL_ADVICE_BY_DISEASE.get(disease) or LOCAL_ADVICE_BY_DISEASE["Lung Opacity"]

    # confidence comes from request JSON, so it may be missing or
    # non-numeric. Omit the sentence instead of failing the whole request.
    try:
        confidence_value = float(confidence)
    except (TypeError, ValueError):
        confidence_value = None

    summary = advice["summary"]
    if confidence_value is not None:
        summary = f"{summary} Model confidence was {confidence_value:.1f}%."

    return {
        **advice,
        "summary": summary,
        "source": "local-fallback",
        "disclaimer": "This guidance is generated from predefined clinical safety rules and is not a substitute for a licensed medical evaluation.",
    }
def parse_openrouter_advice(response):
    """Extract the advice JSON object from an OpenRouter chat completion.

    Args:
        response: An object exposing ``json()`` (raising ``ValueError`` on an
            invalid body) and ``text``, such as a ``requests`` response.

    Returns:
        A ``(advice, error_message, status)`` tuple. On success ``advice`` is
        the parsed object, ``error_message`` is ``None`` and ``status`` is
        200. On any failure ``advice`` is ``None``, ``error_message``
        describes the problem and ``status`` is 502.
    """
    try:
        result = response.json()
    except ValueError:
        print(f"OpenRouter returned non-JSON response: {response.text[:500]}")
        return None, "OpenRouter returned an invalid response", 502

    # The body may be valid JSON without being an object (e.g. a list), so
    # only use dict accessors after checking the type.
    payload = result if isinstance(result, dict) else {}
    choices = payload.get("choices")
    if not choices or not isinstance(choices, list):
        print(f"OpenRouter response missing choices: {result}")
        error_message = "Unable to generate AI advice"
        error = payload.get("error")
        if isinstance(error, dict):
            error_message = error.get("message", error_message)
        elif error:
            error_message = str(error)
        return None, error_message, 502

    # "message" and "content" can each be null or have an unexpected type;
    # treat those cases as an empty reply instead of raising.
    first_choice = choices[0] if isinstance(choices[0], dict) else {}
    message = first_choice.get("message")
    text = message.get("content") if isinstance(message, dict) else None
    if not isinstance(text, str):
        text = ""

    # Drop Markdown code fences, then take the outermost {...} span.
    text = re.sub(r"```json|```", "", text).strip()
    match = re.search(r"\{[\s\S]*\}", text)
    if not match:
        print(f"OpenRouter response did not contain JSON: {text[:500]}")
        return None, "AI advice response was not valid JSON", 502

    advice_json = match.group(0)
    try:
        parsed_advice = json.loads(advice_json)
    except ValueError:
        print(f"Failed to parse AI advice JSON: {advice_json[:500]}")
        return None, "AI advice response could not be parsed", 502
    return parsed_advice, None, 200
def create_mongo_database():
    """Connect to MongoDB and return the application database handle.

    Creates a unique index on ``users.email``, a unique index on
    ``sessions.token`` and a plain index on ``sessions.user_id``.

    Raises:
        RuntimeError: If MONGODB_URI is empty.
    """
    if MONGODB_URI:
        mongo_client = MongoClient(MONGODB_URI, server_api=ServerApi("1"))
        db = mongo_client[MONGODB_DB_NAME]
        db.users.create_index("email", unique=True)
        db.sessions.create_index("token", unique=True)
        db.sessions.create_index("user_id")
        return db
    raise RuntimeError("MONGODB_URI is not set. Add your MongoDB Atlas connection string.")
# Connect to MongoDB at import time; create_mongo_database raises RuntimeError
# when MONGODB_URI is unset, which stops the module from loading.
mongo_db = create_mongo_database()
print("Loading model...")
# Load the model with compile=False, then compile it explicitly with the settings below.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
print("Model loaded")
def serialize_user(user_document):
    """Return the client-facing view of a stored user document.

    Only the id (as a string), full name and email are exposed; every other
    stored field, including the password hash, is left out.
    """
    public_view = {"id": str(user_document["_id"])}
    public_view["fullName"] = user_document["full_name"]
    public_view["email"] = user_document["email"]
    return public_view
def create_session(user_id):
    """Store a new session for *user_id* and return its bearer token.

    The token is a random URL-safe string generated from 32 bytes.
    """
    session_token = secrets.token_urlsafe(32)
    session_record = {"token": session_token, "user_id": user_id}
    mongo_db.sessions.insert_one(session_record)
    return session_token
def get_authenticated_user():
    """Resolve the current request's bearer token to a user.

    Returns:
        ``(user_document, token)`` when the Authorization header carries a
        bearer token that matches a stored session whose user still exists;
        ``(None, None)`` in every other case.
    """
    unauthenticated = (None, None)
    bearer_prefix = "Bearer "

    header_value = request.headers.get("Authorization", "")
    if not header_value.startswith(bearer_prefix):
        return unauthenticated

    bearer_token = header_value[len(bearer_prefix):].strip()
    if not bearer_token:
        return unauthenticated

    session = mongo_db.sessions.find_one({"token": bearer_token})
    if not session:
        return unauthenticated

    user = mongo_db.users.find_one({"_id": session["user_id"]})
    if not user:
        return unauthenticated
    return user, bearer_token
def require_auth(route_handler):
    """Decorator that rejects requests without a valid session token.

    OPTIONS requests are answered with Flask's default options response
    without authentication. For other methods the authenticated user and
    token are stored on ``g.current_user`` and ``g.current_token`` before the
    wrapped handler runs; unauthenticated requests receive a 401 JSON error.
    """

    @wraps(route_handler)
    def guarded(*args, **kwargs):
        if request.method == "OPTIONS":
            return app.make_default_options_response()

        current_user, current_token = get_authenticated_user()
        if current_user:
            g.current_user = current_user
            g.current_token = current_token
            return route_handler(*args, **kwargs)
        return jsonify({"error": "Authentication required"}), 401

    return guarded
def get_gradcam_heatmap(current_model, img_array, pred_index=None):
    """Compute a Grad-CAM heatmap for one class of *current_model*.

    Args:
        current_model: Keras model. Only its top-level ``layers`` are searched
            for the last ``Conv2D`` layer.
        img_array: Model input; cast to float32 before use.
            NOTE(review): presumably a single-image batch, since only the
            first item's activations are used - confirm.
        pred_index: Index of the class to explain. When ``None``, the class
            with the highest score for the first input is used.

    Returns:
        A NumPy array scaled by its maximum after negative values are
        clipped to zero, or ``None`` when no Conv2D layer is found, no
        gradient is available, or any error occurs (the error is printed).
    """
    try:
        last_conv_layer = next(
            (
                layer.name
                for layer in reversed(current_model.layers)
                if isinstance(layer, tf.keras.layers.Conv2D)
            ),
            None,
        )
        if last_conv_layer is None:
            return None

        # Auxiliary model exposing both the conv activations and the predictions.
        grad_model = tf.keras.models.Model(
            inputs=current_model.input,
            outputs=[current_model.get_layer(last_conv_layer).output, current_model.output],
        )
        img_tensor = tf.cast(img_array, tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(img_tensor)
            conv_outputs, predictions = grad_model(img_tensor)
            tape.watch(conv_outputs)
            # Resolve the documented default. Indexing with None would insert
            # a new axis instead of selecting a class.
            if pred_index is None:
                pred_index = tf.argmax(predictions[0])
            class_channel = predictions[:, pred_index]

        grads = tape.gradient(class_channel, conv_outputs)
        if grads is None:
            return None

        # Weight each channel by its mean gradient, then sum over channels.
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        conv_outputs = conv_outputs[0]
        heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap)
        # The small epsilon avoids division by zero for an all-zero map.
        heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
        return heatmap.numpy()
    except Exception as exc:
        # Best effort: callers treat None as "no heatmap available".
        print(f"Grad-CAM error: {exc}")
        return None
def overlay_gradcam(original_img, heatmap, alpha=0.4):
    """Blend a jet-coloured heatmap onto *original_img*.

    Args:
        original_img: Image array whose first two dimensions are height and
            width. NOTE(review): presumably RGB with values in 0-255, since
            the result is clipped to that range - confirm.
        heatmap: Array scaled to 0-255 and used as colormap indices.
            NOTE(review): presumably values in [0, 1] - confirm.
        alpha: Weight applied to the coloured heatmap before it is added.

    Returns:
        A ``uint8`` array with the same height and width as *original_img*.
    """
    heatmap_resized = np.uint8(255 * heatmap)
    # cm.get_cmap was removed in Matplotlib 3.9. Prefer the colormap registry
    # and fall back only for releases that predate it.
    try:
        jet = matplotlib.colormaps["jet"]
    except AttributeError:
        jet = cm.get_cmap("jet")
    # Keep the RGB columns of the 256-entry lookup table (drop alpha).
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap_resized]
    jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)
    # PIL's resize takes (width, height).
    jet_heatmap = jet_heatmap.resize((original_img.shape[1], original_img.shape[0]))
    jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)
    superimposed = jet_heatmap * alpha + original_img
    return np.clip(superimposed, 0, 255).astype(np.uint8)
@app.route("/auth/signup", methods=["POST"])
def sign_up():
    """Register a new user and open a session.

    Expects a JSON object with "fullName", "email" and "password".

    Returns:
        201 with ``{"token", "user"}`` on success; 400 with an error message
        when a field fails validation (including a missing or non-object
        body); 409 when the email is already registered.
    """
    # Function-scope import keeps this block self-contained; pymongo is
    # already a dependency of this file.
    from pymongo.errors import DuplicateKeyError

    # silent=True plus the type checks below turn a missing, malformed or
    # non-object body into the normal validation errors, not a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_name = data.get("fullName")
    raw_email = data.get("email")
    raw_password = data.get("password")
    full_name = raw_name.strip() if isinstance(raw_name, str) else ""
    email = raw_email.strip().lower() if isinstance(raw_email, str) else ""
    password = raw_password if isinstance(raw_password, str) else ""

    if len(full_name) < 2:
        return jsonify({"error": "Full name must be at least 2 characters"}), 400
    if not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email):
        return jsonify({"error": "Please enter a valid email address"}), 400
    if len(password) < 6:
        return jsonify({"error": "Password must be at least 6 characters"}), 400

    duplicate_response = jsonify({"error": "An account with this email already exists"}), 409
    if mongo_db.users.find_one({"email": email}):
        return duplicate_response
    try:
        user_result = mongo_db.users.insert_one(
            {
                "full_name": full_name,
                "email": email,
                "password_hash": generate_password_hash(password),
            }
        )
    except DuplicateKeyError:
        # Another request registered the same email between the check above
        # and this insert; the unique index on users.email rejected it.
        return duplicate_response

    user_document = mongo_db.users.find_one({"_id": user_result.inserted_id})
    token = create_session(user_document["_id"])
    return jsonify({"token": token, "user": serialize_user(user_document)}), 201
@app.route("/auth/signin", methods=["POST"])
def sign_in():
    """Authenticate a user by email and password and open a session.

    Expects a JSON object with "email" and "password".

    Returns:
        200 with ``{"token", "user"}`` on success; 401 with a generic error
        when the credentials do not match or the body is missing, malformed
        or has fields of the wrong type.
    """
    # A missing, malformed or non-object body is handled like wrong
    # credentials instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_email = data.get("email")
    raw_password = data.get("password")
    email = raw_email.strip().lower() if isinstance(raw_email, str) else ""
    password = raw_password if isinstance(raw_password, str) else ""

    user_document = mongo_db.users.find_one({"email": email})
    # One message for both failure modes, so the response does not reveal
    # whether the email is registered.
    if not user_document or not check_password_hash(user_document["password_hash"], password):
        return jsonify({"error": "Invalid email or password"}), 401
    token = create_session(user_document["_id"])
    return jsonify({"token": token, "user": serialize_user(user_document)}), 200
@app.route("/auth/me", methods=["GET"])
@require_auth
def auth_me():
    """Return the public profile of the authenticated user."""
    current_profile = serialize_user(g.current_user)
    return jsonify({"user": current_profile}), 200
@app.route("/auth/logout", methods=["POST"])
@require_auth
def logout():
    """Delete the session belonging to the token used for this request."""
    session_filter = {"token": g.current_token}
    mongo_db.sessions.delete_one(session_filter)
    return jsonify({"message": "Logged out successfully"}), 200
@app.route("/predict", methods=["POST", "OPTIONS"])
@require_auth
def predict():
    """Classify an uploaded chest X-ray image.

    Expects a multipart upload under the "file" field.

    Returns:
        200 with the predicted display name, its confidence (the raw model
        score as a float), one finding per class and the model version;
        400 when no file is uploaded or the upload cannot be decoded as an
        image; 500 when preprocessing or inference fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    upload = request.files["file"]
    print(f"Processing file: {upload.filename}")

    # Decoding failures are caused by the client's input, so report them as
    # 400 instead of folding them into the 500 path below. Pillow raises
    # several unrelated exception types for corrupt data, hence the broad catch.
    try:
        img = Image.open(upload.stream).convert("RGB")
    except Exception as exc:
        print(f"Invalid image upload: {exc}")
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    try:
        # Resize to 224x224, scale pixels to [0, 1] and add a batch axis.
        img_resized = img.resize((224, 224))
        img_array = np.array(img_resized) / 255.0
        img_array = np.expand_dims(img_array, axis=0)
        print("Running model prediction...")
        preds = model.predict(img_array)[0]
        pred_index = int(np.argmax(preds))
        pred_class = CLASSES[pred_index]
        print(f"Prediction complete: {pred_class}")
        # Simplified response without Grad-CAM to reduce memory usage
        findings = [
            {
                "name": DISPLAY_NAMES[class_name],
                "confidence": f"{float(preds[index]):.4f}",
                "severity": SEVERITY[class_name],
                "predicted": index == pred_index,
            }
            for index, class_name in enumerate(CLASSES)
        ]
        return jsonify(
            {
                "predicted": DISPLAY_NAMES[pred_class],
                "predicted_class_display": DISPLAY_NAMES[pred_class],
                "confidence": float(preds[pred_index]),
                "findings": findings,
                "model_version": "DenseNet121-COVID-v1.0",
            }
        )
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
def _coerce_confidence(value):
    """Return *value* as a float, or 0.0 when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _format_findings_for_prompt(findings):
    """Render findings as "name: xx.x%" pairs, skipping malformed entries."""
    if not isinstance(findings, list):
        return ""
    rendered = []
    for finding in findings:
        try:
            rendered.append(f"{finding['name']}: {float(finding['confidence']) * 100:.1f}%")
        except (KeyError, TypeError, ValueError):
            # Entries come from the request body; ignore ones that do not
            # have a name and a numeric confidence.
            continue
    return ", ".join(rendered)


@app.route("/ai-advice", methods=["POST"])
@require_auth
def ai_advice():
    """Return structured clinical guidance for a prediction result.

    Expects a JSON object with "disease", "confidence" and "findings".
    NOTE(review): "confidence" is printed verbatim with a "%" sign while each
    finding's confidence is multiplied by 100 - presumably the client sends
    the top-level value as a percentage; confirm.

    The configured OpenRouter model and its fallbacks are tried in order and
    the first reply that parses is returned as ``{"advice", "model"}``. If no
    API key is configured, or every model fails, predefined guidance is
    returned with a "warning" explaining why (status 200). An OpenRouter 401
    is passed through as a 401 error.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    # Values come from the request body: normalise them so neither the prompt
    # formatting nor the fallback lookup can raise on unexpected types.
    disease = str(data.get("disease") or "")
    confidence = _coerce_confidence(data.get("confidence", 0))
    predictions_text = _format_findings_for_prompt(data.get("findings", []))

    fallback_advice = build_local_advice(disease, confidence)
    if not OPENROUTER_API_KEY:
        return jsonify(
            {
                "advice": fallback_advice,
                "warning": "OPENROUTER_API_KEY is not configured. Returned local fallback guidance.",
            }
        )

    prompt = f"""You are a clinical AI assistant helping a radiologist interpret a chest X-ray analysis result. A DenseNet121 model has analyzed the X-ray with the following results:
Primary Diagnosis: {disease} ({confidence:.1f}% confidence)
All predictions: {predictions_text}
Provide a structured clinical response in this exact JSON format only. No markdown, no text outside the JSON:
{{
"summary": "2-3 sentence clinical summary of the condition",
"urgency": "Low | Moderate | High | Critical",
"precautions": [
"Precaution 1",
"Precaution 2",
"Precaution 3",
"Precaution 4",
"Precaution 5"
],
"medications": [
{{"name": "Drug name", "dosage": "e.g. 500mg", "frequency": "e.g. twice daily", "purpose": "brief purpose"}},
{{"name": "Drug name", "dosage": "...", "frequency": "...", "purpose": "..."}}
],
"lifestyle": [
"Lifestyle recommendation 1",
"Lifestyle recommendation 2",
"Lifestyle recommendation 3"
],
"followUp": "When and what kind of follow-up is recommended"
}}
For Normal findings provide general wellness advice. For diseases provide clinically appropriate guidance."""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": FRONTEND_ORIGINS[0] if FRONTEND_ORIGINS != "*" and FRONTEND_ORIGINS else "http://localhost:5173",
        "X-Title": "PulmoAI",
    }

    # Primary model first, then fallbacks, without duplicates or blanks.
    model_candidates = list(
        dict.fromkeys(name for name in [OPENROUTER_MODEL, *OPENROUTER_FALLBACK_MODELS] if name)
    )
    last_error = "Unable to generate AI advice"
    for model_name in model_candidates:
        body = {
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
        }
        try:
            response = req.post(
                "https://openrouter.ai/api/v1/chat/completions",
                json=body,
                headers=headers,
                timeout=60,
            )
            response.raise_for_status()
        except req.RequestException as exc:
            # exc.response is None for connection-level failures.
            failed_response = getattr(exc, "response", None)
            status_code = getattr(failed_response, "status_code", 502)
            response_text = failed_response.text[:500] if failed_response is not None else ""
            print(f"OpenRouter request failed for model {model_name}: {exc}. Response: {response_text}")
            if status_code == 401:
                # A bad key fails for every model, so stop immediately.
                return jsonify({"error": "OpenRouter API key is invalid, expired, or belongs to a deleted account"}), 401
            if status_code == 429:
                last_error = f"OpenRouter model '{model_name}' is temporarily rate-limited."
            else:
                last_error = f"Failed to fetch AI advice from OpenRouter using model '{model_name}'"
            continue

        parsed_advice, parse_error, _ = parse_openrouter_advice(response)
        if parsed_advice is not None:
            parsed_advice.setdefault("source", "openrouter")
            parsed_advice.setdefault(
                "disclaimer",
                "AI-generated guidance should be reviewed by a licensed clinician before making medical decisions.",
            )
            return jsonify({"advice": parsed_advice, "model": model_name})
        last_error = parse_error or last_error

    # Every candidate failed: degrade to the predefined guidance.
    return jsonify(
        {
            "advice": fallback_advice,
            "warning": (
                f"{last_error} Returned local fallback guidance after trying: {', '.join(model_candidates)}."
            ),
        }
    ), 200
@app.route("/health", methods=["GET"])
def health():
    """Report service status with the model file name and database name."""
    status_payload = {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "database": MONGODB_DB_NAME,
    }
    return jsonify(status_payload)
if __name__ == "__main__":
    # The server listens on every interface, and the Werkzeug debugger lets
    # anyone who can reach it execute code. Debug mode is therefore off
    # unless it is explicitly enabled through FLASK_DEBUG.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled, host="0.0.0.0", port=PORT)
|