Spaces:
Sleeping
Sleeping
| """ | |
| api/server.py — Flask backend with pre-warming, caching, slim responses | |
| """ | |
| import sys, os, json, time, threading | |
# Always run from the project root so that relative paths resolve the same
# way locally and inside Docker.
_API_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(_API_DIR)
os.chdir(ROOT)
# Make modules under <ROOT>/pipeline importable by bare name.
sys.path.insert(0, os.path.join(ROOT, 'pipeline'))
| from flask import Flask, request, jsonify, send_from_directory | |
# Flask application; static files are served from <ROOT>/frontend/dist with
# an empty URL prefix.
app = Flask(
    __name__,
    static_url_path='',
    static_folder=os.path.join(ROOT, 'frontend', 'dist'),
)
| # ── Pre-warm state ──────────────────────────────────────────────────────────── | |
# Progress of the background pre-warm, mutated in place by _prewarm() and
# read by the request handlers.
# 'step' is one of: idle | loading_data | building_models | done | error
_state = dict(
    ready=False,
    step='idle',
    message='等待初始化',
    started=None,   # time.time() at which pre-warming began
    elapsed=None,   # seconds taken by a successful initialisation
)

# Translator instance; stays None until _prewarm() finishes successfully.
_translator = None

# Upper bound on the number of cached responses.
_CACHE_MAX = 50
# Response cache keyed by request parameters; filled and evicted in translate().
_cache: dict = {}
def _prewarm():
    """Build the translator in the background and publish progress in _state.

    Intended to run in a worker thread. Progress is reported by mutating the
    module-level ``_state`` dict ('step', 'message', 'elapsed', 'ready').
    On success the initialised translator is stored in ``_translator`` and
    only then is ``_state['ready']`` set to True, so a reader that sees
    ``ready`` can rely on ``_translator`` being set.

    Any exception is caught: the state switches to step 'error' with the
    exception text in 'message', and the full traceback is printed so the
    failure can be diagnosed from the server log (same approach as the
    request handler's error path).
    """
    # _state is only mutated, never rebound, so it needs no global statement.
    global _translator
    _state['started'] = time.time()
    try:
        _state['step'] = 'loading_data'
        _state['message'] = '加载数据库(医案 + TCMSP + SYMMAP + 靶点)…'
        # Imported lazily so that importing this module stays cheap and any
        # import failure is reported through _state rather than at startup.
        from tcm_translator import TCMTranslator
        t = TCMTranslator()

        _state['step'] = 'building_models'
        _state['message'] = '构建检索模型(PageRank + 相似度索引)…'
        t._initialize()

        # Publish the translator before flipping 'ready'.
        _translator = t
        _state['ready'] = True
        _state['step'] = 'done'
        _state['elapsed'] = round(time.time() - _state['started'], 1)
        _state['message'] = f'就绪(初始化耗时 {_state["elapsed"]}s)'
        print(f"[server] Ready in {_state['elapsed']}s", flush=True)
    except Exception as e:
        # Thread boundary: nothing above us will report this, so log the
        # whole stack instead of just the message.
        import traceback
        traceback.print_exc()
        _state['step'] = 'error'
        _state['message'] = f'初始化失败: {e}'
        print(f"[server] INIT ERROR: {e}", flush=True)
