# Flask web front-end for the ET-Alzheimer prediction API (a Hugging Face Space).
# Offers single-row form prediction, a real-time placeholder, and bulk CSV/Excel
# prediction with a downloadable annotated copy of the uploaded file.
import os
import json

import pandas as pd
import requests
from flask import Flask, render_template, request, send_file, url_for
from werkzeug.utils import secure_filename

app = Flask(__name__)

# --- Configuration -------------------------------------------------------
# Uploaded files (and generated prediction files) live next to this module.
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # make sure the folder exists at startup

# File types accepted for bulk prediction.
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls'}

# Remote model endpoint (Hugging Face Space).
API_URL = "https://ak0601-et-alzheimer.hf.space/predict"

# Feature names, in the exact order the API expects them.
FEATURE_NAMES = [
    "ROI",
    "nFixations",
    "nTobiiSaccades",
    "regSaccades",
    "longSaccades",
    "tinySaccades",
    "saccadeTotLength",
    "totalFixTime",
    "totalSpokenTime",
    "speechDelay",
    "endSpeechDelay",
    "startPupL",
    "startPupR",
    "endPupL",
    "endPupR",
    "diffPupL",
    "diffPupR",
]
def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
# NOTE(review): no @app.route appeared anywhere in the original file, yet this
# view inspects request.method and the templates link to it — the decorator was
# evidently lost; restored here so the view is actually reachable.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Main page: dispatch the three POST actions and render the template.

    Actions (identified by the submit button present in the form):
      * ``predict_btn``  — single prediction from manually entered features.
      * ``realtime_btn`` — real-time prediction (no tracker: always errors).
      * ``upload_btn``   — bulk prediction over an uploaded CSV/Excel file.

    Always returns the rendered ``index.html`` with whichever of
    result / bulk_result / error is populated.
    """
    prediction_result = None
    realtime_result = None
    input_values = {}
    realtime_inputs = {}
    error_message = None
    bulk_result = None
    active_tab = "file-prediction"

    if request.method == 'POST':
        if 'predict_btn' in request.form:
            active_tab = "individual-analysis"
            prediction_result, error_message = _predict_single(request.form, input_values)
        elif 'realtime_btn' in request.form:
            active_tab = "real-time-prediction"
            # No hardware integration exists; report the fixed message.
            error_message = "No eye tracker found."
        elif 'upload_btn' in request.form:
            active_tab = "file-prediction"
            bulk_result, error_message = _predict_bulk(request.files)

    return render_template('index.html',
                           feature_names=FEATURE_NAMES,
                           result=prediction_result,
                           realtime_result=realtime_result,
                           inputs=input_values,
                           realtime_inputs=realtime_inputs,
                           error=error_message,
                           bulk_result=bulk_result,
                           active_tab=active_tab)


def _predict_single(form, input_values):
    """Call the remote API for one manually entered feature row.

    Fills *input_values* with the raw form text (so the template can
    re-populate the form) and returns ``(result_dict_or_None, error_or_None)``.
    """
    try:
        features = []
        for name in FEATURE_NAMES:
            val = form.get(name)
            input_values[name] = val  # keep raw text for re-populating the form
            # float(None) raises TypeError, float("abc") raises ValueError —
            # both must map to the "Invalid input" message (the original only
            # caught ValueError, so a missing field produced a raw TypeError).
            features.append(float(val))

        # timeout so a stalled/sleeping Space cannot hang the worker forever.
        response = requests.post(API_URL, json={"features": features}, timeout=30)
        if response.status_code != 200:
            return None, f"API Error: {response.status_code} - {response.text}"

        result = response.json()
        class_map = {0: "Control/Healthy", 1: "MCI (Mild Cognitive Impairment)", 2: "Alzheimer's Disease"}
        predicted_class_idx = result.get("predicted_class")
        return {
            "class_index": predicted_class_idx,
            "label": class_map.get(predicted_class_idx, f"Class {predicted_class_idx}"),
            "confidence": f"{result.get('confidence', 0):.2%}",
            "probabilities": result.get("probabilities"),
        }, None
    except (ValueError, TypeError):
        return None, "Invalid input: Please ensure all fields contain numeric values."
    except requests.exceptions.ConnectionError:
        return None, "Connection Error: Could not connect to the prediction API."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"


def _predict_bulk(files):
    """Predict every row of an uploaded CSV/Excel file.

    Saves an annotated copy (extra ``predicted`` column) next to the upload
    and returns ``(bulk_result_or_None, error_or_None)``.  A disallowed
    extension is silently ignored (both values None), matching the original.
    """
    if 'file' not in files:
        return None, "No file part"
    file = files['file']
    if file.filename == '':
        return None, "No selected file"
    if not (file and allowed_file(file.filename)):
        return None, None  # original behaviour: disallowed types are ignored

    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    try:
        ext = filename.rsplit('.', 1)[1].lower()
        df = pd.read_csv(filepath) if ext == 'csv' else pd.read_excel(filepath)

        missing_cols = [col for col in FEATURE_NAMES if col not in df.columns]
        if missing_cols:
            return None, f"File is missing columns: {', '.join(missing_cols)}"

        class_map = {0: "Healthy", 1: "MCI", 2: "Alzheimer's"}  # hoisted out of the loop
        predictions = []
        for _, row in df.iterrows():  # '_' — original shadowed the view name 'index'
            try:
                features = [float(row[col]) for col in FEATURE_NAMES]
                resp = requests.post(API_URL, json={"features": features}, timeout=30)
                if resp.status_code == 200:
                    pred_idx = resp.json().get("predicted_class")
                    predictions.append(class_map.get(pred_idx, f"Class {pred_idx}"))
                else:
                    predictions.append("API Error")
            except Exception:
                predictions.append("Error")  # best-effort per row; keep going

        df['predicted'] = predictions
        output_filename = "predicted_" + filename
        output_path = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
        if ext == 'csv':
            df.to_csv(output_path, index=False)
        else:
            df.to_excel(output_path, index=False)
        return {
            "original_filename": filename,
            "download_url": url_for('download_file', filename=output_filename),
        }, None
    except Exception as e:
        return None, f"Error processing file: {str(e)}"
# NOTE(review): the route decorator was missing in the original file although
# index() builds links via url_for('download_file', ...); restored here.
@app.route('/download/<filename>')
def download_file(filename):
    """Serve a generated prediction file from UPLOAD_FOLDER as an attachment.

    secure_filename() strips path separators and '..' components, so a
    crafted filename in the URL cannot escape the upload directory
    (path-traversal guard; generated names are already secure_filename'd,
    so legitimate downloads are unaffected).
    """
    safe_name = secure_filename(filename)
    return send_file(os.path.join(app.config['UPLOAD_FOLDER'], safe_name),
                     as_attachment=True)
if __name__ == '__main__':
    # Defensive re-check of the upload directory before serving (the explicit
    # exists-check is kept: it tolerates the path existing as a non-directory,
    # unlike makedirs(exist_ok=True)).
    if not os.path.exists(UPLOAD_FOLDER):
        os.makedirs(UPLOAD_FOLDER)
    # Development server only; debug=True must not be used in production.
    app.run(debug=True, port=8080)