# Abs6187's picture
# Update app.py
# 2bdc248 verified
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from flask import Flask, render_template, request, jsonify, redirect, url_for, send_from_directory
from flask_pymongo import PyMongo
from flask_bcrypt import Bcrypt
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import cv2
import google.generativeai as genai
from dotenv import load_dotenv
import certifi
import uuid
import secrets
import logging
# Pull configuration from a local .env file (if any) before reading the environment.
load_dotenv()

# Flask application and its configuration.
app = Flask(__name__)
app.config["MONGO_URI"] = os.getenv("MONGODB_URI") or os.getenv("MONGO_URI")
# Without SECRET_KEY in the environment a random key is generated for this process.
app.config['SECRET_KEY'] = os.getenv("SECRET_KEY") or secrets.token_hex(16)
app.config.setdefault("SESSION_COOKIE_HTTPONLY", True)
app.config.setdefault("SESSION_COOKIE_SAMESITE", "Lax")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

# MongoDB client: construction is attempted whether or not a URI is configured;
# a missing URI only produces a warning first. Any failure leaves `mongo` as None.
try:
    if not app.config["MONGO_URI"]:
        logger.warning("MONGO_URI not set. MongoDB operations will fail.")
    mongo = PyMongo(app, tlsCAFile=certifi.where())
except Exception as e:
    logger.error(f"Mongo initialization error: {e}")
    mongo = None

bcrypt = Bcrypt(app)

# Gemini client: stays None when no API key is present or when setup fails.
gemini_model = None
if not GEMINI_API_KEY:
    logger.warning("GEMINI_API_KEY/GOOGLE_API_KEY not set. /chat will return a friendly error.")
else:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
    except Exception as e:
        logger.error(f"Gemini initialization error: {e}")
# Per-model settings keyed by the display name the client sends as ``model``.
# Every entry shares the same Grad-CAM layer name and input size, so the
# registry is assembled from a compact (name, weights file, class labels) table.
MODEL_CONFIG = {
    display_name: {
        "path": weights_path,
        "labels": class_labels,
        "last_conv_layer": "relu",
        "input_size": (224, 224),
    }
    for display_name, weights_path, class_labels in (
        (
            "Pneumonia",
            "model/best_pneumonia_model.h5",
            ["Normal", "Pneumonia"],
        ),
        (
            "Tuberculosis",
            "model/best_tuberculosis_model.h5",
            ["Normal", "Tuberculosis"],
        ),
        (
            "Brain Tumor",
            "model/best_braintumor_model.h5",
            ["glioma", "meningioma", "notumor", "pituitary"],
        ),
        (
            "Skin Cancer",
            "model/best_skincancer_model.h5",
            ["Actinic keratoses", "Basal cell carcinoma", "Benign keratosis-like lesions",
             "Dermatofibroma", "Melanoma", "Melanocytic nevi", "Vascular lesions"],
        ),
        (
            "Kvasir",
            "model/best_kvasir_model.h5",
            ["dyed-lifted-polyps", "dyed-resection-margins", "esophagitis",
             "normal-cecum", "normal-pylorus", "normal-z-line", "polyps", "ulcerative-colitis"],
        ),
    )
}
# Heuristic filename patterns for mapping example images to a model.
# Keys are the same display names used as keys of MODEL_CONFIG. The
# /example_images route lower-cases each file name in the test-image folder
# and keeps it when ANY of the selected model's patterns occurs as a
# substring, so every pattern here must be lower-case.
MODEL_EXAMPLE_PATTERNS = {
    "Pneumonia": ["pneumonia", "normal-"],
    "Tuberculosis": ["tuberculosis", "tb-"],
    "Brain Tumor": ["glioma", "meningioma", "notumor", "pituitary", "brain"],
    "Skin Cancer": ["melanoma", "nev", "keratos", "carcinoma", "vascular", "dermatofibroma", "skin"],
    # NOTE(review): these spellings (e.g. "normalceacum", "polypus") presumably
    # mirror the example files' names rather than the class labels - confirm.
    "Kvasir": [
        "dyedlifted", "dyedresection", "esophagitis", "normalceacum", "normalpylorus",
        "normalzline", "polypus", "ulcerative"
    ],
}
# Registry of successfully loaded Keras models, keyed by MODEL_CONFIG name.
models = {}


