Reboot2004 commited on
Commit
d96a702
·
verified ·
1 Parent(s): 7319547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -71
app.py CHANGED
@@ -1,124 +1,181 @@
1
- from flask import Flask, jsonify, request, send_file, render_template
2
  from flask_cors import CORS
3
  from lrp_pipeline_2 import lrp_main
 
4
  from utils import create_folders, delete_folders, create_zip_file
5
- from cam_pipeline import cam_main, cam_process_single_image
 
 
6
  import os
7
  import base64
8
 
9
  app = Flask(__name__)
10
  CORS(app)
11
 
 
 
12
 
13
- @app.route("/api/upload", methods=["GET"])
14
- def get_data():
15
- data = {"message": "Hello from Flask backend!"}
16
- return jsonify(data)
17
 
 
 
18
 
 
 
 
 
 
 
 
 
19
  @app.route("/api/upload", methods=["POST"])
20
  def submit_data():
21
- # first clear all the existing files in uploads, heatmaps, segmentations, tables, cell_descriptors folders
22
- folder_names = [
23
- "uploads",
24
- "heatmaps",
25
- "segmentations",
26
- "tables",
27
- "cell_descriptors",
28
- ]
29
  delete_folders(folder_names)
30
  create_folders(folder_names)
31
 
32
- # Ensure the uploads directory exists
33
  uploads_dir = "uploads"
34
  if not os.path.exists(uploads_dir):
35
  os.makedirs(uploads_dir)
36
 
37
- # then upload the submitted file(s)
38
  file = list(dict(request.files).values())[0]
39
- print(file)
40
  file_path = os.path.join(uploads_dir, file.filename)
41
- file.save(file_path) # Save to 'uploads' directory
42
 
43
- # Process data here
44
  return jsonify({
45
  "message": "Data received successfully!",
46
  "file_path": file_path
47
  })
48
 
49
 
 
50
  @app.route("/api/inputform", methods=["POST"])
51
  def submit_form():
52
- data = dict(request.json) # format of data: {'model': 'VGGNet', 'xaiMethod': 'LRP'}
53
- print(data)
54
-
55
- # Check if we have images in the uploads directory
56
  uploads_dir = "uploads"
57
- image_files = [f for f in os.listdir(uploads_dir)
58
- if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
59
- and not f.startswith('.')]
60
-
61
  if not image_files:
62
  return jsonify({"error": "No images found in uploads directory"}), 400
63
-
64
- # Process the first image (or all images based on your requirements)
65
  image_path = os.path.join(uploads_dir, image_files[0])
66
-
67
- if "LRP" in data["xaiMethod"]:
68
- result_dict = lrp_main(float(data["magval"]))
69
- # Extract relevant results to show in the frontend
70
- # return jsonify({
71
- # "success": True,
72
- # "summary": f"LRP analysis completed with magnification {data['magval']}",
73
- # "details": "Nucleus and cytoplasm segmented successfully",
74
- # "results": result_dict
75
- # })
76
- print(result_dict)
77
- return jsonify({
78
- "success": True,
79
- "summary": f"GradCAM++ analysis completed with magnification {data['magval']}",
80
- "details": "Nucleus and cytoplasm segmented successfully",
81
  "classification": result_dict["classification"],
82
- "results": {
83
  "originalImage": result_dict["image1"],
84
  "heatmapImage": result_dict["inter1"],
85
  "maskImage": result_dict["mask1"],
86
  "tableImage": result_dict["table1"]
87
- }
 
 
 
 
 
 
 
 
88
  })
