Spaces:
Sleeping
Sleeping
File size: 13,237 Bytes
546a860 614e697 546a860 b1fd5bd 546a860 846d7a0 614e697 03d7f52 069fc47 614e697 03d7f52 069fc47 614e697 916db25 614e697 546a860 34a445a 546a860 b1fd5bd 28428d7 546a860 614e697 546a860 28428d7 546a860 614e697 546a860 614e697 546a860 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 | import os
import numpy as np
import cv2
import sqlite3
import tensorflow as tf
from tensorflow.keras.models import load_model
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from flask import Flask, render_template, request, redirect, url_for, flash
from werkzeug.utils import secure_filename
from datetime import datetime
from markupsafe import escape
from huggingface_hub import hf_hub_download
# Keras models are loaded lazily by load_models(); None means "not loaded yet".
classification_model = None
segmentation_model = None
# The openai package is optional: without it get_openrouter_summary() falls back
# to the static summaries.
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None
app = Flask(__name__)
# Session-signing key (used for flash messages). A hard-coded key is publicly known
# and lets anyone forge session cookies, so read it from the environment. The random
# fallback only suits a single local process: sessions won't survive a restart or be
# shared between worker processes.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or os.urandom(32).hex()
# Upload/result folders are relative to the working directory; the stored paths are
# reused as-is elsewhere, so they are kept relative.
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULTS_FOLDER'] = 'static/results'
# Reject request bodies larger than 16 MiB.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
app.config['ALLOWED_EXTENSIONS'] = {'jpg', 'jpeg', 'png'}
# Model used for the educational summary; overridable without a code change.
app.config['OPENROUTER_MODEL'] = os.environ.get('OPENROUTER_MODEL', 'openai/gpt-oss-120b')
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def get_model_path():
    """Fetch both model weight files with ``hf_hub_download``.

    Returns:
        tuple[str, str]: local paths of the classification model and the
        U-Net segmentation model, in that order.
    """
    # NOTE(review): "Brrain" looks like a typo, but it must match the repository id
    # that actually exists on the Hub — confirm there before "fixing" it.
    model_files = (
        ("MohammedAH/Brrain-MRI-Classification", "brain_mri.h5"),
        ("MohammedAH/Unet-Brain-Segmentation", "Unet_model.h5"),
    )
    classification_path, segmentation_path = (
        hf_hub_download(repo_id=repo, filename=weights_name)
        for repo, weights_name in model_files
    )
    return classification_path, segmentation_path
# Make sure the upload and result directories exist before any request writes into them.
for _folder_key in ('UPLOAD_FOLDER', 'RESULTS_FOLDER'):
    os.makedirs(app.config[_folder_key], exist_ok=True)
del _folder_key
# Labels of the classification model: position i in a prediction vector maps to
# class_names[i]. NOTE(review): presumably this is the label order used in
# training — confirm before reordering.
class_names = [
    'glioma',
    'meningioma',
    'no_tumor',
    'pituitary',
]
# Database initialization
def init_db(db_path='brain_mri.db'):
    """Create the ``analyses`` table in the SQLite database at *db_path* if missing.

    Safe to call repeatedly thanks to ``CREATE TABLE IF NOT EXISTS``. The default
    path is relative to the process working directory.

    Args:
        db_path: SQLite database file to initialise (default ``'brain_mri.db'``).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS analyses (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT NOT NULL,
                original_path TEXT NOT NULL,
                result_path TEXT NOT NULL,
                classification TEXT NOT NULL,
                confidence REAL NOT NULL,
                summary TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
    finally:
        # Close even when the DDL fails so the database handle is never leaked.
        conn.close()
# Create the `analyses` table at import time so it exists before the first request.
# NOTE(review): init_db() only uses sqlite3, not Flask state, so the app context
# pushed here looks unnecessary; it is harmless.
with app.app_context():
    init_db()
# Helper functions for model inference
def dice_coefficient(y_true, y_pred, smooth=1):
    """Return the smoothed Dice similarity of two tensors.

    Both inputs are flattened first, so they are expected to hold the same number
    of elements. ``smooth`` is added to numerator and denominator so two empty
    masks do not divide by zero.
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    prediction = backend.flatten(y_pred)
    overlap = backend.sum(truth * prediction)
    denominator = backend.sum(truth) + backend.sum(prediction) + smooth
    return (2. * overlap + smooth) / denominator
def dice_loss(y_true, y_pred):
    """Return the Dice loss: one minus ``dice_coefficient`` with its default smoothing."""
    coefficient = dice_coefficient(y_true, y_pred)
    return 1 - coefficient