def load_all_models():
    """Load every model listed in ``MODEL_CONFIG`` into ``models``.

    A model whose weights file is missing is skipped with a warning, and a
    model that fails to load is skipped with an error log entry; neither case
    stops the remaining models from being loaded.
    """
    for model_name, spec in MODEL_CONFIG.items():
        try:
            weights_file = spec["path"]
            if not os.path.exists(weights_file):
                logger.warning(f"Model file not found at {weights_file}")
                continue
            # compile=False: the models are loaded without compiling them.
            models[model_name] = load_model(weights_file, compile=False)
            logger.info(f"Successfully loaded {model_name} model from {weights_file}.")
        except Exception as err:
            logger.error(f"Error loading model {model_name}: {err}")


# Populate the registry once, at import time.
load_all_models()
def preprocess_image(img_path, target_size=(224, 224)):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        img_path: Path of the image file to read.
        target_size: Size passed to the Keras image loader for resizing.

    Returns:
        A float32 array with a leading batch axis of length 1 and pixel
        values divided by 255.
    """
    pixels = image.img_to_array(image.load_img(img_path, target_size=target_size))

    # Normalise the channel layout to three channels.
    if pixels.ndim == 2:
        # Single-plane image: replicate it into three channels.
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[-1] == 4:
        # Four channels: keep only the first three.
        pixels = pixels[..., :3]

    batch = np.expand_dims(pixels, axis=0)
    return batch.astype("float32") / 255.0
def _safe_get_layer(model, layer_name):
try:
return model.get_layer(layer_name)
except Exception:
return None
def find_last_conv_layer(model):
    """Return the name of the last convolutional layer with a 4-D output.

    Layers are scanned from the end of ``model.layers`` backwards; the first
    ``Conv2D`` / ``DepthwiseConv2D`` layer whose output shape has four
    dimensions wins.

    Args:
        model: A Keras model exposing ``layers``.

    Returns:
        The ``name`` of the matching layer.

    Raises:
        ValueError: If no suitable convolutional layer is found.
    """
    conv_types = (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)
    for layer in reversed(model.layers):
        if not isinstance(layer, conv_types):
            continue
        try:
            out_shape = layer.output_shape
        except Exception:
            # Keras 3 layers no longer expose ``output_shape``; read the shape
            # from the layer's symbolic output instead of silently skipping
            # every convolutional layer.
            try:
                out_shape = layer.output.shape
            except Exception:
                out_shape = None
        if out_shape and len(out_shape) == 4:
            return layer.name
    raise ValueError("Could not automatically find a convolutional layer in the model.")
def get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for ``img_array`` under ``model``.

    Args:
        model: Keras model to explain.
        img_array: Input batch fed to the model (only the first sample's
            activations are used for the heatmap).
        last_conv_layer_name: Name of the layer whose activations are
            weighted; if the model has no layer with that name, one is
            selected with ``find_last_conv_layer``.
        pred_index: Index of the output unit to explain. Ignored when the
            model's output has a single unit; defaults to the arg-max of the
            first prediction otherwise.

    Returns:
        A NumPy array: the ReLU-ed, max-normalised heatmap, or an all-zero
        map when no gradient is available.
        Assumes the chosen layer's output is rank 4 (batch, H, W, C), giving
        an (H, W) map - TODO confirm for non-standard layers.
    """
    # Fall back to automatic discovery when the configured layer is missing.
    if not _safe_get_layer(model, last_conv_layer_name):
        last_conv_layer_name = find_last_conv_layer(model)
    conv_layer = model.get_layer(last_conv_layer_name)
    # Auxiliary model returning both the conv activations and the predictions.
    grad_model = tf.keras.models.Model([model.inputs], [conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array, training=False)
        # Multi-output models: use the first output only.
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        preds = tf.convert_to_tensor(preds)
        if preds.shape.rank is not None and preds.shape[-1] == 1:
            # Single-unit output: there is only one channel to explain.
            class_channel = preds[:, 0]
        else:
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]
    # Gradient of the selected score with respect to the conv activations.
    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        # No gradient path: return an all-zero map with the activations' spatial size.
        heatmap = tf.zeros(conv_outputs.shape[1:3], dtype=tf.float32)
        return heatmap.numpy()
    # One weight per channel: mean gradient over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    # Channel-weighted sum of the first sample's activations.
    heatmap = tf.tensordot(conv_outputs, pooled_grads, axes=(2, 0))
    heatmap = tf.maximum(heatmap, 0)
    denom = tf.math.reduce_max(heatmap)
    # The epsilon keeps the division defined when the map is all zeros.
    heatmap = heatmap / (denom + 1e-8)
    return heatmap.numpy()
