"""api/server.py — Flask backend with pre-warming, caching, slim responses.

Importing this module has side effects: it switches the working directory to
the project root, puts ``pipeline/`` on ``sys.path``, creates the Flask ``app``
and starts a daemon thread that builds the translator in the background.
"""
import sys
import os
import json
import time
import threading
import traceback

# Always run from the project root (works both locally and in Docker).
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.chdir(ROOT)
sys.path.insert(0, os.path.join(ROOT, 'pipeline'))

from flask import Flask, request, jsonify, send_from_directory

app = Flask(__name__,
            static_folder=os.path.join(ROOT, 'frontend', 'dist'),
            static_url_path='')

# ── Pre-warm state ────────────────────────────────────────────────────────────
# Written by the background _prewarm thread, read by the request handlers.
_state = {
    'ready': False,
    'step': 'idle',   # idle | loading_data | building_models | done | error
    'message': '等待初始化',
    'started': None,
    'elapsed': None,
}
_translator = None  # the initialized TCMTranslator, set once _prewarm succeeds
_cache: dict = {}   # bounded cache of slimmed translate payloads (at most _CACHE_MAX entries)
_CACHE_MAX = 50


def _prewarm():
    """Build the TCMTranslator in the background and publish progress in ``_state``.

    On success, ``_translator`` is set and ``_state['ready']`` becomes True.
    On failure, ``_state['step']`` becomes ``'error'``, the message carries the
    exception text and the full traceback is printed to stderr.
    """
    global _translator
    _state['started'] = time.time()
    try:
        _state['step'] = 'loading_data'
        _state['message'] = '加载数据库(医案 + TCMSP + SYMMAP + 靶点)…'
        from tcm_translator import TCMTranslator
        t = TCMTranslator()

        _state['step'] = 'building_models'
        _state['message'] = '构建检索模型(PageRank + 相似度索引)…'
        t._initialize()

        # Publish the translator before flipping 'ready', so handlers that
        # observe ready=True never see _translator as None.
        _translator = t
        _state['ready'] = True
        _state['step'] = 'done'
        _state['elapsed'] = round(time.time() - _state['started'], 1)
        _state['message'] = f'就绪(初始化耗时 {_state["elapsed"]}s)'
        print(f"[server] Ready in {_state['elapsed']}s", flush=True)
    except Exception as e:
        _state['step'] = 'error'
        _state['message'] = f'初始化失败: {e}'
        print(f"[server] INIT ERROR: {e}", flush=True)
        # Keep the full stack trace; str(e) alone is rarely enough to debug init.
        traceback.print_exc()


# Start pre-warming in the background as soon as the module is imported.
threading.Thread(target=_prewarm, daemon=True).start()


# ── CORS ──────────────────────────────────────────────────────────────────────
@app.after_request
def cors(r):
    """Attach permissive CORS headers to every response."""
    r.headers['Access-Control-Allow-Origin'] = '*'
    r.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    r.headers['Access-Control-Allow-Methods'] = 'GET,POST,OPTIONS'
    return r


# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.route('/api/health')
def health():
    """Liveness probe: always 'ok', plus whether pre-warming has finished."""
    return jsonify({'status': 'ok', 'ready': _state['ready']})
@app.route('/api/status')
def status():
    """Report pre-warm progress: ready flag, current step, message, init time."""
    return jsonify({
        'ready': _state['ready'],
        'step': _state['step'],
        'message': _state['message'],
        'elapsed': _state['elapsed'],
    })


@app.route('/api/examples')
def examples():
    """Return the example syndrome inputs offered by the frontend."""
    return jsonify([
        {'label': '气阴两虚兼湿浊瘀阻', 'desc': '最常见的糖尿病肾病证型'},
        {'label': '脾肾阳虚兼水湿内停', 'desc': '水肿为主'},
        {'label': '阴阳两虚兼瘀浊内结', 'desc': '晚期复杂证型'},
        {'label': '肝肾阴虚兼气虚血瘀', 'desc': '阴虚伴血瘀'},
        {'label': '气虚湿浊瘀血阻滞', 'desc': '气虚湿瘀互结'},
    ])


