Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import cv2 | |
| import sqlite3 | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from flask import Flask, render_template, request, redirect, url_for, flash | |
| from werkzeug.utils import secure_filename | |
| from datetime import datetime | |
| from markupsafe import escape | |
| from huggingface_hub import hf_hub_download | |
import secrets

# Keras models are loaded lazily on first use (see load_models()).
classification_model = None
segmentation_model = None

# The OpenAI SDK is optional: when it is missing, summaries use the built-in
# fallback text (get_openrouter_summary checks for OpenAI is None).
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

app = Flask(__name__)
# Never ship a hard-coded signing key: anyone who knows it can forge session
# cookies. Prefer SECRET_KEY from the environment; the random fallback is safe
# but per-process, so sessions/flash messages will not survive restarts or be
# shared across multiple worker processes -- set SECRET_KEY in deployments.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or secrets.token_hex(32)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULTS_FOLDER'] = 'static/results'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MiB request-body cap
app.config['ALLOWED_EXTENSIONS'] = {'jpg', 'jpeg', 'png'}
app.config['OPENROUTER_MODEL'] = 'openai/gpt-oss-120b'
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def get_model_path():
    """Fetch both model weight files via ``huggingface_hub.hf_hub_download``.

    Returns:
        tuple: ``(classification_path, segmentation_path)`` as returned by
        ``hf_hub_download`` for the classifier and the U-Net respectively.
    """
    # NOTE(review): the classifier repo id is spelled "Brrain"; presumably this
    # matches the actual Hub repo name, so do not "fix" it without checking.
    weight_sources = (
        ("MohammedAH/Brrain-MRI-Classification", "brain_mri.h5"),
        ("MohammedAH/Unet-Brain-Segmentation", "Unet_model.h5"),
    )
    # Downloads run in order: classifier first, then the segmentation model.
    classifier_file, unet_file = (
        hf_hub_download(repo_id=repo, filename=weights)
        for repo, weights in weight_sources
    )
    return classifier_file, unet_file
