File size: 5,363 Bytes
aea514c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d53317c
 
 
 
 
aea514c
25c227b
aea514c
25c227b
aea514c
 
 
 
 
 
 
 
84aab3c
 
2ac2fa7
 
 
aea514c
 
 
 
 
 
 
 
df6d691
aea514c
 
 
 
 
 
 
df6d691
aea514c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df6d691
aea514c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8140c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI
import numpy as np
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
import streamlit as st
from datasets import load_dataset
import os
import pickle
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore  # Add this import
import time

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()

def get_vector_store(data_dir: str = 'src'):
    """Load a FAISS vectorstore from pre-computed embeddings on disk.

    Reads ``medical_embeddings.npy`` (a 2-D array, one embedding row per
    document) and ``medical_texts.pkl`` (a pickled list of document strings)
    from ``data_dir``, builds an exact L2 FAISS index over the stored
    vectors, and wraps it in a LangChain ``FAISS`` vectorstore backed by an
    in-memory docstore.

    Args:
        data_dir: Directory containing the pre-computed files. Defaults to
            ``'src'``, preserving the previously hard-coded paths.

    Returns:
        ``(vectorstore, documents)`` on success, or ``(None, None)`` when the
        files are missing or loading fails — the caller is expected to fall
        back to building embeddings itself.
    """
    # Build each path once instead of repeating the string literals.
    embeddings_path = os.path.join(data_dir, 'medical_embeddings.npy')
    texts_path = os.path.join(data_dir, 'medical_texts.pkl')

    try:
        # Fail fast with a clear message instead of a cryptic loader error.
        if not os.path.exists(embeddings_path):
            raise FileNotFoundError(f"{embeddings_path} not found")
        if not os.path.exists(texts_path):
            raise FileNotFoundError(f"{texts_path} not found")

        print("πŸ“₯ Loading pre-computed embeddings...")
        embeddings_array = np.load(embeddings_path)

        with open(texts_path, 'rb') as f:
            # NOTE(review): pickle.load is only safe on artifacts we produced
            # ourselves — never point this at untrusted files.
            texts = pickle.load(f)

        print(f"βœ… Loaded {len(embeddings_array)} pre-computed embeddings")

        # Flat (exact) L2 index built directly from the stored vectors;
        # the corpus is never re-embedded here.
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array.astype('float32'))  # type: ignore

        # This embedding function is only used to embed *new queries*; it is
        # assumed to match the model used to pre-compute the corpus vectors.
        embeddings_function = HuggingFaceEmbeddings(
            model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )

        # Wrap each text in a Document and register it in an in-memory
        # docstore keyed by its stringified row index.
        documents_dict = {}
        documents = []
        for i, text in enumerate(texts):
            doc = Document(
                page_content=text,
                metadata={"doc_id": i, "type": "medical_qa"}
            )
            documents_dict[str(i)] = doc
            documents.append(doc)

        docstore = InMemoryDocstore(documents_dict)

        # Mapping: FAISS row i -> docstore key str(i).
        index_to_docstore_id = {i: str(i) for i in range(len(texts))}

        vectorstore = FAISS(
            embedding_function=embeddings_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        return vectorstore, documents

    except FileNotFoundError as e:
        print(f"❌ Pre-computed files not found: {e}")
        print("πŸ”„ Falling back to creating embeddings...")
        return None, None

    except Exception as e:
        print(f"❌ Error loading pre-computed embeddings: {e}")
        print("πŸ”„ Falling back to creating embeddings...")
        return None, None


@st.cache_resource
def load_medical_system():
    """Load and cache the full medical RAG stack.

    Builds, in order: the FAISS vectorstore from pre-computed embeddings,
    a BM25 + dense-vector ensemble retriever, the OpenAI chat LLM, and a
    cross-encoder reranker. Decorated with ``st.cache_resource`` so the
    heavy models are loaded once and reused across Streamlit reruns.

    Returns:
        ``(documents, ensemble_retriever, llm, reranker)``

    Side effects:
        Calls ``st.stop()`` (halting the script run) when the vectorstore
        files or the OpenAI API key are missing.
    """
    with st.spinner("πŸ”„ Loading medical knowledge base..."):
        # perf_counter is monotonic and intended for interval timing,
        # unlike time.time() which can jump with clock adjustments.
        start = time.perf_counter()
        vectorstore, documents = get_vector_store()
        total_time = time.perf_counter() - start

        if vectorstore is None or documents is None:
            st.error("❌ Could not load the vectorstore. Please ensure the embeddings and text files exist.")
            st.stop()

        st.success(f"βœ… Loaded existing vectorstore in {total_time:.2f} seconds")

        # Hybrid retrieval: lexical BM25 (weight 0.3) combined with dense
        # vector similarity (weight 0.7, top-2 hits from the vector side).
        bm25_retriever = BM25Retriever.from_documents(documents)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.3, 0.7]
        )

        # temperature=0 for deterministic generations.
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            st.error("❌ OpenAI API key not found! Please set it in your environment variables or .streamlit/secrets.toml")
            st.stop()
        llm = ChatOpenAI(temperature=0, api_key=openai_key) # type: ignore

        # Cross-encoder used downstream to re-score retrieved candidates.
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        return documents, ensemble_retriever, llm, reranker