Spaces:
Sleeping
Sleeping
| import os | |
| os.environ['CUDA_VISIBLE_DEVICES'] = '-1' | |
| from flask import Flask, render_template, request, jsonify, redirect, url_for, send_from_directory | |
| from flask_pymongo import PyMongo | |
| from flask_bcrypt import Bcrypt | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing import image | |
| import numpy as np | |
| import cv2 | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
| import certifi | |
| import uuid | |
| import secrets | |
| import logging | |
load_dotenv()

# Logging is configured first so the setup below can already emit warnings.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

app = Flask(__name__)
# Either variable name is accepted; MONGODB_URI wins when both are set.
app.config["MONGO_URI"] = os.getenv("MONGODB_URI") or os.getenv("MONGO_URI")

_secret_key = os.getenv("SECRET_KEY")
if not _secret_key:
    # A random per-process key means sessions do not survive restarts and are
    # not shared between worker processes.
    logger.warning("SECRET_KEY not set; using a random per-process key.")
    _secret_key = secrets.token_hex(16)
app.config['SECRET_KEY'] = _secret_key

# Flask's default config already contains both keys (SESSION_COOKIE_SAMESITE
# defaults to None), so ``setdefault`` would silently leave SameSite unset.
app.config["SESSION_COOKIE_HTTPONLY"] = True
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
# Optional MongoDB client. ``mongo`` stays None when no URI is configured or
# the client could not be created, so any user of it must check for None.
mongo = None
if app.config["MONGO_URI"]:
    try:
        # Use certifi's CA bundle for TLS certificate verification.
        mongo = PyMongo(app, tlsCAFile=certifi.where())
    except Exception as e:
        logger.error(f"Mongo initialization error: {e}")
else:
    # Without a URI there is nothing to connect to; the original code still
    # constructed PyMongo here, which only produced a spurious init error.
    logger.warning("MONGO_URI not set. MongoDB operations will fail.")
bcrypt = Bcrypt(app)

# Gemini chat model used by /chat; remains None when no API key is present
# or when configuring the client raises.
gemini_model = None
if not GEMINI_API_KEY:
    logger.warning("GEMINI_API_KEY/GOOGLE_API_KEY not set. /chat will return a friendly error.")
else:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
    except Exception as e:
        logger.error(f"Gemini initialization error: {e}")
# Per-model settings, keyed by the model name the client selects:
#   path            -> Keras weights file (relative path)
#   labels          -> class names in model-output index order
#   last_conv_layer -> preferred Grad-CAM layer (a fallback search exists)
#   input_size      -> target size handed to the image loader
MODEL_CONFIG = {
    name: {
        "path": path,
        "labels": labels,
        "last_conv_layer": "relu",
        "input_size": (224, 224),
    }
    for name, path, labels in (
        ("Pneumonia", "model/best_pneumonia_model.h5", ["Normal", "Pneumonia"]),
        ("Tuberculosis", "model/best_tuberculosis_model.h5", ["Normal", "Tuberculosis"]),
        (
            "Brain Tumor",
            "model/best_braintumor_model.h5",
            ["glioma", "meningioma", "notumor", "pituitary"],
        ),
        (
            "Skin Cancer",
            "model/best_skincancer_model.h5",
            ["Actinic keratoses", "Basal cell carcinoma", "Benign keratosis-like lesions",
             "Dermatofibroma", "Melanoma", "Melanocytic nevi", "Vascular lesions"],
        ),
        (
            "Kvasir",
            "model/best_kvasir_model.h5",
            ["dyed-lifted-polyps", "dyed-resection-margins", "esophagitis",
             "normal-cecum", "normal-pylorus", "normal-z-line", "polyps", "ulcerative-colitis"],
        ),
    )
}
# Heuristic filename patterns for mapping example images to each model.
# A file is listed for a model when its lowercased name contains any pattern.
# NOTE(review): spellings such as "normalceacum"/"polypus" must match the files
# actually present in testimages/ — they differ from the model labels; confirm.
MODEL_EXAMPLE_PATTERNS = {
    "Pneumonia": "pneumonia normal-".split(),
    "Tuberculosis": "tuberculosis tb-".split(),
    "Brain Tumor": "glioma meningioma notumor pituitary brain".split(),
    "Skin Cancer": "melanoma nev keratos carcinoma vascular dermatofibroma skin".split(),
    "Kvasir": (
        "dyedlifted dyedresection esophagitis normalceacum normalpylorus "
        "normalzline polypus ulcerative"
    ).split(),
}
# Loaded Keras models keyed by MODEL_CONFIG name. A model whose weights file
# is missing or fails to load is simply absent from this dict.
models = {}