89
-
90
- elif "GradCAM++" in data["xaiMethod"]:
91
- # Process single image with GradCAM++
92
- result_dict, output_paths = cam_process_single_image(image_path, float(data["magval"]))
93
-
94
- # Read and encode the output files for display
95
- original_image = base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
96
- heatmap_image = base64.b64encode(open(output_paths["heatmap"], "rb").read()).decode("utf-8")
97
- mask_image = base64.b64encode(open(output_paths["mask"], "rb").read()).decode("utf-8")
98
- table_image = base64.b64encode(open(output_paths["table"], "rb").read()).decode("utf-8")
99
-
100
- # include predicted class from the pipeline result
101
- predicted_class = result_dict.get("class1")
102
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  return jsonify({
104
  "success": True,
105
- "summary": f"GradCAM++ analysis completed with magnification {data['magval']}",
106
- "details": "Nucleus and cytoplasm segmented successfully",
107
- "classification": predicted_class,
108
- "results": {
109
- "originalImage": original_image,
110
- "heatmapImage": heatmap_image,
111
- "maskImage": mask_image,
112
- "tableImage": table_image
113
- }
114
  })
115
 
 
 
 
116
 
 
117
  @app.route("/api/zip", methods=["GET"])
118
  def get_csv():
 
119
  create_zip_file()