def save_gradcam_image(img_path, heatmap, output_path, threshold=0.6, alpha=0.4):
    """Overlay the hot regions of a Grad-CAM heatmap on an image and save it.

    Pixels whose (resized) heatmap value exceeds ``threshold`` are blended
    with pure red using weight ``alpha``; all other pixels are left exactly
    as in the source image.

    Args:
        img_path: Path of the image to annotate.
        heatmap: 2-D heatmap array; it is resized to the image's size.
        output_path: Destination file path for the annotated image.
        threshold: Heatmap value above which a pixel is highlighted.
        alpha: Blend weight of the red overlay on highlighted pixels.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If the source image cannot be read or the result cannot
            be written.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("Failed to read image with OpenCV.")

    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    mask = heatmap > threshold

    # OpenCV images are BGR, so red is (0, 0, 255). Working directly in BGR
    # gives the same pixels as converting to RGB and back, without two extra
    # full-image conversions.
    overlay = np.zeros_like(img, dtype=np.uint8)
    overlay[mask] = [0, 0, 255]

    superimposed_img = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    # Blending darkens the whole frame; restore the untouched pixels.
    superimposed_img[~mask] = img[~mask]

    # imwrite reports failure through its return value; surface it instead of
    # returning a path to a file that was never written.
    if not cv2.imwrite(output_path, superimposed_img):
        raise ValueError("Failed to write Grad-CAM image with OpenCV.")
    return output_path
# Absolute directory containing this file, and the example-image folder next to it.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
TEST_IMAGES_DIR = os.path.join(BASE_DIR, 'testimages')
@app.route("/")
def home():
    """Send visitors of the site root to the ``index`` page."""
    landing_url = url_for('index')
    return redirect(landing_url)
@app.route('/tmp/<path:filename>')
def serve_tmp_file(filename):
    """Serve an uploaded image or its Grad-CAM overlay from ``/tmp``.

    ``/tmp`` is shared with every other process on the host, so only files
    following the naming scheme used by the prediction route are served:
    ``<uuid>_<name>`` and ``gradcam_<uuid>_<name>``, with no directory
    component. Anything else yields a 404.

    Args:
        filename: Requested file name, taken from the URL.

    Returns:
        The file response produced by ``send_from_directory``.
    """
    from flask import abort  # local import: only this view needs it

    stem = filename[len("gradcam_"):] if filename.startswith("gradcam_") else filename
    prefix, separator, _ = stem.partition("_")
    try:
        # Round-tripping through UUID accepts only the canonical textual form.
        has_uuid_prefix = bool(separator) and str(uuid.UUID(prefix)) == prefix
    except ValueError:
        has_uuid_prefix = False

    if "/" in filename or "\\" in filename or not has_uuid_prefix:
        abort(404)
    return send_from_directory('/tmp', filename)
@app.route('/testimages/<path:filename>')
def serve_test_image(filename):
    """Serve one bundled example image from ``TEST_IMAGES_DIR``."""
    response = send_from_directory(TEST_IMAGES_DIR, filename)
    return response
@app.route('/example_images')
def example_images():
    """List URLs of the bundled example images, optionally filtered by model.

    The optional ``model`` query parameter selects a pattern list from
    ``MODEL_EXAMPLE_PATTERNS``; when that list is non-empty, only files whose
    lower-cased name contains at least one pattern are returned. Any error is
    logged and answered with an empty list.

    Returns:
        JSON of the form ``{"images": [url, ...]}``.
    """
    try:
        selected_model = (request.args.get('model') or '').strip()
        patterns = MODEL_EXAMPLE_PATTERNS.get(selected_model, []) if selected_model else []

        if not os.path.isdir(TEST_IMAGES_DIR):
            return jsonify({"images": []})

        image_suffixes = ('.png', '.jpg', '.jpeg')
        files = [
            url_for('serve_test_image', filename=entry)
            for entry in os.listdir(TEST_IMAGES_DIR)
            if entry.lower().endswith(image_suffixes)
            # With no patterns for the model, every image qualifies.
            and (not patterns or any(p in entry.lower() for p in patterns))
        ]
        return jsonify({"images": files})
    except Exception as e:
        logger.error(f"example_images error: {e}")
        return jsonify({"images": []})
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Redirect to the ``index`` page; no credentials are processed here."""
    landing_url = url_for('index')
    return redirect(landing_url)