def _resolve_model_path(path):
    """Return an existing location for *path*, or None if none exists.

    The path is tried as given (i.e. relative to the working directory) first,
    to keep the original behavior. A relative path is then retried relative to
    this module's directory, so the app also works when started from elsewhere.
    """
    if os.path.exists(path):
        return path
    if not os.path.isabs(path):
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, path)
        if os.path.exists(candidate):
            return candidate
    return None


def load_all_models():
    """Load every model listed in MODEL_CONFIG into ``models``.

    Failures are logged and skipped so one bad model does not prevent the
    others from loading. Models are loaded with ``compile=False`` because
    they are only used for inference.
    """
    for name, config in MODEL_CONFIG.items():
        try:
            model_path = _resolve_model_path(config["path"])
            if model_path is None:
                logger.warning(f"Model file not found at {config['path']}")
                continue
            models[name] = load_model(model_path, compile=False)
            logger.info(f"Successfully loaded {name} model from {model_path}.")
        except Exception as e:
            logger.error(f"Error loading model {name}: {e}")


load_all_models()
def preprocess_image(img_path, target_size=(224, 224)):
    """Load an image file as a single-item float32 batch scaled to [0, 1].

    Args:
        img_path: path of the image file to read.
        target_size: size the image is resized to when loaded.

    Returns:
        A float32 array with a leading batch axis of length 1, whose pixel
        values are divided by 255.
    """
    pixels = image.img_to_array(image.load_img(img_path, target_size=target_size))
    # Coerce to three channels: a 2-D array is replicated three times, and a
    # 4-channel (e.g. RGBA) array has its last channel dropped.
    if pixels.ndim == 2:
        pixels = np.stack((pixels, pixels, pixels), axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = pixels[..., :3]
    batch = pixels[np.newaxis, ...]
    return batch.astype("float32") / 255.0
| def _safe_get_layer(model, layer_name): | |
| try: | |
| return model.get_layer(layer_name) | |
| except Exception: | |
| return None | |
def _layer_output_rank(layer):
    """Return the rank of *layer*'s single output, or None if it is unknown."""
    try:
        shape = layer.output_shape
    except Exception:
        # NOTE: newer Keras releases may not expose ``output_shape`` on layers;
        # fall back to the shape of the symbolic output tensor.
        try:
            shape = layer.output.shape
        except Exception:
            return None
    if not shape or isinstance(shape, list):
        # Unknown/empty shape, or a list of shapes for a multi-output layer.
        return None
    try:
        return len(shape)
    except (TypeError, ValueError):
        return None


def find_last_conv_layer(model):
    """Return the name of the last Conv2D/DepthwiseConv2D layer with a 4-D output.

    Only ``model.layers`` is scanned (from the end), so convolutions nested
    inside a sub-model are not found.

    Raises:
        ValueError: if no such layer exists.
    """
    conv_types = (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)
    for layer in reversed(model.layers):
        if isinstance(layer, conv_types) and _layer_output_rank(layer) == 4:
            return layer.name
    raise ValueError("Could not automatically find a convolutional layer in the model.")
def get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for the first image in *img_array*.

    Args:
        model: Keras model to explain.
        img_array: input batch for ``model``; only item 0 contributes to the map.
        last_conv_layer_name: layer whose activations are weighted. If the model
            has no layer by that name, ``find_last_conv_layer`` picks one.
        pred_index: class to explain. For a single-unit output, 0 explains the
            negative class and any other value the positive one. For multi-unit
            outputs, None means the arg-max class.

    Returns:
        2-D numpy array over the conv layer's spatial grid, scaled to [0, 1].
        It is all zeros when no gradient reaches the conv layer.
    """
    if _safe_get_layer(model, last_conv_layer_name) is None:
        last_conv_layer_name = find_last_conv_layer(model)
    conv_layer = model.get_layer(last_conv_layer_name)
    # Pass ``model.inputs`` unchanged (as Keras' reference Grad-CAM does).
    # Wrapping it as ``[model.inputs]`` adds a nesting level that newer Keras
    # versions may treat as a different input structure.
    grad_model = tf.keras.models.Model(model.inputs, [conv_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array, training=False)
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        preds = tf.convert_to_tensor(preds)
        if preds.shape.rank is not None and preds.shape[-1] == 1:
            score = preds[:, 0]
            # A single unit scores the positive class. To explain the negative
            # class (index 0), differentiate the complement: the gradient sign
            # flips, so regions supporting "negative" survive the ReLU below.
            class_channel = 1.0 - score if pred_index == 0 else score
        else:
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        # No differentiable path from the conv layer to the score.
        return tf.zeros(conv_outputs.shape[1:3], dtype=tf.float32).numpy()

    # Assumes channels-last (batch, H, W, C) conv output. Per-channel
    # importance is the gradient averaged over the batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.tensordot(conv_outputs[0], pooled_grads, axes=(2, 0))
    heatmap = tf.maximum(heatmap, 0)
    # The epsilon keeps an all-zero map at zero instead of dividing by 0.
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()
def save_gradcam_image(img_path, heatmap, output_path, threshold=0.6, alpha=0.4):
    """Write a copy of *img_path* with high-activation regions tinted red.

    The heatmap is resized to the image size. Pixels where it exceeds
    *threshold* are blended with pure red at weight *alpha*; every other pixel
    is copied unchanged.

    Returns:
        output_path.

    Raises:
        ValueError: if the input image cannot be read or the output cannot be
            written. ``cv2.imwrite`` may also raise ``cv2.error`` itself, for
            example for an unsupported file extension.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("Failed to read image with OpenCV.")
    height, width = img.shape[:2]
    # cv2.resize takes (width, height).
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), (width, height))
    mask = heatmap > threshold

    # Stay in OpenCV's native BGR order, where (0, 0, 255) is red. This gives
    # the same pixels as converting to RGB, tinting, and converting back.
    overlay = np.zeros_like(img)
    overlay[mask] = (0, 0, 255)
    blended = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    blended[~mask] = img[~mask]

    # imwrite signals failure by returning False; surface it so callers do not
    # hand out a URL for a file that was never written.
    if not cv2.imwrite(output_path, blended):
        raise ValueError(f"Failed to write Grad-CAM image to {output_path}.")
    return output_path
