"""
CASCADE — Flask app for HF Spaces
Mounted at /cascade_v1/
"""
import os
import json
import uuid
import redis
from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from NMR_Prediction.valid import validate_smiles
# Blueprint that carries every CASCADE route; create_app() mounts it under
# the /cascade_v1 prefix.
bp = Blueprint(name="cascade", import_name=__name__)

# Shared Redis connection: submitted jobs are written and queued here, and
# finished results are read back from it. decode_responses=True makes reads
# return str instead of bytes.
redis_client = redis.StrictRedis(
    decode_responses=True,
    db=0,
    port=6379,
    host="localhost",
)
# ── Pages ─────────────────────────────────────────────────────────────────────
@bp.route("/")
@bp.route("")
@bp.route("/predict/")
@bp.route("/predict")
@bp.route("/home/")
@bp.route("/about/")
def predict():
return render_template("cascade/predict.html")
# ── Job submission ────────────────────────────────────────────────────────────
def _submit_job(smiles, type_):
    """Validate *smiles* and enqueue a prediction job in Redis.

    On success, the job payload ``{"smiles": ..., "type_": ...}`` is stored
    under ``task_detail_<id>``. The id is then appended to the ``task_queue``
    list, and the new hex task id (a ``str``) is returned.

    On validation failure, nothing is enqueued and a JSON response object is
    returned instead. Callers tell the two outcomes apart with
    ``isinstance(result, str)``.
    """
    if validate_smiles(smiles):
        job_id = uuid.uuid4().hex
        payload = json.dumps({"smiles": smiles, "type_": type_})
        # Store the details before queueing the id, so whatever consumes the
        # queue can already find them.
        redis_client.set(f"task_detail_{job_id}", payload)
        redis_client.rpush("task_queue", job_id)
        return job_id
    return jsonify({"message": "Invalid SMILES or molecule exceeds 50 heavy atoms", "task_id": None})
@bp.route("/predict_NMR_C/", methods=["POST"])
def predict_NMR_C():
    """Accept a POSTed form (``smiles``, ``type_``) and queue a 13C job.

    Responds with the new ``task_id``, or with the validation-error JSON
    produced by ``_submit_job``.
    """
    form = request.form
    outcome = _submit_job(form["smiles"], form["type_"])
    if isinstance(outcome, str):
        return jsonify({"message": "Molecule submitted to C queue", "task_id": outcome})
    return outcome
@bp.route("/predict_NMR_H/", methods=["POST"])
def predict_NMR_H():
    """Accept a POSTed form (``smiles``, ``type_``) and queue a 1H job.

    Responds with the new ``task_id``, or with the validation-error JSON
    produced by ``_submit_job``.
    """
    form = request.form
    outcome = _submit_job(form["smiles"], form["type_"])
    if isinstance(outcome, str):
        return jsonify({"message": "Molecule submitted to H queue", "task_id": outcome})
    return outcome
# ── check_task ────────────────────────────────────────────────────────────────
@bp.route("/check_task/")
def check_task():
    """Report the state of a queued job as a plain-text status string.

    The job is identified by the ``task_id`` query parameter. The response
    is always HTTP 200, with one of these bodies:

    * ``"running"`` while no result is stored for the job yet;
    * ``"Error1"`` when the stored result carries an ``errMessage``;
    * ``"done"`` otherwise.
    """
    stored = redis_client.get(f"task_result_{request.args['task_id']}")
    if stored:
        status = "Error1" if "errMessage" in json.loads(stored) else "done"
    else:
        status = "running"
    return status, 200
# ── get_result ────────────────────────────────────────────────────────────────
def _parse_shift_map(weighted_shift_txt):
    """Parse ``"idx,shift;idx,shift;..."`` into ``{idx: shift}``.

    Indices are converted to ``int`` and are 1-based atom indices. Shift
    values are kept as their original strings so they can be used verbatim
    as drawing labels. Empty items, and items without exactly one comma,
    are skipped.
    """
    shift_map = {}
    for item in filter(None, weighted_shift_txt.split(";")):
        parts = item.split(",")
        if len(parts) == 2:
            shift_map[int(parts[0])] = parts[1]
    return shift_map


