# Hugging Face Spaces status banner (non-code artifact from page extraction):
# Spaces: Running
"""Flask web server exposing the Bio-RAG question-answering pipeline."""
from dotenv import load_dotenv

# Environment variables must be loaded before the pipeline modules are
# imported, since they may read API keys / configuration at import time.
load_dotenv()

import json as json_lib
import os

from flask import Flask, request, jsonify, send_from_directory, Response

from src.bio_rag.pipeline import BioRAGPipeline
from src.bio_rag.config import BioRAGConfig

app = Flask(__name__, static_folder='static')

# The pipeline is expensive to construct (models, indexes), so build it
# once at import time and share the singleton across all requests.
print("Loading Bio-RAG pipeline...")
config = BioRAGConfig()
pipeline = BioRAGPipeline(config)
print("Pipeline ready!")
@app.route('/')
def index():
    """Serve the single-page UI from the static folder.

    NOTE(review): the route decorator was absent in the extracted source
    (likely stripped by the page scrape); restored as the root route —
    confirm the original path against the deployed app.
    """
    return send_from_directory('static', 'index.html')
@app.route('/ask', methods=['POST'])
def ask():
    """Answer a biomedical question synchronously.

    Expects a JSON body ``{"question": "..."}``. Returns the pipeline
    result as JSON, 400 for a missing/empty question or body, and 500
    (with the error message) on any pipeline failure.

    NOTE(review): the route decorator was absent in the extracted source;
    restored — confirm the original path and methods.
    """
    try:
        # silent=True makes a missing/malformed JSON body yield None
        # (instead of raising), so it is reported as a clean 400 rather
        # than an AttributeError surfacing as a 500.
        data = request.get_json(silent=True) or {}
        question = data.get('question', '').strip()
        if not question:
            return jsonify({'error': 'No question provided'}), 400
        result = pipeline.ask(question)
        return jsonify(result.to_dict())
    except Exception as e:
        # Top-level request boundary: surface the failure to the client.
        return jsonify({'error': str(e)}), 500
