Spaces:
Sleeping
Sleeping
File size: 10,460 Bytes
48b33cb eeac6cc 48b33cb eeac6cc 48b33cb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 | """
api/server.py โ Flask backend with pre-warming, caching, slim responses
"""
import sys, os, json, time, threading
# ็กฎไฟๅทฅไฝ็ฎๅฝๅง็ปๆฏ้กน็ฎๆ น็ฎๅฝ๏ผๆฌๅฐๅ Docker ้ฝ้็จ๏ผ
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.chdir(ROOT)
sys.path.insert(0, os.path.join(ROOT, 'pipeline'))
from flask import Flask, request, jsonify, send_from_directory
app = Flask(__name__,
static_folder=os.path.join(ROOT, 'frontend', 'dist'),
static_url_path='')
# โโ Pre-warm state โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
_state = {
'ready': False,
'step': 'idle', # idle | loading_data | building_models | done
'message': '็ญๅพ
ๅๅงๅ',
'started': None,
'elapsed': None,
}
_translator = None
_cache: dict = {} # syndrome โ result (LRU max 50)
_CACHE_MAX = 50
def _prewarm():
global _translator, _state
_state['started'] = time.time()
try:
_state['step'] = 'loading_data'
_state['message'] = 'ๅ ่ฝฝๆฐๆฎๅบ๏ผๅปๆก + TCMSP + SYMMAP + ้ถ็น๏ผโฆ'
from tcm_translator import TCMTranslator
t = TCMTranslator()
_state['step'] = 'building_models'
_state['message'] = 'ๆๅปบๆฃ็ดขๆจกๅ๏ผPageRank + ็ธไผผๅบฆ็ดขๅผ๏ผโฆ'
t._initialize()
_translator = t
_state['ready'] = True
_state['step'] = 'done'
_state['elapsed'] = round(time.time() - _state['started'], 1)
_state['message'] = f'ๅฐฑ็ปช๏ผๅๅงๅ่ๆถ {_state["elapsed"]}s๏ผ'
print(f"[server] Ready in {_state['elapsed']}s", flush=True)
except Exception as e:
_state['step'] = 'error'
_state['message'] = f'ๅๅงๅๅคฑ่ดฅ: {e}'
print(f"[server] INIT ERROR: {e}", flush=True)
# ๅฏๅจๆถ็ซๅณๅๅฐ้ข็ญ
threading.Thread(target=_prewarm, daemon=True).start()
# โโ CORS โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
@app.after_request
def cors(r):
r.headers['Access-Control-Allow-Origin'] = '*'
r.headers['Access-Control-Allow-Headers'] = 'Content-Type'
r.headers['Access-Control-Allow-Methods'] = 'GET,POST,OPTIONS'
return r
# โโ Endpoints โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
@app.route('/api/health')
def health():
return jsonify({'status': 'ok', 'ready': _state['ready']})
@app.route('/api/status')
def status():
return jsonify({
'ready': _state['ready'],
'step': _state['step'],
'message': _state['message'],
'elapsed': _state['elapsed'],
})
@app.route('/api/examples')
def examples():
return jsonify([
{'label': 'ๆฐ้ดไธค่ๅ
ผๆนฟๆต็้ป', 'desc': 'ๆๅธธ่ง็็ณๅฐฟ็
่พ็
่ฏๅ'},
{'label': '่พ่พ้ณ่ๅ
ผๆฐดๆนฟๅ
ๅ', 'desc': 'ๆฐด่ฟไธบไธป'},
{'label': '้ด้ณไธค่ๅ
ผ็ๆตๅ
็ป', 'desc': 'ๆๆๅคๆ่ฏๅ'},
{'label': '่่พ้ด่ๅ
ผๆฐ่่ก็', 'desc': '้ด่ไผด่ก็'},
{'label': 'ๆฐ่ๆนฟๆต็่ก้ปๆป', 'desc': 'ๆฐ่ๆนฟ็ไบ็ป'},
])
def _slim(result: dict, top_k_compounds: int, top_k_targets: int) -> dict:
"""็ฒพ็ฎ API ๅๅบ๏ผ่ฃๅช่ฟ้ฟๅญๆฎต๏ผๆงๅถๆฐ็ป้ฟๅบฆใ"""
def _clean_float(v):
if isinstance(v, float) and v != v: # NaN
return None
return v
def clean(obj):
if isinstance(obj, dict):
return {k: clean(v) for k, v in obj.items()}
if isinstance(obj, list):
return [clean(i) for i in obj]
if isinstance(obj, float):
return _clean_float(obj)
return obj
cases = []
for c in result.get('matched_cases', []):
cases.append({
'case_id': c['case_id'],
'source': c['source'],
'zheng_raw': c['zheng_raw'],
'similarity': round(c['similarity'], 4),
'herbs': c['herbs'][:8],
'role_jun': c['role_jun'][:3],
'role_chen': c['role_chen'][:3],
})
herbs = []
for h in result.get('core_herbs', []):
herbs.append({
'herb': h['herb'],
'freq_score': round(h['freq_score'], 4),
'n_cases': h['n_cases'],
'role_weight': h['role_weight'],
'centrality': round(h['centrality'], 4),
'core_score': round(h['core_score'], 4),
})
tcmsp_c = []
for c in result.get('top_compounds_tcmsp', [])[:top_k_compounds]:
tcmsp_c.append({
'mol_id': c.get('mol_id', ''),
'molecule_name': c.get('molecule_name', ''),
'ob': round(float(c.get('ob') or 0), 2),
'dl': round(float(c.get('dl') or 0), 3),
'compound_quality': round(float(c.get('compound_quality') or 0), 4),
'compound_score': round(float(c.get('compound_score') or 0), 5),
'source_herbs': str(c.get('source_herbs', ''))[:60],
})
symmap_c = []
for c in result.get('top_compounds_symmap', [])[:top_k_compounds]:
symmap_c.append({
'smit_id': c.get('smit_id', ''),
'molecule_name': c.get('molecule_name', ''),
'ob_score': round(float(c.get('ob_score') or 0), 2),
'evidence_score':round(float(c.get('evidence_score') or 0), 3),
'compound_score':round(float(c.get('compound_score') or 0), 5),
'pubchem_cid': c.get('pubchem_cid', ''),
})
tcmsp_t = []
for t in result.get('top_targets_tcmsp', [])[:top_k_targets]:
tcmsp_t.append({
'target_name': t.get('target_name', ''),
'target_id': t.get('target_id', ''),
'drugbank_id': t.get('drugbank_id', ''),
'validated': bool(t.get('validated', '')),
'target_score':round(float(t.get('target_score') or 0), 4),
'n_compounds': t.get('n_compounds', 0),
})
symmap_t = []
for t in result.get('top_targets_symmap', [])[:top_k_targets]:
symmap_t.append({
'gene_symbol': t.get('gene_symbol', ''),
'gene_name': t.get('gene_name', ''),
'protein_name': t.get('protein_name', ''),
'uniprot_id': t.get('uniprot_id', ''),
'min_p_value': t.get('min_p_value', 1.0),
'target_score': round(float(t.get('target_score') or 0), 4),
'n_herbs': t.get('n_herbs', 0),
'source_herbs': str(t.get('source_herbs', ''))[:80],
})
# Formula fingerprint top 12
fp = dict(sorted(
result.get('formula_fingerprint', {}).items(),
key=lambda x: x[1], reverse=True
)[:12])
return {
'input': result['input'],
'parsed_elements': result.get('parsed_elements', []),
'query_vector': [round(v, 3) for v in result.get('query_vector', [])],
'pattern_label': result.get('pattern_label', ''),
'n_matched': result.get('n_matched', 0),
'formula_fingerprint': fp,
'matched_cases': cases,
'core_herbs': herbs,
'top_compounds_tcmsp': tcmsp_c,
'top_compounds_symmap': symmap_c,
'top_targets_tcmsp': tcmsp_t,
'top_targets_symmap': symmap_t,
}
@app.route('/api/translate', methods=['POST', 'OPTIONS'])
def translate():
if request.method == 'OPTIONS':
return jsonify({}), 200
if not _state['ready']:
return jsonify({'error': '็ณป็ปๅๅงๅไธญ๏ผ่ฏท็จๅ', 'loading': True,
'message': _state['message']}), 503
body = request.get_json(force=True, silent=True) or {}
syndrome = body.get('syndrome', '').strip()
if not syndrome:
return jsonify({'error': '่ฏท่พๅ
ฅ่ฏๅๆ่ฟฐ'}), 400
top_k_compounds = int(body.get('top_k_compounds', 30))
top_k_targets = int(body.get('top_k_targets', 30))
# Cache check
cache_key = f"{syndrome}|{body.get('top_k_cases',15)}|{top_k_compounds}|{top_k_targets}"
if cache_key in _cache:
cached = dict(_cache[cache_key])
cached['cached'] = True
return jsonify(cached)
params = {
'top_k_cases': int(body.get('top_k_cases', 15)),
'top_k_herbs': int(body.get('top_k_herbs', 20)),
'top_k_compounds': top_k_compounds,
'top_k_targets': top_k_targets,
'ob_min': float(body.get('ob_min', 30.0)),
'dl_min': float(body.get('dl_min', 0.18)),
'confidence_min': float(body.get('confidence_min', 0.5)),
'fdr_max': float(body.get('fdr_max', 0.05)),
}
try:
t0 = time.time()
raw = _translator.translate(syndrome, **params)
elapsed = round(time.time() - t0, 2)
payload = _slim(raw, top_k_compounds, top_k_targets)
payload['elapsed_sec'] = elapsed
payload['cached'] = False
# Store in cache (evict oldest if full)
if len(_cache) >= _CACHE_MAX:
_cache.pop(next(iter(_cache)))
_cache[cache_key] = payload
return jsonify(payload)
except Exception as e:
import traceback; traceback.print_exc()
return jsonify({'error': str(e)}), 500
# โโ Static frontend โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def spa(path):
dist = os.path.join(os.path.dirname(__file__), '..', 'frontend', 'dist')
full = os.path.join(dist, path)
if path and os.path.exists(full):
return send_from_directory(dist, path)
return send_from_directory(dist, 'index.html')
if __name__ == '__main__':
port = int(os.environ.get('PORT', 5001))
print(f"TCM Translation API โ http://localhost:{port}")
print("Pre-warming in background threadโฆ")
app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
|