120
- return send_file("outputs.zip", as_attachment=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
 
123
  if __name__ == "__main__":
124
- app.run(host="0.0.0.0",debug=True)
 
1
+ from flask import Flask, jsonify, request, send_file
2
  from flask_cors import CORS
3
  from lrp_pipeline_2 import lrp_main
4
+ from cam_pipeline import cam_process_single_image
5
  from utils import create_folders, delete_folders, create_zip_file
6
+ from pymongo import MongoClient
7
+ from bson import ObjectId
8
+ from datetime import datetime
9
  import os
10
  import base64
11
 
12
  app = Flask(__name__)
13
  CORS(app)
14
 
15
+ # === MongoDB Atlas Setup (Hugging Face Secret) ===
16
+ MONGO_URI = os.getenv("MONGO_URI") # Add this secret in Hugging Face: Settings → Variables and secrets
17
 
18
+ if not MONGO_URI:
19
+ raise RuntimeError("MONGO_URI not set. Please add it in Hugging Face Space Secrets.")
 
 
20
 
21
+ client = MongoClient(MONGO_URI)
22
+ db = client["xai_results"]
23
 
24
+ try:
25
+ client.admin.command("ping")
26
+ print("✅ Connected to MongoDB Atlas successfully.")
27
+ except Exception as e:
28
+ print("⚠️ MongoDB connection failed:", e)
29
+
30
+
31
+ # === ROUTE: Upload image ===
32
  @app.route("/api/upload", methods=["POST"])
33
  def submit_data():
34
+ folder_names = ["uploads", "heatmaps", "segmentations", "tables", "cell_descriptors"]
 
 
 
 
 
 
 
35
  delete_folders(folder_names)
36
  create_folders(folder_names)
37
 
 
38
  uploads_dir = "uploads"
39
  if not os.path.exists(uploads_dir):
40
  os.makedirs(uploads_dir)
41
 
 
42
  file = list(dict(request.files).values())[0]
 
43
  file_path = os.path.join(uploads_dir, file.filename)
44
+ file.save(file_path)
45
 
 
46
  return jsonify({
47
  "message": "Data received successfully!",
48
  "file_path": file_path
49
  })
50
 
51
 
52
+ # === ROUTE: Process input form (LRP or GradCAM++) ===
53
  @app.route("/api/inputform", methods=["POST"])
54
  def submit_form():
55
+ data = dict(request.json)
 
 
 
56
  uploads_dir = "uploads"
57
+
58
+ image_files = [f for f in os.listdir(uploads_dir)
59
+ if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')) and not f.startswith('.')]
60
+
61
  if not image_files:
62
  return jsonify({"error": "No images found in uploads directory"}), 400
63
+
 
64
  image_path = os.path.join(uploads_dir, image_files[0])
65
+ xai_method = data.get("xaiMethod", "Unknown")
66
+ magval = float(data.get("magval", 1.0))
67
+
68
+ # === LRP ===
69
+ if "LRP" in xai_method:
70
+ result_dict = lrp_main(magval)
71
+ record = {
72
+ "model": data.get("model"),
73
+ "xaiMethod": xai_method,
74
+ "magnification": magval,
 
 
 
 
 
75
  "classification": result_dict["classification"],
76
+ "images": {
77
  "originalImage": result_dict["image1"],
78
  "heatmapImage": result_dict["inter1"],
79
  "maskImage": result_dict["mask1"],
80
  "tableImage": result_dict["table1"]
81
+ },
82
+ "timestamp": datetime.utcnow()
83
+ }
84
+ db.predictions.insert_one(record)
85
+ return jsonify({
86
+ "success": True,
87
+ "summary": f"LRP completed with magnification {magval}",
88
+ "classification": record["classification"],
89
+ "results": record["images"]
90
  })
91
+
92
+ # === GradCAM++ ===
93
+ elif "GradCAM++" in xai_method:
94
+ result_dict, output_paths = cam_process_single_image(image_path, magval)
95
+
96
+ def encode_img(path):
97
+ with open(path, "rb") as f:
98
+ return base64.b64encode(f.read()).decode("utf-8")
99
+
100
+ original = encode_img(image_path)
101
+ heatmap = encode_img(output_paths["heatmap"])
102
+ mask = encode_img(output_paths["mask"])
103
+ table = encode_img(output_paths["table"])
104
+
105
+ record = {
106
+ "model": data.get("model"),
107
+ "xaiMethod": xai_method,
108
+ "magnification": magval,
109
+ "classification": result_dict.get("class1"),
110
+ "images": {
111
+ "originalImage": original,
112
+ "heatmapImage": heatmap,
113
+ "maskImage": mask,
114
+ "tableImage": table
115
+ },
116
+ "timestamp": datetime.utcnow()
117
+ }
118
+
119
+ db.predictions.insert_one(record)
120
  return jsonify({
121
  "success": True,
122
+ "summary": f"GradCAM++ completed with magnification {magval}",
123
+ "classification": record["classification"],
124
+ "results": record["images"]
 
 
 
 
 
 
125
  })
126
 
127
+ else:
128
+ return jsonify({"error": "Invalid XAI method"}), 400
129
+
130
 
131
+ # === ROUTE: Create ZIP (optional) ===
132
  @app.route("/api/zip", methods=["GET"])
133
  def get_csv():
134
+ zip_path = "outputs.zip"
135
  create_zip_file()
136
+
137
+ if not os.path.exists(zip_path):
138
+ return jsonify({"error": "outputs.zip not found"}), 404
139
+
140
+ return send_file(zip_path, as_attachment=True)
141
+
142
+
143
+ # === ROUTE: Fetch all previous predictions ===
144
+ @app.route("/api/oldpreds", methods=["GET"])
145
+ def list_old_predictions():
146
+ preds = list(db.predictions.find().sort("timestamp", -1))
147
+ result = []
148
+ for p in preds:
149
+ result.append({
150
+ "id": str(p["_id"]),
151
+ "model": p.get("model"),
152
+ "xaiMethod": p.get("xaiMethod"),
153
+ "magnification": p.get("magnification"),
154
+ "classification": p.get("classification"),
155
+ "images": p.get("images"),
156
+ "timestamp": p["timestamp"].strftime("%Y-%m-%d %H:%M:%S")
157
+ })
158
+ return jsonify(result)
159
+
160
+
161
+ # === ROUTE: Fetch one old prediction by ID ===
162
+ @app.route("/api/oldpreds/<id>", methods=["GET"])
163
+ def get_old_prediction(id):
164
+ try:
165
+ record = db.predictions.find_one({"_id": ObjectId(id)})
166
+ if not record:
167
+ return jsonify({"error": "Record not found"}), 404
168
+ record["_id"] = str(record["_id"])
169
+ record["timestamp"] = record["timestamp"].strftime("%Y-%m-%d %H:%M:%S")
170
+ return jsonify(record)
171
+ except Exception as e:
172
+ return jsonify({"error": str(e)}), 400
173
+
174
+
175
+ @app.route("/", methods=["GET"])
176
+ def home():
177
+ return jsonify({"message": "Flask XAI API running successfully"})
178
 
179
 
180
  if __name__ == "__main__":
181
+ app.run(host="0.0.0.0", port=7860, debug=True)