# Absolute directory containing this module; bundled assets are located from here.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# Folder of example images listed by example_images and served by serve_test_image.
TEST_IMAGES_DIR = os.path.join(BASE_DIR, "testimages")
def home():
    """Redirect to the main page served by ``index``."""
    destination = url_for('index')
    return redirect(destination)
def serve_tmp_file(filename):
    """Serve an upload or Grad-CAM image that ``predict`` stored in /tmp.

    ``predict`` names its files ``<uuid4>_<name>`` and ``gradcam_<uuid4>_<name>``.
    Any other name is answered with 404, so unrelated files that happen to
    live in /tmp are never exposed. ``send_from_directory`` additionally
    rejects path traversal.
    """
    name = filename[len("gradcam_"):] if filename.startswith("gradcam_") else filename
    token, sep, _ = name.partition("_")
    try:
        is_upload = bool(sep) and len(token) == 36 and uuid.UUID(token).version == 4
    except ValueError:
        is_upload = False
    if not is_upload:
        return "Not found", 404
    return send_from_directory('/tmp', filename)
def serve_test_image(filename):
    """Serve one of the bundled example images from TEST_IMAGES_DIR."""
    directory = TEST_IMAGES_DIR
    return send_from_directory(directory, filename)
def example_images():
    """Return JSON ``{"images": [...]}`` with URLs of example images.

    Only .png/.jpg/.jpeg files in TEST_IMAGES_DIR are considered. If the
    ``model`` query parameter names an entry in MODEL_EXAMPLE_PATTERNS, only
    files whose lowercased name contains one of its patterns are listed.
    Otherwise every image is listed. Results are sorted by filename so the
    response is stable. On any error an empty list is returned and the error
    is logged.
    """
    try:
        selected_model = (request.args.get('model') or '').strip()
        patterns = MODEL_EXAMPLE_PATTERNS.get(selected_model, []) if selected_model else []
        if not os.path.isdir(TEST_IMAGES_DIR):
            return jsonify({"images": []})
        files = []
        # os.listdir order is arbitrary; sort for deterministic output.
        for f in sorted(os.listdir(TEST_IMAGES_DIR)):
            lf = f.lower()
            if not lf.endswith(('.png', '.jpg', '.jpeg')):
                continue
            # With a known model selected, keep only files matching its patterns.
            if patterns and not any(p in lf for p in patterns):
                continue
            files.append(url_for('serve_test_image', filename=f))
        return jsonify({"images": files})
    except Exception as e:
        logger.exception(f"example_images error: {e}")
        return jsonify({"images": []})
