Ahsan-Asim committed on
Commit
4e2451c
·
1 Parent(s): e1cc9d4
Files changed (1) hide show
  1. app.py +155 -31
app.py CHANGED
@@ -1,41 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import faiss
3
  import pickle
4
  import numpy as np
5
  import torch
 
6
  from transformers import T5Tokenizer, T5ForConditionalGeneration
7
  from sentence_transformers import SentenceTransformer
8
 
9
- # Load LLM model (local folder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  @st.cache_resource
11
  def load_llm():
12
- model_path = "./Generator_Model"
13
- tokenizer = T5Tokenizer.from_pretrained(model_path)
14
- model = T5ForConditionalGeneration.from_pretrained(model_path)
15
  return tokenizer, model
16
 
17
- # Load embedding model (local folder)
18
  @st.cache_resource
19
  def load_embedding_model():
20
- embed_model_path = "./Embedding_Model1"
21
- return SentenceTransformer(embed_model_path)
22
 
23
  # Load FAISS index and embeddings
24
  @st.cache_resource
25
  def load_faiss():
26
- faiss_index = faiss.read_index("faiss_index_file.index")
27
- data = np.load("embeddings_file.npy", allow_pickle=True)
28
- return faiss_index, data
29
-
 
30
 
31
- # Search function
32
- def search(query, embed_model, index, data):
33
  query_embedding = embed_model.encode([query]).astype('float32')
34
- _, I = index.search(query_embedding, k=5) # Top 5 results
35
- results = [data['texts'][i] for i in I[0] if i != -1]
36
  return results
37
 
38
- # Generate response using LLM
39
  def generate_response(context, query, tokenizer, model):
40
  input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
41
  inputs = tokenizer.encode(input_text, return_tensors="pt")
@@ -43,33 +165,35 @@ def generate_response(context, query, tokenizer, model):
43
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
  return response
45
 
46
- # Streamlit App
47
  def main():
48
- st.title("Local LLM + FAISS + Embedding Search App")
49
- st.markdown("πŸ” Ask a question, and get context-aware answers!")
50
 
51
- # Load everything once
 
 
 
 
 
 
 
 
52
  tokenizer, llm_model = load_llm()
53
  embed_model = load_embedding_model()
54
- faiss_index, data = load_faiss()
55
 
56
- query = st.text_input("Enter your query:")
57
 
58
  if query:
59
- with st.spinner("Processing..."):
60
- # Search relevant contexts
61
  contexts = search(query, embed_model, faiss_index, data)
62
  combined_context = " ".join(contexts)
63
-
64
- # Generate answer
65
  response = generate_response(combined_context, query, tokenizer, llm_model)
66
 
67
- st.subheader("Response:")
 
68
  st.write(response)
69
 
70
- st.subheader("Top Retrieved Contexts:")
71
- for idx, ctx in enumerate(contexts, 1):
72
- st.markdown(f"**{idx}.** {ctx}")
73
-
74
  if __name__ == "__main__":
75
  main()
 
