"""Smoke test for ``rag_engine.retriever.OncoRAGRetriever``.

Runs sample queries against a local retrieval store and prints the hits so a
human can eyeball them; it does not assert on result content.
"""
import logging
import sys
import os

# Make packages located under the current working directory (e.g. the
# project's ``rag_engine`` package) importable when this file is executed as a
# script. NOTE(review): this is cwd-dependent — presumably the script is meant
# to be launched from the project root; confirm.
# NOTE(review): the path is *appended*, so any other ``rag_engine`` already on
# sys.path would shadow the local one — confirm that is intended.
sys.path.append(os.getcwd())

# This import must stay below the sys.path tweak above, otherwise
# ``rag_engine`` may not be resolvable.
from rag_engine.retriever import OncoRAGRetriever

# Configure root logging at INFO so this script's logger.info(...) messages
# are emitted (basicConfig's default handler writes to stderr).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def _print_results(results, preview_chars=200):
    """Print retrieved chunks as numbered source lines followed by a text preview.

    Args:
        results: Iterable of result mappings, as returned by
            ``OncoRAGRetriever.query``. Each item is read for ``'source'`` and
            ``'text'`` (a missing key raises ``KeyError``) and, optionally,
            ``'type'`` (shown as ``'Standard'`` when absent).
        preview_chars: Maximum number of characters of ``'text'`` to print.
    """
    count = 0
    for rank, res in enumerate(results, start=1):
        count = rank
        print(f"[{rank}] Source: {res['source']} | Type: {res.get('type', 'Standard')}")
        text = res['text']
        # Only signal truncation with an ellipsis when text was actually cut.
        suffix = "..." if len(text) > preview_chars else ""
        print(f"Content: {text[:preview_chars]}{suffix}")
    if count == 0:
        # Counted during iteration (rather than `if not results`) so this also
        # works if query() returns a lazy iterable instead of a list.
        logger.warning("No results returned for this query.")


def test_sota_retrieval():
    """Run sample genomic, clinical-trial and graph-style queries and print the hits.

    Builds an ``OncoRAGRetriever`` over the ``clinical_guidelines`` collection
    stored at the relative path ``data/chroma_db`` (resolved against the
    current working directory) with ``distance_threshold=0.5``, then prints
    each query's results for manual inspection.

    This is a smoke test: it contains no assertions and only fails if an
    exception is raised (e.g. the store cannot be opened, or a result lacks a
    ``'source'``/``'text'`` key). Because of its ``test_`` prefix, pytest will
    collect it if this file matches pytest's discovery pattern.
    """
    retriever = OncoRAGRetriever(
        db_path="data/chroma_db",
        collection_name="clinical_guidelines",
        distance_threshold=0.5,
    )

    # (section label, query) pairs, run in this order.
    queries = (
        ("Genomic Query",
         "Patient has BRAF V600E mutation. What are the evidence-based treatments?"),
        ("Clinical Trial Query",
         "Search for recruiting trials for Non-Small Cell Lung Cancer."),
        ("Graph Search Query",
         "Explain the relation between osimertinib and egfr in nsclc."),
    )

    for number, (label, query) in enumerate(queries, start=1):
        logger.info("\n--- TEST %d: %s ---", number, label)
        _print_results(retriever.query(query))
|
|
if __name__ == "__main__":
    # Allow running the smoke test directly (``python <this file>``) without pytest.
    test_sota_retrieval()
|
|