BioRAG / web_app.py
aseelflihan's picture
feat: add token usage tracking and display, update sample questions for demo scenarios
5cb4c11
from dotenv import load_dotenv
load_dotenv()
from flask import Flask, request, jsonify, send_from_directory, Response
import json as json_lib
import os
from src.bio_rag.pipeline import BioRAGPipeline
from src.bio_rag.config import BioRAGConfig
# Flask application; static assets are served from the local 'static' folder.
app = Flask(__name__, static_folder='static')
# Load pipeline once at startup so the route handlers below share this single
# module-level instance instead of building one per request.
print("Loading Bio-RAG pipeline...")
# Configuration is constructed with no arguments here.
config = BioRAGConfig()
pipeline = BioRAGPipeline(config)
print("Pipeline ready!")
@app.route('/')
def index():
    """Serve the front-end entry page, ``index.html`` from ``static``."""
    directory, filename = 'static', 'index.html'
    return send_from_directory(directory, filename)
@app.route('/api/ask', methods=['POST'])
def ask():
    """Answer a question in a single (non-streaming) JSON response.

    Expects a JSON object body with a string ``question`` field.

    Returns:
        200 with ``result.to_dict()`` from ``pipeline.ask`` on success.
        400 with ``{'error': 'No question provided'}`` when the body is not a
        JSON object or ``question`` is missing, not a string, or blank.
        500 with ``{'error': <message>}`` when the pipeline raises.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so malformed requests are reported as 400 rather than falling
    # into the pipeline error handler and surfacing as 500.
    data = request.get_json(silent=True)
    question = data.get('question') if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        return jsonify({'error': 'No question provided'}), 400

    # Only the pipeline call is guarded: this is the request boundary, so any
    # failure is logged and converted into a JSON error payload.
    try:
        result = pipeline.ask(question.strip())
        return jsonify(result.to_dict())
    except Exception as e:
        app.logger.exception("Pipeline failed while handling /api/ask")
        return jsonify({'error': str(e)}), 500
# Tunables used by the streaming endpoint below.
_MIN_PASSAGES = 3             # fewer retrieved passages => "Insufficient evidence."
_EVIDENCE_PER_CLAIM = 10      # passages kept per claim during verification
_EVIDENCE_CHAR_LIMIT = 1500   # joined evidence text is truncated to this length
_RISK_THRESHOLD = 0.7         # max risk at or above this marks the answer unsafe
_EVIDENCE_IN_RESULT = 3       # passages echoed back in the final result


def _sse_event(payload, **dumps_kwargs):
    """Serialize *payload* to JSON and wrap it in one SSE ``data:`` frame."""
    return f"data: {json_lib.dumps(payload, **dumps_kwargs)}\n\n"


def _step_event(step, status):
    """Build the progress frame for pipeline *step* with the given *status*."""
    return _sse_event({'step': step, 'status': status})


def _short_circuit_result(question, message, elapsed):
    """Build the result payload used when the pipeline stops before generation."""
    return {
        'question': question,
        'original_answer': '',
        'final_answer': message,
        'evidence': [],
        'claims': [],
        'claim_checks': [],
        'max_risk_score': 0,
        'safe': True,
        'rejection_message': message,
        'processing_time_seconds': elapsed,
    }


@app.route('/api/ask-stream', methods=['POST'])
def ask_stream():
    """Answer a question while streaming progress as Server-Sent Events.

    Expects a JSON object body with a string ``question`` field; anything
    else yields a 400 JSON response ``{'error': 'No question provided'}``.

    Otherwise returns a ``text/event-stream`` response whose frames are JSON
    objects of these shapes:

    * ``{'step': n, 'status': 'active' | 'done'}`` for steps 0-8. Step 0
      precedes domain validation, 1 is query expansion, 2 retrieval,
      3 generation, 4 claim decomposition, 5 claim verification, 8 the
      safety decision. Steps 6 and 7 are emitted with short pauses only; no
      work is performed for them in this function.
    * ``{'answer_ready': True, 'answer': ...}`` once the answer is generated.
    * ``{'complete': True, 'result': {...}}`` as the final frame.
    * ``{'error': <message>}`` if any stage raises.
    """
    # silent=True returns None for a missing/invalid JSON body rather than
    # raising; together with the type checks this keeps null bodies, JSON
    # arrays and non-string questions from crashing with AttributeError.
    data = request.get_json(silent=True)
    raw_question = data.get('question') if isinstance(data, dict) else None
    if not isinstance(raw_question, str) or not raw_question.strip():
        return jsonify({'error': 'No question provided'}), 400
    question = raw_question.strip()

    def generate():
        import time

        try:
            started_at = time.time()
            phase_times = {}
            token_stats = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

            # Step 0: progress frames are sent first, then the domain check runs.
            yield _step_event(0, 'active')
            time.sleep(0.1)
            yield _step_event(0, 'done')
            time.sleep(0.1)
            is_valid, msg = pipeline.query_processor.validate_domain(question)
            if not is_valid:
                elapsed = round(time.time() - started_at, 2)
                rejected = _short_circuit_result(question, msg, elapsed)
                yield _sse_event({'complete': True, 'result': rejected})
                return

            # Step 1: query expansion.
            yield _step_event(1, 'active')
            time.sleep(0.1)
            phase_start = time.time()
            queries = pipeline.query_processor.expand_queries(question)
            phase_times['query_expansion'] = round(time.time() - phase_start, 2)
            yield _step_event(1, 'done')
            time.sleep(0.1)

            # Step 2: retrieval.
            yield _step_event(2, 'active')
            time.sleep(0.1)
            phase_start = time.time()
            passages = pipeline.retriever.retrieve(queries)
            phase_times['retrieval'] = round(time.time() - phase_start, 2)
            yield _step_event(2, 'done')
            time.sleep(0.1)
            if len(passages) < _MIN_PASSAGES:
                elapsed = round(time.time() - started_at, 2)
                insufficient = _short_circuit_result(question, 'Insufficient evidence.', elapsed)
                yield _sse_event({'complete': True, 'result': insufficient})
                return

            # Step 3: answer generation.
            yield _step_event(3, 'active')
            time.sleep(0.1)
            phase_start = time.time()
            original_answer = pipeline.generator.generate(question, passages)
            phase_times['generation'] = round(time.time() - phase_start, 2)
            # The usage attribute may be absent or None; missing/None counters
            # count as zero instead of aborting the whole stream.
            usage = getattr(pipeline.generator, 'last_usage', None)
            if usage is not None:
                for key in token_stats:
                    token_stats[key] += getattr(usage, key, 0) or 0
            yield _step_event(3, 'done')
            time.sleep(0.1)

            # Send answer_ready event. Only serialization is guarded; on
            # failure a placeholder answer is sent so the client still
            # receives the frame.
            try:
                answer_event = json_lib.dumps({'answer_ready': True, 'answer': original_answer}, ensure_ascii=False)
                print(f"[DEBUG] answer_ready event length: {len(answer_event)}")
            except (TypeError, ValueError) as e:
                print(f"[ERROR] Failed to send answer_ready: {e}")
                answer_event = json_lib.dumps({'answer_ready': True, 'answer': 'Error encoding answer'})
            yield f"data: {answer_event}\n\n"

            # Step 4: claim decomposition; falls back to treating the whole
            # answer as one claim when decomposition fails or returns nothing.
            yield _step_event(4, 'active')
            time.sleep(0.1)
            phase_start = time.time()
            try:
                decomposed = pipeline.claim_decomposer.decompose(question, original_answer)
                claims = decomposed if isinstance(decomposed, list) and len(decomposed) > 0 else [original_answer]
            except Exception:
                app.logger.warning("Claim decomposition failed; verifying the whole answer", exc_info=True)
                claims = [original_answer]
            # Record the duration before yielding so time spent suspended at
            # the yield is not counted as part of the phase.
            phase_times['decomposition'] = round(time.time() - phase_start, 2)
            yield _step_event(4, 'done')
            time.sleep(0.1)

            # Step 5: per-claim verification and risk scoring.
            yield _step_event(5, 'active')
            time.sleep(0.1)
            phase_start = time.time()
            claim_checks = []
            max_risk = 0.0
            for claim in claims:
                claim_passages = pipeline.retriever.retrieve([f"{question} {claim}"])[:_EVIDENCE_PER_CLAIM]
                claim_evidence = " ".join(p.text for p in claim_passages)[:_EVIDENCE_CHAR_LIMIT]
                nli = pipeline.nli_evaluator.evaluate(claim, [claim_evidence])
                profile = pipeline.risk_scorer.calculate_profile(claim)
                risk = pipeline.risk_scorer.compute_weighted_risk(nli, profile)
                max_risk = max(max_risk, risk)
                claim_checks.append({
                    "claim": claim,
                    "nli_prob": round(nli, 4),
                    "severity_score": profile.severity,
                    "type_score": profile.type_score,
                    "omission_score": profile.omission,
                    "risk_score": round(risk, 4),
                })
            phase_times['verification'] = round(time.time() - phase_start, 2)
            yield _step_event(5, 'done')

            # Steps 6 and 7: progress frames only, with short pauses.
            for step in (6, 7):
                yield _step_event(step, 'active')
                time.sleep(0.05)
                yield _step_event(step, 'done')
                time.sleep(0.05)

            # Step 8: safety decision based on the highest claim risk.
            yield _step_event(8, 'active')
            time.sleep(0.1)
            is_safe = max_risk < _RISK_THRESHOLD
            if is_safe:
                final_answer = original_answer
            else:
                final_answer = f"WARNING: This answer contains potentially unverified medical information.\n\n{original_answer}"
            yield _step_event(8, 'done')
            time.sleep(0.1)

            evidence = [
                {
                    'text': p.text if hasattr(p, 'text') else str(p),
                    'qid': p.qid if hasattr(p, 'qid') else '',
                }
                for p in passages[:_EVIDENCE_IN_RESULT]
            ]
            result = {
                'question': question,
                'original_answer': original_answer,
                'final_answer': final_answer,
                'evidence': evidence,
                'claims': claims,
                'claim_checks': claim_checks,
                'max_risk_score': round(max_risk, 4),
                'safe': is_safe,
                'rejection_message': '',
                'processing_time_seconds': round(time.time() - started_at, 2),
                'processing_stats': {
                    'total_db_size': len(pipeline.retriever._docs),
                    'queries_generated': len(queries),
                    'passages_retrieved': len(passages),
                    'claims_verified': len(claims),
                    'evidence_per_claim': _EVIDENCE_PER_CLAIM,
                    # Upper bound: a claim may retrieve fewer passages than the cap.
                    'total_evidence_evaluated': len(claims) * _EVIDENCE_PER_CLAIM,
                    'phase_times': phase_times,
                    'token_usage': token_stats,
                },
            }
            yield _sse_event({'complete': True, 'result': result})
        except Exception as e:
            # Stream boundary: log the traceback and tell the client, since
            # the HTTP status line has already been sent by this point.
            app.logger.exception("Streaming pipeline failed for /api/ask-stream")
            yield _sse_event({'error': str(e)})

    return Response(
        generate(),
        mimetype='text/event-stream',
        headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no', 'Connection': 'keep-alive'},
    )
if __name__ == '__main__':
    # Script entry point: listen on all interfaces, on the port named by the
    # PORT environment variable (7860 when it is unset).
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)