1
+ # import streamlit as st
2
+ # import faiss
3
+ # import pickle
4
+ # import numpy as np
5
+ # import torch
6
+ # from transformers import T5Tokenizer, T5ForConditionalGeneration
7
+ # from sentence_transformers import SentenceTransformer
8
+
9
+ # # Load LLM model (local folder)
10
+ # @st.cache_resource
11
+ # def load_llm():
12
+ # model_path = "./Generator_Model"
13
+ # tokenizer = T5Tokenizer.from_pretrained(model_path)
14
+ # model = T5ForConditionalGeneration.from_pretrained(model_path)
15
+ # return tokenizer, model
16
+
17
+ # # Load embedding model (local folder)
18
+ # @st.cache_resource
19
+ # def load_embedding_model():
20
+ # embed_model_path = "./Embedding_Model1"
21
+ # return SentenceTransformer(embed_model_path)
22
+
23
+ # # Load FAISS index and embeddings
24
+ # @st.cache_resource
25
+ # def load_faiss():
26
+ # faiss_index = faiss.read_index("faiss_index_file.index")
27
+ # data = np.load("embeddings_file.npy", allow_pickle=True)
28
+ # return faiss_index, data
29
+
30
+
31
+ # # Search function
32
+ # def search(query, embed_model, index, data):
33
+ # query_embedding = embed_model.encode([query]).astype('float32')
34
+ # _, I = index.search(query_embedding, k=5) # Top 5 results
35
+ # results = [data['texts'][i] for i in I[0] if i != -1]
36
+ # return results
37
+
38
+ # # Generate response using LLM
39
+ # def generate_response(context, query, tokenizer, model):
40
+ # input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
41
+ # inputs = tokenizer.encode(input_text, return_tensors="pt")
42
+ # outputs = model.generate(inputs, max_length=512, do_sample=True, temperature=0.7)
43
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
+ # return response
45
+
46
+ # # Streamlit App
47
+ # def main():
48
+ # st.title("Local LLM + FAISS + Embedding Search App")
49
+ # st.markdown("πŸ” Ask a question, and get context-aware answers!")
50
+
51
+ # # Load everything once
52
+ # tokenizer, llm_model = load_llm()
53
+ # embed_model = load_embedding_model()
54
+ # faiss_index, data = load_faiss()
55
+
56
+ # query = st.text_input("Enter your query:")
57
+
58
+ # if query:
59
+ # with st.spinner("Processing..."):
60
+ # # Search relevant contexts
61
+ # contexts = search(query, embed_model, faiss_index, data)
62
+ # combined_context = " ".join(contexts)
63
+
64
+ # # Generate answer
65
+ # response = generate_response(combined_context, query, tokenizer, llm_model)
66
+
67
+ # st.subheader("Response:")
68
+ # st.write(response)
69
+
70
+ # st.subheader("Top Retrieved Contexts:")
71
+ # for idx, ctx in enumerate(contexts, 1):
72
+ # st.markdown(f"**{idx}.** {ctx}")
73
+
74
+ # if __name__ == "__main__":
75
+ # main()
76
+
77
+
78
+
79
+ ###########################
80
+ import os
81
  import streamlit as st
82
  import faiss
83
  import pickle
84
  import numpy as np
85
  import torch
86
+ import gdown
87
  from transformers import T5Tokenizer, T5ForConditionalGeneration
88
  from sentence_transformers import SentenceTransformer
89
 
90
# Fetch a whole shared Google Drive folder, but only on the first run.
def download_folder_from_google_drive(folder_url, output_path):
    """Download the Drive folder at *folder_url* into *output_path*.

    Skipped entirely when *output_path* already exists, so repeated app
    restarts do not re-download the models.
    """
    if os.path.exists(output_path):
        return  # already present locally - nothing to do
    gdown.download_folder(url=folder_url, output=output_path, quiet=False, use_cookies=False)
94
+
95
# Fetch a single Google Drive file by its file id, but only on the first run.
def download_file_from_google_drive(file_id, destination):
    """Download the Drive file *file_id* to *destination* (skipped when it exists)."""
    if os.path.exists(destination):
        return  # cached from a previous run
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
100
+
101
# One-time bootstrap: create local dirs and pull every model / data artifact.
@st.cache_resource
def setup_files():
    """Ensure all model folders and retrieval files exist locally.

    Creates the ``models/`` directory layout, then downloads the embedding
    model folder, the generator model folder, and the FAISS index /
    texts / embeddings files from Google Drive. Each download helper is a
    no-op when its target already exists; ``st.cache_resource`` additionally
    keeps this from re-running within one server process.
    """
    for directory in ("models/embedding_model", "models/generator_model", "models/files"):
        os.makedirs(directory, exist_ok=True)

    # Model folders: (shared-folder URL, local destination).
    folder_downloads = (
        ("https://drive.google.com/drive/folders/1GzPk2ehr7rzOr65Am1Hg3A87FOTNHLAM?usp=sharing",
         "models/embedding_model"),
        ("https://drive.google.com/drive/folders/1338KWiBE-6sWsTO2iH7Pgu8eRI7EE7Vr?usp=sharing",
         "models/generator_model"),
    )
    for url, destination in folder_downloads:
        download_folder_from_google_drive(url, destination)

    # Retrieval artifacts: (Drive file id, local destination).
    file_downloads = (
        ("11J_VI1buTgnvhoP3z2HM6X5aPzbBO2ed", "models/files/faiss_index_file.index"),
        ("1RTEwp8xDgxLnRUiy7ClTskFuTu0GtWBT", "models/files/texts.pkl"),
        ("1N54imsqJIJGeqM3buiRzp1ivK_BtC7rR", "models/files/embeddings.npy"),
    )
    for file_id, destination in file_downloads:
        download_file_from_google_drive(file_id, destination)
