# Scraped page header (not code): Hugging Face Space status "Running"; file size 8,724 bytes.
from dotenv import load_dotenv

# Load .env into os.environ before importing the project modules below, in
# case anything they execute at import time reads environment variables.
load_dotenv()

import json as json_lib
import os
import time

from flask import Flask, request, jsonify, send_from_directory, Response

from src.bio_rag.pipeline import BioRAGPipeline
from src.bio_rag.config import BioRAGConfig
# Flask application; static assets (including index.html) live in ./static.
app = Flask(__name__, static_folder='static')

# Build the Bio-RAG pipeline exactly once, at import time, so every request
# handler below shares the same instance instead of rebuilding it per call.
# flush=True: when stdout is not a TTY (a pipe, a container log) Python
# block-buffers it, and these progress lines could otherwise appear late.
print("Loading Bio-RAG pipeline...", flush=True)
config = BioRAGConfig()
pipeline = BioRAGPipeline(config)
print("Pipeline ready!", flush=True)
@app.route('/')
def index():
    """Serve the front-end entry page, ``index.html`` from the ``static`` directory."""
    entry_page = 'index.html'
    return send_from_directory('static', entry_page)
@app.route('/api/ask', methods=['POST'])
def ask():
    """Answer one question synchronously with the full Bio-RAG pipeline.

    Expects a JSON request body of the form ``{"question": "<text>"}``.

    Returns:
        200 with ``pipeline.ask(question).to_dict()`` serialized as JSON;
        400 ``{"error": ...}`` when the body is not a JSON object, or the
        question is missing, blank, or not a string;
        500 ``{"error": ...}`` when the pipeline raises.
    """
    # silent=True returns None instead of raising on a missing/invalid JSON
    # body or wrong Content-Type, so client mistakes become a 400 here rather
    # than falling through to the generic 500 handler below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    question = data.get('question') or ''
    if not isinstance(question, str):
        return jsonify({'error': "'question' must be a string"}), 400
    question = question.strip()
    if not question:
        return jsonify({'error': 'No question provided'}), 400

    try:
        result = pipeline.ask(question)
        return jsonify(result.to_dict())
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # return the message to the client as before.
        app.logger.exception('Bio-RAG /api/ask failed')
        return jsonify({'error': str(e)}), 500
# Tuning knobs for the streaming endpoint (same values as the former inline literals).
_MIN_PASSAGES = 3            # fewer retrieved passages than this -> refuse to answer
_EVIDENCE_PER_CLAIM = 10     # passages retrieved to verify each individual claim
_EVIDENCE_CHAR_LIMIT = 1500  # joined per-claim evidence text is cut to this many chars
_RISK_THRESHOLD = 0.7        # a max claim risk at or above this marks the answer unsafe
_EVIDENCE_PREVIEW = 3        # leading passages echoed back to the client as 'evidence'