def iou(y_true, y_pred, smooth=1):
    """Return the smoothed intersection-over-union of two flattened tensors.

    The union is computed as ``sum(y_true) + sum(y_pred) - intersection``;
    ``smooth`` keeps the ratio defined when both inputs are empty.
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    prediction = backend.flatten(y_pred)
    overlap = backend.sum(truth * prediction)
    union = (backend.sum(truth) + backend.sum(prediction)) - overlap
    return (overlap + smooth) / (union + smooth)
# Load the models
def load_models():
    """Populate the module-level ``classification_model`` / ``segmentation_model``.

    Returns immediately when both are already loaded. Otherwise downloads the
    weight files via get_model_path() and loads both models uncompiled. The
    segmentation model is given the custom metric/loss functions it references.
    """
    global classification_model, segmentation_model
    if classification_model is not None and segmentation_model is not None:
        return
    classification_path, segmentation_path = get_model_path()
    unet_custom_objects = {
        'dice_coefficient': dice_coefficient,
        'dice_loss': dice_loss,
        'iou': iou,
    }
    classification_model = load_model(classification_path, compile=False)
    segmentation_model = load_model(
        segmentation_path,
        compile=False,
        custom_objects=unet_custom_objects,
    )
def allowed_file(filename):
    """Return True if *filename*'s last extension is in ALLOWED_EXTENSIONS (case-insensitive)."""
    _, separator, extension = filename.rpartition('.')
    if not separator:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def preprocess_image_for_classification(image_path):
    """Load an image file and shape it for the classification model.

    Args:
        image_path: path of the image on disk.

    Returns:
        numpy.ndarray of shape (1, 224, 224, 3): RGB channel order, values
        scaled from 0-255 into [0, 1].

    Raises:
        ValueError: if OpenCV cannot decode an image from *image_path*.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising; without
        # this check the next call dies with an opaque OpenCV assertion error.
        raise ValueError(f"Could not read an image from {image_path!r}")
    # OpenCV decodes to BGR; convert to RGB for the model.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    img = img / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img, axis=0)
def preprocess_image_for_segmentation(image_path):
    """Load an image file as grayscale and shape it for the segmentation model.

    Args:
        image_path: path of the image on disk.

    Returns:
        numpy.ndarray of shape (1, 128, 128, 1) with values scaled into [0, 1].

    Raises:
        ValueError: if OpenCV cannot decode an image from *image_path*.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable input.
        raise ValueError(f"Could not read an image from {image_path!r}")
    img = cv2.resize(img, (128, 128))
    img = img / 255.0
    # Add batch (front) and channel (back) dimensions.
    return np.expand_dims(img, axis=(0, -1))
def format_summary_html(text):
    """Convert plain *text* into HTML paragraphs followed by a fixed disclaimer.

    Each non-blank line becomes one ``<p>`` element, whitespace-trimmed and
    HTML-escaped. If there are no non-blank lines, a single placeholder paragraph
    is returned and the disclaimer is omitted.
    """
    html_parts = []
    for line in text.split('\n'):
        cleaned = line.strip()
        if cleaned:
            html_parts.append(f"<p>{escape(cleaned)}</p>")
    if not html_parts:
        return "<p>No summary available.</p>"
    html_parts.append(
        "<p><strong>Note:</strong> This explanation is educational and must not be used "
        "as a medical diagnosis or treatment plan.</p>"
    )
    return ''.join(html_parts)
# Static per-class explanations used when no LLM summary is available.
_FALLBACK_BASE_SUMMARIES = {
    'glioma': (
        "Glioma is a tumor arising from glial tissue in the brain. Clinical significance usually depends on tumor grade, location, and surrounding brain involvement. "
        "Typical follow-up in real care includes radiology review, symptom correlation, and specialist evaluation."
    ),
    'meningioma': (
        "Meningioma usually develops from the membranes surrounding the brain and is often slow-growing, though behavior varies by subtype and location. "
        "Real-world next steps often include reviewing size, pressure effect, growth pattern, and need for observation versus intervention."
    ),
    'pituitary': (
        "Pituitary tumors involve the pituitary region and may matter because of hormone effects, visual pathway compression, or local mass effect. "
        "Clinical workup commonly includes endocrine assessment and focused review of symptoms such as headache or visual disturbance."
    ),
    'no_tumor': (
        "No tumor class was detected by the model for this image. That does not rule out other abnormalities, image quality issues, or findings outside the model's scope. "
        "Formal interpretation should still depend on a qualified radiology or neurology workflow when clinically needed."
    ),
}
# (exclusive lower bound, note) pairs, checked from the highest threshold down.
_CONFIDENCE_NOTES = (
    (0.9, "The model confidence is high for this predicted class."),
    (0.7, "The model confidence is moderate for this predicted class."),
)
def get_fallback_summary(classification, confidence):
    """Build an HTML summary from static text when the OpenRouter summary is unavailable.

    Args:
        classification: predicted class label; unknown labels get a generic sentence.
        confidence: model confidence; > 0.9 is reported as high, > 0.7 as moderate,
            anything else as limited.

    Returns:
        str: the HTML produced by format_summary_html().
    """
    explanation = _FALLBACK_BASE_SUMMARIES.get(
        classification, 'The predicted class is not recognized by the summary helper.'
    )
    note = next(
        (message for threshold, message in _CONFIDENCE_NOTES if confidence > threshold),
        "The model confidence is limited, so the output should be treated cautiously.",
    )
    return format_summary_html(f"{explanation} {note}")