124
+
125
# Local filesystem layout produced by setup_files().
EMBEDDING_MODEL_PATH = "models/embedding_model"            # SentenceTransformer folder
GENERATOR_MODEL_PATH = "models/generator_model"            # T5 generator folder
FAISS_INDEX_PATH = "models/files/faiss_index_file.index"   # FAISS vector index
TEXTS_PATH = "models/files/texts.pkl"                      # pickled corpus texts
EMBEDDINGS_PATH = "models/files/embeddings.npy"            # stored corpus embeddings
131
+
132
# Generator (answer-writing) model.
@st.cache_resource
def load_llm():
    """Load the local T5 generator and its tokenizer (cached by Streamlit).

    Returns:
        (tokenizer, model) pair loaded from GENERATOR_MODEL_PATH.
    """
    t5_tokenizer = T5Tokenizer.from_pretrained(GENERATOR_MODEL_PATH)
    t5_model = T5ForConditionalGeneration.from_pretrained(GENERATOR_MODEL_PATH)
    return t5_tokenizer, t5_model
138
 
139
# Retrieval (query/corpus embedding) model.
@st.cache_resource
def load_embedding_model():
    """Load the local sentence-embedding model (cached by Streamlit)."""
    embedder = SentenceTransformer(EMBEDDING_MODEL_PATH)
    return embedder
 
143
 
144
# Load FAISS index and embeddings
@st.cache_resource
def load_faiss():
    """Load the FAISS index, the pickled corpus texts, and stored embeddings.

    Returns:
        (faiss_index, data, embeddings) where *data* indexes texts by FAISS id.
    """
    index = faiss.read_index(FAISS_INDEX_PATH)
    # NOTE(review): pickle.load is unsafe on untrusted files - this assumes the
    # Drive artifact is trusted; confirm provenance.
    with open(TEXTS_PATH, "rb") as handle:
        texts = pickle.load(handle)
    stored_embeddings = np.load(EMBEDDINGS_PATH, allow_pickle=True)
    return index, texts, stored_embeddings
152
 
153
# Retrieve the k most similar stored texts for a query.
def search(query, embed_model, index, data, k=5):
    """Return up to *k* texts from *data* most similar to *query*.

    Encodes the query with *embed_model*, searches *index* (FAISS), and maps
    the returned ids back into *data*; FAISS pads missing hits with -1,
    which are filtered out.
    """
    encoded = embed_model.encode([query]).astype('float32')
    _, hit_ids = index.search(encoded, k)
    return [data[pos] for pos in hit_ids[0] if pos != -1]
159
 
160
# Answer generation: prompt the T5 model with retrieved context + question.
def generate_response(context, query, tokenizer, model):
    """Generate an answer for *query* grounded in *context* via the T5 model.

    NOTE(review): the ``model.generate`` call is not visible in this diff
    chunk (file line 164 was elided); it is reconstructed from the
    commented-out previous version - confirm against the full file.
    """
    prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    outputs = model.generate(token_ids, max_length=512, do_sample=True, temperature=0.7)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
167
 
168
# Streamlit app
def main():
    """Streamlit entry point: bootstrap assets, load models, answer one query."""
    st.set_page_config(page_title="Clinical QA with RAG", page_icon="🩺")
    st.title("πŸ”Ž Clinical QA System (RAG + FAISS + T5)")

    st.markdown(
        """
        Enter your **clinical question** below.
        The system will retrieve relevant context and generate an informed answer using a local model. πŸš€
        """
    )

    # Download + Load everything
    setup_files()
    tokenizer, llm_model = load_llm()
    embed_model = load_embedding_model()
    # NOTE(review): `embeddings` is unpacked here but never used below - confirm intent.
    faiss_index, data, embeddings = load_faiss()

    query = st.text_input("πŸ’¬ Your Question:")

    if query:
        with st.spinner("πŸ” Retrieving and Generating..."):
            # Retrieve top contexts, then generate a grounded answer.
            contexts = search(query, embed_model, faiss_index, data)
            combined_context = " ".join(contexts)

            response = generate_response(combined_context, query, tokenizer, llm_model)

            st.success("βœ… Answer Ready!")
            st.subheader("πŸ“„ Response:")
            st.write(response)


if __name__ == "__main__":
    main()