TCM / api /server.py
mardin123456's picture
fix: set ROOT cwd so relative data paths resolve correctly
eeac6cc verified
"""
api/server.py — Flask backend with pre-warming, caching, slim responses
"""
import sys, os, json, time, threading
# Make sure the working directory is always the project root (applies both
# when run locally and inside Docker).
# ROOT is the parent of the directory that contains this file.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.chdir(ROOT)
# Put <ROOT>/pipeline first on sys.path so its modules import by bare name.
sys.path.insert(0, os.path.join(ROOT, 'pipeline'))
from flask import Flask, request, jsonify, send_from_directory
# The Flask app serves the built frontend bundle from <ROOT>/frontend/dist,
# with static files exposed under an empty URL prefix.
_frontend_dist = os.path.join(ROOT, 'frontend', 'dist')
app = Flask(__name__, static_url_path='', static_folder=_frontend_dist)
# ── Pre-warm state ────────────────────────────────────────────────────────────
# Progress of the background initialization, reported by the status endpoints.
_state = dict(
    ready=False,
    step='idle',  # idle | loading_data | building_models | done | error
    message='等待初始化',
    started=None,   # time.time() at which pre-warming began
    elapsed=None,   # seconds the initialization took, once finished
)
# Translator instance; remains None until pre-warming has succeeded.
_translator = None
# Cache of slimmed translate responses, bounded to _CACHE_MAX entries.
_cache: dict = {}
_CACHE_MAX = 50
def _prewarm():
    """Build the translator in the background and publish it when ready.

    Progress is reported through the module-level ``_state`` dict; on success
    the initialized instance is stored in ``_translator``.

    Failures never propagate: they are recorded in ``_state`` (``step`` becomes
    ``'error'``, ``elapsed`` is still filled in) and the full traceback is
    printed so the root cause is not lost.
    """
    global _translator
    started = time.time()
    _state['started'] = started
    try:
        _state['step'] = 'loading_data'
        _state['message'] = '加载数据库(医案 + TCMSP + SYMMAP + 靶点)…'
        # Imported here, inside the try, so an import failure is reported
        # through _state like any other initialization error.
        from tcm_translator import TCMTranslator
        t = TCMTranslator()
        _state['step'] = 'building_models'
        _state['message'] = '构建检索模型(PageRank + 相似度索引)…'
        t._initialize()
        elapsed = round(time.time() - started, 1)
        _translator = t
        _state['step'] = 'done'
        _state['elapsed'] = elapsed
        _state['message'] = f'就绪(初始化耗时 {elapsed}s)'
        # Flip the ready flag last: anyone who observes ready=True also sees
        # the final step/message/elapsed and a usable _translator.
        _state['ready'] = True
        print(f"[server] Ready in {elapsed}s", flush=True)
    except Exception as e:
        import traceback
        _state['elapsed'] = round(time.time() - started, 1)
        _state['step'] = 'error'
        _state['message'] = f'初始化失败: {e}'
        print(f"[server] INIT ERROR: {e}", flush=True)
        # The one-line message above rarely identifies the failing call.
        traceback.print_exc()
# Start pre-warming in the background immediately at startup. The thread is a
# daemon, so it does not keep the process alive on shutdown.
threading.Thread(target=_prewarm, daemon=True).start()
# ── CORS ──────────────────────────────────────────────────────────────────────
@app.after_request
def cors(r):
    """Stamp permissive CORS headers onto every outgoing response."""
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Headers', 'Content-Type'),
        ('Access-Control-Allow-Methods', 'GET,POST,OPTIONS'),
    )
    for header_name, header_value in cors_headers:
        r.headers[header_name] = header_value
    return r
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.route('/api/health')
def health():
    """Liveness probe: always ``status: ok`` plus the pre-warm ready flag."""
    is_ready = _state['ready']
    return jsonify(status='ok', ready=is_ready)
@app.route('/api/status')
def status():
    """Report pre-warm progress: ready flag, current step, message, elapsed."""
    # 'started' is deliberately not part of the public view.
    exposed_fields = ('ready', 'step', 'message', 'elapsed')
    snapshot = {field: _state[field] for field in exposed_fields}
    return jsonify(snapshot)
@app.route('/api/examples')
def examples():
    """Return the fixed list of example syndrome inputs with short blurbs."""
    samples = (
        ('气阴两虚兼湿浊瘀阻', '最常见的糖尿病肾病证型'),
        ('脾肾阳虚兼水湿内停', '水肿为主'),
        ('阴阳两虚兼瘀浊内结', '晚期复杂证型'),
        ('肝肾阴虚兼气虚血瘀', '阴虚伴血瘀'),
        ('气虚湿浊瘀血阻滞', '气虚湿瘀互结'),
    )
    return jsonify([{'label': label, 'desc': desc} for label, desc in samples])