# Kick off pre-warming in the background as soon as the module is loaded.
# The thread is a daemon, so it never blocks interpreter shutdown.
_prewarm_thread = threading.Thread(target=_prewarm, daemon=True)
_prewarm_thread.start()
| # ── CORS ────────────────────────────────────────────────────────────────────── | |
def cors(r):
    """Attach permissive CORS headers to response ``r`` and return it.

    Sets Access-Control-Allow-Origin to ``*``, allows the ``Content-Type``
    request header and the GET, POST and OPTIONS methods. The same response
    object is returned.

    NOTE(review): presumably registered as an after-request hook; no
    decorator is visible in this view -- confirm.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Headers', 'Content-Type'),
        ('Access-Control-Allow-Methods', 'GET,POST,OPTIONS'),
    )
    for header_name, header_value in cors_headers:
        r.headers[header_name] = header_value
    return r
| # ── Endpoints ───────────────────────────────────────────────────────────────── | |
def health():
    """Return a liveness payload: a constant status plus the pre-warm flag.

    NOTE(review): no route decorator is visible in this view -- confirm how
    this handler is registered.
    """
    payload = {'status': 'ok', 'ready': _state['ready']}
    return jsonify(payload)
def status():
    """Return the pre-warm progress fields of ``_state`` as JSON.

    Only 'ready', 'step', 'message' and 'elapsed' are exposed; 'started' is
    left out.

    NOTE(review): no route decorator is visible in this view -- confirm how
    this handler is registered.
    """
    exposed = ('ready', 'step', 'message', 'elapsed')
    return jsonify({field: _state[field] for field in exposed})
def examples():
    """Return the fixed list of sample syndrome inputs as JSON.

    Each entry has a 'label' (the text to submit) and a short 'desc'.

    NOTE(review): no route decorator is visible in this view -- confirm how
    this handler is registered.
    """
    samples = (
        ('气阴两虚兼湿浊瘀阻', '最常见的糖尿病肾病证型'),
        ('脾肾阳虚兼水湿内停', '水肿为主'),
        ('阴阳两虚兼瘀浊内结', '晚期复杂证型'),
        ('肝肾阴虚兼气虚血瘀', '阴虚伴血瘀'),
        ('气虚湿浊瘀血阻滞', '气虚湿瘀互结'),
    )
    return jsonify([{'label': label, 'desc': desc} for label, desc in samples])
def _slim(result: dict, top_k_compounds: int, top_k_targets: int) -> dict:
    """Build the trimmed, JSON-safe API payload from a raw translator result.

    Only a fixed set of fields is kept per record, numeric scores are
    rounded, long strings are truncated, and list lengths are capped.
    Non-finite floats (NaN, +/-inf) anywhere in the payload are replaced by
    ``None`` so the response is valid JSON.

    Args:
        result: Raw result dict. ``result['input']`` is required; every other
            top-level key is optional and defaults to an empty value.
        top_k_compounds: Maximum number of entries kept in each of the two
            compound lists.
        top_k_targets: Maximum number of entries kept in each of the two
            target lists.

    Returns:
        A new dict; ``result`` itself is not modified.

    Raises:
        KeyError: If ``result`` has no 'input' key, or a matched case / core
            herb record lacks one of the fields read from it.
    """
    import math  # local import keeps this block self-contained

    def scrub(obj):
        """Recursively replace non-finite floats with None."""
        if isinstance(obj, dict):
            return {k: scrub(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [scrub(item) for item in obj]
        if isinstance(obj, float) and not math.isfinite(obj):
            return None
        return obj

    def num(record, key, digits):
        """Read record[key] as a float (missing/falsy -> 0) and round it."""
        return round(float(record.get(key) or 0), digits)

    cases = [
        {
            'case_id': c['case_id'],
            'source': c['source'],
            'zheng_raw': c['zheng_raw'],
            'similarity': round(c['similarity'], 4),
            'herbs': c['herbs'][:8],
            'role_jun': c['role_jun'][:3],
            'role_chen': c['role_chen'][:3],
        }
        for c in result.get('matched_cases', [])
    ]

    herbs = [
        {
            'herb': h['herb'],
            'freq_score': round(h['freq_score'], 4),
            'n_cases': h['n_cases'],
            'role_weight': h['role_weight'],
            'centrality': round(h['centrality'], 4),
            'core_score': round(h['core_score'], 4),
        }
        for h in result.get('core_herbs', [])
    ]

    tcmsp_c = [
        {
            'mol_id': c.get('mol_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob': num(c, 'ob', 2),
            'dl': num(c, 'dl', 3),
            'compound_quality': num(c, 'compound_quality', 4),
            'compound_score': num(c, 'compound_score', 5),
            'source_herbs': str(c.get('source_herbs', ''))[:60],
        }
        for c in result.get('top_compounds_tcmsp', [])[:top_k_compounds]
    ]

    symmap_c = [
        {
            'smit_id': c.get('smit_id', ''),
            'molecule_name': c.get('molecule_name', ''),
            'ob_score': num(c, 'ob_score', 2),
            'evidence_score': num(c, 'evidence_score', 3),
            'compound_score': num(c, 'compound_score', 5),
            'pubchem_cid': c.get('pubchem_cid', ''),
        }
        for c in result.get('top_compounds_symmap', [])[:top_k_compounds]
    ]

    tcmsp_t = [
        {
            'target_name': t.get('target_name', ''),
            'target_id': t.get('target_id', ''),
            'drugbank_id': t.get('drugbank_id', ''),
            'validated': bool(t.get('validated', '')),
            'target_score': num(t, 'target_score', 4),
            'n_compounds': t.get('n_compounds', 0),
        }
        for t in result.get('top_targets_tcmsp', [])[:top_k_targets]
    ]

    symmap_t = [
        {
            'gene_symbol': t.get('gene_symbol', ''),
            'gene_name': t.get('gene_name', ''),
            'protein_name': t.get('protein_name', ''),
            'uniprot_id': t.get('uniprot_id', ''),
            'min_p_value': t.get('min_p_value', 1.0),
            'target_score': num(t, 'target_score', 4),
            'n_herbs': t.get('n_herbs', 0),
            'source_herbs': str(t.get('source_herbs', ''))[:80],
        }
        for t in result.get('top_targets_symmap', [])[:top_k_targets]
    ]

    # Keep only the 12 highest-valued fingerprint entries, largest first.
    fp = dict(sorted(
        result.get('formula_fingerprint', {}).items(),
        key=lambda item: item[1], reverse=True,
    )[:12])

    # Sanitise once, over the finished payload, so every field is covered
    # (round() passes NaN/inf through unchanged).
    return scrub({
        'input': result['input'],
        'parsed_elements': result.get('parsed_elements', []),
        'query_vector': [round(v, 3) for v in result.get('query_vector', [])],
        'pattern_label': result.get('pattern_label', ''),
        'n_matched': result.get('n_matched', 0),
        'formula_fingerprint': fp,
        'matched_cases': cases,
        'core_herbs': herbs,
        'top_compounds_tcmsp': tcmsp_c,
        'top_compounds_symmap': symmap_c,
        'top_targets_tcmsp': tcmsp_t,
        'top_targets_symmap': symmap_t,
    })
def translate():
    """Translate a syndrome description and return the slimmed result as JSON.

    Request body (JSON object): ``syndrome`` (required, non-empty string)
    plus optional tuning parameters ``top_k_cases``, ``top_k_herbs``,
    ``top_k_compounds``, ``top_k_targets`` (integers) and ``ob_min``,
    ``dl_min``, ``confidence_min``, ``fdr_max`` (floats).

    Responses:
        200: the slimmed payload with ``elapsed_sec`` and ``cached`` added
             (an OPTIONS request gets an empty 200 object).
        400: missing/blank/non-string ``syndrome``, or a tuning parameter
             that cannot be converted to its numeric type.
        503: the translator has not finished pre-warming.
        500: the translator or the slimming step raised; the exception text
             is returned and the traceback is printed.

    Results are cached in the module-level ``_cache`` (at most
    ``_CACHE_MAX`` entries, least recently used evicted first).

    NOTE(review): no route decorator is visible in this view -- confirm how
    this handler is registered.
    """
    if request.method == 'OPTIONS':
        return jsonify({}), 200

    if not _state['ready']:
        return jsonify({'error': '系统初始化中,请稍候', 'loading': True,
                        'message': _state['message']}), 503

    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        # Unparseable body, or valid JSON that is not an object.
        body = {}

    syndrome = body.get('syndrome')
    syndrome = syndrome.strip() if isinstance(syndrome, str) else ''
    if not syndrome:
        return jsonify({'error': '请输入证候描述'}), 400

    int_defaults = (
        ('top_k_cases', 15),
        ('top_k_herbs', 20),
        ('top_k_compounds', 30),
        ('top_k_targets', 30),
    )
    float_defaults = (
        ('ob_min', 30.0),
        ('dl_min', 0.18),
        ('confidence_min', 0.5),
        ('fdr_max', 0.05),
    )
    try:
        params = {name: int(body.get(name, default))
                  for name, default in int_defaults}
        params.update((name, float(body.get(name, default)))
                      for name, default in float_defaults)
    except (TypeError, ValueError, OverflowError):
        # Malformed client input is a client error, not a server failure.
        return jsonify({'error': '参数格式不正确'}), 400

    top_k_compounds = params['top_k_compounds']
    top_k_targets = params['top_k_targets']

    # The key covers every parameter that influences the result; leaving any
    # of them out would serve results computed with different settings.
    cache_key = (syndrome, *params.values())
    hit = _cache.get(cache_key)
    if hit is not None:
        # Re-insert to mark the entry as most recently used.
        _cache.pop(cache_key, None)
        _cache[cache_key] = hit
        cached = dict(hit)
        cached['cached'] = True
        return jsonify(cached)

    try:
        t0 = time.time()
        raw = _translator.translate(syndrome, **params)
        elapsed = round(time.time() - t0, 2)

        payload = _slim(raw, top_k_compounds, top_k_targets)
        payload['elapsed_sec'] = elapsed
        payload['cached'] = False

        # Dicts keep insertion order, so the first key is the least recently
        # used one. pop() with a default tolerates a concurrent eviction.
        if len(_cache) >= _CACHE_MAX:
            _cache.pop(next(iter(_cache)), None)
        _cache[cache_key] = payload

        return jsonify(payload)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
| # ── Static frontend ─────────────────────────────────────────────────────────── | |
def spa(path):
    """Serve a built frontend file, falling back to the SPA entry page.

    Args:
        path: Request path relative to the frontend build directory; may be
            empty.

    Returns:
        The file at ``path`` when it names an existing regular file under
        ``<ROOT>/frontend/dist``; otherwise ``index.html`` from that
        directory, so that client-side routes resolve to the app shell.

    NOTE(review): no route decorator is visible in this view -- confirm how
    this handler is registered.
    """
    # Built from ROOT (absolute, computed before os.chdir) to match the
    # app's static_folder and to stay correct if __file__ is relative.
    dist = os.path.join(ROOT, 'frontend', 'dist')
    # isfile, not exists: a directory would pass exists() and then be
    # rejected by send_from_directory instead of falling back to index.html.
    if path and os.path.isfile(os.path.join(dist, path)):
        return send_from_directory(dist, path)
    return send_from_directory(dist, 'index.html')
if __name__ == '__main__':
    # Direct-run entry point. The listening port comes from the PORT
    # environment variable and defaults to 5001.
    port = int(os.environ.get('PORT', 5001))
    print(f"TCM Translation API → http://localhost:{port}")
    print("Pre-warming in background thread…")
    run_options = dict(host='0.0.0.0', port=port, debug=False, threaded=True)
    app.run(**run_options)