def _sse(payload, **dumps_kwargs):
    """Return *payload* JSON-encoded as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json_lib.dumps(payload, **dumps_kwargs)}\n\n"


def _early_result(question, message, start_time):
    """Build the final result for a question rejected before answer generation."""
    return {
        'question': question,
        'original_answer': '',
        'final_answer': message,
        'evidence': [],
        'claims': [],
        'claim_checks': [],
        'max_risk_score': 0,
        'safe': True,
        'rejection_message': message,
        'processing_time_seconds': round(time.time() - start_time, 2),
    }


@app.route('/api/ask-stream', methods=['POST'])
def ask_stream():
    """Run the Bio-RAG pipeline for one question, streaming progress as SSE.

    Expects a JSON body ``{"question": "<text>"}``. An invalid body gets a
    plain JSON 400 response (no stream). Otherwise the response is a
    ``text/event-stream`` whose ``data:`` frames carry JSON objects:

    * ``{"step": n, "status": "active" | "done"}`` progress markers for
      steps 0-8 (the stream ends early after step 0 if the domain check
      rejects the question, or after step 2 if too few passages are found);
    * ``{"answer_ready": true, "answer": ...}`` once the draft answer exists;
    * ``{"complete": true, "result": {...}}`` with the final result; or
    * ``{"error": "<message>"}`` if any stage raises.
    """
    # Validate and extract everything needed from the request up front: the
    # generator below does not touch `request`, since the body is iterated
    # after the view function has returned.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    question = data.get('question') or ''
    if not isinstance(question, str):
        return jsonify({'error': "'question' must be a string"}), 400
    question = question.strip()
    if not question:
        return jsonify({'error': 'No question provided'}), 400

    def generate():
        try:
            start_time = time.time()
            phase_times = {}
            token_stats = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

            # Step 0 is a progress marker only; the domain check runs right after it.
            # The short sleeps throughout pace the client's progress display.
            yield _sse({'step': 0, 'status': 'active'})
            time.sleep(0.1)
            yield _sse({'step': 0, 'status': 'done'})
            time.sleep(0.1)
            is_valid, msg = pipeline.query_processor.validate_domain(question)
            if not is_valid:
                yield _sse({'complete': True, 'result': _early_result(question, msg, start_time)})
                return

            # Step 1: query expansion.
            yield _sse({'step': 1, 'status': 'active'})
            time.sleep(0.1)
            phase_start = time.time()
            queries = pipeline.query_processor.expand_queries(question)
            phase_times['query_expansion'] = round(time.time() - phase_start, 2)
            yield _sse({'step': 1, 'status': 'done'})
            time.sleep(0.1)

            # Step 2: passage retrieval for all expanded queries.
            yield _sse({'step': 2, 'status': 'active'})
            time.sleep(0.1)
            phase_start = time.time()
            passages = pipeline.retriever.retrieve(queries)
            phase_times['retrieval'] = round(time.time() - phase_start, 2)
            yield _sse({'step': 2, 'status': 'done'})
            time.sleep(0.1)
            if len(passages) < _MIN_PASSAGES:
                result = _early_result(question, 'Insufficient evidence.', start_time)
                yield _sse({'complete': True, 'result': result})
                return

            # Step 3: answer generation.
            yield _sse({'step': 3, 'status': 'active'})
            time.sleep(0.1)
            phase_start = time.time()
            original_answer = pipeline.generator.generate(question, passages)
            phase_times['generation'] = round(time.time() - phase_start, 2)
            # The generator may lack `last_usage` or leave it None; only
            # accumulate token counts when usage data is actually present.
            usage = getattr(pipeline.generator, 'last_usage', None)
            if usage is not None:
                token_stats['prompt_tokens'] += usage.prompt_tokens
                token_stats['completion_tokens'] += usage.completion_tokens
                token_stats['total_tokens'] += usage.total_tokens
            yield _sse({'step': 3, 'status': 'done'})
            time.sleep(0.1)

            # Send the unverified draft answer early. Encoding is best effort:
            # if the answer cannot be serialized, send a placeholder instead.
            try:
                answer_json = json_lib.dumps(
                    {'answer_ready': True, 'answer': original_answer}, ensure_ascii=False
                )
            except Exception:
                app.logger.exception('Failed to encode answer_ready event')
                answer_json = json_lib.dumps({'answer_ready': True, 'answer': 'Error encoding answer'})
            else:
                app.logger.debug('answer_ready event length: %d', len(answer_json))
            yield f"data: {answer_json}\n\n"

            # Step 4: split the answer into claims. Best effort: on failure or
            # an empty/non-list result, verify the whole answer as one claim.
            yield _sse({'step': 4, 'status': 'active'})
            time.sleep(0.1)
            phase_start = time.time()
            try:
                decomposed = pipeline.claim_decomposer.decompose(question, original_answer)
            except Exception:
                app.logger.exception('Claim decomposition failed; verifying the whole answer')
                decomposed = None
            claims = decomposed if isinstance(decomposed, list) and decomposed else [original_answer]
            phase_times['decomposition'] = round(time.time() - phase_start, 2)
            yield _sse({'step': 4, 'status': 'done'})
            time.sleep(0.1)

            # Step 5: verify each claim against freshly retrieved evidence.
            yield _sse({'step': 5, 'status': 'active'})
            time.sleep(0.1)
            phase_start = time.time()
            claim_checks = []
            max_risk = 0.0
            evidence_evaluated = 0
            for claim in claims:
                claim_passages = pipeline.retriever.retrieve([f"{question} {claim}"])[:_EVIDENCE_PER_CLAIM]
                evidence_evaluated += len(claim_passages)
                evidence_text = " ".join(p.text for p in claim_passages)[:_EVIDENCE_CHAR_LIMIT]
                nli = pipeline.nli_evaluator.evaluate(claim, [evidence_text])
                profile = pipeline.risk_scorer.calculate_profile(claim)
                risk = pipeline.risk_scorer.compute_weighted_risk(nli, profile)
                max_risk = max(max_risk, risk)
                claim_checks.append({
                    'claim': claim,
                    'nli_prob': round(nli, 4),
                    'severity_score': profile.severity,
                    'type_score': profile.type_score,
                    'omission_score': profile.omission,
                    'risk_score': round(risk, 4),
                })
            phase_times['verification'] = round(time.time() - phase_start, 2)
            yield _sse({'step': 5, 'status': 'done'})

            # Steps 6 and 7 do no server-side work here; they only advance the
            # client's progress display.
            for step in (6, 7):
                yield _sse({'step': step, 'status': 'active'})
                time.sleep(0.05)
                yield _sse({'step': step, 'status': 'done'})
                time.sleep(0.05)

            # Step 8: safety decision; risky answers are prefixed with a warning.
            yield _sse({'step': 8, 'status': 'active'})
            time.sleep(0.1)
            is_safe = max_risk < _RISK_THRESHOLD
            if is_safe:
                final_answer = original_answer
            else:
                final_answer = (
                    "WARNING: This answer contains potentially unverified medical information.\n\n"
                    f"{original_answer}"
                )
            yield _sse({'step': 8, 'status': 'done'})
            time.sleep(0.1)

            evidence = [
                {
                    'text': p.text if hasattr(p, 'text') else str(p),
                    'qid': p.qid if hasattr(p, 'qid') else '',
                }
                for p in passages[:_EVIDENCE_PREVIEW]
            ]
            result = {
                'question': question,
                'original_answer': original_answer,
                'final_answer': final_answer,
                'evidence': evidence,
                'claims': claims,
                'claim_checks': claim_checks,
                'max_risk_score': round(max_risk, 4),
                'safe': is_safe,
                'rejection_message': '',
                'processing_time_seconds': round(time.time() - start_time, 2),
                'processing_stats': {
                    # `_docs` is a private retriever attribute; report 0 rather
                    # than failing the whole response after all work is done.
                    'total_db_size': len(getattr(pipeline.retriever, '_docs', ())),
                    'queries_generated': len(queries),
                    'passages_retrieved': len(passages),
                    'claims_verified': len(claims),
                    'evidence_per_claim': _EVIDENCE_PER_CLAIM,
                    # Actual count: retrieval may return fewer than the per-claim cap.
                    'total_evidence_evaluated': evidence_evaluated,
                    'phase_times': phase_times,
                    'token_usage': token_stats,
                },
            }
            yield _sse({'complete': True, 'result': result})
        except Exception as e:
            # Stream boundary: log the traceback and report the message to the client.
            app.logger.exception('Bio-RAG /api/ask-stream failed')
            yield _sse({'error': str(e)})

    # No `Connection` header: it is hop-by-hop, which WSGI applications must
    # not set (PEP 3333); the server manages connection persistence itself.
    return Response(
        generate(),
        mimetype='text/event-stream',
        headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'},
    )
if __name__ == '__main__':
    # Direct-run entry point. `os` is already imported at module top, so no
    # local re-import is needed; the PORT env var overrides the 7860 default.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port, debug=False)
|