def login():
    """Placeholder login endpoint; it currently just redirects to the main page."""
    target = url_for('index')
    return redirect(target)


def signup():
    """Placeholder signup endpoint; it currently just redirects to the main page."""
    target = url_for('index')
    return redirect(target)


def index():
    """Render the main page from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)


def logout():
    """Placeholder logout endpoint; it currently just redirects to the main page."""
    target = url_for('index')
    return redirect(target)
| def _postprocess_binary_prediction(raw): | |
| arr = np.array(raw, dtype=np.float32) | |
| arr = np.squeeze(arr) | |
| if arr.ndim == 0: | |
| prob = float(arr) | |
| if prob < 0.0 or prob > 1.0: | |
| prob = float(1.0 / (1.0 + np.exp(-prob))) | |
| return min(max(prob, 0.0), 1.0) | |
| prob = float(arr[0]) | |
| if prob < 0.0 or prob > 1.0: | |
| prob = float(1.0 / (1.0 + np.exp(-prob))) | |
| return min(max(prob, 0.0), 1.0) | |
def _safe_upload_name(raw_name):
    """Reduce a client-supplied filename to a basename of letters, digits, '.', '_' and '-'."""
    # Drop any directory components the client sent, with either separator style.
    base = raw_name.replace("\\", "/").rsplit("/", 1)[-1]
    cleaned = "".join(c if (c.isalnum() or c in "._-") else "_" for c in base)
    return cleaned or "upload"


def _classify_prediction(prediction, labels):
    """Map raw model output to ``(predicted_index, predicted_label, confidence)``.

    A two-label model with a single output unit is treated as binary (see
    ``_postprocess_binary_prediction``). Otherwise the output vector for the
    first sample is used. It is treated as logits and softmaxed unless it
    already looks like a probability distribution.
    """
    prediction = np.asarray(prediction)
    if (len(labels) == 2 and prediction.size >= 1
            and prediction.ndim >= 1 and prediction.shape[-1] == 1):
        prob_pos = _postprocess_binary_prediction(prediction)
        if prob_pos >= 0.5:
            return 1, labels[1], prob_pos
        return 0, labels[0], 1.0 - prob_pos

    vec = prediction[0] if prediction.ndim == 2 else prediction.reshape(-1)
    if np.any(vec < 0) or np.any(vec > 1) or not np.isclose(np.sum(vec), 1.0, atol=1e-3):
        exps = np.exp(vec - np.max(vec))
        probs = exps / (np.sum(exps) + 1e-8)
    else:
        probs = vec
    predicted_index = int(np.argmax(probs))
    return predicted_index, labels[predicted_index], float(np.max(probs))


def predict():
    """Classify an uploaded image with the selected model and build a Grad-CAM overlay.

    Expects multipart form data with ``file`` (the image) and ``model`` (a
    key of ``models``). The upload and its Grad-CAM overlay are written to
    /tmp.

    Returns:
        JSON with ``original_image`` and ``gradcam_image`` URLs (the latter is
        null if Grad-CAM failed), ``prediction``, ``confidence`` and
        ``model_used``. Responds 400 for missing or invalid input and 500 if
        prediction fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    model_name = request.form.get("model")
    if not file or file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    if model_name not in models:
        return jsonify({"error": "Invalid model selected"}), 400

    try:
        # The uuid prefix keeps names unique. The client-provided name is
        # sanitized so it cannot carry path separators into the /tmp path.
        filename = f"{uuid.uuid4()}_{_safe_upload_name(file.filename)}"
        filepath = os.path.join("/tmp", filename)
        file.save(filepath)

        model_config = MODEL_CONFIG[model_name]
        model = models[model_name]
        input_size = model_config.get("input_size", (224, 224))
        img_array = preprocess_image(filepath, target_size=input_size)
        prediction = model.predict(img_array, verbose=0)
        predicted_index, predicted_label, confidence = _classify_prediction(
            prediction, model_config["labels"]
        )

        # Grad-CAM is best-effort: on failure the prediction is still returned.
        gradcam_url = None
        try:
            last_conv_layer_name = model_config.get('last_conv_layer') or ""
            heatmap = get_gradcam_heatmap(model, img_array, last_conv_layer_name,
                                          pred_index=predicted_index)
            gradcam_filename = f"gradcam_{filename}"
            save_gradcam_image(filepath, heatmap, os.path.join("/tmp", gradcam_filename))
            gradcam_url = url_for('serve_tmp_file', filename=gradcam_filename)
        except Exception as e:
            logger.error(f"Grad-CAM error: {e}")

        return jsonify({
            "original_image": url_for('serve_tmp_file', filename=filename),
            "gradcam_image": gradcam_url,
            "prediction": str(predicted_label),
            "confidence": float(confidence),
            "model_used": str(model_name)
        })
    except Exception as e:
        logger.exception("Prediction error")
        return jsonify({"error": str(e)}), 500
