bobbypaton commited on
Commit
d7ad5f5
·
1 Parent(s): 9b2e0f3

Allow iframe embedding for HF Spaces

Browse files
Files changed (1) hide show
  1. app.py +7 -159
app.py CHANGED
@@ -1,162 +1,3 @@
1
- """
2
- CASCADE – Flask app for HF Spaces
3
- Mounted at /cascade_v1/
4
- """
5
-
6
- import os
7
- import json
8
- import uuid
9
-
10
- import redis
11
- from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response
12
-
13
- from rdkit import Chem
14
- from rdkit.Chem import AllChem
15
- from rdkit.Chem.Draw import rdMolDraw2D
16
-
17
- from NMR_Prediction.valid import validate_smiles
18
-
19
- bp = Blueprint("cascade", __name__)
20
-
21
- redis_client = redis.StrictRedis(
22
- host="localhost", port=6379, db=0, decode_responses=True
23
- )
24
-
25
- # ── Pages ─────────────────────────────────────────────────────────────────────
26
- @bp.route("/")
27
- @bp.route("")
28
- @bp.route("/predict/")
29
- @bp.route("/predict")
30
- @bp.route("/home/")
31
- @bp.route("/about/")
32
- def predict():
33
- return render_template("cascade/predict.html")
34
-
35
-
36
- # ── Job submission ────────────────────────────────────────────────────────────
37
- def _submit_job(smiles, type_):
38
- if not validate_smiles(smiles):
39
- return jsonify({"message": "Input molecule is not allowed", "task_id": None})
40
- task_id = uuid.uuid4().hex
41
- redis_client.set(f"task_detail_{task_id}",
42
- json.dumps({"smiles": smiles, "type_": type_}))
43
- redis_client.rpush("task_queue", task_id)
44
- return task_id
45
-
46
-
47
- @bp.route("/predict_NMR_C/", methods=["POST"])
48
- def predict_NMR_C():
49
- result = _submit_job(request.form["smiles"], request.form["type_"])
50
- if not isinstance(result, str):
51
- return result
52
- return jsonify({"message": "Molecule submitted to C queue", "task_id": result})
53
-
54
-
55
- @bp.route("/predict_NMR_H/", methods=["POST"])
56
- def predict_NMR_H():
57
- result = _submit_job(request.form["smiles"], request.form["type_"])
58
- if not isinstance(result, str):
59
- return result
60
- return jsonify({"message": "Molecule submitted to H queue", "task_id": result})
61
-
62
-
63
- # ── check_task ────────────────────────────────────────────────────────────────
64
- @bp.route("/check_task/")
65
- def check_task():
66
- raw = redis_client.get(f"task_result_{request.args['task_id']}")
67
- if not raw:
68
- return "running", 200
69
- result = json.loads(raw)
70
- if "errMessage" in result:
71
- return "Error1", 200
72
- return "done", 200
73
-
74
-
75
- # ── get_result ────────────────────────────────────────────────────────────────
76
- @bp.route("/get_result/")
77
- def get_result():
78
- task_id = request.args["task_id"]
79
- raw = redis_client.get(f"task_result_{task_id}")
80
- if not raw:
81
- return jsonify({"error": "Result not found"}), 404
82
-
83
- result = json.loads(raw)
84
- if "errMessage" in result:
85
- return jsonify({"error": result["errMessage"]}), 500
86
-
87
- smiles = result["smiles"]
88
- nucleus = result.get("type_", "C") # "C" or "H"
89
-
90
- weighted_shift_txt = result["weightedShiftTxt"]
91
- shift_map = {}
92
- for item in filter(None, weighted_shift_txt.split(";")):
93
- parts = item.split(",")
94
- if len(parts) == 2:
95
- shift_map[int(parts[0])] = parts[1]
96
-
97
- if nucleus == "H":
98
- # ¹H: draw with explicit H so H atoms are visible
99
- mol = Chem.MolFromSmiles(smiles)
100
- mol = Chem.AddHs(mol)
101
- AllChem.Compute2DCoords(mol)
102
- mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
103
- n_label = mol.GetNumAtoms()
104
- drawer = rdMolDraw2D.MolDraw2DSVG(700, 500)
105
- else:
106
- # ¹³C: skeletal structure, no explicit H
107
- mol = Chem.MolFromSmiles(smiles)
108
- AllChem.Compute2DCoords(mol)
109
- mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
110
- n_label = mol.GetNumAtoms()
111
- drawer = rdMolDraw2D.MolDraw2DSVG(600, 450)
112
-
113
- opts = drawer.drawOptions()
114
- for atom_1idx, shift_val in shift_map.items():
115
- atom_0idx = atom_1idx - 1
116
- if atom_0idx < n_label:
117
- opts.atomLabels[atom_0idx] = shift_val
118
- opts.clearBackground = False
119
- opts.bondLineWidth = 1
120
- opts.padding = 0.15
121
- opts.additionalAtomLabelPadding = 0.1
122
-
123
- drawer.DrawMolecule(mol_draw)
124
- drawer.FinishDrawing()
125
- svg = drawer.GetDrawingText().replace("svg:", "").replace(":svg", "")
126
-
127
- return jsonify({
128
- "svg": svg,
129
- "smiles": smiles,
130
- "nucleus": nucleus,
131
- "conf_sdfs": result.get("conf_sdfs", []),
132
- "weightedShift": weighted_shift_txt,
133
- "confShift": result["confShiftTxt"],
134
- "relative_E": result["relative_E"],
135
- "taskId": task_id,
136
- })
137
-
138
-
139
- # ── Download as CSV ───────────────────────────────────────────────────────────
140
- @bp.route("/download/<task_id>/")
141
- def download(task_id):
142
- raw = redis_client.get(f"task_result_{task_id}")
143
- if not raw:
144
- abort(404)
145
- result = json.loads(raw)
146
- if "errMessage" in result:
147
- abort(404)
148
-
149
- nucleus = result.get("type_", "C")
150
- header = f"Atom Index,Predicted {'1H' if nucleus == 'H' else '13C'} Shift (ppm)"
151
- lines = [header]
152
- for item in filter(None, result["weightedShiftTxt"].split(";")):
153
- parts = item.split(",")
154
- if len(parts) == 2:
155
- lines.append(f"{parts[0]},{parts[1]}")
156
-
157
- return Response(
158
- "\n".join(lines),
159
- mimetype="text/csv",
160
  headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"}
161
  )
162
 
@@ -170,6 +11,13 @@ def create_app():
170
  def root():
171
  return redirect("/cascade_v1/predict/")
172
 
 
 
 
 
 
 
 
173
  return app
174
 
175
  app = create_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"}
2
  )
3
 
 
11
  def root():
12
  return redirect("/cascade_v1/predict/")
13
 
14
+ # Allow HF Spaces to embed the app in an iframe
15
+ @app.after_request
16
+ def remove_iframe_restriction(response):
17
+ response.headers.pop("X-Frame-Options", None)
18
+ response.headers["Content-Security-Policy"] = "frame-ancestors *"
19
+ return response
20
+
21
  return app
22
 
23
  app = create_app()