def _draw_shift_svg(smiles, nucleus, shift_map):
    """Render *smiles* as SVG, labelling each atom with its predicted shift.

    For ``nucleus == "H"``, explicit hydrogens are added and a 700x500
    canvas is used; otherwise the canvas is 600x450. *shift_map* maps
    1-based atom indices to label strings; indices outside the molecule are
    ignored. Returns ``None`` when RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit signals a parse failure by returning None rather than raising.
        return None
    if nucleus == "H":
        mol = Chem.AddHs(mol)
        width, height = 700, 500
    else:
        width, height = 600, 450
    AllChem.Compute2DCoords(mol)
    mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
    # Count atoms on `mol`, not `mol_draw`: the labels refer to the original
    # atom indices.
    n_atoms = mol.GetNumAtoms()
    drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
    opts = drawer.drawOptions()
    for atom_1idx, shift_val in shift_map.items():
        atom_0idx = atom_1idx - 1
        # Two-sided bound: an index <= 0 must not become a negative label key.
        if 0 <= atom_0idx < n_atoms:
            opts.atomLabels[atom_0idx] = shift_val
    opts.clearBackground = False
    opts.bondLineWidth = 1
    opts.padding = 0.15
    opts.additionalAtomLabelPadding = 0.1
    drawer.DrawMolecule(mol_draw)
    drawer.FinishDrawing()
    # Remove "svg:" / ":svg" from the markup (presumably namespace prefixes
    # emitted by some RDKit versions) so the SVG can be inlined directly.
    return drawer.GetDrawingText().replace("svg:", "").replace(":svg", "")


@bp.route("/get_result/")
def get_result():
    """Return the finished prediction for the ``task_id`` query parameter.

    The JSON body contains an SVG depiction of the molecule, with each atom
    labelled by its weighted predicted shift. It also contains the SMILES,
    the nucleus, and the stored ``conf_sdfs``, ``confShiftTxt`` and
    ``relative_E`` fields.

    Errors are returned as JSON ``{"error": ...}``:

    * 404 if no result is stored for the task;
    * 500 if the stored result carries an ``errMessage``;
    * 500 if RDKit cannot parse the stored SMILES.
    """
    task_id = request.args["task_id"]
    raw = redis_client.get(f"task_result_{task_id}")
    if not raw:
        return jsonify({"error": "Result not found"}), 404
    result = json.loads(raw)
    if "errMessage" in result:
        return jsonify({"error": result["errMessage"]}), 500

    smiles = result["smiles"]
    nucleus = result.get("type_", "C")
    weighted_shift_txt = result["weightedShiftTxt"]
    svg = _draw_shift_svg(smiles, nucleus, _parse_shift_map(weighted_shift_txt))
    if svg is None:
        # Report unparsable SMILES in the same JSON error shape as the other
        # failure paths, instead of letting RDKit raise.
        return jsonify({"error": "Could not parse SMILES"}), 500

    return jsonify({
        "svg": svg,
        "smiles": smiles,
        "nucleus": nucleus,
        "conf_sdfs": result.get("conf_sdfs", []),
        "weightedShift": weighted_shift_txt,
        "confShift": result["confShiftTxt"],
        "relative_E": result["relative_E"],
        "taskId": task_id,
    })
# ── Download as CSV ───────────────────────────────────────────────────────────
@bp.route("/download/<task_id>/")
def download(task_id):
    """Serve the weighted shifts of a finished task as a CSV attachment.

    The CSV has a header row (naming 1H or 13C according to the stored
    ``type_``) followed by one ``index,shift`` row per entry. Aborts with 404
    when no result is stored for *task_id*, or when the stored result
    carries an ``errMessage``.
    """
    raw = redis_client.get(f"task_result_{task_id}")
    if not raw:
        abort(404)
    result = json.loads(raw)
    if "errMessage" in result:
        abort(404)

    isotope = "1H" if result.get("type_", "C") == "H" else "13C"
    rows = [f"Atom Index,Predicted {isotope} Shift (ppm)"]
    # Keep only well-formed "index,shift" entries (exactly one comma). Such
    # an entry is already a valid CSV row, so it is copied through unchanged.
    rows.extend(
        entry
        for entry in result["weightedShiftTxt"].split(";")
        if entry.count(",") == 1
    )

    disposition = f"attachment; filename=cascade_{task_id[:8]}.csv"
    return Response(
        "\n".join(rows),
        mimetype="text/csv",
        headers={"Content-Disposition": disposition},
    )
def create_app():
    """Build the Flask application that serves CASCADE.

    * The ``cascade`` blueprint is mounted under ``/cascade_v1``.
    * ``/`` and ``/cascade_v1`` return a small HTML page that redirects the
      browser to ``/cascade_v1/predict/``.
    * Every response drops ``X-Frame-Options`` and sets
      ``Content-Security-Policy: frame-ancestors *``, so the app may be
      embedded in an iframe on any origin.
    """
    flask_app = Flask(__name__, static_folder="static", template_folder="templates")
    flask_app.register_blueprint(bp, url_prefix="/cascade_v1")

    landing_page = "\n".join((
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        '<meta http-equiv="refresh" content="0; url=/cascade_v1/predict/">',
        '<script>window.location.href="/cascade_v1/predict/";</script>',
        "</head>",
        "<body></body>",
        "</html>",
    ))

    def root():
        return landing_page

    # Rules are registered in the same order stacked @route decorators would
    # produce (innermost first); the endpoint name defaults to "root".
    flask_app.add_url_rule("/cascade_v1", view_func=root)
    flask_app.add_url_rule("/", view_func=root)

    def remove_iframe_restriction(response):
        response.headers.pop("X-Frame-Options", None)
        response.headers["Content-Security-Policy"] = "frame-ancestors *"
        return response

    flask_app.after_request(remove_iframe_restriction)
    return flask_app
# Module-level application object, so a WSGI server can import it directly.
app = create_app()

if __name__ == "__main__":
    # Direct run: listen on all interfaces, port 7860 (presumably the HF
    # Spaces default), with the debugger disabled.
    app.run(debug=False, port=7860, host="0.0.0.0")