def get_openrouter_summary(classification, confidence):
    """Return an HTML educational summary for a prediction, generated via OpenRouter.

    Falls back to get_fallback_summary() when OPENROUTER_API_KEY is unset, the
    ``openai`` package is not installed, the request fails, or the reply is empty.

    Args:
        classification: predicted class label, inserted into the prompt.
        confidence: model confidence for that label, inserted with 4 decimals.

    Returns:
        str: HTML produced by format_summary_html() or by the fallback helper.
    """
    api_key = os.environ.get('OPENROUTER_API_KEY')
    if not api_key or OpenAI is None:
        return get_fallback_summary(classification, confidence)
    prompt = (
        f"Brain MRI model output:\n"
        f"- Predicted classification: {classification}\n"
        f"- Confidence: {confidence:.4f}\n\n"
        "Write a concise educational explanation for a web app result page. "
        "Explain what this tumor class generally means, why the confidence level matters, "
        "what clinical factors are usually reviewed next, and one caution about not using AI output alone. "
        "If the class is no_tumor, explain that no tumor was detected by the model but other abnormalities may still require professional review. "
        "Keep it to 3 short paragraphs in plain text. Do not claim a diagnosis. Do not mention that you are an AI model."
    )
    try:
        client = OpenAI(
            base_url='https://openrouter.ai/api/v1',
            api_key=api_key,
            default_headers={
                'HTTP-Referer': 'http://127.0.0.1:5000',
                'X-OpenRouter-Title': 'NeuroScope MRI'
            },
            # This call runs inside the upload request; the SDK's default timeout is
            # minutes long, so fail fast and use the static fallback instead.
            timeout=30.0,
            max_retries=1,
        )
        response = client.chat.completions.create(
            model=app.config['OPENROUTER_MODEL'],
            messages=[
                {
                    'role': 'system',
                    'content': (
                        "You write careful educational MRI result summaries for a student medical imaging app. "
                        "Be clinically literate, concise, and explicit that the output is not a diagnosis."
                    )
                },
                {
                    'role': 'user',
                    'content': prompt
                }
            ],
            temperature=0.4,
            max_tokens=350
        )
        content = (response.choices[0].message.content or '').strip()
    except Exception:
        # Best-effort by design: any SDK, network, or response-shape error degrades to
        # the static summary, but is logged so failures are not invisible.
        app.logger.exception('OpenRouter summary request failed; using fallback summary')
        return get_fallback_summary(classification, confidence)
    if not content:
        return get_fallback_summary(classification, confidence)
    return format_summary_html(content)
