bobbypaton commited on
Commit
ff85ac8
Β·
1 Parent(s): d7ad5f5

Fix app.py truncation; add iframe headers for HF Spaces

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py CHANGED
@@ -1,3 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"}
2
  )
3
 
 
1
+ """
2
+ CASCADE – Flask app for HF Spaces
3
+ Mounted at /cascade_v1/
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import uuid
9
+
10
+ import redis
11
+ from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response
12
+
13
+ from rdkit import Chem
14
+ from rdkit.Chem import AllChem
15
+ from rdkit.Chem.Draw import rdMolDraw2D
16
+
17
+ from NMR_Prediction.valid import validate_smiles
18
+
19
+ bp = Blueprint("cascade", __name__)
20
+
21
+ redis_client = redis.StrictRedis(
22
+ host="localhost", port=6379, db=0, decode_responses=True
23
+ )
24
+
25
+ # ── Pages ─────────────────────────────────────────────────────────────────────
26
+ @bp.route("/")
27
+ @bp.route("")
28
+ @bp.route("/predict/")
29
+ @bp.route("/predict")
30
+ @bp.route("/home/")
31
+ @bp.route("/about/")
32
+ def predict():
33
+ return render_template("cascade/predict.html")
34
+
35
+
36
+ # ── Job submission ────────────────────────────────────────────────────────────
37
+ def _submit_job(smiles, type_):
38
+ if not validate_smiles(smiles):
39
+ return jsonify({"message": "Input molecule is not allowed", "task_id": None})
40
+ task_id = uuid.uuid4().hex
41
+ redis_client.set(f"task_detail_{task_id}",
42
+ json.dumps({"smiles": smiles, "type_": type_}))
43
+ redis_client.rpush("task_queue", task_id)
44
+ return task_id
45
+
46
+
47
+ @bp.route("/predict_NMR_C/", methods=["POST"])
48
+ def predict_NMR_C():
49
+ result = _submit_job(request.form["smiles"], request.form["type_"])
50
+ if not isinstance(result, str):
51
+ return result
52
+ return jsonify({"message": "Molecule submitted to C queue", "task_id": result})
53
+
54
+
55
+ @bp.route("/predict_NMR_H/", methods=["POST"])
56
+ def predict_NMR_H():
57
+ result = _submit_job(request.form["smiles"], request.form["type_"])
58
+ if not isinstance(result, str):
59
+ return result
60
+ return jsonify({"message": "Molecule submitted to H queue", "task_id": result})
61
+
62
+
63
+ # ── check_task ────────────────────────────────────────────────────────────────
64
+ @bp.route("/check_task/")
65
+ def check_task():
66
+ raw = redis_client.get(f"task_result_{request.args['task_id']}")
67
+ if not raw:
68
+ return "running", 200
69
+ result = json.loads(raw)
70
+ if "errMessage" in result:
71
+ return "Error1", 200
72
+ return "done", 200
73
+
74
+
75
+ # ── get_result ────────────────────────────────────────────────────────────────
76
+ @bp.route("/get_result/")
77
+ def get_result():
78
+ task_id = request.args["task_id"]
79
+ raw = redis_client.get(f"task_result_{task_id}")
80
+ if not raw:
81
+ return jsonify({"error": "Result not found"}), 404
82
+
83
+ result = json.loads(raw)
84
+ if "errMessage" in result:
85
+ return jsonify({"error": result["errMessage"]}), 500
86
+
87
+ smiles = result["smiles"]
88
+ nucleus = result.get("type_", "C")
89
+
90
+ weighted_shift_txt = result["weightedShiftTxt"]
91
+ shift_map = {}
92
+ for item in filter(None, weighted_shift_txt.split(";")):
93
+ parts = item.split(",")
94
+ if len(parts) == 2:
95
+ shift_map[int(parts[0])] = parts[1]
96
+
97
+ if nucleus == "H":
98
+ mol = Chem.MolFromSmiles(smiles)
99
+ mol = Chem.AddHs(mol)
100
+ AllChem.Compute2DCoords(mol)
101
+ mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
102
+ n_label = mol.GetNumAtoms()
103
+ drawer = rdMolDraw2D.MolDraw2DSVG(700, 500)
104
+ else:
105
+ mol = Chem.MolFromSmiles(smiles)
106
+ AllChem.Compute2DCoords(mol)
107
+ mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
108
+ n_label = mol.GetNumAtoms()
109
+ drawer = rdMolDraw2D.MolDraw2DSVG(600, 450)
110
+
111
+ opts = drawer.drawOptions()
112
+ for atom_1idx, shift_val in shift_map.items():
113
+ atom_0idx = atom_1idx - 1
114
+ if atom_0idx < n_label:
115
+ opts.atomLabels[atom_0idx] = shift_val
116
+ opts.clearBackground = False
117
+ opts.bondLineWidth = 1
118
+ opts.padding = 0.15
119
+ opts.additionalAtomLabelPadding = 0.1
120
+
121
+ drawer.DrawMolecule(mol_draw)
122
+ drawer.FinishDrawing()
123
+ svg = drawer.GetDrawingText().replace("svg:", "").replace(":svg", "")
124
+
125
+ return jsonify({
126
+ "svg": svg,
127
+ "smiles": smiles,
128
+ "nucleus": nucleus,
129
+ "conf_sdfs": result.get("conf_sdfs", []),
130
+ "weightedShift": weighted_shift_txt,
131
+ "confShift": result["confShiftTxt"],
132
+ "relative_E": result["relative_E"],
133
+ "taskId": task_id,
134
+ })
135
+
136
+
137
+ # ── Download as CSV ───────────────────────────────────────────────────────────
138
+ @bp.route("/download/<task_id>/")
139
+ def download(task_id):
140
+ raw = redis_client.get(f"task_result_{task_id}")
141
+ if not raw:
142
+ abort(404)
143
+ result = json.loads(raw)
144
+ if "errMessage" in result:
145
+ abort(404)
146
+
147
+ nucleus = result.get("type_", "C")
148
+ header = f"Atom Index,Predicted {'1H' if nucleus == 'H' else '13C'} Shift (ppm)"
149
+ lines = [header]
150
+ for item in filter(None, result["weightedShiftTxt"].split(";")):
151
+ parts = item.split(",")
152
+ if len(parts) == 2:
153
+ lines.append(f"{parts[0]},{parts[1]}")
154
+
155
+ return Response(
156
+ "\n".join(lines),
157
+ mimetype="text/csv",
158
  headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"}
159
  )
160