# IOMP / app.py
# Reboot2004's picture
# Update app.py
# d96a702 verified
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from lrp_pipeline_2 import lrp_main
from cam_pipeline import cam_process_single_image
from utils import create_folders, delete_folders, create_zip_file
from pymongo import MongoClient
from bson import ObjectId
from datetime import datetime
import os
import base64
# Flask application object, with CORS enabled using flask-cors defaults.
app = Flask(__name__)
CORS(app)

# === MongoDB Atlas Setup (Hugging Face Secret) ===
# The connection string comes from the environment. On Hugging Face it is
# configured under Settings → Variables and secrets.
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    raise RuntimeError("MONGO_URI not set. Please add it in Hugging Face Space Secrets.")

client = MongoClient(MONGO_URI)
db = client["xai_results"]

# Startup connectivity check. A failed ping is only reported, never re-raised,
# so the module still imports when the database cannot be reached.
try:
    client.admin.command("ping")
    print("✅ Connected to MongoDB Atlas successfully.")
except Exception as exc:
    print("⚠️ MongoDB connection failed:", exc)
# === ROUTE: Upload image ===
@app.route("/api/upload", methods=["POST"])
def submit_data():
    """Save one uploaded file into the ``uploads`` folder.

    The first file part of the multipart request is used, whatever its form
    field name is. Before saving, the working folders are reset through
    ``delete_folders`` / ``create_folders`` (presumably removing results of
    the previous run -- confirm against ``utils``).

    Returns:
        JSON with ``message`` and ``file_path`` on success, or a JSON
        ``error`` with HTTP 400 when no usable file was sent.
    """
    # Validate the request first so a malformed upload does not wipe the
    # working folders of the previous run.
    file = next(iter(request.files.values()), None)
    if file is None:
        return jsonify({"error": "No file uploaded"}), 400

    # The file name is client-controlled: keep only its last path component
    # (for both "/" and "\\" separators) so it cannot escape ``uploads``.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        return jsonify({"error": "Invalid or missing file name"}), 400

    folder_names = ["uploads", "heatmaps", "segmentations", "tables", "cell_descriptors"]
    delete_folders(folder_names)
    create_folders(folder_names)

    uploads_dir = "uploads"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(uploads_dir, exist_ok=True)

    file_path = os.path.join(uploads_dir, filename)
    file.save(file_path)
    return jsonify({
        "message": "Data received successfully!",
        "file_path": file_path
    })
