"""Flask backend serving a stroke-classification Keras model.

On startup the model file is downloaded from Hugging Face (only if it is not
already present at MODEL_PATH) and loaded once; the /predict route then passes
uploaded files to ``utils.predict_image`` together with the loaded model.
"""

import os

import requests
import tensorflow as tf
from flask import Flask, jsonify, request
from flask_cors import CORS

from utils import predict_image

app = Flask(__name__)
CORS(app)

# ------------------------------
# MODEL CONFIG
# ------------------------------
MODEL_PATH = "model.h5"
MODEL_URL = "https://huggingface.co/bakhili/stroke-classification-resnet-model/resolve/main/stroke_classification_model.h5"

# Seconds to wait on connect/read during the model download. Without a
# timeout, a stalled server would hang application startup indefinitely.
DOWNLOAD_TIMEOUT = 60
DOWNLOAD_CHUNK_SIZE = 8192


def _download_model(url, dest, timeout=DOWNLOAD_TIMEOUT):
    """Stream the file at *url* to *dest*, installing it only on success.

    The body is written to a sibling ``<dest>.part`` file and atomically
    renamed onto *dest* once fully written. A failed or interrupted download
    therefore never leaves a truncated file at *dest* -- which matters because
    the startup check below only tests for existence and would otherwise keep
    trying to load a corrupt model forever.

    Args:
        url: Location of the model file.
        dest: Final path for the downloaded file.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    tmp_path = dest + ".part"
    try:
        # Using the response as a context manager releases the connection
        # even if writing fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page (e.g. a 404) would be
            # saved to disk as if it were the model.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # ignore empty chunks
                        f.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Clean up the partial file, then let the original error propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# ------------------------------
# DOWNLOAD MODEL IF NOT EXISTS
# ------------------------------
if not os.path.exists(MODEL_PATH):
    print("Downloading model from Hugging Face...")
    _download_model(MODEL_URL, MODEL_PATH)
    print("Model downloaded successfully!")

# ------------------------------
# LOAD MODEL
# ------------------------------
print("Loading model...")
model = tf.keras.models.load_model(MODEL_PATH)
print("Model loaded successfully!")


# ------------------------------
# ROUTES
# ------------------------------
@app.route("/")
def home():
    """Liveness endpoint: return a plain-text banner."""
    return "Stroke Detection Backend Running"


@app.route("/predict", methods=["POST"])
def predict():
    """Run the model on the uploaded multipart field ``file``.

    Returns:
        200 with the JSON-serialized result of ``predict_image``;
        400 ``{"error": ...}`` if no file (or an unnamed file) was sent;
        500 ``{"error": "Prediction failed"}`` if inference or serialization
        raises.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]
    # Treat a missing (None) filename the same as an empty one.
    if not upload.filename:
        return jsonify({"error": "Empty filename"}), 400

    try:
        result = predict_image(model, upload)
        # jsonify stays inside the try: a non-serializable result should
        # still produce the JSON 500 response rather than an HTML error page.
        return jsonify(result)
    except Exception:
        # Top-level request boundary: log the full traceback (not just the
        # message) and return a generic error to the client.
        app.logger.exception("Error during prediction")
        return jsonify({"error": "Prediction failed"}), 500


# ------------------------------
# RUN SERVER
# ------------------------------
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)