def chat():
    """Answer a user question about a prediction using Gemini.

    Reads a JSON body with ``message`` and an optional ``context`` object
    (``model_used``, ``prediction``, ``confidence``). Missing or malformed
    fields fall back to defaults.

    Returns:
        ``{"response": text}`` on success, or ``{"error": ...}`` with status
        500 when Gemini is not configured or the call fails.
    """
    data = request.get_json(silent=True)
    # A JSON body that is not an object (list, string, number) would
    # otherwise crash on .get() outside any error handler.
    if not isinstance(data, dict):
        data = {}
    user_message = data.get("message", "")
    prediction_context = data.get("context")
    if not isinstance(prediction_context, dict):
        prediction_context = {}

    model_used = prediction_context.get('model_used', 'Unknown Model')
    pred_label = prediction_context.get('prediction', 'Unknown')
    conf = prediction_context.get('confidence', 0.0)
    try:
        conf_pct = float(conf) * 100.0
    except (TypeError, ValueError, OverflowError):
        conf_pct = 0.0

    prompt = f"""
You are a helpful medical assistant chatbot.
A medical image was analyzed with the following results:
- Model Used: {model_used}
- Prediction: {pred_label}
- Confidence Score: {conf_pct:.2f}%
The user's question is: "{user_message}"
Based on this context, provide a helpful and informative response.
Do not provide a diagnosis. Advise the user to consult a medical professional.
"""
    if gemini_model is None:
        return jsonify({"error": "Gemini API not configured. Set GEMINI_API_KEY in environment."}), 500
    try:
        response = gemini_model.generate_content(prompt)
        text = getattr(response, "text", None)
        if not text:
            text = str(response)
        return jsonify({"response": text})
    except Exception as e:
        logger.exception("Gemini chat error")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # The Werkzeug interactive debugger can execute arbitrary code, so it is
    # only enabled when explicitly requested via FLASK_DEBUG=1/true/yes.
    _debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=_debug)