@app.route('/signup', methods=['GET', 'POST'])
def signup():
    """Redirect to the ``index`` page; no account is created here."""
    landing_url = url_for('index')
    return redirect(landing_url)
@app.route('/index')
def index():
    """Render the main page from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/logout')
def logout():
    """Redirect to the ``index`` page; no session state is touched here."""
    landing_url = url_for('index')
    return redirect(landing_url)
def _postprocess_binary_prediction(raw):
arr = np.array(raw, dtype=np.float32)
arr = np.squeeze(arr)
if arr.ndim == 0:
prob = float(arr)
if prob < 0.0 or prob > 1.0:
prob = float(1.0 / (1.0 + np.exp(-prob)))
return min(max(prob, 0.0), 1.0)
prob = float(arr[0])
if prob < 0.0 or prob > 1.0:
prob = float(1.0 / (1.0 + np.exp(-prob)))
return min(max(prob, 0.0), 1.0)
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image with the selected model and build a Grad-CAM.

    Expects a multipart form with a ``file`` part and a ``model`` field naming
    one of the loaded models. The upload is stored under ``/tmp`` with a
    random UUID prefix.

    Returns:
        JSON with ``original_image``, ``gradcam_image`` (``None`` when the
        Grad-CAM step fails), ``prediction``, ``confidence`` and
        ``model_used``; a 400 JSON error for invalid requests and a 500 JSON
        error for processing failures.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    model_name = request.form.get("model")
    if not file or file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    if model_name not in models:
        return jsonify({"error": "Invalid model selected"}), 400

    try:
        # The client controls file.filename, so never embed it in a path
        # verbatim: drop directory components and replace every character
        # outside a conservative ASCII whitelist.
        client_name = os.path.basename(file.filename.replace("\\", "/"))
        safe_name = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_"
            for ch in client_name
        ) or "upload"
        filename = f"{uuid.uuid4()}_{safe_name}"
        filepath = os.path.join("/tmp", filename)
        file.save(filepath)

        model_config = MODEL_CONFIG[model_name]
        model = models[model_name]
        labels = model_config["labels"]
        input_size = model_config.get("input_size", (224, 224))

        img_array = preprocess_image(filepath, target_size=input_size)
        prediction = np.array(model.predict(img_array, verbose=0))

        single_unit_output = (
            prediction.ndim >= 1 and prediction.shape[-1] == 1 and prediction.size >= 1
        )
        if len(labels) == 2 and single_unit_output:
            # Binary model with one output unit: threshold its probability.
            prob_pos = _postprocess_binary_prediction(prediction)
            if prob_pos >= 0.5:
                predicted_index = 1
                predicted_label = labels[1]
                confidence = prob_pos
            else:
                predicted_index = 0
                predicted_label = labels[0]
                confidence = 1.0 - prob_pos
        else:
            # Multi-unit output: use the first sample's score vector.
            if prediction.ndim == 2:
                vec = prediction[0]
            else:
                vec = prediction.reshape(-1)
            # Scores that do not already look like a probability distribution
            # are treated as logits and passed through a stable softmax.
            if np.any(vec < 0) or np.any(vec > 1) or not np.isclose(np.sum(vec), 1.0, atol=1e-3):
                exps = np.exp(vec - np.max(vec))
                probs = exps / (np.sum(exps) + 1e-8)
            else:
                probs = vec
            predicted_index = int(np.argmax(probs))
            predicted_label = labels[predicted_index]
            confidence = float(np.max(probs))

        # Grad-CAM is best-effort: a failure is logged and reported as None.
        gradcam_url = None
        try:
            last_conv_layer_name = model_config.get('last_conv_layer') or ""
            heatmap = get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=predicted_index)
            gradcam_filename = f"gradcam_{filename}"
            gradcam_filepath = os.path.join("/tmp", gradcam_filename)
            save_gradcam_image(filepath, heatmap, gradcam_filepath)
            gradcam_url = url_for('serve_tmp_file', filename=gradcam_filename)
        except Exception as e:
            logger.error(f"Grad-CAM error: {e}")

        return jsonify({
            "original_image": url_for('serve_tmp_file', filename=filename),
            "gradcam_image": gradcam_url,
            "prediction": str(predicted_label),
            "confidence": float(confidence),
            "model_used": str(model_name)
        })
    except Exception as e:
        logger.exception("Prediction error")
        return jsonify({"error": str(e)}), 500
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user question about a prediction through the Gemini model.

    Expects a JSON object with ``message`` (the question) and an optional
    ``context`` object carrying ``model_used``, ``prediction`` and
    ``confidence``. A body that is not a JSON object, or a ``context`` that
    is not an object, is treated as empty instead of crashing the request.

    Returns:
        ``{"response": text}`` on success, or ``{"error": ...}`` with status
        500 when Gemini is not configured or the request to it fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # e.g. a JSON list or string: nothing usable, fall back to defaults.
        data = {}
    user_message = data.get("message", "")
    prediction_context = data.get("context")
    if not isinstance(prediction_context, dict):
        prediction_context = {}

    model_used = prediction_context.get('model_used', 'Unknown Model')
    pred_label = prediction_context.get('prediction', 'Unknown')
    conf = prediction_context.get('confidence', 0.0)
    try:
        conf_pct = float(conf) * 100.0
    except (TypeError, ValueError, OverflowError):
        # Non-numeric confidence from the client: report it as 0%.
        conf_pct = 0.0

    prompt = f"""
You are a helpful medical assistant chatbot.
A medical image was analyzed with the following results:
- Model Used: {model_used}
- Prediction: {pred_label}
- Confidence Score: {conf_pct:.2f}%
The user's question is: "{user_message}"
Based on this context, provide a helpful and informative response.
Do not provide a diagnosis. Advise the user to consult a medical professional.
"""

    if gemini_model is None:
        return jsonify({"error": "Gemini API not configured. Set GEMINI_API_KEY in environment."}), 500

    try:
        response = gemini_model.generate_content(prompt)
        text = getattr(response, "text", None)
        if not text:
            text = str(response)
        return jsonify({"response": text})
    except Exception as e:
        # Keep a server-side trace; the client still receives the message.
        logger.exception("Chat error")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Debug mode turns on the interactive Werkzeug debugger, which allows
    # arbitrary code execution from the browser. Keep it off unless it is
    # explicitly requested through the FLASK_DEBUG environment variable.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug_enabled)