from dotenv import load_dotenv

load_dotenv()

from flask import Flask, request, jsonify, send_from_directory, Response
import json as json_lib
import os
import time

from src.bio_rag.pipeline import BioRAGPipeline
from src.bio_rag.config import BioRAGConfig

app = Flask(__name__, static_folder='static')

# Load pipeline once at startup
print("Loading Bio-RAG pipeline...")
config = BioRAGConfig()
pipeline = BioRAGPipeline(config)
print("Pipeline ready!")


@app.route('/')
def index():
    return send_from_directory('static', 'index.html')


@app.route('/api/ask', methods=['POST'])
def ask():
    # Non-streaming endpoint: run the full pipeline and return one JSON result.
    try:
        data = request.get_json()
        question = data.get('question', '').strip()
        if not question:
            return jsonify({'error': 'No question provided'}), 400
        result = pipeline.ask(question)
        return jsonify(result.to_dict())
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/api/ask-stream', methods=['POST'])
def ask_stream():
    # Streaming endpoint: emit Server-Sent Events as each pipeline phase runs.
    data = request.get_json()
    question = data.get('question', '').strip()
    if not question:
        return jsonify({'error': 'No question provided'}), 400

    def generate():
        try:
            _start_time = time.time()
            phase_times = {}
            token_stats = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

            # Step 0: request received
            yield f"data: {json_lib.dumps({'step': 0, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            yield f"data: {json_lib.dumps({'step': 0, 'status': 'done'})}\n\n"
            time.sleep(0.1)

            # Domain gate: reject questions outside the biomedical domain early.
            is_valid, msg = pipeline.query_processor.validate_domain(question)
            if not is_valid:
                r = {'question': question, 'original_answer': '', 'final_answer': msg,
                     'evidence': [], 'claims': [], 'claim_checks': [],
                     'max_risk_score': 0, 'safe': True, 'rejection_message': msg,
                     'processing_time_seconds': round(time.time() - _start_time, 2)}
                yield f"data: {json_lib.dumps({'complete': True, 'result': r})}\n\n"
                return

            # Step 1: query expansion
            yield f"data: {json_lib.dumps({'step': 1, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            _p1_start = time.time()
            queries = pipeline.query_processor.expand_queries(question)
            phase_times['query_expansion'] = round(time.time() - _p1_start, 2)
            yield f"data: {json_lib.dumps({'step': 1, 'status': 'done'})}\n\n"
            time.sleep(0.1)

            # Step 2: passage retrieval
            yield f"data: {json_lib.dumps({'step': 2, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            _p2_start = time.time()
            passages = pipeline.retriever.retrieve(queries)
            phase_times['retrieval'] = round(time.time() - _p2_start, 2)
            yield f"data: {json_lib.dumps({'step': 2, 'status': 'done'})}\n\n"
            time.sleep(0.1)

            # Evidence gate: refuse to answer when retrieval is too thin.
            if len(passages) < 3:
                r = {'question': question, 'original_answer': '',
                     'final_answer': 'Insufficient evidence.',
                     'evidence': [], 'claims': [], 'claim_checks': [],
                     'max_risk_score': 0, 'safe': True,
                     'rejection_message': 'Insufficient evidence.',
                     'processing_time_seconds': round(time.time() - _start_time, 2)}
                yield f"data: {json_lib.dumps({'complete': True, 'result': r})}\n\n"
                return

            # Step 3: answer generation (with token accounting when available)
            yield f"data: {json_lib.dumps({'step': 3, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            _p3_start = time.time()
            original_answer = pipeline.generator.generate(question, passages)
            phase_times['generation'] = round(time.time() - _p3_start, 2)
            if hasattr(pipeline.generator, 'last_usage'):
                u = pipeline.generator.last_usage
                token_stats['prompt_tokens'] += u.prompt_tokens
                token_stats['completion_tokens'] += u.completion_tokens
                token_stats['total_tokens'] += u.total_tokens
            yield f"data: {json_lib.dumps({'step': 3, 'status': 'done'})}\n\n"
            time.sleep(0.1)

            # Send answer_ready event so the client can show the draft answer
            # before verification finishes.
            try:
                answer_event = json_lib.dumps(
                    {'answer_ready': True, 'answer': original_answer},
                    ensure_ascii=False)
                print(f"[DEBUG] answer_ready event length: {len(answer_event)}")
                yield f"data: {answer_event}\n\n"
            except Exception as e:
                print(f"[ERROR] Failed to send answer_ready: {e}")
                yield f"data: {json_lib.dumps({'answer_ready': True, 'answer': 'Error encoding answer'})}\n\n"

            # Step 4: claim decomposition (falls back to treating the whole
            # answer as a single claim if decomposition fails or is empty)
            yield f"data: {json_lib.dumps({'step': 4, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            _p4_start = time.time()
            try:
                co = pipeline.claim_decomposer.decompose(question, original_answer)
                claims = co if isinstance(co, list) and len(co) > 0 else [original_answer]
            except Exception:
                claims = [original_answer]
            yield f"data: {json_lib.dumps({'step': 4, 'status': 'done'})}\n\n"
            phase_times['decomposition'] = round(time.time() - _p4_start, 2)
            time.sleep(0.1)

            # Step 5: per-claim verification — retrieve evidence for each claim,
            # score entailment with NLI, then weight by the claim's risk profile.
            yield f"data: {json_lib.dumps({'step': 5, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            _p5_start = time.time()
            claim_checks = []
            max_risk = 0.0
            for claim in claims:
                eq = f"{question} {claim}"
                cp = pipeline.retriever.retrieve([eq])[:10]
                ce = " ".join([p.text for p in cp])[:1500]
                nli = pipeline.nli_evaluator.evaluate(claim, [ce])
                pf = pipeline.risk_scorer.calculate_profile(claim)
                rs = pipeline.risk_scorer.compute_weighted_risk(nli, pf)
                max_risk = max(max_risk, rs)
                claim_checks.append({
                    "claim": claim,
                    "nli_prob": round(nli, 4),
                    "severity_score": pf.severity,
                    "type_score": pf.type_score,
                    "omission_score": pf.omission,
                    "risk_score": round(rs, 4),
                })
            yield f"data: {json_lib.dumps({'step': 5, 'status': 'done'})}\n\n"
            phase_times['verification'] = round(time.time() - _p5_start, 2)

            # Steps 6-7: progress-only ticks for the client UI (no extra work)
            yield f"data: {json_lib.dumps({'step': 6, 'status': 'active'})}\n\n"
            time.sleep(0.05)
            yield f"data: {json_lib.dumps({'step': 6, 'status': 'done'})}\n\n"
            time.sleep(0.05)
            yield f"data: {json_lib.dumps({'step': 7, 'status': 'active'})}\n\n"
            time.sleep(0.05)
            yield f"data: {json_lib.dumps({'step': 7, 'status': 'done'})}\n\n"
            time.sleep(0.05)

            # Step 8: safety gate — flag the answer if any claim's risk score
            # reaches the 0.7 threshold.
            yield f"data: {json_lib.dumps({'step': 8, 'status': 'active'})}\n\n"
            time.sleep(0.1)
            is_safe = max_risk < 0.7
            fa = original_answer if is_safe else f"WARNING: This answer contains potentially unverified medical information.\n\n{original_answer}"
            yield f"data: {json_lib.dumps({'step': 8, 'status': 'done'})}\n\n"
            time.sleep(0.1)

            # Final payload: answer, top evidence, per-claim checks, and stats.
            ev = [{'text': p.text if hasattr(p, 'text') else str(p),
                   'qid': p.qid if hasattr(p, 'qid') else ''}
                  for p in passages[:3]]
            r = {
                'question': question,
                'original_answer': original_answer,
                'final_answer': fa,
                'evidence': ev,
                'claims': claims,
                'claim_checks': claim_checks,
                'max_risk_score': round(max_risk, 4),
                'safe': is_safe,
                'rejection_message': '',
                'processing_time_seconds': round(time.time() - _start_time, 2),
                'processing_stats': {
                    'total_db_size': len(pipeline.retriever._docs),
                    'queries_generated': len(queries),
                    'passages_retrieved': len(passages),
                    'claims_verified': len(claims),
                    'evidence_per_claim': 10,
                    'total_evidence_evaluated': len(claims) * 10,
                    'phase_times': phase_times,
                    'token_usage': token_stats,
                },
            }
            yield f"data: {json_lib.dumps({'complete': True, 'result': r})}\n\n"
        except Exception as e:
            yield f"data: {json_lib.dumps({'error': str(e)})}\n\n"

    return Response(generate(), mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache',
                             'X-Accel-Buffering': 'no',
                             'Connection': 'keep-alive'})


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=port)
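

# --- Illustrative client for /api/ask-stream (a sketch, not part of the app) ---
# A minimal way to consume the SSE endpoint from Python, assuming the server is
# running locally on the default port 7860 and that the `requests` package is
# installed. The question text is a hypothetical example; the event payload
# shapes match the JSON emitted by generate() above.
#
#   import json
#   import requests
#
#   resp = requests.post(
#       'http://localhost:7860/api/ask-stream',
#       json={'question': 'What is the role of hemoglobin in oxygen transport?'},
#       stream=True,
#   )
#   for line in resp.iter_lines(decode_unicode=True):
#       if not line or not line.startswith('data: '):
#           continue  # SSE events are separated by blank lines; skip them
#       event = json.loads(line[len('data: '):])
#       if 'error' in event:
#           raise RuntimeError(event['error'])
#       if event.get('complete'):
#           print(event['result']['final_answer'])
#           break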