def _slim(result: dict, top_k_compounds: int, top_k_targets: int) -> dict:
    """Slim the API response: truncate long fields and cap array lengths.

    Args:
        result: raw dict returned by ``TCMTranslator.translate``.
        top_k_compounds: max entries kept in each compound list.
        top_k_targets: max entries kept in each target list.

    Returns:
        A JSON-safe dict. Non-finite floats (NaN / ±Infinity) anywhere in the
        output are replaced by None, because they are not valid JSON.
    """
    import math

    def clean(obj):
        # Recursively replace non-finite floats so jsonify emits valid JSON.
        if isinstance(obj, dict):
            return {k: clean(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [clean(i) for i in obj]
        if isinstance(obj, float) and not math.isfinite(obj):
            return None
        return obj

    def num(d, key, ndigits):
        # Missing/None/0 values all become 0 before rounding.
        return round(float(d.get(key) or 0), ndigits)

    cases = [
        {
            'case_id': c['case_id'],
            'source': c['source'],
            'zheng_raw': c['zheng_raw'],
            'similarity': round(c['similarity'], 4),
            'herbs': c['herbs'][:8],
            'role_jun': c['role_jun'][:3],
            'role_chen': c['role_chen'][:3],
        }
        for c in result.get('matched_cases', [])
    ]

    herbs = [
        {
            'herb': h['herb'],
            'freq_score': round(h['freq_score'], 4),
            'n_cases': h['n_cases'],
            'role_weight': h['role_weight'],
            'centrality': round(h['centrality'], 4),
            'core_score': round(h['core_score'], 4),
        }
        for h in result.get('core_herbs', [])
    ]

    tcmsp_c = [
        {
            'mol_id': c.get('mol_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob': num(c, 'ob', 2),
            'dl': num(c, 'dl', 3),
            'compound_quality': num(c, 'compound_quality', 4),
            'compound_score': num(c, 'compound_score', 5),
            'source_herbs': str(c.get('source_herbs', ''))[:60],
        }
        for c in result.get('top_compounds_tcmsp', [])[:top_k_compounds]
    ]

    symmap_c = [
        {
            'smit_id': c.get('smit_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob_score': num(c, 'ob_score', 2),
            'evidence_score': num(c, 'evidence_score', 3),
            'compound_score': num(c, 'compound_score', 5),
            'pubchem_cid': c.get('pubchem_cid', ''),
        }
        for c in result.get('top_compounds_symmap', [])[:top_k_compounds]
    ]

    tcmsp_t = [
        {
            'target_name': t.get('target_name', ''),
            'target_id': t.get('target_id', ''),
            'drugbank_id': t.get('drugbank_id', ''),
            'validated': bool(t.get('validated', '')),
            'target_score': num(t, 'target_score', 4),
            'n_compounds': t.get('n_compounds', 0),
        }
        for t in result.get('top_targets_tcmsp', [])[:top_k_targets]
    ]

    symmap_t = [
        {
            'gene_symbol': t.get('gene_symbol', ''),
            'gene_name': t.get('gene_name', ''),
            'protein_name': t.get('protein_name', ''),
            'uniprot_id': t.get('uniprot_id', ''),
            'min_p_value': t.get('min_p_value', 1.0),
            'target_score': num(t, 'target_score', 4),
            'n_herbs': t.get('n_herbs', 0),
            'source_herbs': str(t.get('source_herbs', ''))[:80],
        }
        for t in result.get('top_targets_symmap', [])[:top_k_targets]
    ]

    # Formula fingerprint: keep only the 12 highest-weighted entries.
    fp = dict(sorted(
        result.get('formula_fingerprint', {}).items(),
        key=lambda x: x[1], reverse=True,
    )[:12])

    return clean({
        'input': result['input'],
        'parsed_elements': result.get('parsed_elements', []),
        'query_vector': [round(v, 3) for v in result.get('query_vector', [])],
        'pattern_label': result.get('pattern_label', ''),
        'n_matched': result.get('n_matched', 0),
        'formula_fingerprint': fp,
        'matched_cases': cases,
        'core_herbs': herbs,
        'top_compounds_tcmsp': tcmsp_c,
        'top_compounds_symmap': symmap_c,
        'top_targets_tcmsp': tcmsp_t,
        'top_targets_symmap': symmap_t,
    })


# Guards _cache: the server runs with threaded=True, so requests may look up,
# insert and evict concurrently.
_cache_lock = threading.Lock()

# (request field, converter, default) for every numeric translate option.
_TRANSLATE_PARAM_SPECS = (
    ('top_k_cases', int, 15),
    ('top_k_herbs', int, 20),
    ('top_k_compounds', int, 30),
    ('top_k_targets', int, 30),
    ('ob_min', float, 30.0),
    ('dl_min', float, 0.18),
    ('confidence_min', float, 0.5),
    ('fdr_max', float, 0.05),
)


def _parse_translate_params(body: dict) -> dict:
    """Convert request options into ``TCMTranslator.translate`` keyword args.

    Raises:
        ValueError: with the offending field name as its message, when a value
            cannot be converted or an integer (top_k_*) option is negative.
    """
    params = {}
    for name, cast, default in _TRANSLATE_PARAM_SPECS:
        try:
            value = cast(body.get(name, default))
        except (TypeError, ValueError, OverflowError):
            raise ValueError(name) from None
        if cast is int and value < 0:
            raise ValueError(name)
        params[name] = value
    return params


def _cache_get(key):
    """Return the cached payload for *key* (marking it most recently used), or None."""
    with _cache_lock:
        if key not in _cache:
            return None
        payload = _cache.pop(key)
        _cache[key] = payload  # re-insert at the end = most recently used
        return payload


def _cache_put(key, payload):
    """Store *payload*, evicting least-recently-used entries beyond _CACHE_MAX."""
    with _cache_lock:
        _cache.pop(key, None)
        while _cache and len(_cache) >= _CACHE_MAX:
            _cache.pop(next(iter(_cache)))  # dicts keep insertion order: oldest first
        _cache[key] = payload


@app.route('/api/translate', methods=['POST', 'OPTIONS'])
def translate():
    """Translate a syndrome description into herbs / compounds / targets.

    Returns 503 while pre-warming, 400 for a missing syndrome or invalid
    numeric options, and 500 (with the exception text) if translation fails.
    Results are cached per syndrome and full option set.
    """
    if request.method == 'OPTIONS':
        return jsonify({}), 200
    if not _state['ready']:
        return jsonify({'error': '系统初始化中,请稍候', 'loading': True,
                        'message': _state['message']}), 503

    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        body = {}
    syndrome = body.get('syndrome', '')
    syndrome = syndrome.strip() if isinstance(syndrome, str) else ''
    if not syndrome:
        return jsonify({'error': '请输入证候描述'}), 400

    try:
        params = _parse_translate_params(body)
    except ValueError as e:
        return jsonify({'error': f'参数无效: {e}'}), 400

    # Every option influences the result, so all of them belong in the key.
    cache_key = (syndrome, tuple(sorted(params.items())))
    cached = _cache_get(cache_key)
    if cached is not None:
        hit = dict(cached)  # shallow copy: don't mutate the stored payload
        hit['cached'] = True
        return jsonify(hit)

    try:
        t0 = time.time()
        raw = _translator.translate(syndrome, **params)
        elapsed = round(time.time() - t0, 2)
        payload = _slim(raw, params['top_k_compounds'], params['top_k_targets'])
        payload['elapsed_sec'] = elapsed
        payload['cached'] = False
        _cache_put(cache_key, payload)
        return jsonify(payload)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


# ── Static frontend ───────────────────────────────────────────────────────────
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def spa(path):
    """Serve a built frontend file if it exists, else the SPA's index.html."""
    # app.static_folder is absolute (built from ROOT), so this does not depend
    # on __file__ being absolute after the os.chdir at import time.
    dist = app.static_folder
    if path and os.path.isfile(os.path.join(dist, path)):
        return send_from_directory(dist, path)
    return send_from_directory(dist, 'index.html')


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5001))
    print(f"TCM Translation API → http://localhost:{port}")
    print("Pre-warming in background thread…")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)