@app.route('/ask/stream', methods=['POST'])
def ask_stream():
    """Answer a question while streaming pipeline progress over SSE.

    Expects a JSON body ``{"question": "..."}``. Emits Server-Sent Events:
    ``{'step': n, 'status': 'active'|'done'}`` progress markers, an
    ``{'answer_ready': True, 'answer': ...}`` event as soon as the draft
    answer exists, and finally ``{'complete': True, 'result': {...}}``
    (or ``{'error': ...}`` if the pipeline raises).

    NOTE(review): the route decorator was absent in the extracted source;
    restored — confirm the original path and methods.
    """
    # silent=True: treat a missing/malformed JSON body as "no question"
    # (400) instead of letting get_json raise inside the request handler.
    data = request.get_json(silent=True) or {}
    question = data.get('question', '').strip()
    if not question:
        return jsonify({'error': 'No question provided'}), 400

    def generate():
        import time

        def sse(payload):
            """Frame one payload as a Server-Sent-Events 'data:' message."""
            return f"data: {json_lib.dumps(payload)}\n\n"

        def rejection_result(message):
            """Terminal result payload for an early exit (rejection)."""
            return {
                'question': question,
                'original_answer': '',
                'final_answer': message,
                'evidence': [],
                'claims': [],
                'claim_checks': [],
                'max_risk_score': 0,
                'safe': True,
                'rejection_message': message,
                'processing_time_seconds': round(time.time() - _start_time, 2),
            }

        try:
            _start_time = time.time()
            phase_times = {}
            token_stats = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

            # Step 0: handshake so the client can render the progress UI.
            # The short sleeps throughout pace the UI animation.
            yield sse({'step': 0, 'status': 'active'})
            time.sleep(0.1)
            yield sse({'step': 0, 'status': 'done'})
            time.sleep(0.1)

            # Domain gate: reject questions outside the biomedical domain.
            is_valid, msg = pipeline.query_processor.validate_domain(question)
            if not is_valid:
                yield sse({'complete': True, 'result': rejection_result(msg)})
                return

            # Step 1: query expansion.
            yield sse({'step': 1, 'status': 'active'})
            time.sleep(0.1)
            _phase_start = time.time()
            queries = pipeline.query_processor.expand_queries(question)
            phase_times['query_expansion'] = round(time.time() - _phase_start, 2)
            yield sse({'step': 1, 'status': 'done'})
            time.sleep(0.1)

            # Step 2: passage retrieval.
            yield sse({'step': 2, 'status': 'active'})
            time.sleep(0.1)
            _phase_start = time.time()
            passages = pipeline.retriever.retrieve(queries)
            phase_times['retrieval'] = round(time.time() - _phase_start, 2)
            yield sse({'step': 2, 'status': 'done'})
            time.sleep(0.1)

            # Too little evidence to answer responsibly — bail out early.
            if len(passages) < 3:
                yield sse({'complete': True,
                           'result': rejection_result('Insufficient evidence.')})
                return

            # Step 3: answer generation (plus token accounting if the
            # generator exposes its last usage record).
            yield sse({'step': 3, 'status': 'active'})
            time.sleep(0.1)
            _phase_start = time.time()
            original_answer = pipeline.generator.generate(question, passages)
            phase_times['generation'] = round(time.time() - _phase_start, 2)
            if hasattr(pipeline.generator, 'last_usage'):
                u = pipeline.generator.last_usage
                token_stats['prompt_tokens'] += u.prompt_tokens
                token_stats['completion_tokens'] += u.completion_tokens
                token_stats['total_tokens'] += u.total_tokens
            yield sse({'step': 3, 'status': 'done'})
            time.sleep(0.1)

            # Push the draft answer to the client as soon as it exists.
            # ensure_ascii=False keeps non-ASCII medical text readable.
            try:
                answer_event = json_lib.dumps(
                    {'answer_ready': True, 'answer': original_answer},
                    ensure_ascii=False)
                print(f"[DEBUG] answer_ready event length: {len(answer_event)}")
                yield f"data: {answer_event}\n\n"
            except Exception as e:
                print(f"[ERROR] Failed to send answer_ready: {e}")
                yield sse({'answer_ready': True, 'answer': 'Error encoding answer'})

            # Step 4: claim decomposition; fall back to treating the whole
            # answer as a single claim if decomposition fails or is empty.
            yield sse({'step': 4, 'status': 'active'})
            time.sleep(0.1)
            _phase_start = time.time()
            try:
                co = pipeline.claim_decomposer.decompose(question, original_answer)
                claims = co if isinstance(co, list) and len(co) > 0 else [original_answer]
            except Exception:
                claims = [original_answer]
            yield sse({'step': 4, 'status': 'done'})
            phase_times['decomposition'] = round(time.time() - _phase_start, 2)
            time.sleep(0.1)

            # Step 5: per-claim verification — retrieve fresh evidence for
            # each claim, score entailment (NLI) and weighted risk, and
            # track the worst risk seen.
            yield sse({'step': 5, 'status': 'active'})
            time.sleep(0.1)
            _phase_start = time.time()
            claim_checks = []
            max_risk = 0.0
            for claim in claims:
                eq = f"{question} {claim}"
                cp = pipeline.retriever.retrieve([eq])[:10]
                ce = " ".join([p.text for p in cp])[:1500]
                nli = pipeline.nli_evaluator.evaluate(claim, [ce])
                pf = pipeline.risk_scorer.calculate_profile(claim)
                rs = pipeline.risk_scorer.compute_weighted_risk(nli, pf)
                max_risk = max(max_risk, rs)
                claim_checks.append({
                    "claim": claim,
                    "nli_prob": round(nli, 4),
                    "severity_score": pf.severity,
                    "type_score": pf.type_score,
                    "omission_score": pf.omission,
                    "risk_score": round(rs, 4),
                })
            yield sse({'step': 5, 'status': 'done'})
            phase_times['verification'] = round(time.time() - _phase_start, 2)

            # Steps 6-7: UI-only phases (no backend work is performed).
            for step in (6, 7):
                yield sse({'step': step, 'status': 'active'})
                time.sleep(0.05)
                yield sse({'step': step, 'status': 'done'})
                time.sleep(0.05)

            # Step 8: safety gate — prepend a warning when the worst claim
            # risk reaches the 0.7 threshold.
            yield sse({'step': 8, 'status': 'active'})
            time.sleep(0.1)
            is_safe = max_risk < 0.7
            fa = original_answer if is_safe else f"WARNING: This answer contains potentially unverified medical information.\n\n{original_answer}"
            yield sse({'step': 8, 'status': 'done'})
            time.sleep(0.1)

            # Top-3 passages as displayable evidence (defensive hasattr
            # checks in case a passage is a bare string).
            ev = [{'text': p.text if hasattr(p, 'text') else str(p),
                   'qid': p.qid if hasattr(p, 'qid') else ''}
                  for p in passages[:3]]
            r = {
                'question': question,
                'original_answer': original_answer,
                'final_answer': fa,
                'evidence': ev,
                'claims': claims,
                'claim_checks': claim_checks,
                'max_risk_score': round(max_risk, 4),
                'safe': is_safe,
                'rejection_message': '',
                'processing_time_seconds': round(time.time() - _start_time, 2),
                'processing_stats': {
                    'total_db_size': len(pipeline.retriever._docs),
                    'queries_generated': len(queries),
                    'passages_retrieved': len(passages),
                    'claims_verified': len(claims),
                    'evidence_per_claim': 10,
                    'total_evidence_evaluated': len(claims) * 10,
                    'phase_times': phase_times,
                    'token_usage': token_stats,
                }
            }
            yield sse({'complete': True, 'result': r})
        except Exception as e:
            # Stream-level boundary: report the error as a final SSE event.
            yield sse({'error': str(e)})

    # X-Accel-Buffering: no — stop reverse proxies from buffering the stream.
    return Response(generate(), mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache',
                             'X-Accel-Buffering': 'no',
                             'Connection': 'keep-alive'})
if __name__ == '__main__':
    # `os` is already imported at module level; the redundant local
    # re-import has been removed.
    # PORT is injected by the hosting environment (e.g. HF Spaces);
    # default to 7860 for local runs, binding on all interfaces.
    port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=port)