Spaces:
Sleeping
Sleeping
| import os | |
| import traceback | |
| from flask import Flask, request, jsonify, render_template | |
| from flask_cors import CORS | |
| from cnnClassifier.pipeline.prediction import PredictionPipeline | |
# Flask application object; CORS(app) wraps it with flask_cors.
app = Flask(__name__)
CORS(app)

# Directory used to stage uploaded files; created up front if missing.
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Decide the model path a single time: use the .keras file when it exists on
# disk, otherwise fall back to the .h5 path (which is not checked here).
_keras_path = os.path.join("artifacts", "training", "model.keras")
_h5_path = os.path.join("artifacts", "training", "model.h5")
if os.path.isfile(_keras_path):
    MODEL_PATH = _keras_path
else:
    MODEL_PATH = _h5_path

# Load the model during import so it is held in memory for later use.
# Any failure (including a failed tensorflow import) is recorded in
# _MODEL_ERROR instead of being raised, leaving _MODEL as None.
_MODEL = None
_MODEL_ERROR = None
try:
    from tensorflow.keras.models import load_model

    _MODEL = load_model(MODEL_PATH)
    print(f"[startup] Model loaded from {MODEL_PATH}")
except Exception as _e:
    _MODEL_ERROR = str(_e)
    print(f"[startup] WARNING: model failed to load — {_MODEL_ERROR}")
def home():
    """Render and return the ``index.html`` template."""
    # NOTE(review): no route decorator is visible in this view — confirm this
    # handler is registered with the app.
    page = render_template("index.html")
    return page
def health():
    """Report whether the model was loaded at startup.

    Returns:
        A ``(json_response, status_code)`` pair. The status code is 200 when
        the model is in memory and 503 otherwise; the JSON body carries the
        status label, the loaded flag, the absolute model path and the
        recorded load error (``None`` when loading succeeded).
    """
    # NOTE(review): no route decorator is visible in this view — confirm this
    # handler is registered with the app.
    model_ready = _MODEL is not None
    if model_ready:
        label, http_status = "ok", 200
    else:
        label, http_status = "degraded", 503

    payload = {
        "status": label,
        "model_loaded": model_ready,
        "model_path": os.path.abspath(MODEL_PATH),
        "error": _MODEL_ERROR,
    }
    return jsonify(payload), http_status
def train():
    """Run ``main.py`` as a child process and report whether it succeeded.

    The script is started with the interpreter running this process (so the
    same environment and installed packages are used) and without a shell.
    The call blocks until the child exits; its stdout/stderr are inherited.

    Returns:
        ``"Training completed successfully!"`` when the child exits with
        status 0. Otherwise a ``(message, 500)`` pair describing the failure,
        so a failed or missing training script is no longer reported as a
        success.
    """
    # NOTE(review): no route decorator or access control is visible in this
    # view — confirm who is allowed to trigger this handler.
    import subprocess
    import sys

    # sys.executable can be empty in unusual embedded setups; fall back to
    # resolving "python" from PATH in that case.
    command = [sys.executable or "python", "main.py"]
    try:
        completed = subprocess.run(command, check=False)
    except OSError as exc:
        # The interpreter itself could not be started.
        return f"Training failed to start: {exc}", 500

    if completed.returncode != 0:
        return f"Training failed with exit code {completed.returncode}", 500
    return "Training completed successfully!"
def _unique_upload_path(client_filename):
    """Build a server-chosen path inside UPLOAD_FOLDER for an uploaded file.

    The client-supplied name is untrusted, so it is never used as a path
    component: only its extension is kept, and only when that extension is
    short and purely alphanumeric. The stem is a random UUID, which also keeps
    concurrent uploads with the same original name from colliding.
    """
    import uuid

    extension = os.path.splitext(os.path.basename(client_filename))[1]
    if not (extension[1:].isalnum() and len(extension) <= 16):
        extension = ""
    return os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4().hex}{extension}")


def predict():
    """Run the prediction pipeline on the uploaded ``file`` form field.

    The upload is saved under a unique server-generated name, passed to
    ``PredictionPipeline`` together with the preloaded model, and removed
    again whether or not the prediction succeeds.

    Returns:
        The pipeline result as JSON on success, or a ``(json, status)`` pair:
        503 when the model is not loaded, 400 when no file was uploaded or
        selected, 500 when saving or predicting raises.
    """
    # NOTE(review): no route decorator is visible in this view — confirm this
    # handler is registered with the app.
    if _MODEL is None:
        return jsonify({"error": f"Model not loaded: {_MODEL_ERROR}"}), 503
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]
    # Covers both an empty name and a missing (None) name.
    if not upload.filename:
        return jsonify({"error": "No file selected"}), 400

    filepath = _unique_upload_path(upload.filename)
    try:
        upload.save(filepath)
        pipeline = PredictionPipeline(filepath, model=_MODEL)
        result = pipeline.predict()
        return jsonify(result)
    except Exception as e:
        # Request boundary: log the traceback and surface the message as JSON.
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Cleanup; the file may not exist if saving itself failed.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable, defaulting to 7860 when it is unset.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)), debug=False)