# Alzheimer_UI / flask_app.py
# ak0601 — "Update flask_app.py" (commit b7365a1, verified)
import json
import os

import pandas as pd
import requests
from flask import Flask, render_template, request, send_file, send_from_directory, url_for
from werkzeug.utils import secure_filename
# --- Application & configuration ---------------------------------------------

# Uploaded spreadsheets and the generated prediction files are stored in an
# "uploads" directory next to this file; it is created up front if missing.
UPLOAD_FOLDER = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'uploads',
)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Lower-case extensions (without the dot) accepted for bulk prediction.
ALLOWED_EXTENSIONS = {'xls', 'xlsx', 'csv'}

# Remote inference endpoint; it is POSTed a JSON body {"features": [...]}.
API_URL = "https://ak0601-et-alzheimer.hf.space/predict"

# Model input columns, in the order their values are sent to the API. The
# payload is a bare list, so this order presumably has to match what the
# model expects — do not reorder without checking the API.
FEATURE_NAMES = [
    "ROI",
    "nFixations", "nTobiiSaccades", "regSaccades",
    "longSaccades", "tinySaccades", "saccadeTotLength",
    "totalFixTime", "totalSpokenTime", "speechDelay", "endSpeechDelay",
    "startPupL", "startPupR", "endPupL", "endPupR", "diffPupL", "diffPupR",
]
def allowed_file(filename):
    """Return True when *filename* carries an accepted upload extension.

    The extension is whatever follows the last '.', compared
    case-insensitively against ALLOWED_EXTENSIONS; a name with no '.'
    at all is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in ALLOWED_EXTENSIONS
# Seconds to wait for each prediction API call. Without a timeout,
# requests.post() can block a worker indefinitely if the API stalls.
_API_TIMEOUT_SECONDS = 30

# Labels for the individual-analysis view, keyed by the API's "predicted_class".
_SINGLE_CLASS_LABELS = {
    0: "Control/Healthy",
    1: "MCI (Mild Cognitive Impairment)",
    2: "Alzheimer's Disease",
}

# Shorter labels written into the "predicted" column of bulk output files.
_BULK_CLASS_LABELS = {0: "Healthy", 1: "MCI", 2: "Alzheimer's"}


def _predict_single(form, input_values):
    """Run one prediction from the individual-analysis form.

    Args:
        form: the submitted form mapping (``request.form``).
        input_values: dict filled in place with the raw submitted value of
            every feature, so the template can re-populate the form.

    Returns:
        ``(prediction_result, error_message)``; exactly one is ``None``.
    """
    # Record every field before parsing so the whole form is re-populated
    # even when an earlier value turns out to be invalid.
    for name in FEATURE_NAMES:
        input_values[name] = form.get(name)

    try:
        features = [float(input_values[name]) for name in FEATURE_NAMES]
    except (TypeError, ValueError):
        # TypeError: the field was absent, so form.get() returned None.
        return None, "Invalid input: Please ensure all fields contain numeric values."

    try:
        response = requests.post(
            API_URL, json={"features": features}, timeout=_API_TIMEOUT_SECONDS
        )
    except requests.exceptions.ConnectionError:
        return None, "Connection Error: Could not connect to the prediction API."
    except requests.exceptions.Timeout:
        return None, "Timeout Error: The prediction API did not respond in time."
    except requests.exceptions.RequestException as e:
        return None, f"An error occurred: {str(e)}"

    if response.status_code != 200:
        return None, f"API Error: {response.status_code} - {response.text}"

    try:
        result = response.json()
        predicted_class_idx = result.get("predicted_class")
        prediction = {
            "class_index": predicted_class_idx,
            "label": _SINGLE_CLASS_LABELS.get(
                predicted_class_idx, f"Class {predicted_class_idx}"
            ),
            "confidence": f"{result.get('confidence', 0):.2%}",
            "probabilities": result.get("probabilities"),
        }
    except Exception as e:
        # Non-JSON body, non-object payload or non-numeric confidence. These
        # used to surface as the misleading "Invalid input" message.
        return None, f"Unexpected API response: {str(e)}"
    return prediction, None


def _predict_rows(df):
    """Return one label per row of *df*: a class label, "API Error" or "Error"."""
    predictions = []
    # A single Session reuses the connection to the API across rows instead
    # of reconnecting for every request.
    with requests.Session() as session:
        for _, row in df.iterrows():
            try:
                features = [float(row[col]) for col in FEATURE_NAMES]
                resp = session.post(
                    API_URL, json={"features": features}, timeout=_API_TIMEOUT_SECONDS
                )
                if resp.status_code == 200:
                    pred_idx = resp.json().get("predicted_class")
                    predictions.append(_BULK_CLASS_LABELS.get(pred_idx, f"Class {pred_idx}"))
                else:
                    predictions.append("API Error")
            except Exception:
                # Deliberate best-effort: a bad row (non-numeric cell, network
                # failure, malformed response) is marked and the rest of the
                # file is still processed.
                predictions.append("Error")
    return predictions


def _predict_bulk(files):
    """Predict every row of an uploaded CSV/Excel file.

    Saves the upload into UPLOAD_FOLDER and calls the prediction API once per
    row. It then appends a ``predicted`` column and writes the result to
    UPLOAD_FOLDER as ``predicted_<name>``.

    Args:
        files: the uploaded-files mapping (``request.files``).

    Returns:
        ``(bulk_result, error_message)``; exactly one is ``None``.
    """
    if 'file' not in files:
        return None, "No file part"
    upload = files['file']
    if upload.filename == '':
        return None, "No selected file"
    if not upload or not allowed_file(upload.filename):
        # Previously an unsupported upload was silently ignored.
        allowed = ', '.join(sorted(ALLOWED_EXTENSIONS))
        return None, f"Unsupported file type. Allowed types: {allowed}"

    # Take the extension from the client-supplied name: secure_filename() may
    # drop characters (e.g. a non-ASCII stem) together with the '.', which made
    # the old rsplit('.', 1)[1] raise IndexError.
    ext = upload.filename.rsplit('.', 1)[1].lower()
    filename = secure_filename(upload.filename)
    if not filename.lower().endswith('.' + ext):
        filename = f"{filename or 'upload'}.{ext}"

    upload_folder = app.config['UPLOAD_FOLDER']
    filepath = os.path.join(upload_folder, filename)
    try:
        upload.save(filepath)
        df = pd.read_csv(filepath) if ext == 'csv' else pd.read_excel(filepath)

        missing_cols = [col for col in FEATURE_NAMES if col not in df.columns]
        if missing_cols:
            return None, f"File is missing columns: {', '.join(missing_cols)}"

        df['predicted'] = _predict_rows(df)

        output_filename = "predicted_" + filename
        if ext == 'xls':
            # pandas >= 2.0 can no longer write legacy .xls files (the xlwt
            # engine was removed), so .xls uploads come back as .xlsx.
            output_filename = output_filename[:-len('.xls')] + '.xlsx'
        output_path = os.path.join(upload_folder, output_filename)
        if ext == 'csv':
            df.to_csv(output_path, index=False)
        else:
            df.to_excel(output_path, index=False)
    except Exception as e:
        return None, f"Error processing file: {str(e)}"

    return {
        "original_filename": filename,
        "download_url": url_for('download_file', filename=output_filename),
    }, None


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the main page and dispatch whichever form button was submitted.

    - ``predict_btn``: one prediction from the individual-analysis fields.
    - ``realtime_btn``: real-time capture. It is not implemented and always
      reports a missing eye tracker.
    - ``upload_btn``: bulk prediction over an uploaded CSV/Excel file.
    """
    prediction_result = None
    realtime_result = None
    input_values = {}
    realtime_inputs = {}
    error_message = None
    bulk_result = None
    active_tab = "file-prediction"

    if request.method == 'POST':
        if 'predict_btn' in request.form:
            active_tab = "individual-analysis"
            prediction_result, error_message = _predict_single(request.form, input_values)
        elif 'realtime_btn' in request.form:
            active_tab = "real-time-prediction"
            error_message = "No eye tracker found."
        elif 'upload_btn' in request.form:
            active_tab = "file-prediction"
            bulk_result, error_message = _predict_bulk(request.files)

    return render_template('index.html',
                           feature_names=FEATURE_NAMES,
                           result=prediction_result,
                           realtime_result=realtime_result,
                           inputs=input_values,
                           realtime_inputs=realtime_inputs,
                           error=error_message,
                           bulk_result=bulk_result,
                           active_tab=active_tab)
@app.route('/download/<filename>')
def download_file(filename):
    """Serve a file from the upload folder as an attachment.

    send_from_directory joins *filename* onto the folder with werkzeug's
    safe_join. A crafted name (e.g. containing backslashes on Windows)
    therefore cannot escape the folder, which the plain os.path.join +
    send_file did not guarantee. A missing file yields 404 rather than a 500.
    """
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename, as_attachment=True)
if __name__ == '__main__':
    # Idempotent and race-free, unlike an exists()-then-makedirs() check:
    # with exist_ok=True an already-present folder is simply a no-op.
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    # Development server only: debug=True enables Flask's reloader and the
    # interactive in-browser debugger, which must never be exposed publicly.
    app.run(debug=True, port=8080)