def save_to_database(filename, original_path, result_path, classification, confidence, summary,
                     db_path='brain_mri.db'):
    """Insert one analysis row and return its new id.

    Args:
        filename: stored (timestamped) upload filename.
        original_path: path of the saved upload.
        result_path: path of the rendered overlay image.
        classification: predicted class label.
        confidence: model confidence for that label.
        summary: HTML summary text (may be None).
        db_path: SQLite database file (default ``'brain_mri.db'``).

    Returns:
        int: the ``id`` assigned to the inserted row.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute('''
            INSERT INTO analyses (filename, original_path, result_path, classification, confidence, summary)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', (filename, original_path, result_path, classification, confidence, summary))
        conn.commit()
        return cursor.lastrowid
    finally:
        # Close even if the insert or commit fails, so connections are not leaked.
        conn.close()
# Routes
@app.route('/')
def index():
    """Render the landing page."""
    landing_template = 'index.html'
    return render_template(landing_template)
@app.route('/analyze')
def analyze():
    """Render the MRI upload form."""
    upload_template = 'upload.html'
    return render_template(upload_template)
def _classify_image(image_path):
    """Run the classification model on *image_path*; return (label, confidence)."""
    predictions = classification_model.predict(preprocess_image_for_classification(image_path))
    scores = predictions[0]
    predicted_class_index = int(np.argmax(scores))
    return class_names[predicted_class_index], float(scores[predicted_class_index])
def _segment_image(image_path):
    """Run the segmentation model on *image_path*; return its output thresholded at 0.5 as uint8."""
    segmentation_mask = segmentation_model.predict(preprocess_image_for_segmentation(image_path))
    return (segmentation_mask > 0.5).astype(np.uint8)
def _save_overlay(image_path, segmentation_mask, result_path):
    """Save a side-by-side PNG: the 128x128 grayscale image, and the same image with the mask overlaid."""
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    img_resized = cv2.resize(img, (128, 128))
    # NOTE(review): pyplot keeps global state and is not thread-safe; confirm the
    # server is not running concurrent requests through this path.
    fig = plt.figure(figsize=(10, 8))
    try:
        plt.subplot(1, 2, 1)
        plt.imshow(img_resized, cmap='gray')
        plt.title('Original MRI')
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(img_resized, cmap='gray')
        # Mask is indexed as (batch, height, width, channel), matching the original code.
        plt.imshow(segmentation_mask[0, :, :, 0], alpha=0.5, cmap='jet')
        plt.title('Tumor Segmentation')
        plt.axis('off')
        fig.savefig(result_path, bbox_inches='tight')
    finally:
        # Always release the figure; otherwise pyplot keeps it alive after an error.
        plt.close(fig)
@app.route('/upload', methods=['POST'])
def upload_file():
    """Handle an MRI upload: classify, segment, summarize, store, then show the result.

    Invalid requests and processing failures flash an error and redirect back to
    the upload form. On success, redirects to the result page of the new row.
    """
    if 'mri_image' not in request.files:
        flash('No file part', 'error')
        return redirect(url_for('analyze'))
    file = request.files['mri_image']
    if file.filename == '':
        flash('No selected file', 'error')
        return redirect(url_for('analyze'))
    if not (file and allowed_file(file.filename)):
        flash('Invalid file type. Please upload JPG, JPEG, or PNG files.', 'error')
        return redirect(url_for('analyze'))
    # Load the models before touching the disk so a load failure leaves no orphaned upload.
    try:
        load_models()
    except Exception as exc:
        flash(str(exc), 'error')
        return redirect(url_for('analyze'))
    # Save original image. Microseconds stop two same-named uploads within one second
    # from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    saved_filename = f"{timestamp}_{secure_filename(file.filename)}"
    original_path = os.path.join(app.config['UPLOAD_FOLDER'], saved_filename)
    file.save(original_path)
    # splitext drops only the final extension; split('.')[0] truncated stems like "scan.v2".
    result_filename = f"result_{os.path.splitext(saved_filename)[0]}.png"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)
    try:
        predicted_class, confidence = _classify_image(original_path)
        segmentation_mask = _segment_image(original_path)
        _save_overlay(original_path, segmentation_mask, result_path)
    except Exception:
        # Request boundary: an undecodable image or an inference error should give the
        # user a message rather than a 500, and should not leave the upload behind.
        app.logger.exception('MRI analysis failed for %s', saved_filename)
        try:
            os.remove(original_path)
        except OSError:
            pass  # best-effort cleanup; the analysis error is what gets reported
        flash('The uploaded image could not be analyzed. Please upload a valid MRI image.', 'error')
        return redirect(url_for('analyze'))
    # Get an educational insight summary from OpenRouter via an OpenAI-compatible API.
    summary = get_openrouter_summary(predicted_class, confidence)
    analysis_id = save_to_database(
        saved_filename,
        original_path,
        result_path,
        predicted_class,
        confidence,
        summary
    )
    return redirect(url_for('result', analysis_id=analysis_id))
@app.route('/result/<int:analysis_id>')
def result(analysis_id):
    """Render the stored analysis *analysis_id*.

    If no such row exists, flashes 'Analysis not found' and redirects to the index.
    """
    conn = sqlite3.connect('brain_mri.db')
    conn.row_factory = sqlite3.Row
    try:
        analysis = conn.execute(
            'SELECT * FROM analyses WHERE id = ?', (analysis_id,)
        ).fetchone()
    finally:
        # Close even if the query fails, so the connection is not leaked.
        conn.close()
    if analysis is None:
        flash('Analysis not found', 'error')
        return redirect(url_for('index'))
    return render_template('result.html', analysis=analysis)
@app.route('/history')
def history():
    """Render the 20 most recent analyses, newest first."""
    conn = sqlite3.connect('brain_mri.db')
    conn.row_factory = sqlite3.Row
    try:
        # CURRENT_TIMESTAMP has one-second resolution, so break ties with the
        # autoincrement id to keep a deterministic newest-first order.
        analyses = conn.execute(
            'SELECT * FROM analyses ORDER BY created_at DESC, id DESC LIMIT 20'
        ).fetchall()
    finally:
        conn.close()
    return render_template('history.html', analyses=analyses)
if __name__ == '__main__':
    # Listen on all interfaces, using the PORT environment variable when set (else 5000).
    listen_port = int(os.environ.get('PORT', '5000'))
    app.run(debug=False, host='0.0.0.0', port=listen_port)
|