# === ROUTE: Process input form (LRP or GradCAM++) ===
def _encode_image_b64(path):
    """Return the contents of the file at *path* as a base64 text string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def _store_and_respond(method_label, data, xai_method, magval, classification, images):
    """Insert one prediction record into MongoDB and build the JSON reply."""
    record = {
        "model": data.get("model"),
        "xaiMethod": xai_method,
        "magnification": magval,
        "classification": classification,
        "images": images,
        "timestamp": datetime.utcnow()
    }
    db.predictions.insert_one(record)
    return jsonify({
        "success": True,
        "summary": f"{method_label} completed with magnification {magval}",
        "classification": record["classification"],
        "results": record["images"]
    })


@app.route("/api/inputform", methods=["POST"])
def submit_form():
    """Run the requested XAI pipeline on the uploaded image and store the result.

    Expects a JSON object with ``xaiMethod`` (must contain ``"LRP"`` or
    ``"GradCAM++"``), optional ``magval`` (number, default 1.0) and optional
    ``model``. The image is the first image file found in ``uploads``.

    Returns:
        JSON with ``success``, ``summary``, ``classification`` and ``results``
        (the four images), or a JSON ``error`` with HTTP 400 for a malformed
        body, a missing image, a non-numeric ``magval`` or an unknown method.
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    uploads_dir = "uploads"
    try:
        candidates = os.listdir(uploads_dir)
    except FileNotFoundError:
        # Nothing has been uploaded yet, so the folder does not exist.
        candidates = []
    image_files = [f for f in candidates
                   if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')) and not f.startswith('.')]
    if not image_files:
        return jsonify({"error": "No images found in uploads directory"}), 400
    image_path = os.path.join(uploads_dir, image_files[0])

    xai_method = data.get("xaiMethod", "Unknown")
    if not isinstance(xai_method, str):
        # JSON null / numbers would make the substring tests below raise.
        return jsonify({"error": "Invalid XAI method"}), 400

    try:
        magval = float(data.get("magval", 1.0))
    except (TypeError, ValueError):
        return jsonify({"error": "magval must be a number"}), 400

    # === LRP ===
    if "LRP" in xai_method:
        result_dict = lrp_main(magval)
        images = {
            "originalImage": result_dict["image1"],
            "heatmapImage": result_dict["inter1"],
            "maskImage": result_dict["mask1"],
            "tableImage": result_dict["table1"]
        }
        return _store_and_respond(
            "LRP", data, xai_method, magval, result_dict["classification"], images
        )

    # === GradCAM++ ===
    if "GradCAM++" in xai_method:
        result_dict, output_paths = cam_process_single_image(image_path, magval)
        # This pipeline hands back file paths, which are base64-encoded here
        # before being stored and returned.
        images = {
            "originalImage": _encode_image_b64(image_path),
            "heatmapImage": _encode_image_b64(output_paths["heatmap"]),
            "maskImage": _encode_image_b64(output_paths["mask"]),
            "tableImage": _encode_image_b64(output_paths["table"])
        }
        return _store_and_respond(
            "GradCAM++", data, xai_method, magval, result_dict.get("class1"), images
        )

    return jsonify({"error": "Invalid XAI method"}), 400
# === ROUTE: Create ZIP (optional) ===
@app.route("/api/zip", methods=["GET"])
def get_csv():
    """Build the results archive and return it as a file download.

    Calls ``create_zip_file()`` and then serves ``outputs.zip`` from the
    working directory (presumably the file that helper writes -- confirm
    against ``utils``). Responds with HTTP 404 and a JSON error when the
    archive is not there afterwards.
    """
    archive_name = "outputs.zip"
    create_zip_file()
    if os.path.exists(archive_name):
        return send_file(archive_name, as_attachment=True)
    return jsonify({"error": "outputs.zip not found"}), 404
# === ROUTE: Fetch all previous predictions ===
@app.route("/api/oldpreds", methods=["GET"])
def list_old_predictions():
    """Return every stored prediction as a JSON array, newest first.

    Each entry carries the record id as a string, the stored model, method,
    magnification, classification and images, and the timestamp formatted as
    ``YYYY-MM-DD HH:MM:SS``.
    """
    newest_first = db.predictions.find().sort("timestamp", -1)
    summaries = [
        {
            "id": str(doc["_id"]),
            "model": doc.get("model"),
            "xaiMethod": doc.get("xaiMethod"),
            "magnification": doc.get("magnification"),
            "classification": doc.get("classification"),
            "images": doc.get("images"),
            "timestamp": doc["timestamp"].strftime("%Y-%m-%d %H:%M:%S")
        }
        for doc in newest_first
    ]
    return jsonify(summaries)
# === ROUTE: Fetch one old prediction by ID ===
@app.route("/api/oldpreds/<id>", methods=["GET"])
def get_old_prediction(id):
    """Return one stored prediction, looked up by its MongoDB ObjectId.

    Args:
        id: ObjectId in string form, taken from the URL. The name is fixed by
            the route's ``<id>`` variable, which Flask passes by keyword.

    Returns:
        The record as JSON with ``_id`` converted to a string and the
        timestamp formatted as ``YYYY-MM-DD HH:MM:SS``. HTTP 400 for a
        malformed id, 404 when no record matches, 500 when the database
        lookup itself fails.
    """
    # Reject malformed ids up front so that only genuine client mistakes are
    # reported as 400.
    if not ObjectId.is_valid(id):
        return jsonify({"error": f"'{id}' is not a valid ObjectId"}), 400

    try:
        record = db.predictions.find_one({"_id": ObjectId(id)})
    except Exception:
        # A database failure is a server-side fault: log the details and
        # answer 500 without exposing the raw exception text to the client.
        app.logger.exception("Lookup of prediction %s failed", id)
        return jsonify({"error": "Database lookup failed"}), 500

    if not record:
        return jsonify({"error": "Record not found"}), 404

    record["_id"] = str(record["_id"])
    # Format the timestamp only when one is stored as a datetime, so a record
    # without it is still returned instead of failing the request.
    stamp = record.get("timestamp")
    if isinstance(stamp, datetime):
        record["timestamp"] = stamp.strftime("%Y-%m-%d %H:%M:%S")
    return jsonify(record)
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint: confirm the API is up with a fixed JSON message."""
    return jsonify(message="Flask XAI API running successfully")
if __name__ == "__main__":
    # The server listens on all interfaces, so the Werkzeug debugger (which
    # allows arbitrary code execution) must not be on by default. Set
    # FLASK_DEBUG=1 to opt in during local development.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    # The default port is unchanged (7860); PORT can override it.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")), debug=debug_enabled)