# Ensure the upload and result directories exist before any request writes to them.
for _storage_dir in (app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER']):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir

# Labels for the classifier's outputs: upload_file looks up
# class_names[argmax(predictions[0])], so index i names output column i.
class_names = ['glioma', 'meningioma', 'no_tumor', 'pituitary']
# Database initialization
def init_db():
    """Create the ``analyses`` table in ``brain_mri.db`` if it does not exist.

    Each row records one uploaded image: stored filename, the upload and
    result-image paths, predicted class and confidence, an optional HTML
    summary, and a creation timestamp defaulting to CURRENT_TIMESTAMP.
    Safe to call repeatedly (CREATE TABLE IF NOT EXISTS). The connection is
    closed even if the statement or commit fails.
    """
    conn = sqlite3.connect('brain_mri.db')
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS analyses (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT NOT NULL,
                original_path TEXT NOT NULL,
                result_path TEXT NOT NULL,
                classification TEXT NOT NULL,
                confidence REAL NOT NULL,
                summary TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
    finally:
        conn.close()
# Build the SQLite schema once, at import time, so every view can assume the
# `analyses` table exists; init_db() is idempotent.
with app.app_context():
    init_db()
# Metric/loss functions that the saved U-Net references by name; load_models()
# passes them to load_model via custom_objects so the model can be deserialized.
def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient over the flattened inputs.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    guess = backend.flatten(y_pred)
    overlap = backend.sum(truth * guess)
    denominator = backend.sum(truth) + backend.sum(guess) + smooth
    return (2. * overlap + smooth) / denominator
def dice_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient (default smoothing)."""
    coefficient = dice_coefficient(y_true, y_pred)
    return 1 - coefficient
def iou(y_true, y_pred, smooth=1):
    """Smoothed intersection-over-union over the flattened inputs.

    Union is computed as ``sum(t) + sum(p) - sum(t * p)``; the result is
    ``(intersection + smooth) / (union + smooth)``.
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    guess = backend.flatten(y_pred)
    overlap = backend.sum(truth * guess)
    combined = backend.sum(truth) + backend.sum(guess)
    union = combined - overlap
    return (overlap + smooth) / (union + smooth)
# Load the models
def load_models():
    """Populate the module-level models on first use; later calls are no-ops.

    If either global is still None, both weight files are fetched via
    get_model_path() and both models are (re)loaded with compile=False. The
    U-Net needs its custom metric/loss functions supplied by name.
    """
    global classification_model, segmentation_model
    if classification_model is not None and segmentation_model is not None:
        return

    classifier_file, unet_file = get_model_path()
    classification_model = load_model(classifier_file, compile=False)
    unet_custom_objects = {
        'dice_coefficient': dice_coefficient,
        'dice_loss': dice_loss,
        'iou': iou,
    }
    segmentation_model = load_model(
        unet_file, compile=False, custom_objects=unet_custom_objects
    )
def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last '.', compared case-insensitively;
    names without any '.' are rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in app.config['ALLOWED_EXTENSIONS']
def preprocess_image_for_classification(image_path):
    """Load an image file as a single-item batch for the classifier.

    The file is read in colour with OpenCV, converted BGR -> RGB, resized to
    224x224 and divided by 255.

    Returns:
        numpy array with a leading batch axis of size 1.

    Raises:
        ValueError: if OpenCV cannot read/decode the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports unreadable or corrupt files by returning None
        # instead of raising; fail here with a clear message rather than an
        # opaque cv2.error from cvtColor.
        raise ValueError(f"Could not read image file: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    img = img / 255.0
    return np.expand_dims(img, axis=0)
def preprocess_image_for_segmentation(image_path):
    """Load an image file as a single-item grayscale batch for the U-Net.

    The file is read as grayscale, resized to 128x128, divided by 255, and
    given a leading batch axis and a trailing channel axis.

    Raises:
        ValueError: if OpenCV cannot read/decode the file.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None for unreadable files; surface a clear error
        # instead of the cv2.error that resize would raise.
        raise ValueError(f"Could not read image file: {image_path}")
    img = cv2.resize(img, (128, 128))
    img = img / 255.0
    return np.expand_dims(img, axis=(0, -1))
def format_summary_html(text):
    """Render plain *text* as escaped HTML paragraphs plus a fixed disclaimer.

    The text is split on newlines; each line is stripped, blank lines are
    dropped, and every remaining line becomes one ``<p>`` whose content is
    escaped with markupsafe.escape. If nothing remains, a single
    "No summary available." paragraph is returned (without the disclaimer).
    """
    stripped_lines = (line.strip() for line in text.split('\n'))
    kept = [line for line in stripped_lines if line]
    if not kept:
        return "<p>No summary available.</p>"

    disclaimer_html = (
        "<p><strong>Note:</strong> This explanation is educational and must not be used "
        "as a medical diagnosis or treatment plan.</p>"
    )
    body_html = ''.join(f"<p>{escape(line)}</p>" for line in kept)
    return body_html + disclaimer_html
def get_fallback_summary(classification, confidence):
    """Build the offline summary HTML for a predicted class.

    Combines a canned description of the class (or a generic "not recognized"
    sentence for unknown labels) with a confidence note: "high" when
    confidence > 0.9, "moderate" when > 0.7, otherwise "limited". The result
    is passed through format_summary_html.
    """
    canned_descriptions = {
        'glioma': (
            "Glioma is a tumor arising from glial tissue in the brain. Clinical significance usually depends on tumor grade, location, and surrounding brain involvement. "
            "Typical follow-up in real care includes radiology review, symptom correlation, and specialist evaluation."
        ),
        'meningioma': (
            "Meningioma usually develops from the membranes surrounding the brain and is often slow-growing, though behavior varies by subtype and location. "
            "Real-world next steps often include reviewing size, pressure effect, growth pattern, and need for observation versus intervention."
        ),
        'pituitary': (
            "Pituitary tumors involve the pituitary region and may matter because of hormone effects, visual pathway compression, or local mass effect. "
            "Clinical workup commonly includes endocrine assessment and focused review of symptoms such as headache or visual disturbance."
        ),
        'no_tumor': (
            "No tumor class was detected by the model for this image. That does not rule out other abnormalities, image quality issues, or findings outside the model's scope. "
            "Formal interpretation should still depend on a qualified radiology or neurology workflow when clinically needed."
        ),
    }
    unknown_description = 'The predicted class is not recognized by the summary helper.'

    # First threshold strictly exceeded wins; anything else is "limited".
    confidence_bands = (
        (0.9, "The model confidence is high for this predicted class."),
        (0.7, "The model confidence is moderate for this predicted class."),
    )
    confidence_note = next(
        (note for floor, note in confidence_bands if confidence > floor),
        "The model confidence is limited, so the output should be treated cautiously.",
    )

    description = canned_descriptions.get(classification, unknown_description)
    return format_summary_html(f"{description} {confidence_note}")
def get_openrouter_summary(classification, confidence):
    """Return an educational HTML summary for a prediction, generated via OpenRouter.

    Uses OpenRouter's OpenAI-compatible chat-completions API when both the
    ``OPENROUTER_API_KEY`` environment variable and the ``openai`` package are
    available; otherwise, or on any request failure or empty reply, returns
    get_fallback_summary(classification, confidence).

    The request timeout (seconds) can be tuned with
    ``app.config['OPENROUTER_TIMEOUT']`` (default 30). It bounds how long an
    upload request can block on the upstream service.
    """
    api_key = os.environ.get('OPENROUTER_API_KEY')
    if not api_key or OpenAI is None:
        return get_fallback_summary(classification, confidence)

    prompt = (
        f"Brain MRI model output:\n"
        f"- Predicted classification: {classification}\n"
        f"- Confidence: {confidence:.4f}\n\n"
        "Write a concise educational explanation for a web app result page. "
        "Explain what this tumor class generally means, why the confidence level matters, "
        "what clinical factors are usually reviewed next, and one caution about not using AI output alone. "
        "If the class is no_tumor, explain that no tumor was detected by the model but other abnormalities may still require professional review. "
        "Keep it to 3 short paragraphs in plain text. Do not claim a diagnosis. Do not mention that you are an AI model."
    )
    try:
        # Client construction is inside the try so a misconfiguration also
        # degrades to the fallback instead of failing the upload request.
        client = OpenAI(
            base_url='https://openrouter.ai/api/v1',
            api_key=api_key,
            timeout=app.config.get('OPENROUTER_TIMEOUT', 30.0),
            default_headers={
                'HTTP-Referer': 'http://127.0.0.1:5000',
                'X-OpenRouter-Title': 'NeuroScope MRI'
            }
        )
        response = client.chat.completions.create(
            model=app.config['OPENROUTER_MODEL'],
            messages=[
                {
                    'role': 'system',
                    'content': (
                        "You write careful educational MRI result summaries for a student medical imaging app. "
                        "Be clinically literate, concise, and explicit that the output is not a diagnosis."
                    )
                },
                {
                    'role': 'user',
                    'content': prompt
                }
            ],
            temperature=0.4,
            max_tokens=350
        )
        content = (response.choices[0].message.content or '').strip()
    except Exception:
        # Deliberate best-effort: the summary is optional, so fall back, but
        # keep the traceback in the log instead of swallowing it silently.
        app.logger.exception('OpenRouter summary request failed; using fallback summary')
        return get_fallback_summary(classification, confidence)

    if not content:
        return get_fallback_summary(classification, confidence)
    return format_summary_html(content)
def save_to_database(filename, original_path, result_path, classification, confidence, summary):
    """Insert one analysis row into ``brain_mri.db`` and return its new id.

    Values are bound as SQL parameters (never interpolated). The connection
    is closed even if the insert or commit raises; database errors such as
    sqlite3.IntegrityError still propagate to the caller.
    """
    conn = sqlite3.connect('brain_mri.db')
    try:
        cursor = conn.execute('''
            INSERT INTO analyses (filename, original_path, result_path, classification, confidence, summary)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', (filename, original_path, result_path, classification, confidence, summary))
        analysis_id = cursor.lastrowid
        conn.commit()
    finally:
        conn.close()
    return analysis_id
# Routes
# NOTE(review): no @app.route decorators are visible on the view functions in
# this file; confirm they are registered elsewhere (e.g. app.add_url_rule),
# otherwise url_for('analyze') / url_for('result', ...) cannot build URLs.
def index():
    """Render the landing page."""
    template_name = 'index.html'
    return render_template(template_name)
def analyze():
    """Render the upload form (also the redirect target after upload errors)."""
    page = 'upload.html'
    return render_template(page)
def _remove_upload_quietly(path):
    """Best-effort delete of an upload whose analysis failed; a leftover file is harmless."""
    try:
        os.remove(path)
    except OSError:
        pass


def _classify_image(image_path):
    """Run the loaded classifier on one image and return (label, confidence)."""
    classification_input = preprocess_image_for_classification(image_path)
    predictions = classification_model.predict(classification_input)
    predicted_class_index = int(np.argmax(predictions[0]))
    predicted_class = class_names[predicted_class_index]
    confidence = float(predictions[0][predicted_class_index])
    return predicted_class, confidence


def _render_segmentation_overlay(image_path, result_path, threshold=0.5):
    """Segment one image and save a side-by-side original / mask-overlay PNG.

    Uses matplotlib's object-oriented Figure API instead of pyplot, so no
    global pyplot figure state is shared between concurrent requests and no
    figure can leak if an exception occurs mid-render.
    """
    from matplotlib.figure import Figure

    segmentation_input = preprocess_image_for_segmentation(image_path)
    segmentation_mask = segmentation_model.predict(segmentation_input)
    # Threshold the model output into a 0/1 mask for display.
    segmentation_mask = (segmentation_mask > threshold).astype(np.uint8)

    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    img_resized = cv2.resize(img, (128, 128))

    fig = Figure(figsize=(10, 8))
    ax_original, ax_overlay = fig.subplots(1, 2)
    ax_original.imshow(img_resized, cmap='gray')
    ax_original.set_title('Original MRI')
    ax_original.axis('off')
    ax_overlay.imshow(img_resized, cmap='gray')
    ax_overlay.imshow(segmentation_mask[0, :, :, 0], alpha=0.5, cmap='jet')
    ax_overlay.set_title('Tumor Segmentation')
    ax_overlay.axis('off')
    fig.savefig(result_path, bbox_inches='tight')


def upload_file():
    """Handle an MRI upload: validate, save, classify, segment, summarise, persist.

    Reads the multipart field ``mri_image``. On success, redirects to the
    ``result`` view for the newly stored analysis. On a missing or invalid
    file, a model-loading failure, or an analysis failure, flashes an error
    and redirects back to ``analyze``.
    """
    if 'mri_image' not in request.files:
        flash('No file part', 'error')
        return redirect(url_for('analyze'))
    file = request.files['mri_image']
    if file.filename == '':
        flash('No selected file', 'error')
        return redirect(url_for('analyze'))
    if not (file and allowed_file(file.filename)):
        flash('Invalid file type. Please upload JPG, JPEG, or PNG files.', 'error')
        return redirect(url_for('analyze'))

    # Save original image. Microseconds in the timestamp keep two uploads of
    # the same name within one second from overwriting each other.
    filename = secure_filename(file.filename)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    saved_filename = f"{timestamp}_{filename}"
    original_path = os.path.join(app.config['UPLOAD_FOLDER'], saved_filename)
    file.save(original_path)

    try:
        load_models()
    except Exception as exc:
        _remove_upload_quietly(original_path)
        flash(str(exc), 'error')
        return redirect(url_for('analyze'))

    # splitext rather than split('.'): "a.b.png" and "a.c.png" must not map
    # to the same result image.
    result_filename = f"result_{os.path.splitext(saved_filename)[0]}.png"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)
    try:
        predicted_class, confidence = _classify_image(original_path)
        _render_segmentation_overlay(original_path, result_path)
    except Exception:
        # Request boundary: an undecodable image or a model error should give
        # the user a message rather than a 500; keep the traceback in the log.
        app.logger.exception('MRI analysis failed for %s', saved_filename)
        _remove_upload_quietly(original_path)
        flash('The image could not be analyzed. Please upload a valid MRI image.', 'error')
        return redirect(url_for('analyze'))

    # Educational insight summary (OpenRouter, with an offline fallback).
    summary = get_openrouter_summary(predicted_class, confidence)
    analysis_id = save_to_database(
        saved_filename,
        original_path,
        result_path,
        predicted_class,
        confidence,
        summary
    )
    return redirect(url_for('result', analysis_id=analysis_id))
def result(analysis_id):
    """Show one stored analysis, or flash 'Analysis not found' and go to the index.

    Rows are fetched as sqlite3.Row so the template can access columns by
    name. The connection is closed even if the query raises.
    """
    conn = sqlite3.connect('brain_mri.db')
    try:
        conn.row_factory = sqlite3.Row
        analysis = conn.execute(
            'SELECT * FROM analyses WHERE id = ?', (analysis_id,)
        ).fetchone()
    finally:
        conn.close()

    if analysis is None:
        flash('Analysis not found', 'error')
        return redirect(url_for('index'))
    return render_template('result.html', analysis=analysis)
def history():
    """Render the 20 most recent analyses, newest first by ``created_at``.

    The connection is closed even if the query raises.
    """
    conn = sqlite3.connect('brain_mri.db')
    try:
        conn.row_factory = sqlite3.Row
        analyses = conn.execute(
            'SELECT * FROM analyses ORDER BY created_at DESC LIMIT 20'
        ).fetchall()
    finally:
        conn.close()
    return render_template('history.html', analyses=analyses)
if __name__ == '__main__':
    # Listen on all interfaces; PORT comes from the environment when set,
    # otherwise 5000. The debugger stays off.
    listen_port = int(os.environ.get('PORT', 5000))
    app.run(debug=False, host='0.0.0.0', port=listen_port)