def _slim(result: dict, top_k_compounds: int, top_k_targets: int) -> dict:
    """Slim down a raw translation result into a compact, JSON-safe payload.

    Long lists are cut, scores are rounded, long strings are truncated, and
    every non-finite float (NaN / +inf / -inf) anywhere in the payload is
    replaced by ``None``, because those values do not serialize to valid JSON.

    Args:
        result: Raw result dict. ``input`` is required, and every entry of
            ``matched_cases`` / ``core_herbs`` must carry the keys read below.
            All other sections are optional and default to empty.
        top_k_compounds: Maximum entries kept in each compound list.
        top_k_targets: Maximum entries kept in each target list.

    Returns:
        A new dict holding only the fields listed in the final mapping.

    Raises:
        KeyError: If ``input`` or a required case/herb key is missing.
    """
    inf = float('inf')

    def json_safe(obj):
        """Recursively replace NaN/±inf floats inside dicts and lists."""
        if isinstance(obj, dict):
            return {k: json_safe(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [json_safe(item) for item in obj]
        # `obj != obj` is true only for NaN.
        if isinstance(obj, float) and (obj != obj or obj in (inf, -inf)):
            return None
        return obj

    def num(record, key, digits):
        """Read record[key] as float (missing/None/0/'' -> 0) and round it."""
        return round(float(record.get(key) or 0), digits)

    cases = [
        {
            'case_id': c['case_id'],
            'source': c['source'],
            'zheng_raw': c['zheng_raw'],
            'similarity': round(c['similarity'], 4),
            'herbs': c['herbs'][:8],
            'role_jun': c['role_jun'][:3],
            'role_chen': c['role_chen'][:3],
        }
        for c in result.get('matched_cases', [])
    ]

    herbs = [
        {
            'herb': h['herb'],
            'freq_score': round(h['freq_score'], 4),
            'n_cases': h['n_cases'],
            'role_weight': h['role_weight'],
            'centrality': round(h['centrality'], 4),
            'core_score': round(h['core_score'], 4),
        }
        for h in result.get('core_herbs', [])
    ]

    tcmsp_c = [
        {
            'mol_id': c.get('mol_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob': num(c, 'ob', 2),
            'dl': num(c, 'dl', 3),
            'compound_quality': num(c, 'compound_quality', 4),
            'compound_score': num(c, 'compound_score', 5),
            'source_herbs': str(c.get('source_herbs', ''))[:60],
        }
        for c in result.get('top_compounds_tcmsp', [])[:top_k_compounds]
    ]

    symmap_c = [
        {
            'smit_id': c.get('smit_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob_score': num(c, 'ob_score', 2),
            'evidence_score': num(c, 'evidence_score', 3),
            'compound_score': num(c, 'compound_score', 5),
            'pubchem_cid': c.get('pubchem_cid', ''),
        }
        for c in result.get('top_compounds_symmap', [])[:top_k_compounds]
    ]

    tcmsp_t = [
        {
            'target_name': t.get('target_name', ''),
            'target_id': t.get('target_id', ''),
            'drugbank_id': t.get('drugbank_id', ''),
            # NOTE(review): any truthy value counts as validated, including a
            # NaN float or the string 'False' — confirm the upstream type.
            'validated': bool(t.get('validated', '')),
            'target_score': num(t, 'target_score', 4),
            'n_compounds': t.get('n_compounds', 0),
        }
        for t in result.get('top_targets_tcmsp', [])[:top_k_targets]
    ]

    symmap_t = [
        {
            'gene_symbol': t.get('gene_symbol', ''),
            'gene_name': t.get('gene_name', ''),
            'protein_name': t.get('protein_name', ''),
            'uniprot_id': t.get('uniprot_id', ''),
            'min_p_value': t.get('min_p_value', 1.0),
            'target_score': num(t, 'target_score', 4),
            'n_herbs': t.get('n_herbs', 0),
            'source_herbs': str(t.get('source_herbs', ''))[:80],
        }
        for t in result.get('top_targets_symmap', [])[:top_k_targets]
    ]

    # Formula fingerprint: keep the 12 highest-valued entries, largest first.
    fp = dict(sorted(
        result.get('formula_fingerprint', {}).items(),
        key=lambda kv: kv[1], reverse=True,
    )[:12])

    # Sanitize the whole payload so pass-through fields are covered as well.
    return json_safe({
        'input': result['input'],
        'parsed_elements': result.get('parsed_elements', []),
        'query_vector': [round(v, 3) for v in result.get('query_vector', [])],
        'pattern_label': result.get('pattern_label', ''),
        'n_matched': result.get('n_matched', 0),
        'formula_fingerprint': fp,
        'matched_cases': cases,
        'core_herbs': herbs,
        'top_compounds_tcmsp': tcmsp_c,
        'top_compounds_symmap': symmap_c,
        'top_targets_tcmsp': tcmsp_t,
        'top_targets_symmap': symmap_t,
    })
@app.route('/api/translate', methods=['POST', 'OPTIONS'])
def translate():
    """Translate a syndrome description and return the slimmed result.

    Expects a JSON object with a non-empty string ``syndrome`` and optional
    numeric tuning fields (``top_k_cases``, ``top_k_herbs``,
    ``top_k_compounds``, ``top_k_targets``, ``ob_min``, ``dl_min``,
    ``confidence_min``, ``fdr_max``).

    Responses:
        200: slimmed payload plus ``elapsed_sec`` and ``cached``.
        400: ``syndrome`` missing/blank/not a string, or a tuning field that
            cannot be converted to a number.
        503: the translator is still pre-warming.
        500: the translation itself raised; the error text is returned.
    """
    if request.method == 'OPTIONS':
        return jsonify({}), 200
    if not _state['ready']:
        return jsonify({'error': '系统初始化中,请稍候', 'loading': True,
                        'message': _state['message']}), 503

    body = request.get_json(force=True, silent=True)
    # A JSON array/scalar body has no fields; treat it like an empty object.
    if not isinstance(body, dict):
        body = {}

    syndrome = body.get('syndrome')
    if not isinstance(syndrome, str) or not syndrome.strip():
        return jsonify({'error': '请输入证候描述'}), 400
    syndrome = syndrome.strip()

    # Convert every tuning field up front so malformed input is a 400 and not
    # an unhandled exception.
    try:
        params = {
            'top_k_cases': int(body.get('top_k_cases', 15)),
            'top_k_herbs': int(body.get('top_k_herbs', 20)),
            'top_k_compounds': int(body.get('top_k_compounds', 30)),
            'top_k_targets': int(body.get('top_k_targets', 30)),
            'ob_min': float(body.get('ob_min', 30.0)),
            'dl_min': float(body.get('dl_min', 0.18)),
            'confidence_min': float(body.get('confidence_min', 0.5)),
            'fdr_max': float(body.get('fdr_max', 0.05)),
        }
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({'error': f'参数格式错误: {e}'}), 400

    # The key covers every parameter that influences the result, using the
    # converted values so that e.g. "15" and 15 share one entry.
    cache_key = (syndrome, *params.values())
    cached = _cache.get(cache_key)
    if cached is not None:
        # Re-insert so this entry becomes the most recently used one.
        _cache.pop(cache_key, None)
        _cache[cache_key] = cached
        hit = dict(cached)
        hit['cached'] = True
        return jsonify(hit)

    try:
        t0 = time.time()
        raw = _translator.translate(syndrome, **params)
        elapsed = round(time.time() - t0, 2)
        payload = _slim(raw, params['top_k_compounds'], params['top_k_targets'])
        payload['elapsed_sec'] = elapsed
        payload['cached'] = False
        # Store in cache; when full, evict the least recently used entry
        # (dicts iterate in insertion order, so that is the first key).
        if len(_cache) >= _CACHE_MAX:
            _cache.pop(next(iter(_cache)), None)
        _cache[cache_key] = payload
        return jsonify(payload)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
# ── Static frontend ───────────────────────────────────────────────────────────
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def spa(path):
    """Serve a built frontend file, falling back to ``index.html``.

    Args:
        path: URL path relative to the site root ('' for the root itself).

    Returns:
        The matching file from the frontend build directory when ``path``
        names a regular file there; otherwise ``index.html``.
    """
    # Reuse the directory the app was configured with instead of rebuilding
    # it from __file__, so both always point at the same place.
    dist = app.static_folder
    # isfile (not exists): a directory cannot be sent as a file, so directory
    # paths fall through to the index.html fallback.
    if path and os.path.isfile(os.path.join(dist, path)):
        return send_from_directory(dist, path)
    return send_from_directory(dist, 'index.html')
if __name__ == '__main__':
    # Direct-run entry point: the port comes from $PORT, defaulting to 5001.
    listen_port = int(os.environ.get('PORT', 5001))
    print(f"TCM Translation API → http://localhost:{listen_port}")
    print("Pre-warming in background thread…")
    app.run(
        host='0.0.0.0',
        port=listen_port,
        debug=False,
        threaded=True,
    )