# stroke_predict / app.py
# Author: jagadeesh72
# Commit: initial backend (4975e29)
from flask import Flask, request, jsonify
import tensorflow as tf
from flask_cors import CORS
from utils import predict_image
import os
import requests
# Flask application object; the route handlers in this file are registered on it.
app = Flask(__name__)
# Enable cross-origin requests using flask-cors' defaults.
# NOTE(review): with no arguments this presumably allows every origin — confirm
# that is intended before exposing the service publicly.
CORS(app)
# ------------------------------
# MODEL CONFIG
# ------------------------------
# Local file the model is loaded from; it is downloaded here when missing.
MODEL_PATH = "model.h5"
# Remote location the model file is fetched from when MODEL_PATH does not exist.
MODEL_URL = "https://huggingface.co/bakhili/stroke-classification-resnet-model/resolve/main/stroke_classification_model.h5"
# ------------------------------
# DOWNLOAD MODEL IF NOT EXISTS
# ------------------------------
if not os.path.exists(MODEL_PATH):
    print("Downloading model from Hugging Face...")
    # Stream into a temporary sibling file first: if the transfer fails or the
    # process is killed midway, MODEL_PATH is never left holding a truncated
    # file that the next start-up would mistake for a complete model.
    partial_path = MODEL_PATH + ".part"
    try:
        # timeout=(connect, read) stops start-up from hanging forever on a
        # stalled connection; the context manager releases the connection.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as the model.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(partial_path, MODEL_PATH)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download before re-raising.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Model downloaded successfully!")
# ------------------------------
# LOAD MODEL
# ------------------------------
# Load the Keras model once at import time; the module-level `model` object is
# then passed to predict_image by the /predict handler on every request.
print("Loading model...")
model = tf.keras.models.load_model(MODEL_PATH)
print("Model loaded successfully!")
# ------------------------------
# ROUTES
# ------------------------------
@app.route("/")
def home():
    """Root endpoint: reply with a plain-text message showing the backend is up."""
    status_message = "Stroke Detection Backend Running"
    return status_message
@app.route("/predict", methods=["POST"])
def predict():
    """Run the loaded model on an uploaded file and return the result as JSON.

    The upload is read from the ``file`` field of the request's files.

    Returns:
        - ``jsonify`` of whatever ``predict_image(model, file)`` returns, on
          success.
        - ``{"error": "No file uploaded"}`` with status 400 when the ``file``
          field is absent.
        - ``{"error": "Empty filename"}`` with status 400 when the upload has
          no filename.
        - ``{"error": "Prediction failed"}`` with status 500 when any
          exception is raised while handling the request.
    """
    # The try block deliberately covers the whole handler so that every
    # failure keeps producing the same JSON 500 response for clients.
    try:
        if "file" not in request.files:
            return jsonify({"error": "No file uploaded"}), 400

        uploaded = request.files["file"]
        # Truthiness check rejects a missing (None) filename as well as "".
        if not uploaded.filename:
            return jsonify({"error": "Empty filename"}), 400

        result = predict_image(model, uploaded)
        return jsonify(result)
    except Exception:
        # logger.exception records the full traceback, which the previous
        # print(str(e)) discarded; the client still gets a generic message.
        app.logger.exception("Error during prediction")
        return jsonify({"error": "Prediction failed"}), 500
# ------------------------------
# RUN SERVER
# ------------------------------
if __name__ == "__main__":
    # Started directly as a script: serve on every network interface, port 7860.
    app.run(port=7860, host="0.0.0.0")