ghuman7 commited on
Commit
1001445
·
verified ·
1 Parent(s): c31f3b6

Upload 25 files

Browse files
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MentalHealth_RAG/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
37
  MentalHealth_RAG/MentalHealth/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
 
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MentalHealth_RAG/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
37
  MentalHealth_RAG/MentalHealth/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
38
+ data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
39
+ MentalHealth/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
Evaluation_MH/.ipynb_checkpoints/Evaluation-checkpoint.ipynb ADDED
@@ -0,0 +1,1403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f7b87c2c",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "c22401c2-2fd2-4459-9ee8-71bc3bd362c8",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# pip install -U sentence-transformers"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "8a7cc9d8",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/Users/arnabchakraborty/anaconda3/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
32
+ " from tqdm.autonotebook import tqdm, trange\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "from sentence_transformers import SentenceTransformer\n",
38
+ "from langchain.prompts import PromptTemplate\n",
39
+ "from langchain.chains import LLMChain\n",
40
+ "from langchain_community.llms import Ollama\n",
41
+ "from langchain.evaluation import load_evaluator\n",
42
+ "import faiss\n",
43
+ "import pandas as pd\n",
44
+ "import numpy as np\n",
45
+ "import pickle\n",
46
+ "import time\n",
47
+ "from tqdm import tqdm"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "id": "b6efca1d",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Initialization"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 2,
61
+ "id": "cc9a49d2",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# Load the FAISS index\n",
66
+ "index = faiss.read_index(\"database/pdf_sections_index.faiss\")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "id": "9af39b55",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "model = SentenceTransformer('all-MiniLM-L6-v2')"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 4,
82
+ "id": "fee8cdfd",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "with open('database/pdf_sections_data.pkl', 'rb') as f:\n",
87
+ " sections_data = pickle.load(f)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "id": "d6a1ba6a",
93
+ "metadata": {},
94
+ "source": [
95
+ "# RAG functions"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 5,
101
+ "id": "182bdbd8",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def search_faiss(query, k=3):\n",
106
+ " query_vector = model.encode([query])[0].astype('float32')\n",
107
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
108
+ " distances, indices = index.search(query_vector, k)\n",
109
+ " \n",
110
+ " results = []\n",
111
+ " for dist, idx in zip(distances[0], indices[0]):\n",
112
+ " results.append({\n",
113
+ " 'distance': dist,\n",
114
+ " 'content': sections_data[idx]['content'],\n",
115
+ " 'metadata': sections_data[idx]['metadata']\n",
116
+ " })\n",
117
+ " \n",
118
+ " return results"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 15,
124
+ "id": "67edc46a",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Create a prompt template\n",
129
+ "prompt_template = \"\"\"\n",
130
+ "You are an AI assistant specialized in Mental Health guidelines. \n",
131
+ "Use the following pieces of context to answer the question. \n",
132
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
133
+ "\n",
134
+ "Context:\n",
135
+ "{context}\n",
136
+ "\n",
137
+ "Question: {question}\n",
138
+ "\n",
139
+ "Answer:\"\"\"\n",
140
+ "\n",
141
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
142
+ "\n",
143
+ "llm = Ollama(\n",
144
+ " model=\"llama3\"\n",
145
+ ")\n",
146
+ "\n",
147
+ "# Create the chain\n",
148
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
149
+ "\n",
150
+ "def answer_question(query):\n",
151
+ " # Search for relevant context\n",
152
+ " search_results = search_faiss(query)\n",
153
+ " \n",
154
+ " # Combine the content from the search results\n",
155
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
156
+ "\n",
157
+ " # Run the chain\n",
158
+ " response = chain.run(context=context, question=query)\n",
159
+ " \n",
160
+ " return response"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "3b176af9",
166
+ "metadata": {},
167
+ "source": [
168
+ "# Reading GT"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 16,
174
+ "id": "4ab68dff",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "df = pd.read_csv('data/MentalHealth_Dataset.csv')"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 17,
184
+ "id": "4e7e22d7",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stderr",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "100%|███████████████████████████████████████████| 10/10 [01:45<00:00, 10.55s/it]\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "time_list=[]\n",
197
+ "response_list=[]\n",
198
+ "for i in tqdm(range(len(df))):\n",
199
+ " query = df['Questions'].values[i]\n",
200
+ " start = time.time()\n",
201
+ " response = answer_question(query)\n",
202
+ " end = time.time() \n",
203
+ " time_list.append(end-start)\n",
204
+ " response_list.append(response)"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 18,
210
+ "id": "2b327e90",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "df['latency'] = time_list\n",
215
+ "df['response'] = response_list"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "3c147204",
221
+ "metadata": {},
222
+ "source": [
223
+ "# Evaluation"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 29,
229
+ "id": "d799e541",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "eval_llm = Ollama(\n",
234
+ " model=\"phi3\"\n",
235
+ ")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 30,
241
+ "id": "c2f788dc",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "metrics = ['correctness', 'relevance', 'coherence', 'conciseness']"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 31,
251
+ "id": "83ec2b8d",
252
+ "metadata": {},
253
+ "outputs": [
254
+ {
255
+ "name": "stderr",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.51s/it]\n",
259
+ "100%|███████████████████████████████████████████| 10/10 [00:59<00:00, 5.99s/it]\n",
260
+ "100%|███████████████████████████████████████████| 10/10 [00:50<00:00, 5.10s/it]\n",
261
+ "100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.88s/it]\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "for metric in metrics:\n",
267
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
268
+ " \n",
269
+ " reasoning = []\n",
270
+ " value = []\n",
271
+ " score = []\n",
272
+ " \n",
273
+ " for i in tqdm(range(len(df))):\n",
274
+ " eval_result = evaluator.evaluate_strings(\n",
275
+ " prediction=df.response.values[i],\n",
276
+ " input=df.Questions.values[i],\n",
277
+ " reference=df.Answers.values[i]\n",
278
+ " )\n",
279
+ " reasoning.append(eval_result['reasoning'])\n",
280
+ " value.append(eval_result['value'])\n",
281
+ " score.append(eval_result['score'])\n",
282
+ " \n",
283
+ " df[metric+'_reasoning'] = reasoning\n",
284
+ " df[metric+'_value'] = value\n",
285
+ " df[metric+'_score'] = score "
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 78,
291
+ "id": "f1673a31",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>Questions</th>\n",
316
+ " <th>Answers</th>\n",
317
+ " <th>latency</th>\n",
318
+ " <th>response</th>\n",
319
+ " <th>correctness_reasoning</th>\n",
320
+ " <th>correctness_value</th>\n",
321
+ " <th>correctness_score</th>\n",
322
+ " <th>relevance_reasoning</th>\n",
323
+ " <th>relevance_value</th>\n",
324
+ " <th>relevance_score</th>\n",
325
+ " <th>coherence_reasoning</th>\n",
326
+ " <th>coherence_value</th>\n",
327
+ " <th>coherence_score</th>\n",
328
+ " <th>conciseness_reasoning</th>\n",
329
+ " <th>conciseness_value</th>\n",
330
+ " <th>conciseness_score</th>\n",
331
+ " </tr>\n",
332
+ " </thead>\n",
333
+ " <tbody>\n",
334
+ " <tr>\n",
335
+ " <th>0</th>\n",
336
+ " <td>What is Mental Health</td>\n",
337
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
338
+ " <td>11.974234</td>\n",
339
+ " <td>Based on the provided context, specifically fr...</td>\n",
340
+ " <td>The submission refers to the provided input wh...</td>\n",
341
+ " <td>Y</td>\n",
342
+ " <td>1</td>\n",
343
+ " <td>Step 1: Evaluate relevance criterion\\nThe subm...</td>\n",
344
+ " <td>Y</td>\n",
345
+ " <td>1</td>\n",
346
+ " <td>Step 1: Assess coherence\\nThe submission direc...</td>\n",
347
+ " <td>Y</td>\n",
348
+ " <td>1</td>\n",
349
+ " <td>1. The submission directly answers the questio...</td>\n",
350
+ " <td>Y</td>\n",
351
+ " <td>1</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>1</th>\n",
355
+ " <td>What are the most common mental disorders ment...</td>\n",
356
+ " <td>The most common mental disorders include depre...</td>\n",
357
+ " <td>5.863329</td>\n",
358
+ " <td>Based on the provided context, the mental diso...</td>\n",
359
+ " <td>Step 1: Check if the submission is factually a...</td>\n",
360
+ " <td>Y</td>\n",
361
+ " <td>1</td>\n",
362
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
363
+ " <td>Y</td>\n",
364
+ " <td>1</td>\n",
365
+ " <td>The submission begins with an appropriate ques...</td>\n",
366
+ " <td>Y</td>\n",
367
+ " <td>1</td>\n",
368
+ " <td>Step 1: Review conciseness criterion\\nThe subm...</td>\n",
369
+ " <td>Y</td>\n",
370
+ " <td>1</td>\n",
371
+ " </tr>\n",
372
+ " <tr>\n",
373
+ " <th>2</th>\n",
374
+ " <td>What are the early warning signs and symptoms ...</td>\n",
375
+ " <td>Early warning signs and symptoms of depression...</td>\n",
376
+ " <td>13.434543</td>\n",
377
+ " <td>Based on the provided context, I found a refer...</td>\n",
378
+ " <td>Step 1: Evaluate Correctness\\nThe submission a...</td>\n",
379
+ " <td>Y</td>\n",
380
+ " <td>1</td>\n",
381
+ " <td>Step 1: Identify the relevant criterion from t...</td>\n",
382
+ " <td>Y</td>\n",
383
+ " <td>1</td>\n",
384
+ " <td>Step 1: Evaluate coherence\\nThe submission is ...</td>\n",
385
+ " <td>Y</td>\n",
386
+ " <td>1</td>\n",
387
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
388
+ " <td>Y</td>\n",
389
+ " <td>1</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <th>3</th>\n",
393
+ " <td>How can someone help a person who suffers from...</td>\n",
394
+ " <td>To help someone with anxiety, one can support ...</td>\n",
395
+ " <td>13.838464</td>\n",
396
+ " <td>According to the provided context, specificall...</td>\n",
397
+ " <td>Step 1: Correctness\\nThe submission accurately...</td>\n",
398
+ " <td>Y</td>\n",
399
+ " <td>1</td>\n",
400
+ " <td>Step 1: Analyze relevance criterion\\nThe submi...</td>\n",
401
+ " <td>Y</td>\n",
402
+ " <td>1</td>\n",
403
+ " <td>Step 1: Evaluate coherence\\nThe submission dis...</td>\n",
404
+ " <td>Y</td>\n",
405
+ " <td>1</td>\n",
406
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
407
+ " <td>N</td>\n",
408
+ " <td>0</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <th>4</th>\n",
412
+ " <td>What are the causes of mental illness listed i...</td>\n",
413
+ " <td>Causes of mental illness include abnormal func...</td>\n",
414
+ " <td>6.871735</td>\n",
415
+ " <td>According to the provided context, the causes ...</td>\n",
416
+ " <td>The submission lists factors that align with t...</td>\n",
417
+ " <td>N</td>\n",
418
+ " <td>0</td>\n",
419
+ " <td>Step 1: Review relevance criterion - Check if ...</td>\n",
420
+ " <td>Y</td>\n",
421
+ " <td>1</td>\n",
422
+ " <td>Step 1: Compare the submission with the provid...</td>\n",
423
+ " <td>Y</td>\n",
424
+ " <td>1</td>\n",
425
+ " <td>Step 1: Assess conciseness\\nThe submission is ...</td>\n",
426
+ " <td>Y</td>\n",
427
+ " <td>1</td>\n",
428
+ " </tr>\n",
429
+ " </tbody>\n",
430
+ "</table>\n",
431
+ "</div>"
432
+ ],
433
+ "text/plain": [
434
+ " Questions \\\n",
435
+ "0 What is Mental Health \n",
436
+ "1 What are the most common mental disorders ment... \n",
437
+ "2 What are the early warning signs and symptoms ... \n",
438
+ "3 How can someone help a person who suffers from... \n",
439
+ "4 What are the causes of mental illness listed i... \n",
440
+ "\n",
441
+ " Answers latency \\\n",
442
+ "0 Mental Health is a \" state of well-being in wh... 11.974234 \n",
443
+ "1 The most common mental disorders include depre... 5.863329 \n",
444
+ "2 Early warning signs and symptoms of depression... 13.434543 \n",
445
+ "3 To help someone with anxiety, one can support ... 13.838464 \n",
446
+ "4 Causes of mental illness include abnormal func... 6.871735 \n",
447
+ "\n",
448
+ " response \\\n",
449
+ "0 Based on the provided context, specifically fr... \n",
450
+ "1 Based on the provided context, the mental diso... \n",
451
+ "2 Based on the provided context, I found a refer... \n",
452
+ "3 According to the provided context, specificall... \n",
453
+ "4 According to the provided context, the causes ... \n",
454
+ "\n",
455
+ " correctness_reasoning correctness_value \\\n",
456
+ "0 The submission refers to the provided input wh... Y \n",
457
+ "1 Step 1: Check if the submission is factually a... Y \n",
458
+ "2 Step 1: Evaluate Correctness\\nThe submission a... Y \n",
459
+ "3 Step 1: Correctness\\nThe submission accurately... Y \n",
460
+ "4 The submission lists factors that align with t... N \n",
461
+ "\n",
462
+ " correctness_score relevance_reasoning \\\n",
463
+ "0 1 Step 1: Evaluate relevance criterion\\nThe subm... \n",
464
+ "1 1 Step 1: Analyze the relevance criterion\\nThe s... \n",
465
+ "2 1 Step 1: Identify the relevant criterion from t... \n",
466
+ "3 1 Step 1: Analyze relevance criterion\\nThe submi... \n",
467
+ "4 0 Step 1: Review relevance criterion - Check if ... \n",
468
+ "\n",
469
+ " relevance_value relevance_score \\\n",
470
+ "0 Y 1 \n",
471
+ "1 Y 1 \n",
472
+ "2 Y 1 \n",
473
+ "3 Y 1 \n",
474
+ "4 Y 1 \n",
475
+ "\n",
476
+ " coherence_reasoning coherence_value \\\n",
477
+ "0 Step 1: Assess coherence\\nThe submission direc... Y \n",
478
+ "1 The submission begins with an appropriate ques... Y \n",
479
+ "2 Step 1: Evaluate coherence\\nThe submission is ... Y \n",
480
+ "3 Step 1: Evaluate coherence\\nThe submission dis... Y \n",
481
+ "4 Step 1: Compare the submission with the provid... Y \n",
482
+ "\n",
483
+ " coherence_score conciseness_reasoning \\\n",
484
+ "0 1 1. The submission directly answers the questio... \n",
485
+ "1 1 Step 1: Review conciseness criterion\\nThe subm... \n",
486
+ "2 1 Step 1: Evaluate conciseness - The submission ... \n",
487
+ "3 1 Step 1: Evaluate conciseness - The submission ... \n",
488
+ "4 1 Step 1: Assess conciseness\\nThe submission is ... \n",
489
+ "\n",
490
+ " conciseness_value conciseness_score \n",
491
+ "0 Y 1 \n",
492
+ "1 Y 1 \n",
493
+ "2 Y 1 \n",
494
+ "3 N 0 \n",
495
+ "4 Y 1 "
496
+ ]
497
+ },
498
+ "execution_count": 78,
499
+ "metadata": {},
500
+ "output_type": "execute_result"
501
+ }
502
+ ],
503
+ "source": [
504
+ "df.head()"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 32,
510
+ "id": "7797a360",
511
+ "metadata": {},
512
+ "outputs": [
513
+ {
514
+ "data": {
515
+ "text/plain": [
516
+ "correctness_score 0.800000\n",
517
+ "relevance_score 0.900000\n",
518
+ "coherence_score 1.000000\n",
519
+ "conciseness_score 0.800000\n",
520
+ "latency 10.544803\n",
521
+ "dtype: float64"
522
+ ]
523
+ },
524
+ "execution_count": 32,
525
+ "metadata": {},
526
+ "output_type": "execute_result"
527
+ }
528
+ ],
529
+ "source": [
530
+ "df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 34,
536
+ "id": "fe667926",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "irr_q=pd.read_csv('data/Unrelated_questions.csv')"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 35,
546
+ "id": "189f8a0f",
547
+ "metadata": {},
548
+ "outputs": [
549
+ {
550
+ "name": "stderr",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.30s/it]\n"
554
+ ]
555
+ }
556
+ ],
557
+ "source": [
558
+ "time_list=[]\n",
559
+ "response_list=[]\n",
560
+ "for i in tqdm(range(len(irr_q))):\n",
561
+ " query = irr_q['Questions'].values[i]\n",
562
+ " start = time.time()\n",
563
+ " response = answer_question(query)\n",
564
+ " end = time.time() \n",
565
+ " time_list.append(end-start)\n",
566
+ " response_list.append(response)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 36,
572
+ "id": "b0244ea0",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "irr_q['response']=response_list\n",
577
+ "irr_q['latency']=time_list"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 79,
583
+ "id": "dc3b1ade",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/html": [
589
+ "<div>\n",
590
+ "<style scoped>\n",
591
+ " .dataframe tbody tr th:only-of-type {\n",
592
+ " vertical-align: middle;\n",
593
+ " }\n",
594
+ "\n",
595
+ " .dataframe tbody tr th {\n",
596
+ " vertical-align: top;\n",
597
+ " }\n",
598
+ "\n",
599
+ " .dataframe thead th {\n",
600
+ " text-align: right;\n",
601
+ " }\n",
602
+ "</style>\n",
603
+ "<table border=\"1\" class=\"dataframe\">\n",
604
+ " <thead>\n",
605
+ " <tr style=\"text-align: right;\">\n",
606
+ " <th></th>\n",
607
+ " <th>Questions</th>\n",
608
+ " <th>response</th>\n",
609
+ " <th>latency</th>\n",
610
+ " <th>irrelevant_score</th>\n",
611
+ " </tr>\n",
612
+ " </thead>\n",
613
+ " <tbody>\n",
614
+ " <tr>\n",
615
+ " <th>0</th>\n",
616
+ " <td>What is the capital of Mars?</td>\n",
617
+ " <td>I don't know. The provided context does not se...</td>\n",
618
+ " <td>12.207266</td>\n",
619
+ " <td>True</td>\n",
620
+ " </tr>\n",
621
+ " <tr>\n",
622
+ " <th>1</th>\n",
623
+ " <td>How many unicorns live in New York City?</td>\n",
624
+ " <td>I don't know. The information provided does no...</td>\n",
625
+ " <td>2.368774</td>\n",
626
+ " <td>True</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <th>2</th>\n",
630
+ " <td>What is the color of happiness?</td>\n",
631
+ " <td>I don't know! The provided context only talks ...</td>\n",
632
+ " <td>5.480067</td>\n",
633
+ " <td>True</td>\n",
634
+ " </tr>\n",
635
+ " <tr>\n",
636
+ " <th>3</th>\n",
637
+ " <td>Can cats fly on Tuesdays?</td>\n",
638
+ " <td>I don't know the answer to this question as it...</td>\n",
639
+ " <td>5.272529</td>\n",
640
+ " <td>True</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <th>4</th>\n",
644
+ " <td>How much does a thought weigh?</td>\n",
645
+ " <td>I don't know. The context provided is about me...</td>\n",
646
+ " <td>5.253224</td>\n",
647
+ " <td>True</td>\n",
648
+ " </tr>\n",
649
+ " </tbody>\n",
650
+ "</table>\n",
651
+ "</div>"
652
+ ],
653
+ "text/plain": [
654
+ " Questions \\\n",
655
+ "0 What is the capital of Mars? \n",
656
+ "1 How many unicorns live in New York City? \n",
657
+ "2 What is the color of happiness? \n",
658
+ "3 Can cats fly on Tuesdays? \n",
659
+ "4 How much does a thought weigh? \n",
660
+ "\n",
661
+ " response latency \\\n",
662
+ "0 I don't know. The provided context does not se... 12.207266 \n",
663
+ "1 I don't know. The information provided does no... 2.368774 \n",
664
+ "2 I don't know! The provided context only talks ... 5.480067 \n",
665
+ "3 I don't know the answer to this question as it... 5.272529 \n",
666
+ "4 I don't know. The context provided is about me... 5.253224 \n",
667
+ "\n",
668
+ " irrelevant_score \n",
669
+ "0 True \n",
670
+ "1 True \n",
671
+ "2 True \n",
672
+ "3 True \n",
673
+ "4 True "
674
+ ]
675
+ },
676
+ "execution_count": 79,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "irr_q.head()"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": 37,
688
+ "id": "8620e50c",
689
+ "metadata": {},
690
+ "outputs": [
691
+ {
692
+ "data": {
693
+ "text/plain": [
694
+ "0 12.207266\n",
695
+ "1 2.368774\n",
696
+ "2 5.480067\n",
697
+ "3 5.272529\n",
698
+ "4 5.253224\n",
699
+ "5 5.351224\n",
700
+ "6 8.118429\n",
701
+ "7 7.288261\n",
702
+ "8 3.856500\n",
703
+ "9 7.745016\n",
704
+ "Name: latency, dtype: float64"
705
+ ]
706
+ },
707
+ "execution_count": 37,
708
+ "metadata": {},
709
+ "output_type": "execute_result"
710
+ }
711
+ ],
712
+ "source": [
713
+ "irr_q['latency']"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": 39,
719
+ "id": "debd3461",
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": [
723
+ "irr_q['irrelevant_score'] = irr_q['response'].str.contains(\"I don't know\")"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 40,
729
+ "id": "bef1d3a4",
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "data": {
734
+ "text/plain": [
735
+ "irrelevant_score 0.900000\n",
736
+ "latency 6.294129\n",
737
+ "dtype: float64"
738
+ ]
739
+ },
740
+ "execution_count": 40,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "irr_q[['irrelevant_score','latency']].mean()"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "markdown",
751
+ "id": "c1610a70",
752
+ "metadata": {},
753
+ "source": [
754
+ "# Improvement"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 48,
760
+ "id": "ff6614f9",
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": [
764
+ "new_prompt_template = \"\"\"\n",
765
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
766
+ "Use the provided context to answer the question short and accurately. \n",
767
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
768
+ "\n",
769
+ "Context:\n",
770
+ "{context}\n",
771
+ "\n",
772
+ "Question: {question}\n",
773
+ "\n",
774
+ "Answer:\"\"\"\n",
775
+ "\n",
776
+ "prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n",
777
+ "\n",
778
+ "llm = Ollama(\n",
779
+ " model=\"llama3\"\n",
780
+ ")\n",
781
+ "\n",
782
+ "# Create the chain\n",
783
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
784
+ "\n",
785
+ "def answer_question_new(query):\n",
786
+ " # Search for relevant context\n",
787
+ " search_results = search_faiss(query)\n",
788
+ " \n",
789
+ " # Combine the content from the search results\n",
790
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
791
+ "\n",
792
+ " # Run the chain\n",
793
+ " response = chain.run(context=context, question=query)\n",
794
+ " \n",
795
+ " return response"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 49,
801
+ "id": "20580d50",
802
+ "metadata": {},
803
+ "outputs": [],
804
+ "source": [
805
+ "df2=df.copy()"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 50,
811
+ "id": "b1b3d725",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "name": "stderr",
816
+ "output_type": "stream",
817
+ "text": [
818
+ "100%|███████████████████████████████████████████| 10/10 [01:34<00:00, 9.40s/it]\n"
819
+ ]
820
+ }
821
+ ],
822
+ "source": [
823
+ "time_list=[]\n",
824
+ "response_list=[]\n",
825
+ "for i in tqdm(range(len(df2))):\n",
826
+ " query = df2['Questions'].values[i]\n",
827
+ " start = time.time()\n",
828
+ "    response = answer_question_new(query)\n",
829
+ " end = time.time() \n",
830
+ " time_list.append(end-start)\n",
831
+ " response_list.append(response)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 51,
837
+ "id": "63f41256",
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "df2['latency'] = time_list\n",
842
+ "df2['response'] = response_list"
843
+ ]
844
+ },
845
+ {
846
+ "cell_type": "code",
847
+ "execution_count": 52,
848
+ "id": "0d8a6065",
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stderr",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.01s/it]\n",
856
+ "100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.35s/it]\n",
857
+ "100%|███████████████████████████████████████████| 10/10 [00:47<00:00, 4.77s/it]\n",
858
+ "100%|███████████████████████████████████████████| 10/10 [00:55<00:00, 5.60s/it]\n"
859
+ ]
860
+ }
861
+ ],
862
+ "source": [
863
+ "for metric in metrics:\n",
864
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
865
+ " \n",
866
+ " reasoning = []\n",
867
+ " value = []\n",
868
+ " score = []\n",
869
+ " \n",
870
+ " for i in tqdm(range(len(df2))):\n",
871
+ " eval_result = evaluator.evaluate_strings(\n",
872
+ " prediction=df2.response.values[i],\n",
873
+ " input=df2.Questions.values[i],\n",
874
+ " reference=df2.Answers.values[i]\n",
875
+ " )\n",
876
+ " reasoning.append(eval_result['reasoning'])\n",
877
+ " value.append(eval_result['value'])\n",
878
+ " score.append(eval_result['score'])\n",
879
+ " \n",
880
+ " df2[metric+'_reasoning'] = reasoning\n",
881
+ " df2[metric+'_value'] = value\n",
882
+ " df2[metric+'_score'] = score "
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 77,
888
+ "id": "c648632c",
889
+ "metadata": {},
890
+ "outputs": [
891
+ {
892
+ "data": {
893
+ "text/html": [
894
+ "<div>\n",
895
+ "<style scoped>\n",
896
+ " .dataframe tbody tr th:only-of-type {\n",
897
+ " vertical-align: middle;\n",
898
+ " }\n",
899
+ "\n",
900
+ " .dataframe tbody tr th {\n",
901
+ " vertical-align: top;\n",
902
+ " }\n",
903
+ "\n",
904
+ " .dataframe thead th {\n",
905
+ " text-align: right;\n",
906
+ " }\n",
907
+ "</style>\n",
908
+ "<table border=\"1\" class=\"dataframe\">\n",
909
+ " <thead>\n",
910
+ " <tr style=\"text-align: right;\">\n",
911
+ " <th></th>\n",
912
+ " <th>Questions</th>\n",
913
+ " <th>Answers</th>\n",
914
+ " <th>latency</th>\n",
915
+ " <th>response</th>\n",
916
+ " <th>correctness_reasoning</th>\n",
917
+ " <th>correctness_value</th>\n",
918
+ " <th>correctness_score</th>\n",
919
+ " <th>relevance_reasoning</th>\n",
920
+ " <th>relevance_value</th>\n",
921
+ " <th>relevance_score</th>\n",
922
+ " <th>coherence_reasoning</th>\n",
923
+ " <th>coherence_value</th>\n",
924
+ " <th>coherence_score</th>\n",
925
+ " <th>conciseness_reasoning</th>\n",
926
+ " <th>conciseness_value</th>\n",
927
+ " <th>conciseness_score</th>\n",
928
+ " </tr>\n",
929
+ " </thead>\n",
930
+ " <tbody>\n",
931
+ " <tr>\n",
932
+ " <th>0</th>\n",
933
+ " <td>What is Mental Health</td>\n",
934
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
935
+ " <td>11.046327</td>\n",
936
+ " <td>Based on the context provided, mental health r...</td>\n",
937
+ " <td>Step 1: Evaluate if the submission is factuall...</td>\n",
938
+ " <td>N</td>\n",
939
+ " <td>0</td>\n",
940
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
941
+ " <td>N</td>\n",
942
+ " <td>0</td>\n",
943
+ " <td>The submission discusses mental health in rela...</td>\n",
944
+ " <td>Y</td>\n",
945
+ " <td>1</td>\n",
946
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
947
+ " <td>Y</td>\n",
948
+ " <td>1</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <th>1</th>\n",
952
+ " <td>What are the most common mental disorders ment...</td>\n",
953
+ " <td>The most common mental disorders include depre...</td>\n",
954
+ " <td>4.509713</td>\n",
955
+ " <td>The handbook mentions several mental illnesses...</td>\n",
956
+ " <td>The submission mentions depression and schizop...</td>\n",
957
+ " <td>N</td>\n",
958
+ " <td>0</td>\n",
959
+ " <td>Step 1: Analyze relevance criterion - Check if...</td>\n",
960
+ " <td>Y</td>\n",
961
+ " <td>1</td>\n",
962
+ " <td>Step 1: Assess coherence\\nThe submission menti...</td>\n",
963
+ " <td>N</td>\n",
964
+ " <td>0</td>\n",
965
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
966
+ " <td>N</td>\n",
967
+ " <td>0</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <th>2</th>\n",
971
+ " <td>What are the early warning signs and symptoms ...</td>\n",
972
+ " <td>Early warning signs and symptoms of depression...</td>\n",
973
+ " <td>8.501180</td>\n",
974
+ " <td>According to the provided context, specificall...</td>\n",
975
+ " <td>The submission matches the reference data in t...</td>\n",
976
+ " <td>Y</td>\n",
977
+ " <td>1</td>\n",
978
+ " <td>The submission refers directly to information ...</td>\n",
979
+ " <td>Y</td>\n",
980
+ " <td>1</td>\n",
981
+ " <td>Step 1: Evaluate coherence - The submission is...</td>\n",
982
+ " <td>Y</td>\n",
983
+ " <td>1</td>\n",
984
+ " <td>The submission is concise and includes most of...</td>\n",
985
+ " <td>Y</td>\n",
986
+ " <td>1</td>\n",
987
+ " </tr>\n",
988
+ " <tr>\n",
989
+ " <th>3</th>\n",
990
+ " <td>How can someone help a person who suffers from...</td>\n",
991
+ " <td>To help someone with anxiety, one can support ...</td>\n",
992
+ " <td>10.611402</td>\n",
993
+ " <td>According to the Mental Health Handbook, when ...</td>\n",
994
+ " <td>The submission seems consistent with the refer...</td>\n",
995
+ " <td>Y</td>\n",
996
+ " <td>1</td>\n",
997
+ " <td>Step 1: Review relevance criterion\\nThe submis...</td>\n",
998
+ " <td>Y</td>\n",
999
+ " <td>1</td>\n",
1000
+ " <td>The submission is coherent, well-structured, a...</td>\n",
1001
+ " <td>Y</td>\n",
1002
+ " <td>1</td>\n",
1003
+ " <td>The submission is relatively concise and cover...</td>\n",
1004
+ " <td>Y</td>\n",
1005
+ " <td>1</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <th>4</th>\n",
1009
+ " <td>What are the causes of mental illness listed i...</td>\n",
1010
+ " <td>Causes of mental illness include abnormal func...</td>\n",
1011
+ " <td>6.299272</td>\n",
1012
+ " <td>According to the context, the causes of mental...</td>\n",
1013
+ " <td>The submission lists causes such as neglect, s...</td>\n",
1014
+ " <td>N</td>\n",
1015
+ " <td>0</td>\n",
1016
+ " <td>The submission mentions factors that are part ...</td>\n",
1017
+ " <td>N</td>\n",
1018
+ " <td>0</td>\n",
1019
+ " <td>The submission is coherent and well-structured...</td>\n",
1020
+ " <td>Y</td>\n",
1021
+ " <td>1</td>\n",
1022
+ " <td>Step 1: Read and understand both the input dat...</td>\n",
1023
+ " <td>N</td>\n",
1024
+ " <td>0</td>\n",
1025
+ " </tr>\n",
1026
+ " </tbody>\n",
1027
+ "</table>\n",
1028
+ "</div>"
1029
+ ],
1030
+ "text/plain": [
1031
+ " Questions \\\n",
1032
+ "0 What is Mental Health \n",
1033
+ "1 What are the most common mental disorders ment... \n",
1034
+ "2 What are the early warning signs and symptoms ... \n",
1035
+ "3 How can someone help a person who suffers from... \n",
1036
+ "4 What are the causes of mental illness listed i... \n",
1037
+ "\n",
1038
+ " Answers latency \\\n",
1039
+ "0 Mental Health is a \" state of well-being in wh... 11.046327 \n",
1040
+ "1 The most common mental disorders include depre... 4.509713 \n",
1041
+ "2 Early warning signs and symptoms of depression... 8.501180 \n",
1042
+ "3 To help someone with anxiety, one can support ... 10.611402 \n",
1043
+ "4 Causes of mental illness include abnormal func... 6.299272 \n",
1044
+ "\n",
1045
+ " response \\\n",
1046
+ "0 Based on the context provided, mental health r... \n",
1047
+ "1 The handbook mentions several mental illnesses... \n",
1048
+ "2 According to the provided context, specificall... \n",
1049
+ "3 According to the Mental Health Handbook, when ... \n",
1050
+ "4 According to the context, the causes of mental... \n",
1051
+ "\n",
1052
+ " correctness_reasoning correctness_value \\\n",
1053
+ "0 Step 1: Evaluate if the submission is factuall... N \n",
1054
+ "1 The submission mentions depression and schizop... N \n",
1055
+ "2 The submission matches the reference data in t... Y \n",
1056
+ "3 The submission seems consistent with the refer... Y \n",
1057
+ "4 The submission lists causes such as neglect, s... N \n",
1058
+ "\n",
1059
+ " correctness_score relevance_reasoning \\\n",
1060
+ "0 0 Step 1: Analyze the relevance criterion\\nThe s... \n",
1061
+ "1 0 Step 1: Analyze relevance criterion - Check if... \n",
1062
+ "2 1 The submission refers directly to information ... \n",
1063
+ "3 1 Step 1: Review relevance criterion\\nThe submis... \n",
1064
+ "4 0 The submission mentions factors that are part ... \n",
1065
+ "\n",
1066
+ " relevance_value relevance_score \\\n",
1067
+ "0 N 0 \n",
1068
+ "1 Y 1 \n",
1069
+ "2 Y 1 \n",
1070
+ "3 Y 1 \n",
1071
+ "4 N 0 \n",
1072
+ "\n",
1073
+ " coherence_reasoning coherence_value \\\n",
1074
+ "0 The submission discusses mental health in rela... Y \n",
1075
+ "1 Step 1: Assess coherence\\nThe submission menti... N \n",
1076
+ "2 Step 1: Evaluate coherence - The submission is... Y \n",
1077
+ "3 The submission is coherent, well-structured, a... Y \n",
1078
+ "4 The submission is coherent and well-structured... Y \n",
1079
+ "\n",
1080
+ " coherence_score conciseness_reasoning \\\n",
1081
+ "0 1 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1082
+ "1 0 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1083
+ "2 1 The submission is concise and includes most of... \n",
1084
+ "3 1 The submission is relatively concise and cover... \n",
1085
+ "4 1 Step 1: Read and understand both the input dat... \n",
1086
+ "\n",
1087
+ " conciseness_value conciseness_score \n",
1088
+ "0 Y 1 \n",
1089
+ "1 N 0 \n",
1090
+ "2 Y 1 \n",
1091
+ "3 Y 1 \n",
1092
+ "4 N 0 "
1093
+ ]
1094
+ },
1095
+ "execution_count": 77,
1096
+ "metadata": {},
1097
+ "output_type": "execute_result"
1098
+ }
1099
+ ],
1100
+ "source": [
1101
+ "df2.head()"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": 47,
1107
+ "id": "2d1002b2",
1108
+ "metadata": {},
1109
+ "outputs": [
1110
+ {
1111
+ "data": {
1112
+ "text/plain": [
1113
+ "correctness_score 0.500000\n",
1114
+ "relevance_score 0.888889\n",
1115
+ "coherence_score 0.888889\n",
1116
+ "conciseness_score 0.900000\n",
1117
+ "latency 8.190205\n",
1118
+ "dtype: float64"
1119
+ ]
1120
+ },
1121
+ "execution_count": 47,
1122
+ "metadata": {},
1123
+ "output_type": "execute_result"
1124
+ }
1125
+ ],
1126
+ "source": [
1127
+ "df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
1128
+ ]
1129
+ },
1130
+ {
1131
+ "cell_type": "markdown",
1132
+ "id": "e808bdcf",
1133
+ "metadata": {},
1134
+ "source": [
1135
+ "# Query relevance"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "execution_count": 66,
1141
+ "id": "6b541f3d",
1142
+ "metadata": {},
1143
+ "outputs": [],
1144
+ "source": [
1145
+ "def new_search_faiss(query, k=3, threshold=1.5):\n",
1146
+ " query_vector = model.encode([query])[0].astype('float32')\n",
1147
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
1148
+ " distances, indices = index.search(query_vector, k)\n",
1149
+ " \n",
1150
+ " results = []\n",
1151
+ " for dist, idx in zip(distances[0], indices[0]):\n",
1152
+ " if dist < threshold: # Only include results within the threshold distance\n",
1153
+ " results.append({\n",
1154
+ " 'distance': dist,\n",
1155
+ " 'content': sections_data[idx]['content'],\n",
1156
+ " 'metadata': sections_data[idx]['metadata']\n",
1157
+ " })\n",
1158
+ " \n",
1159
+ " return results"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "code",
1164
+ "execution_count": 70,
1165
+ "id": "4f579654",
1166
+ "metadata": {},
1167
+ "outputs": [],
1168
+ "source": [
1169
+ "new_prompt_template = \"\"\"\n",
1170
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
1171
+ "Use the provided context to answer the question short and accurately. \n",
1172
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
1173
+ "\n",
1174
+ "Context:\n",
1175
+ "{context}\n",
1176
+ "\n",
1177
+ "Question: {question}\n",
1178
+ "\n",
1179
+ "Answer:\"\"\"\n",
1180
+ "\n",
1181
+ "prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n",
1182
+ "\n",
1183
+ "llm = Ollama(\n",
1184
+ " model=\"llama3\"\n",
1185
+ ")\n",
1186
+ "\n",
1187
+ "# Create the chain\n",
1188
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
1189
+ "\n",
1190
+ "def new_answer_question(query):\n",
1191
+ " # Search for relevant context\n",
1192
+ " search_results = new_search_faiss(query)\n",
1193
+ " \n",
1194
+ " if search_results==[]:\n",
1195
+ " response=\"I don't know, sorry\"\n",
1196
+ " else:\n",
1197
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
1198
+ " response = chain.run(context=context, question=query)\n",
1199
+ " \n",
1200
+ " return response"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": 71,
1206
+ "id": "1f83ef1b",
1207
+ "metadata": {},
1208
+ "outputs": [],
1209
+ "source": [
1210
+ "irr_q2=irr_q.copy()"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": 72,
1216
+ "id": "f06474e3",
1217
+ "metadata": {},
1218
+ "outputs": [
1219
+ {
1220
+ "name": "stderr",
1221
+ "output_type": "stream",
1222
+ "text": [
1223
+ "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 61.93it/s]\n"
1224
+ ]
1225
+ }
1226
+ ],
1227
+ "source": [
1228
+ "time_list=[]\n",
1229
+ "response_list=[]\n",
1230
+ "for i in tqdm(range(len(irr_q2))):\n",
1231
+ " query = irr_q['Questions'].values[i]\n",
1232
+ " start = time.time()\n",
1233
+ " response = new_answer_question(query)\n",
1234
+ " end = time.time() \n",
1235
+ " time_list.append(end-start)\n",
1236
+ " response_list.append(response)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": 73,
1242
+ "id": "52db6b82",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": [
1246
+ "irr_q2['response']=response_list\n",
1247
+ "irr_q2['latency']=time_list"
1248
+ ]
1249
+ },
1250
+ {
1251
+ "cell_type": "code",
1252
+ "execution_count": 80,
1253
+ "id": "80a178ee",
1254
+ "metadata": {},
1255
+ "outputs": [
1256
+ {
1257
+ "data": {
1258
+ "text/html": [
1259
+ "<div>\n",
1260
+ "<style scoped>\n",
1261
+ " .dataframe tbody tr th:only-of-type {\n",
1262
+ " vertical-align: middle;\n",
1263
+ " }\n",
1264
+ "\n",
1265
+ " .dataframe tbody tr th {\n",
1266
+ " vertical-align: top;\n",
1267
+ " }\n",
1268
+ "\n",
1269
+ " .dataframe thead th {\n",
1270
+ " text-align: right;\n",
1271
+ " }\n",
1272
+ "</style>\n",
1273
+ "<table border=\"1\" class=\"dataframe\">\n",
1274
+ " <thead>\n",
1275
+ " <tr style=\"text-align: right;\">\n",
1276
+ " <th></th>\n",
1277
+ " <th>Questions</th>\n",
1278
+ " <th>response</th>\n",
1279
+ " <th>latency</th>\n",
1280
+ " <th>irrelevant_score</th>\n",
1281
+ " </tr>\n",
1282
+ " </thead>\n",
1283
+ " <tbody>\n",
1284
+ " <tr>\n",
1285
+ " <th>0</th>\n",
1286
+ " <td>What is the capital of Mars?</td>\n",
1287
+ " <td>I don't know, sorry</td>\n",
1288
+ " <td>0.061378</td>\n",
1289
+ " <td>True</td>\n",
1290
+ " </tr>\n",
1291
+ " <tr>\n",
1292
+ " <th>1</th>\n",
1293
+ " <td>How many unicorns live in New York City?</td>\n",
1294
+ " <td>I don't know, sorry</td>\n",
1295
+ " <td>0.012511</td>\n",
1296
+ " <td>True</td>\n",
1297
+ " </tr>\n",
1298
+ " <tr>\n",
1299
+ " <th>2</th>\n",
1300
+ " <td>What is the color of happiness?</td>\n",
1301
+ " <td>I don't know, sorry</td>\n",
1302
+ " <td>0.011900</td>\n",
1303
+ " <td>True</td>\n",
1304
+ " </tr>\n",
1305
+ " <tr>\n",
1306
+ " <th>3</th>\n",
1307
+ " <td>Can cats fly on Tuesdays?</td>\n",
1308
+ " <td>I don't know, sorry</td>\n",
1309
+ " <td>0.011438</td>\n",
1310
+ " <td>True</td>\n",
1311
+ " </tr>\n",
1312
+ " <tr>\n",
1313
+ " <th>4</th>\n",
1314
+ " <td>How much does a thought weigh?</td>\n",
1315
+ " <td>I don't know, sorry</td>\n",
1316
+ " <td>0.010644</td>\n",
1317
+ " <td>True</td>\n",
1318
+ " </tr>\n",
1319
+ " </tbody>\n",
1320
+ "</table>\n",
1321
+ "</div>"
1322
+ ],
1323
+ "text/plain": [
1324
+ " Questions response latency \\\n",
1325
+ "0 What is the capital of Mars? I don't know, sorry 0.061378 \n",
1326
+ "1 How many unicorns live in New York City? I don't know, sorry 0.012511 \n",
1327
+ "2 What is the color of happiness? I don't know, sorry 0.011900 \n",
1328
+ "3 Can cats fly on Tuesdays? I don't know, sorry 0.011438 \n",
1329
+ "4 How much does a thought weigh? I don't know, sorry 0.010644 \n",
1330
+ "\n",
1331
+ " irrelevant_score \n",
1332
+ "0 True \n",
1333
+ "1 True \n",
1334
+ "2 True \n",
1335
+ "3 True \n",
1336
+ "4 True "
1337
+ ]
1338
+ },
1339
+ "execution_count": 80,
1340
+ "metadata": {},
1341
+ "output_type": "execute_result"
1342
+ }
1343
+ ],
1344
+ "source": [
1345
+ "irr_q2.head()"
1346
+ ]
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "execution_count": 74,
1351
+ "id": "4508de9e",
1352
+ "metadata": {},
1353
+ "outputs": [],
1354
+ "source": [
1355
+ "irr_q2['irrelevant_score'] = irr_q2['response'].str.contains(\"I don't know\")"
1356
+ ]
1357
+ },
1358
+ {
1359
+ "cell_type": "code",
1360
+ "execution_count": 75,
1361
+ "id": "3d34ba06",
1362
+ "metadata": {},
1363
+ "outputs": [
1364
+ {
1365
+ "data": {
1366
+ "text/plain": [
1367
+ "irrelevant_score 1.000000\n",
1368
+ "latency 0.016068\n",
1369
+ "dtype: float64"
1370
+ ]
1371
+ },
1372
+ "execution_count": 75,
1373
+ "metadata": {},
1374
+ "output_type": "execute_result"
1375
+ }
1376
+ ],
1377
+ "source": [
1378
+ "irr_q2[['irrelevant_score','latency']].mean()"
1379
+ ]
1380
+ }
1381
+ ],
1382
+ "metadata": {
1383
+ "kernelspec": {
1384
+ "display_name": "Python 3 (ipykernel)",
1385
+ "language": "python",
1386
+ "name": "python3"
1387
+ },
1388
+ "language_info": {
1389
+ "codemirror_mode": {
1390
+ "name": "ipython",
1391
+ "version": 3
1392
+ },
1393
+ "file_extension": ".py",
1394
+ "mimetype": "text/x-python",
1395
+ "name": "python",
1396
+ "nbconvert_exporter": "python",
1397
+ "pygments_lexer": "ipython3",
1398
+ "version": "3.11.5"
1399
+ }
1400
+ },
1401
+ "nbformat": 4,
1402
+ "nbformat_minor": 5
1403
+ }
Evaluation_MH/Evaluation.ipynb ADDED
@@ -0,0 +1,1403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f7b87c2c",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "c22401c2-2fd2-4459-9ee8-71bc3bd362c8",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# pip install -U sentence-transformers"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "8a7cc9d8",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/Users/arnabchakraborty/anaconda3/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
32
+ " from tqdm.autonotebook import tqdm, trange\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "from sentence_transformers import SentenceTransformer\n",
38
+ "from langchain.prompts import PromptTemplate\n",
39
+ "from langchain.chains import LLMChain\n",
40
+ "from langchain_community.llms import Ollama\n",
41
+ "from langchain.evaluation import load_evaluator\n",
42
+ "import faiss\n",
43
+ "import pandas as pd\n",
44
+ "import numpy as np\n",
45
+ "import pickle\n",
46
+ "import time\n",
47
+ "from tqdm import tqdm"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "id": "b6efca1d",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Initialization"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 2,
61
+ "id": "cc9a49d2",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# Load the FAISS index\n",
66
+ "index = faiss.read_index(\"database/pdf_sections_index.faiss\")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "id": "9af39b55",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "model = SentenceTransformer('all-MiniLM-L6-v2')"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 4,
82
+ "id": "fee8cdfd",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "with open('database/pdf_sections_data.pkl', 'rb') as f:\n",
87
+ " sections_data = pickle.load(f)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "id": "d6a1ba6a",
93
+ "metadata": {},
94
+ "source": [
95
+ "# RAG functions"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 5,
101
+ "id": "182bdbd8",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def search_faiss(query, k=3):\n",
106
+ " query_vector = model.encode([query])[0].astype('float32')\n",
107
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
108
+ " distances, indices = index.search(query_vector, k)\n",
109
+ " \n",
110
+ " results = []\n",
111
+ " for dist, idx in zip(distances[0], indices[0]):\n",
112
+ " results.append({\n",
113
+ " 'distance': dist,\n",
114
+ " 'content': sections_data[idx]['content'],\n",
115
+ " 'metadata': sections_data[idx]['metadata']\n",
116
+ " })\n",
117
+ " \n",
118
+ " return results"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 15,
124
+ "id": "67edc46a",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Create a prompt template\n",
129
+ "prompt_template = \"\"\"\n",
130
+ "You are an AI assistant specialized in Mental Health guidelines. \n",
131
+ "Use the following pieces of context to answer the question. \n",
132
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
133
+ "\n",
134
+ "Context:\n",
135
+ "{context}\n",
136
+ "\n",
137
+ "Question: {question}\n",
138
+ "\n",
139
+ "Answer:\"\"\"\n",
140
+ "\n",
141
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
142
+ "\n",
143
+ "llm = Ollama(\n",
144
+ " model=\"llama3\"\n",
145
+ ")\n",
146
+ "\n",
147
+ "# Create the chain\n",
148
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
149
+ "\n",
150
+ "def answer_question(query):\n",
151
+ " # Search for relevant context\n",
152
+ " search_results = search_faiss(query)\n",
153
+ " \n",
154
+ " # Combine the content from the search results\n",
155
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
156
+ "\n",
157
+ " # Run the chain\n",
158
+ " response = chain.run(context=context, question=query)\n",
159
+ " \n",
160
+ " return response"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "3b176af9",
166
+ "metadata": {},
167
+ "source": [
168
+ "# Reading GT"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 16,
174
+ "id": "4ab68dff",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "df = pd.read_csv('data/MentalHealth_Dataset.csv')"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 17,
184
+ "id": "4e7e22d7",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stderr",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "100%|███████████████████████████████████████████| 10/10 [01:45<00:00, 10.55s/it]\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "time_list=[]\n",
197
+ "response_list=[]\n",
198
+ "for i in tqdm(range(len(df))):\n",
199
+ " query = df['Questions'].values[i]\n",
200
+ " start = time.time()\n",
201
+ " response = answer_question(query)\n",
202
+ " end = time.time() \n",
203
+ " time_list.append(end-start)\n",
204
+ " response_list.append(response)"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 18,
210
+ "id": "2b327e90",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "df['latency'] = time_list\n",
215
+ "df['response'] = response_list"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "3c147204",
221
+ "metadata": {},
222
+ "source": [
223
+ "# Evaluation"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 29,
229
+ "id": "d799e541",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "eval_llm = Ollama(\n",
234
+ " model=\"phi3\"\n",
235
+ ")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 30,
241
+ "id": "c2f788dc",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "metrics = ['correctness', 'relevance', 'coherence', 'conciseness']"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 31,
251
+ "id": "83ec2b8d",
252
+ "metadata": {},
253
+ "outputs": [
254
+ {
255
+ "name": "stderr",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.51s/it]\n",
259
+ "100%|███████████████████████████████████████████| 10/10 [00:59<00:00, 5.99s/it]\n",
260
+ "100%|███████████████████████████████████████████| 10/10 [00:50<00:00, 5.10s/it]\n",
261
+ "100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.88s/it]\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "for metric in metrics:\n",
267
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
268
+ " \n",
269
+ " reasoning = []\n",
270
+ " value = []\n",
271
+ " score = []\n",
272
+ " \n",
273
+ " for i in tqdm(range(len(df))):\n",
274
+ " eval_result = evaluator.evaluate_strings(\n",
275
+ " prediction=df.response.values[i],\n",
276
+ " input=df.Questions.values[i],\n",
277
+ " reference=df.Answers.values[i]\n",
278
+ " )\n",
279
+ " reasoning.append(eval_result['reasoning'])\n",
280
+ " value.append(eval_result['value'])\n",
281
+ " score.append(eval_result['score'])\n",
282
+ " \n",
283
+ " df[metric+'_reasoning'] = reasoning\n",
284
+ " df[metric+'_value'] = value\n",
285
+ " df[metric+'_score'] = score "
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 78,
291
+ "id": "f1673a31",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>Questions</th>\n",
316
+ " <th>Answers</th>\n",
317
+ " <th>latency</th>\n",
318
+ " <th>response</th>\n",
319
+ " <th>correctness_reasoning</th>\n",
320
+ " <th>correctness_value</th>\n",
321
+ " <th>correctness_score</th>\n",
322
+ " <th>relevance_reasoning</th>\n",
323
+ " <th>relevance_value</th>\n",
324
+ " <th>relevance_score</th>\n",
325
+ " <th>coherence_reasoning</th>\n",
326
+ " <th>coherence_value</th>\n",
327
+ " <th>coherence_score</th>\n",
328
+ " <th>conciseness_reasoning</th>\n",
329
+ " <th>conciseness_value</th>\n",
330
+ " <th>conciseness_score</th>\n",
331
+ " </tr>\n",
332
+ " </thead>\n",
333
+ " <tbody>\n",
334
+ " <tr>\n",
335
+ " <th>0</th>\n",
336
+ " <td>What is Mental Health</td>\n",
337
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
338
+ " <td>11.974234</td>\n",
339
+ " <td>Based on the provided context, specifically fr...</td>\n",
340
+ " <td>The submission refers to the provided input wh...</td>\n",
341
+ " <td>Y</td>\n",
342
+ " <td>1</td>\n",
343
+ " <td>Step 1: Evaluate relevance criterion\\nThe subm...</td>\n",
344
+ " <td>Y</td>\n",
345
+ " <td>1</td>\n",
346
+ " <td>Step 1: Assess coherence\\nThe submission direc...</td>\n",
347
+ " <td>Y</td>\n",
348
+ " <td>1</td>\n",
349
+ " <td>1. The submission directly answers the questio...</td>\n",
350
+ " <td>Y</td>\n",
351
+ " <td>1</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>1</th>\n",
355
+ " <td>What are the most common mental disorders ment...</td>\n",
356
+ " <td>The most common mental disorders include depre...</td>\n",
357
+ " <td>5.863329</td>\n",
358
+ " <td>Based on the provided context, the mental diso...</td>\n",
359
+ " <td>Step 1: Check if the submission is factually a...</td>\n",
360
+ " <td>Y</td>\n",
361
+ " <td>1</td>\n",
362
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
363
+ " <td>Y</td>\n",
364
+ " <td>1</td>\n",
365
+ " <td>The submission begins with an appropriate ques...</td>\n",
366
+ " <td>Y</td>\n",
367
+ " <td>1</td>\n",
368
+ " <td>Step 1: Review conciseness criterion\\nThe subm...</td>\n",
369
+ " <td>Y</td>\n",
370
+ " <td>1</td>\n",
371
+ " </tr>\n",
372
+ " <tr>\n",
373
+ " <th>2</th>\n",
374
+ " <td>What are the early warning signs and symptoms ...</td>\n",
375
+ " <td>Early warning signs and symptoms of depression...</td>\n",
376
+ " <td>13.434543</td>\n",
377
+ " <td>Based on the provided context, I found a refer...</td>\n",
378
+ " <td>Step 1: Evaluate Correctness\\nThe submission a...</td>\n",
379
+ " <td>Y</td>\n",
380
+ " <td>1</td>\n",
381
+ " <td>Step 1: Identify the relevant criterion from t...</td>\n",
382
+ " <td>Y</td>\n",
383
+ " <td>1</td>\n",
384
+ " <td>Step 1: Evaluate coherence\\nThe submission is ...</td>\n",
385
+ " <td>Y</td>\n",
386
+ " <td>1</td>\n",
387
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
388
+ " <td>Y</td>\n",
389
+ " <td>1</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <th>3</th>\n",
393
+ " <td>How can someone help a person who suffers from...</td>\n",
394
+ " <td>To help someone with anxiety, one can support ...</td>\n",
395
+ " <td>13.838464</td>\n",
396
+ " <td>According to the provided context, specificall...</td>\n",
397
+ " <td>Step 1: Correctness\\nThe submission accurately...</td>\n",
398
+ " <td>Y</td>\n",
399
+ " <td>1</td>\n",
400
+ " <td>Step 1: Analyze relevance criterion\\nThe submi...</td>\n",
401
+ " <td>Y</td>\n",
402
+ " <td>1</td>\n",
403
+ " <td>Step 1: Evaluate coherence\\nThe submission dis...</td>\n",
404
+ " <td>Y</td>\n",
405
+ " <td>1</td>\n",
406
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
407
+ " <td>N</td>\n",
408
+ " <td>0</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <th>4</th>\n",
412
+ " <td>What are the causes of mental illness listed i...</td>\n",
413
+ " <td>Causes of mental illness include abnormal func...</td>\n",
414
+ " <td>6.871735</td>\n",
415
+ " <td>According to the provided context, the causes ...</td>\n",
416
+ " <td>The submission lists factors that align with t...</td>\n",
417
+ " <td>N</td>\n",
418
+ " <td>0</td>\n",
419
+ " <td>Step 1: Review relevance criterion - Check if ...</td>\n",
420
+ " <td>Y</td>\n",
421
+ " <td>1</td>\n",
422
+ " <td>Step 1: Compare the submission with the provid...</td>\n",
423
+ " <td>Y</td>\n",
424
+ " <td>1</td>\n",
425
+ " <td>Step 1: Assess conciseness\\nThe submission is ...</td>\n",
426
+ " <td>Y</td>\n",
427
+ " <td>1</td>\n",
428
+ " </tr>\n",
429
+ " </tbody>\n",
430
+ "</table>\n",
431
+ "</div>"
432
+ ],
433
+ "text/plain": [
434
+ " Questions \\\n",
435
+ "0 What is Mental Health \n",
436
+ "1 What are the most common mental disorders ment... \n",
437
+ "2 What are the early warning signs and symptoms ... \n",
438
+ "3 How can someone help a person who suffers from... \n",
439
+ "4 What are the causes of mental illness listed i... \n",
440
+ "\n",
441
+ " Answers latency \\\n",
442
+ "0 Mental Health is a \" state of well-being in wh... 11.974234 \n",
443
+ "1 The most common mental disorders include depre... 5.863329 \n",
444
+ "2 Early warning signs and symptoms of depression... 13.434543 \n",
445
+ "3 To help someone with anxiety, one can support ... 13.838464 \n",
446
+ "4 Causes of mental illness include abnormal func... 6.871735 \n",
447
+ "\n",
448
+ " response \\\n",
449
+ "0 Based on the provided context, specifically fr... \n",
450
+ "1 Based on the provided context, the mental diso... \n",
451
+ "2 Based on the provided context, I found a refer... \n",
452
+ "3 According to the provided context, specificall... \n",
453
+ "4 According to the provided context, the causes ... \n",
454
+ "\n",
455
+ " correctness_reasoning correctness_value \\\n",
456
+ "0 The submission refers to the provided input wh... Y \n",
457
+ "1 Step 1: Check if the submission is factually a... Y \n",
458
+ "2 Step 1: Evaluate Correctness\\nThe submission a... Y \n",
459
+ "3 Step 1: Correctness\\nThe submission accurately... Y \n",
460
+ "4 The submission lists factors that align with t... N \n",
461
+ "\n",
462
+ " correctness_score relevance_reasoning \\\n",
463
+ "0 1 Step 1: Evaluate relevance criterion\\nThe subm... \n",
464
+ "1 1 Step 1: Analyze the relevance criterion\\nThe s... \n",
465
+ "2 1 Step 1: Identify the relevant criterion from t... \n",
466
+ "3 1 Step 1: Analyze relevance criterion\\nThe submi... \n",
467
+ "4 0 Step 1: Review relevance criterion - Check if ... \n",
468
+ "\n",
469
+ " relevance_value relevance_score \\\n",
470
+ "0 Y 1 \n",
471
+ "1 Y 1 \n",
472
+ "2 Y 1 \n",
473
+ "3 Y 1 \n",
474
+ "4 Y 1 \n",
475
+ "\n",
476
+ " coherence_reasoning coherence_value \\\n",
477
+ "0 Step 1: Assess coherence\\nThe submission direc... Y \n",
478
+ "1 The submission begins with an appropriate ques... Y \n",
479
+ "2 Step 1: Evaluate coherence\\nThe submission is ... Y \n",
480
+ "3 Step 1: Evaluate coherence\\nThe submission dis... Y \n",
481
+ "4 Step 1: Compare the submission with the provid... Y \n",
482
+ "\n",
483
+ " coherence_score conciseness_reasoning \\\n",
484
+ "0 1 1. The submission directly answers the questio... \n",
485
+ "1 1 Step 1: Review conciseness criterion\\nThe subm... \n",
486
+ "2 1 Step 1: Evaluate conciseness - The submission ... \n",
487
+ "3 1 Step 1: Evaluate conciseness - The submission ... \n",
488
+ "4 1 Step 1: Assess conciseness\\nThe submission is ... \n",
489
+ "\n",
490
+ " conciseness_value conciseness_score \n",
491
+ "0 Y 1 \n",
492
+ "1 Y 1 \n",
493
+ "2 Y 1 \n",
494
+ "3 N 0 \n",
495
+ "4 Y 1 "
496
+ ]
497
+ },
498
+ "execution_count": 78,
499
+ "metadata": {},
500
+ "output_type": "execute_result"
501
+ }
502
+ ],
503
+ "source": [
504
+ "df.head()"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 32,
510
+ "id": "7797a360",
511
+ "metadata": {},
512
+ "outputs": [
513
+ {
514
+ "data": {
515
+ "text/plain": [
516
+ "correctness_score 0.800000\n",
517
+ "relevance_score 0.900000\n",
518
+ "coherence_score 1.000000\n",
519
+ "conciseness_score 0.800000\n",
520
+ "latency 10.544803\n",
521
+ "dtype: float64"
522
+ ]
523
+ },
524
+ "execution_count": 32,
525
+ "metadata": {},
526
+ "output_type": "execute_result"
527
+ }
528
+ ],
529
+ "source": [
530
+ "df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 34,
536
+ "id": "fe667926",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "irr_q=pd.read_csv('data/Unrelated_questions.csv')"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 35,
546
+ "id": "189f8a0f",
547
+ "metadata": {},
548
+ "outputs": [
549
+ {
550
+ "name": "stderr",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.30s/it]\n"
554
+ ]
555
+ }
556
+ ],
557
+ "source": [
558
+ "time_list=[]\n",
559
+ "response_list=[]\n",
560
+ "for i in tqdm(range(len(irr_q))):\n",
561
+ " query = irr_q['Questions'].values[i]\n",
562
+ " start = time.time()\n",
563
+ " response = answer_question(query)\n",
564
+ " end = time.time() \n",
565
+ " time_list.append(end-start)\n",
566
+ " response_list.append(response)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 36,
572
+ "id": "b0244ea0",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "irr_q['response']=response_list\n",
577
+ "irr_q['latency']=time_list"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 79,
583
+ "id": "dc3b1ade",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/html": [
589
+ "<div>\n",
590
+ "<style scoped>\n",
591
+ " .dataframe tbody tr th:only-of-type {\n",
592
+ " vertical-align: middle;\n",
593
+ " }\n",
594
+ "\n",
595
+ " .dataframe tbody tr th {\n",
596
+ " vertical-align: top;\n",
597
+ " }\n",
598
+ "\n",
599
+ " .dataframe thead th {\n",
600
+ " text-align: right;\n",
601
+ " }\n",
602
+ "</style>\n",
603
+ "<table border=\"1\" class=\"dataframe\">\n",
604
+ " <thead>\n",
605
+ " <tr style=\"text-align: right;\">\n",
606
+ " <th></th>\n",
607
+ " <th>Questions</th>\n",
608
+ " <th>response</th>\n",
609
+ " <th>latency</th>\n",
610
+ " <th>irrelevant_score</th>\n",
611
+ " </tr>\n",
612
+ " </thead>\n",
613
+ " <tbody>\n",
614
+ " <tr>\n",
615
+ " <th>0</th>\n",
616
+ " <td>What is the capital of Mars?</td>\n",
617
+ " <td>I don't know. The provided context does not se...</td>\n",
618
+ " <td>12.207266</td>\n",
619
+ " <td>True</td>\n",
620
+ " </tr>\n",
621
+ " <tr>\n",
622
+ " <th>1</th>\n",
623
+ " <td>How many unicorns live in New York City?</td>\n",
624
+ " <td>I don't know. The information provided does no...</td>\n",
625
+ " <td>2.368774</td>\n",
626
+ " <td>True</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <th>2</th>\n",
630
+ " <td>What is the color of happiness?</td>\n",
631
+ " <td>I don't know! The provided context only talks ...</td>\n",
632
+ " <td>5.480067</td>\n",
633
+ " <td>True</td>\n",
634
+ " </tr>\n",
635
+ " <tr>\n",
636
+ " <th>3</th>\n",
637
+ " <td>Can cats fly on Tuesdays?</td>\n",
638
+ " <td>I don't know the answer to this question as it...</td>\n",
639
+ " <td>5.272529</td>\n",
640
+ " <td>True</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <th>4</th>\n",
644
+ " <td>How much does a thought weigh?</td>\n",
645
+ " <td>I don't know. The context provided is about me...</td>\n",
646
+ " <td>5.253224</td>\n",
647
+ " <td>True</td>\n",
648
+ " </tr>\n",
649
+ " </tbody>\n",
650
+ "</table>\n",
651
+ "</div>"
652
+ ],
653
+ "text/plain": [
654
+ " Questions \\\n",
655
+ "0 What is the capital of Mars? \n",
656
+ "1 How many unicorns live in New York City? \n",
657
+ "2 What is the color of happiness? \n",
658
+ "3 Can cats fly on Tuesdays? \n",
659
+ "4 How much does a thought weigh? \n",
660
+ "\n",
661
+ " response latency \\\n",
662
+ "0 I don't know. The provided context does not se... 12.207266 \n",
663
+ "1 I don't know. The information provided does no... 2.368774 \n",
664
+ "2 I don't know! The provided context only talks ... 5.480067 \n",
665
+ "3 I don't know the answer to this question as it... 5.272529 \n",
666
+ "4 I don't know. The context provided is about me... 5.253224 \n",
667
+ "\n",
668
+ " irrelevant_score \n",
669
+ "0 True \n",
670
+ "1 True \n",
671
+ "2 True \n",
672
+ "3 True \n",
673
+ "4 True "
674
+ ]
675
+ },
676
+ "execution_count": 79,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "irr_q.head()"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": 37,
688
+ "id": "8620e50c",
689
+ "metadata": {},
690
+ "outputs": [
691
+ {
692
+ "data": {
693
+ "text/plain": [
694
+ "0 12.207266\n",
695
+ "1 2.368774\n",
696
+ "2 5.480067\n",
697
+ "3 5.272529\n",
698
+ "4 5.253224\n",
699
+ "5 5.351224\n",
700
+ "6 8.118429\n",
701
+ "7 7.288261\n",
702
+ "8 3.856500\n",
703
+ "9 7.745016\n",
704
+ "Name: latency, dtype: float64"
705
+ ]
706
+ },
707
+ "execution_count": 37,
708
+ "metadata": {},
709
+ "output_type": "execute_result"
710
+ }
711
+ ],
712
+ "source": [
713
+ "irr_q['latency']"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": 39,
719
+ "id": "debd3461",
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": [
723
+ "irr_q['irrelevant_score'] = irr_q['response'].str.contains(\"I don't know\")"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 40,
729
+ "id": "bef1d3a4",
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "data": {
734
+ "text/plain": [
735
+ "irrelevant_score 0.900000\n",
736
+ "latency 6.294129\n",
737
+ "dtype: float64"
738
+ ]
739
+ },
740
+ "execution_count": 40,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "irr_q[['irrelevant_score','latency']].mean()"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "markdown",
751
+ "id": "c1610a70",
752
+ "metadata": {},
753
+ "source": [
754
+ "# Improvement"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 48,
760
+ "id": "ff6614f9",
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": [
764
+ "new_prompt_template = \"\"\"\n",
765
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
766
+ "Use the provided context to answer the question short and accurately. \n",
767
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
768
+ "\n",
769
+ "Context:\n",
770
+ "{context}\n",
771
+ "\n",
772
+ "Question: {question}\n",
773
+ "\n",
774
+ "Answer:\"\"\"\n",
775
+ "\n",
776
+ "prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n",
777
+ "\n",
778
+ "llm = Ollama(\n",
779
+ " model=\"llama3\"\n",
780
+ ")\n",
781
+ "\n",
782
+ "# Create the chain\n",
783
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
784
+ "\n",
785
+ "def answer_question_new(query):\n",
786
+ " # Search for relevant context\n",
787
+ " search_results = search_faiss(query)\n",
788
+ " \n",
789
+ " # Combine the content from the search results\n",
790
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
791
+ "\n",
792
+ " # Run the chain\n",
793
+ " response = chain.run(context=context, question=query)\n",
794
+ " \n",
795
+ " return response"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 49,
801
+ "id": "20580d50",
802
+ "metadata": {},
803
+ "outputs": [],
804
+ "source": [
805
+ "df2=df.copy()"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 50,
811
+ "id": "b1b3d725",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "name": "stderr",
816
+ "output_type": "stream",
817
+ "text": [
818
+ "100%|███████████████████████████████████████████| 10/10 [01:34<00:00, 9.40s/it]\n"
819
+ ]
820
+ }
821
+ ],
822
+ "source": [
823
+ "time_list=[]\n",
824
+ "response_list=[]\n",
825
+ "for i in tqdm(range(len(df2))):\n",
826
+ " query = df2['Questions'].values[i]\n",
827
+ " start = time.time()\n",
828
+ " response = answer_question(query)\n",
829
+ " end = time.time() \n",
830
+ " time_list.append(end-start)\n",
831
+ " response_list.append(response)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 51,
837
+ "id": "63f41256",
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "df2['latency'] = time_list\n",
842
+ "df2['response'] = response_list"
843
+ ]
844
+ },
845
+ {
846
+ "cell_type": "code",
847
+ "execution_count": 52,
848
+ "id": "0d8a6065",
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stderr",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.01s/it]\n",
856
+ "100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.35s/it]\n",
857
+ "100%|███████████████████████████████████████████| 10/10 [00:47<00:00, 4.77s/it]\n",
858
+ "100%|███████████████████████████████████████████| 10/10 [00:55<00:00, 5.60s/it]\n"
859
+ ]
860
+ }
861
+ ],
862
+ "source": [
863
+ "for metric in metrics:\n",
864
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
865
+ " \n",
866
+ " reasoning = []\n",
867
+ " value = []\n",
868
+ " score = []\n",
869
+ " \n",
870
+ " for i in tqdm(range(len(df2))):\n",
871
+ " eval_result = evaluator.evaluate_strings(\n",
872
+ " prediction=df2.response.values[i],\n",
873
+ " input=df2.Questions.values[i],\n",
874
+ " reference=df2.Answers.values[i]\n",
875
+ " )\n",
876
+ " reasoning.append(eval_result['reasoning'])\n",
877
+ " value.append(eval_result['value'])\n",
878
+ " score.append(eval_result['score'])\n",
879
+ " \n",
880
+ " df2[metric+'_reasoning'] = reasoning\n",
881
+ " df2[metric+'_value'] = value\n",
882
+ " df2[metric+'_score'] = score "
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 77,
888
+ "id": "c648632c",
889
+ "metadata": {},
890
+ "outputs": [
891
+ {
892
+ "data": {
893
+ "text/html": [
894
+ "<div>\n",
895
+ "<style scoped>\n",
896
+ " .dataframe tbody tr th:only-of-type {\n",
897
+ " vertical-align: middle;\n",
898
+ " }\n",
899
+ "\n",
900
+ " .dataframe tbody tr th {\n",
901
+ " vertical-align: top;\n",
902
+ " }\n",
903
+ "\n",
904
+ " .dataframe thead th {\n",
905
+ " text-align: right;\n",
906
+ " }\n",
907
+ "</style>\n",
908
+ "<table border=\"1\" class=\"dataframe\">\n",
909
+ " <thead>\n",
910
+ " <tr style=\"text-align: right;\">\n",
911
+ " <th></th>\n",
912
+ " <th>Questions</th>\n",
913
+ " <th>Answers</th>\n",
914
+ " <th>latency</th>\n",
915
+ " <th>response</th>\n",
916
+ " <th>correctness_reasoning</th>\n",
917
+ " <th>correctness_value</th>\n",
918
+ " <th>correctness_score</th>\n",
919
+ " <th>relevance_reasoning</th>\n",
920
+ " <th>relevance_value</th>\n",
921
+ " <th>relevance_score</th>\n",
922
+ " <th>coherence_reasoning</th>\n",
923
+ " <th>coherence_value</th>\n",
924
+ " <th>coherence_score</th>\n",
925
+ " <th>conciseness_reasoning</th>\n",
926
+ " <th>conciseness_value</th>\n",
927
+ " <th>conciseness_score</th>\n",
928
+ " </tr>\n",
929
+ " </thead>\n",
930
+ " <tbody>\n",
931
+ " <tr>\n",
932
+ " <th>0</th>\n",
933
+ " <td>What is Mental Health</td>\n",
934
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
935
+ " <td>11.046327</td>\n",
936
+ " <td>Based on the context provided, mental health r...</td>\n",
937
+ " <td>Step 1: Evaluate if the submission is factuall...</td>\n",
938
+ " <td>N</td>\n",
939
+ " <td>0</td>\n",
940
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
941
+ " <td>N</td>\n",
942
+ " <td>0</td>\n",
943
+ " <td>The submission discusses mental health in rela...</td>\n",
944
+ " <td>Y</td>\n",
945
+ " <td>1</td>\n",
946
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
947
+ " <td>Y</td>\n",
948
+ " <td>1</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <th>1</th>\n",
952
+ " <td>What are the most common mental disorders ment...</td>\n",
953
+ " <td>The most common mental disorders include depre...</td>\n",
954
+ " <td>4.509713</td>\n",
955
+ " <td>The handbook mentions several mental illnesses...</td>\n",
956
+ " <td>The submission mentions depression and schizop...</td>\n",
957
+ " <td>N</td>\n",
958
+ " <td>0</td>\n",
959
+ " <td>Step 1: Analyze relevance criterion - Check if...</td>\n",
960
+ " <td>Y</td>\n",
961
+ " <td>1</td>\n",
962
+ " <td>Step 1: Assess coherence\\nThe submission menti...</td>\n",
963
+ " <td>N</td>\n",
964
+ " <td>0</td>\n",
965
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
966
+ " <td>N</td>\n",
967
+ " <td>0</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <th>2</th>\n",
971
+ " <td>What are the early warning signs and symptoms ...</td>\n",
972
+ " <td>Early warning signs and symptoms of depression...</td>\n",
973
+ " <td>8.501180</td>\n",
974
+ " <td>According to the provided context, specificall...</td>\n",
975
+ " <td>The submission matches the reference data in t...</td>\n",
976
+ " <td>Y</td>\n",
977
+ " <td>1</td>\n",
978
+ " <td>The submission refers directly to information ...</td>\n",
979
+ " <td>Y</td>\n",
980
+ " <td>1</td>\n",
981
+ " <td>Step 1: Evaluate coherence - The submission is...</td>\n",
982
+ " <td>Y</td>\n",
983
+ " <td>1</td>\n",
984
+ " <td>The submission is concise and includes most of...</td>\n",
985
+ " <td>Y</td>\n",
986
+ " <td>1</td>\n",
987
+ " </tr>\n",
988
+ " <tr>\n",
989
+ " <th>3</th>\n",
990
+ " <td>How can someone help a person who suffers from...</td>\n",
991
+ " <td>To help someone with anxiety, one can support ...</td>\n",
992
+ " <td>10.611402</td>\n",
993
+ " <td>According to the Mental Health Handbook, when ...</td>\n",
994
+ " <td>The submission seems consistent with the refer...</td>\n",
995
+ " <td>Y</td>\n",
996
+ " <td>1</td>\n",
997
+ " <td>Step 1: Review relevance criterion\\nThe submis...</td>\n",
998
+ " <td>Y</td>\n",
999
+ " <td>1</td>\n",
1000
+ " <td>The submission is coherent, well-structured, a...</td>\n",
1001
+ " <td>Y</td>\n",
1002
+ " <td>1</td>\n",
1003
+ " <td>The submission is relatively concise and cover...</td>\n",
1004
+ " <td>Y</td>\n",
1005
+ " <td>1</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <th>4</th>\n",
1009
+ " <td>What are the causes of mental illness listed i...</td>\n",
1010
+ " <td>Causes of mental illness include abnormal func...</td>\n",
1011
+ " <td>6.299272</td>\n",
1012
+ " <td>According to the context, the causes of mental...</td>\n",
1013
+ " <td>The submission lists causes such as neglect, s...</td>\n",
1014
+ " <td>N</td>\n",
1015
+ " <td>0</td>\n",
1016
+ " <td>The submission mentions factors that are part ...</td>\n",
1017
+ " <td>N</td>\n",
1018
+ " <td>0</td>\n",
1019
+ " <td>The submission is coherent and well-structured...</td>\n",
1020
+ " <td>Y</td>\n",
1021
+ " <td>1</td>\n",
1022
+ " <td>Step 1: Read and understand both the input dat...</td>\n",
1023
+ " <td>N</td>\n",
1024
+ " <td>0</td>\n",
1025
+ " </tr>\n",
1026
+ " </tbody>\n",
1027
+ "</table>\n",
1028
+ "</div>"
1029
+ ],
1030
+ "text/plain": [
1031
+ " Questions \\\n",
1032
+ "0 What is Mental Health \n",
1033
+ "1 What are the most common mental disorders ment... \n",
1034
+ "2 What are the early warning signs and symptoms ... \n",
1035
+ "3 How can someone help a person who suffers from... \n",
1036
+ "4 What are the causes of mental illness listed i... \n",
1037
+ "\n",
1038
+ " Answers latency \\\n",
1039
+ "0 Mental Health is a \" state of well-being in wh... 11.046327 \n",
1040
+ "1 The most common mental disorders include depre... 4.509713 \n",
1041
+ "2 Early warning signs and symptoms of depression... 8.501180 \n",
1042
+ "3 To help someone with anxiety, one can support ... 10.611402 \n",
1043
+ "4 Causes of mental illness include abnormal func... 6.299272 \n",
1044
+ "\n",
1045
+ " response \\\n",
1046
+ "0 Based on the context provided, mental health r... \n",
1047
+ "1 The handbook mentions several mental illnesses... \n",
1048
+ "2 According to the provided context, specificall... \n",
1049
+ "3 According to the Mental Health Handbook, when ... \n",
1050
+ "4 According to the context, the causes of mental... \n",
1051
+ "\n",
1052
+ " correctness_reasoning correctness_value \\\n",
1053
+ "0 Step 1: Evaluate if the submission is factuall... N \n",
1054
+ "1 The submission mentions depression and schizop... N \n",
1055
+ "2 The submission matches the reference data in t... Y \n",
1056
+ "3 The submission seems consistent with the refer... Y \n",
1057
+ "4 The submission lists causes such as neglect, s... N \n",
1058
+ "\n",
1059
+ " correctness_score relevance_reasoning \\\n",
1060
+ "0 0 Step 1: Analyze the relevance criterion\\nThe s... \n",
1061
+ "1 0 Step 1: Analyze relevance criterion - Check if... \n",
1062
+ "2 1 The submission refers directly to information ... \n",
1063
+ "3 1 Step 1: Review relevance criterion\\nThe submis... \n",
1064
+ "4 0 The submission mentions factors that are part ... \n",
1065
+ "\n",
1066
+ " relevance_value relevance_score \\\n",
1067
+ "0 N 0 \n",
1068
+ "1 Y 1 \n",
1069
+ "2 Y 1 \n",
1070
+ "3 Y 1 \n",
1071
+ "4 N 0 \n",
1072
+ "\n",
1073
+ " coherence_reasoning coherence_value \\\n",
1074
+ "0 The submission discusses mental health in rela... Y \n",
1075
+ "1 Step 1: Assess coherence\\nThe submission menti... N \n",
1076
+ "2 Step 1: Evaluate coherence - The submission is... Y \n",
1077
+ "3 The submission is coherent, well-structured, a... Y \n",
1078
+ "4 The submission is coherent and well-structured... Y \n",
1079
+ "\n",
1080
+ " coherence_score conciseness_reasoning \\\n",
1081
+ "0 1 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1082
+ "1 0 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1083
+ "2 1 The submission is concise and includes most of... \n",
1084
+ "3 1 The submission is relatively concise and cover... \n",
1085
+ "4 1 Step 1: Read and understand both the input dat... \n",
1086
+ "\n",
1087
+ " conciseness_value conciseness_score \n",
1088
+ "0 Y 1 \n",
1089
+ "1 N 0 \n",
1090
+ "2 Y 1 \n",
1091
+ "3 Y 1 \n",
1092
+ "4 N 0 "
1093
+ ]
1094
+ },
1095
+ "execution_count": 77,
1096
+ "metadata": {},
1097
+ "output_type": "execute_result"
1098
+ }
1099
+ ],
1100
+ "source": [
1101
+ "df2.head()"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": 47,
1107
+ "id": "2d1002b2",
1108
+ "metadata": {},
1109
+ "outputs": [
1110
+ {
1111
+ "data": {
1112
+ "text/plain": [
1113
+ "correctness_score 0.500000\n",
1114
+ "relevance_score 0.888889\n",
1115
+ "coherence_score 0.888889\n",
1116
+ "conciseness_score 0.900000\n",
1117
+ "latency 8.190205\n",
1118
+ "dtype: float64"
1119
+ ]
1120
+ },
1121
+ "execution_count": 47,
1122
+ "metadata": {},
1123
+ "output_type": "execute_result"
1124
+ }
1125
+ ],
1126
+ "source": [
1127
+ "df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
1128
+ ]
1129
+ },
1130
+ {
1131
+ "cell_type": "markdown",
1132
+ "id": "e808bdcf",
1133
+ "metadata": {},
1134
+ "source": [
1135
+ "# Query relevance"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "execution_count": 66,
1141
+ "id": "6b541f3d",
1142
+ "metadata": {},
1143
+ "outputs": [],
1144
+ "source": [
1145
+ "def new_search_faiss(query, k=3, threshold=1.5):\n",
1146
+ " query_vector = model.encode([query])[0].astype('float32')\n",
1147
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
1148
+ " distances, indices = index.search(query_vector, k)\n",
1149
+ " \n",
1150
+ " results = []\n",
1151
+ " for dist, idx in zip(distances[0], indices[0]):\n",
1152
+ " if dist < threshold: # Only include results within the threshold distance\n",
1153
+ " results.append({\n",
1154
+ " 'distance': dist,\n",
1155
+ " 'content': sections_data[idx]['content'],\n",
1156
+ " 'metadata': sections_data[idx]['metadata']\n",
1157
+ " })\n",
1158
+ " \n",
1159
+ " return results"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "code",
1164
+ "execution_count": 70,
1165
+ "id": "4f579654",
1166
+ "metadata": {},
1167
+ "outputs": [],
1168
+ "source": [
1169
+ "new_prompt_template = \"\"\"\n",
1170
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
1171
+ "Use the provided context to answer the question short and accurately. \n",
1172
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
1173
+ "\n",
1174
+ "Context:\n",
1175
+ "{context}\n",
1176
+ "\n",
1177
+ "Question: {question}\n",
1178
+ "\n",
1179
+ "Answer:\"\"\"\n",
1180
+ "\n",
1181
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
1182
+ "\n",
1183
+ "llm = Ollama(\n",
1184
+ " model=\"llama3\"\n",
1185
+ ")\n",
1186
+ "\n",
1187
+ "# Create the chain\n",
1188
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
1189
+ "\n",
1190
+ "def new_answer_question(query):\n",
1191
+ " # Search for relevant context\n",
1192
+ " search_results = new_search_faiss(query)\n",
1193
+ " \n",
1194
+ " if search_results==[]:\n",
1195
+ " response=\"I don't know, sorry\"\n",
1196
+ " else:\n",
1197
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
1198
+ " response = chain.run(context=context, question=query)\n",
1199
+ " \n",
1200
+ " return response"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": 71,
1206
+ "id": "1f83ef1b",
1207
+ "metadata": {},
1208
+ "outputs": [],
1209
+ "source": [
1210
+ "irr_q2=irr_q.copy()"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": 72,
1216
+ "id": "f06474e3",
1217
+ "metadata": {},
1218
+ "outputs": [
1219
+ {
1220
+ "name": "stderr",
1221
+ "output_type": "stream",
1222
+ "text": [
1223
+ "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 61.93it/s]\n"
1224
+ ]
1225
+ }
1226
+ ],
1227
+ "source": [
1228
+ "time_list=[]\n",
1229
+ "response_list=[]\n",
1230
+ "for i in tqdm(range(len(irr_q2))):\n",
1231
+ " query = irr_q['Questions'].values[i]\n",
1232
+ " start = time.time()\n",
1233
+ " response = new_answer_question(query)\n",
1234
+ " end = time.time() \n",
1235
+ " time_list.append(end-start)\n",
1236
+ " response_list.append(response)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": 73,
1242
+ "id": "52db6b82",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": [
1246
+ "irr_q2['response']=response_list\n",
1247
+ "irr_q2['latency']=time_list"
1248
+ ]
1249
+ },
1250
+ {
1251
+ "cell_type": "code",
1252
+ "execution_count": 80,
1253
+ "id": "80a178ee",
1254
+ "metadata": {},
1255
+ "outputs": [
1256
+ {
1257
+ "data": {
1258
+ "text/html": [
1259
+ "<div>\n",
1260
+ "<style scoped>\n",
1261
+ " .dataframe tbody tr th:only-of-type {\n",
1262
+ " vertical-align: middle;\n",
1263
+ " }\n",
1264
+ "\n",
1265
+ " .dataframe tbody tr th {\n",
1266
+ " vertical-align: top;\n",
1267
+ " }\n",
1268
+ "\n",
1269
+ " .dataframe thead th {\n",
1270
+ " text-align: right;\n",
1271
+ " }\n",
1272
+ "</style>\n",
1273
+ "<table border=\"1\" class=\"dataframe\">\n",
1274
+ " <thead>\n",
1275
+ " <tr style=\"text-align: right;\">\n",
1276
+ " <th></th>\n",
1277
+ " <th>Questions</th>\n",
1278
+ " <th>response</th>\n",
1279
+ " <th>latency</th>\n",
1280
+ " <th>irrelevant_score</th>\n",
1281
+ " </tr>\n",
1282
+ " </thead>\n",
1283
+ " <tbody>\n",
1284
+ " <tr>\n",
1285
+ " <th>0</th>\n",
1286
+ " <td>What is the capital of Mars?</td>\n",
1287
+ " <td>I don't know, sorry</td>\n",
1288
+ " <td>0.061378</td>\n",
1289
+ " <td>True</td>\n",
1290
+ " </tr>\n",
1291
+ " <tr>\n",
1292
+ " <th>1</th>\n",
1293
+ " <td>How many unicorns live in New York City?</td>\n",
1294
+ " <td>I don't know, sorry</td>\n",
1295
+ " <td>0.012511</td>\n",
1296
+ " <td>True</td>\n",
1297
+ " </tr>\n",
1298
+ " <tr>\n",
1299
+ " <th>2</th>\n",
1300
+ " <td>What is the color of happiness?</td>\n",
1301
+ " <td>I don't know, sorry</td>\n",
1302
+ " <td>0.011900</td>\n",
1303
+ " <td>True</td>\n",
1304
+ " </tr>\n",
1305
+ " <tr>\n",
1306
+ " <th>3</th>\n",
1307
+ " <td>Can cats fly on Tuesdays?</td>\n",
1308
+ " <td>I don't know, sorry</td>\n",
1309
+ " <td>0.011438</td>\n",
1310
+ " <td>True</td>\n",
1311
+ " </tr>\n",
1312
+ " <tr>\n",
1313
+ " <th>4</th>\n",
1314
+ " <td>How much does a thought weigh?</td>\n",
1315
+ " <td>I don't know, sorry</td>\n",
1316
+ " <td>0.010644</td>\n",
1317
+ " <td>True</td>\n",
1318
+ " </tr>\n",
1319
+ " </tbody>\n",
1320
+ "</table>\n",
1321
+ "</div>"
1322
+ ],
1323
+ "text/plain": [
1324
+ " Questions response latency \\\n",
1325
+ "0 What is the capital of Mars? I don't know, sorry 0.061378 \n",
1326
+ "1 How many unicorns live in New York City? I don't know, sorry 0.012511 \n",
1327
+ "2 What is the color of happiness? I don't know, sorry 0.011900 \n",
1328
+ "3 Can cats fly on Tuesdays? I don't know, sorry 0.011438 \n",
1329
+ "4 How much does a thought weigh? I don't know, sorry 0.010644 \n",
1330
+ "\n",
1331
+ " irrelevant_score \n",
1332
+ "0 True \n",
1333
+ "1 True \n",
1334
+ "2 True \n",
1335
+ "3 True \n",
1336
+ "4 True "
1337
+ ]
1338
+ },
1339
+ "execution_count": 80,
1340
+ "metadata": {},
1341
+ "output_type": "execute_result"
1342
+ }
1343
+ ],
1344
+ "source": [
1345
+ "irr_q2.head()"
1346
+ ]
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "execution_count": 74,
1351
+ "id": "4508de9e",
1352
+ "metadata": {},
1353
+ "outputs": [],
1354
+ "source": [
1355
+ "irr_q2['irrelevant_score'] = irr_q2['response'].str.contains(\"I don't know\")"
1356
+ ]
1357
+ },
1358
+ {
1359
+ "cell_type": "code",
1360
+ "execution_count": 75,
1361
+ "id": "3d34ba06",
1362
+ "metadata": {},
1363
+ "outputs": [
1364
+ {
1365
+ "data": {
1366
+ "text/plain": [
1367
+ "irrelevant_score 1.000000\n",
1368
+ "latency 0.016068\n",
1369
+ "dtype: float64"
1370
+ ]
1371
+ },
1372
+ "execution_count": 75,
1373
+ "metadata": {},
1374
+ "output_type": "execute_result"
1375
+ }
1376
+ ],
1377
+ "source": [
1378
+ "irr_q2[['irrelevant_score','latency']].mean()"
1379
+ ]
1380
+ }
1381
+ ],
1382
+ "metadata": {
1383
+ "kernelspec": {
1384
+ "display_name": "Python 3 (ipykernel)",
1385
+ "language": "python",
1386
+ "name": "python3"
1387
+ },
1388
+ "language_info": {
1389
+ "codemirror_mode": {
1390
+ "name": "ipython",
1391
+ "version": 3
1392
+ },
1393
+ "file_extension": ".py",
1394
+ "mimetype": "text/x-python",
1395
+ "name": "python",
1396
+ "nbconvert_exporter": "python",
1397
+ "pygments_lexer": "ipython3",
1398
+ "version": "3.11.0"
1399
+ }
1400
+ },
1401
+ "nbformat": 4,
1402
+ "nbformat_minor": 5
1403
+ }
Evaluation_MH/Mental Health Evaluation Report.pdf ADDED
Binary file (72.9 kB). View file
 
Evaluation_MH/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mental Health Evaluations
2
+
3
+ ## Overview
4
+ This project focuses on evaluating the performance and latency of our Retrieval-Augmented Generation (RAG) application in the context of mental health. We implemented and evaluated several methods to enhance the system's performance.
5
+
6
+ ## Requirements
7
+
8
+ Evaluation data: All the datasets, including Ground Truth and Unrelated queries used in the evaluation notebook, are present in the Data folder.
9
+ Evaluation LLM: Phi3 is used for evaluation as we have used Llama3 to develop our application. Make sure to run the following after you have installed Ollama:
10
+
11
+ ```
12
+ ollama pull phi3
13
+ ```
14
+ ## Detailed Report & Demonstration
15
+ * For a comprehensive analysis of the methodology, improvements, results, and future work, please refer to our Mental Health Evaluation Report.
16
+ * Video Demonstration: [Mental Health Evaluation Video](https://youtu.be/XUXMPrq55oU)
17
+
18
+ ## Key Improvements
19
+ 1. Modified prompt for concise answers
20
+ 2. Implemented threshold-based retrieval
21
+
22
+ ## Results
23
+ Our improvements led to:
24
+ - Increased accuracy and relevance scores
25
+ - Significant reduction in latency, especially for irrelevant queries
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Aditi Yadav
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MentalHealth/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Aditi Yadav
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MentalHealth/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MentalHealth
2
+ This project implements a question-answering system for mental health and wellness guidelines using a combination of document embedding, vector search, and language model-based generation.
3
+
4
+ ## Setup
5
+
6
+ 1. Clone the repository:
7
+ ```
8
+ git clone https://github.com/chakraborty-arnab/NutriNudge.git
9
+ ```
10
+ 2. Move into the Folder:
11
+ ```
12
+ cd NutriNudge
13
+ ```
14
+ 3. Install the required dependencies:
15
+ ```
16
+ pip install -r requirements.txt
17
+ ```
18
+ 4. Download and install Ollama from [Ollama](https://ollama.com/)
19
+
20
+ 5. Pull the Llama3 model using Ollama:
21
+ ```
22
+ ollama pull llama3
23
+ ```
24
+
25
+ ## Creating the Vector Database
26
+
27
+ 1. Ensure you have the PDF file `Mental Health Handbook English.pdf` in the `data/` directory.
28
+
29
+ 2. Run the script to create the vector database:
30
+ ```
31
+ python create_vectordb.py
32
+ ```
33
+ This will create two files in the `database/` directory:
34
+ - `pdf_sections_index.faiss`: The FAISS index file
35
+ - `pdf_sections_data.pkl`: The pickle file containing section data
36
+
37
+ ## Running the Application
38
+
39
+ Start the Streamlit app:
40
+ ```
41
+ streamlit run app.py
42
+ ```
43
+ The application will be accessible in your web browser at `http://localhost:8501`.
44
+
45
+ ## Usage
46
+
47
+ 1. Enter your question about dietary guidelines in the text input field.
48
+ 2. Click the "Get Answer" button.
49
+ 3. The system will search for relevant information and generate an answer.
50
+ 4. The answer will be displayed, and you can expand the "Show Context" section to see the relevant text used to generate the answer.
51
+
52
+ ## Files
53
+
54
+ - `create_vectordb.py`: Script to create the vector database from the PDF document.
55
+ - `app.py`: The main Streamlit application for the Q&A system.
56
+ - `data/`: Directory containing the source PDF.
57
+ - `database/`: Directory where the vector database files are stored.
58
+
59
+ ## Note
60
+
61
+ Ensure that you have sufficient disk space and computational resources to run the vector database creation and the Streamlit application. The performance may vary depending on your hardware capabilities.
MentalHealth/app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains import LLMChain
5
+ from langchain_community.llms import Ollama
6
+ import faiss
7
+ import numpy as np
8
+ import pickle
9
+
10
# Cached resource loaders.
# NOTE: st.cache was deprecated and then removed from Streamlit well before
# the 1.37.1 this Space pins, so the original @st.cache(allow_output_mutation=True)
# raises AttributeError at import time. st.cache_resource is the supported
# replacement for unserializable globals (FAISS index, models, pickled data).

@st.cache_resource
def load_faiss_index():
    """Load the FAISS vector index from disk (cached once per session)."""
    try:
        return faiss.read_index("database/pdf_sections_index.faiss")
    except (FileNotFoundError, RuntimeError):
        # faiss.read_index reports a missing file as RuntimeError, not
        # FileNotFoundError, so the original handler would never fire.
        st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
        st.stop()

@st.cache_resource
def load_embedding_model():
    """Load the sentence-transformer used to embed user queries."""
    return SentenceTransformer('all-MiniLM-L6-v2')

@st.cache_resource
def load_sections_data():
    """Load the pickled section texts/metadata aligned row-for-row with the index."""
    try:
        with open('database/pdf_sections_data.pkl', 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
        st.stop()

# Initialize shared resources once at module import.
index = load_faiss_index()
model = load_embedding_model()
sections_data = load_sections_data()
38
+
39
def search_faiss(query, k=3):
    """Return the k nearest PDF sections to `query`.

    Each result dict carries the L2 distance, the section text, and its
    source metadata. Relies on the module-level `model`, `index`, and
    `sections_data`.
    """
    embedded = model.encode([query])[0].astype('float32')
    embedded = np.expand_dims(embedded, axis=0)
    distances, indices = index.search(embedded, k)

    return [
        {
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata'],
        }
        for dist, idx in zip(distances[0], indices[0])
    ]
53
+
54
# System prompt. This app answers mental-health questions, so the system role
# must match; the original said "dietary guidelines", a leftover from the
# NutriNudge project this code was copied from.
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

@st.cache_resource  # st.cache no longer exists in Streamlit 1.37
def load_llm():
    """Create the Ollama-backed Llama3 client (cached once per session)."""
    return Ollama(model="llama3")

llm = load_llm()
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    """Retrieve relevant sections for `query` and generate an answer.

    Returns (answer_text, concatenated_context).
    """
    search_results = search_faiss(query)
    context = "\n\n".join(result['content'] for result in search_results)
    response = chain.run(context=context, question=query)
    return response, context

# Streamlit UI
st.title("Mental Health Guidelines Q&A")

query = st.text_input("Enter your question about Mental Health guidelines:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Searching and generating answer..."):
            answer, context = answer_question(query)
            st.subheader("Answer:")
            st.write(answer)
            with st.expander("Show Context"):
                st.write(context)
    else:
        st.warning("Please enter a question.")
MentalHealth/create_vectordb.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import pickle
7
+
8
# Load the PDF. Use a forward slash: the original "data\Mental Health ..."
# embeds a literal backslash, which breaks on Linux/macOS and triggers an
# invalid-escape-sequence warning on modern Python.
pdf_path = "data/Mental Health Handbook English.pdf"
loader = PyPDFLoader(file_path=pdf_path)

# Load the content
documents = loader.load()

# Split the document into overlapping sections so context is preserved
# across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
sections = text_splitter.split_documents(documents)

# Load the embedding model (queries must later be embedded with the same one).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings for each section
section_texts = [section.page_content for section in sections]
embeddings = model.encode(section_texts)

print(embeddings.shape)

embeddings_np = np.array(embeddings).astype('float32')

# Create an exact L2-search FAISS index and add all section vectors.
dimension = embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings_np)

# Save the index to a file
faiss.write_index(index, "database/pdf_sections_index.faiss")

# Persist section texts/metadata in the same order as the index rows so a
# search hit at position i maps back to sections_data[i].
sections_data = [
    {
        'content': section.page_content,
        'metadata': section.metadata,
    }
    for section in sections
]

with open('database/pdf_sections_data.pkl', 'wb') as f:
    pickle.dump(sections_data, f)

print("Embeddings stored in FAISS index and saved to file.")
MentalHealth/data/Mental Health Handbook English.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
3
+ size 5413709
MentalHealth/database/pdf_sections_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
3
+ size 60012
MentalHealth/database/pdf_sections_index.faiss ADDED
Binary file (66.1 kB). View file
 
MentalHealth/rag.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain.chains import LLMChain
4
+ from langchain_community.llms import Ollama
5
+ import faiss
6
+ import numpy as np
7
+ import pickle
8
+
9
# Load the FAISS index built by create_vectordb.py; abort with a clear
# message if the database has not been generated yet.
try:
    index = faiss.read_index("database/pdf_sections_index.faiss")
except FileNotFoundError:
    print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
    exit(1)

# Load the embedding model — must be the same model used to build the index
# so query vectors live in the same space as the stored section vectors.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data: section texts/metadata aligned row-for-row with the
# FAISS index, so index position i maps to sections_data[i].
try:
    with open('database/pdf_sections_data.pkl', 'rb') as f:
        sections_data = pickle.load(f)
except FileNotFoundError:
    print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
    exit(1)
26
+
27
def search_faiss(query, k=3):
    """Return the k nearest handbook sections to `query`.

    Each result dict carries the L2 distance, the section text, and its
    source metadata. Uses the module-level `model`, `index`, and
    `sections_data`.
    """
    vec = model.encode([query])[0].astype('float32')
    vec = np.expand_dims(vec, axis=0)
    distances, indices = index.search(vec, k)

    return [
        {
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata'],
        }
        for dist, idx in zip(distances[0], indices[0])
    ]
41
+
42
# Create a prompt template. The system role must match this project's
# domain: the original said "dietary guidelines", a copy-paste leftover
# (the sibling root-level rag.py already uses the mental-health wording).
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# Local LLM served by Ollama; requires `ollama pull llama3` beforehand.
llm = Ollama(
    model="llama3"
)

# Create the retrieval-augmented generation chain.
chain = LLMChain(llm=llm, prompt=prompt)
61
+
62
def answer_question(query):
    """Answer `query` from the handbook.

    Retrieves the top matching sections, concatenates them into one context
    string, and lets the LLM chain produce the final answer.
    """
    hits = search_faiss(query)
    context = "\n\n".join(hit['content'] for hit in hits)
    return chain.run(context=context, question=query)

# Example usage
query = "What is Mental Health?"
answer = answer_question(query)

print(f"Question: {query}")
print(f"Answer: {answer}")
MentalHealth/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pypdf
3
+ langchain
4
+ sentence-transformers
5
+ langchain-community
6
+ opensearch-py
7
+ faiss-cpu
8
+
MentalHealth/simple_retrieval.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import faiss
3
+ import numpy as np
4
+
5
# Load the FAISS index built by create_vectordb.py.
index = faiss.read_index("database/pdf_sections_index.faiss")

# Embedding model; must be the same one used when the index was built.
model = SentenceTransformer('all-MiniLM-L6-v2')

def search_faiss(query, k=3):
    """Embed `query` and return (distances, indices) of its k nearest sections."""
    vec = np.expand_dims(model.encode([query])[0].astype('float32'), axis=0)
    return index.search(vec, k)

# Example usage
# NOTE(review): this sample query mentions dietary guidelines although the
# project indexes a mental-health handbook — looks copied from NutriNudge.
query = "What are the main dietary guidelines for protein intake?"
distances, indices = search_faiss(query)

print(f"Query: {query}")
print(f"Distances: {distances}")
print(f"Indices: {indices}")
README.md CHANGED
@@ -1,12 +1,57 @@
1
- ---
2
- title: MentalHealth Chatbot
3
- emoji: 📚
4
- colorFrom: yellow
5
- colorTo: red
6
- sdk: streamlit
7
- sdk_version: 1.37.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MentalHealth
2
+ The Mental Health project develops an assistant system for addressing mental health and wellness inquiries by leveraging document embedding, vector search, and language model-based generation using the phi3 model.
3
+
4
+ ## Setup
5
+
6
+ 1. Clone the repository:
7
+ ```
8
+ git clone
9
+ ```
10
+ 2. Install the required dependencies:
11
+ ```
12
+ pip install -r requirements.txt
13
+ ```
14
+ 3. Download and install Ollama from [Ollama](https://ollama.com/)
15
+
16
+ 4. Pull the phi3 model using Ollama:
17
+ ```
18
+ ollama pull phi3
19
+ ```
20
+
21
+ ## Creating the Vector Database
22
+
23
+ 1. Ensure you have the PDF file `Mental Health Handbook English.pdf` in the `data/` directory.
24
+
25
+ 2. Run the script to create the vector database:
26
+ ```
27
+ python create_vectordb.py
28
+ ```
29
+ This will create two files in the `database/` directory:
30
+ - `pdf_sections_index.faiss`: The FAISS index file
31
+ - `pdf_sections_data.pkl`: The pickle file containing section data
32
+
33
+ ## Running the Application
34
+
35
+ Start the Streamlit app:
36
+ ```
37
+ streamlit run app.py
38
+ ```
39
+ The application will be accessible in your web browser at `http://localhost:8501`.
40
+
41
+ ## Usage
42
+
43
+ 1. Type your mental health and wellness query into the text input box.
44
+ 2. Press the "Get Answer" button.
45
+ 3. The system will find relevant information and create a response.
46
+ 4. The response will be shown, and you can click the "Show Context" section to view the text used to generate the response.
47
+
48
+ ## Files
49
+
50
+ - `create_vectordb.py`: Script for generating the vector database from the PDF document.
51
+ - `app.py`: The primary Streamlit application for the Q&A system.
52
+ - `data/`: Directory holding the source PDF.
53
+ - `database/`: Directory for storing the vector database files.
54
+
55
+ ## Note
56
+
57
+ Make sure you have enough disk space and computational resources to run the vector database creation and the Streamlit application. Performance may vary based on your hardware capabilities.
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains import LLMChain
5
+ from langchain_community.llms import Ollama
6
+ import faiss
7
+ import numpy as np
8
+ import pickle
9
+ import requests
10
+ import json
11
+
12
# Cached resource loaders.
# NOTE: st.cache was deprecated and then removed from Streamlit well before
# the 1.37.1 this Space pins, so the original @st.cache(allow_output_mutation=True)
# raises AttributeError at import time. st.cache_resource is the supported
# replacement for unserializable globals (FAISS index, models, pickled data).

@st.cache_resource
def load_faiss_index():
    """Load the FAISS vector index from disk (cached once per session)."""
    try:
        return faiss.read_index("database/pdf_sections_index.faiss")
    except (FileNotFoundError, RuntimeError):
        # faiss.read_index reports a missing file as RuntimeError, not
        # FileNotFoundError, so the original handler would never fire.
        st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
        st.stop()

@st.cache_resource
def load_embedding_model():
    """Load the sentence-transformer used to embed user queries."""
    return SentenceTransformer('all-MiniLM-L6-v2')

@st.cache_resource
def load_sections_data():
    """Load the pickled section texts/metadata aligned row-for-row with the index."""
    try:
        with open('database/pdf_sections_data.pkl', 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
        st.stop()

# Initialize shared resources once at module import.
index = load_faiss_index()
model = load_embedding_model()
sections_data = load_sections_data()
40
+
41
def search_faiss(query, k=3):
    """Return the k nearest handbook sections to `query`.

    Each result dict carries the L2 distance, the section text, and its
    source metadata. Uses the module-level `model`, `index`, and
    `sections_data`.
    """
    embedded = model.encode([query])[0].astype('float32')
    embedded = np.expand_dims(embedded, axis=0)
    distances, indices = index.search(embedded, k)

    return [
        {
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata'],
        }
        for dist, idx in zip(distances[0], indices[0])
    ]
55
+
56
# System prompt: answer strictly from retrieved context; admit ignorance
# rather than hallucinate.
prompt_template = """
You are an AI assistant specialized in Mental Health & wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

@st.cache_resource  # st.cache no longer exists in Streamlit 1.37
def load_llm():
    """Create the Ollama-backed phi3 client (cached once per session)."""
    return Ollama(model="phi3")

llm = load_llm()
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    """Retrieve relevant sections for `query` and generate an answer.

    Returns (answer_text, concatenated_context).
    """
    search_results = search_faiss(query)
    context = "\n\n".join(result['content'] for result in search_results)
    response = chain.run(context=context, question=query)
    return response, context

# Streamlit UI
st.title("Mental Health & Wellness Assistant")

query = st.text_input("Enter your question about Mental Health:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Searching, Thinking and generating answer..."):
            answer, context = answer_question(query)
            st.subheader("Answer:")
            st.write(answer)
            with st.expander("Show Context"):
                st.write(context)
    else:
        st.warning("Please enter a question.")

# Footer section with social links
st.markdown("""
<div class="social-icons">
    <a href="https://github.com/yadavadit" target="_blank"><img src="https://img.icons8.com/material-outlined/48/e50914/github.png"/></a>
    <a href="https://www.linkedin.com/in/yaditi/" target="_blank"><img src="https://img.icons8.com/color/48/e50914/linkedin.png"/></a>
    <a href="mailto:yadavadit@northeastern.edu"><img src="https://img.icons8.com/color/48/e50914/gmail.png"/></a>
</div>
""", unsafe_allow_html=True)
create_vectordb.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import pickle
7
+
8
# Load the PDF. Use a forward slash: the original "data\Mental Health ..."
# embeds a literal backslash, which breaks on Linux/macOS and triggers an
# invalid-escape-sequence warning on modern Python.
pdf_path = "data/Mental Health Handbook English.pdf"
loader = PyPDFLoader(file_path=pdf_path)

# Load the content
documents = loader.load()

# Split the document into overlapping sections so context is preserved
# across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
sections = text_splitter.split_documents(documents)

# Load the embedding model (queries must later be embedded with the same one).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings for each section
section_texts = [section.page_content for section in sections]
embeddings = model.encode(section_texts)

print(embeddings.shape)

embeddings_np = np.array(embeddings).astype('float32')

# Create an exact L2-search FAISS index and add all section vectors.
dimension = embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings_np)

# Save the index to a file
faiss.write_index(index, "database/pdf_sections_index.faiss")

# Persist section texts/metadata in the same order as the index rows so a
# search hit at position i maps back to sections_data[i].
sections_data = [
    {
        'content': section.page_content,
        'metadata': section.metadata,
    }
    for section in sections
]

with open('database/pdf_sections_data.pkl', 'wb') as f:
    pickle.dump(sections_data, f)

print("Embeddings stored in FAISS index and saved to file.")
data/Mental Health Handbook English.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
3
+ size 5413709
data/MentalHealth_Dataset.xlsx ADDED
Binary file (17.7 kB). View file
 
database/pdf_sections_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
3
+ size 60012
database/pdf_sections_index.faiss ADDED
Binary file (66.1 kB). View file
 
rag.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain.chains import LLMChain
4
+ from langchain_community.llms import Ollama
5
+ import faiss
6
+ import numpy as np
7
+ import pickle
8
+
9
# Load the FAISS index built by create_vectordb.py; abort with a clear
# message if the database has not been generated yet.
try:
    index = faiss.read_index("database/pdf_sections_index.faiss")
except FileNotFoundError:
    print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
    exit(1)

# Load the embedding model — must be the same model used to build the index
# so query vectors live in the same space as the stored section vectors.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data: section texts/metadata aligned row-for-row with the
# FAISS index, so index position i maps to sections_data[i].
try:
    with open('database/pdf_sections_data.pkl', 'rb') as f:
        sections_data = pickle.load(f)
except FileNotFoundError:
    print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
    exit(1)
26
+
27
def search_faiss(query, k=3):
    """Return the k nearest handbook sections to `query`.

    Each result dict carries the L2 distance, the section text, and its
    source metadata. Uses the module-level `model`, `index`, and
    `sections_data`.
    """
    vec = model.encode([query])[0].astype('float32')
    vec = np.expand_dims(vec, axis=0)
    distances, indices = index.search(vec, k)

    return [
        {
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata'],
        }
        for dist, idx in zip(distances[0], indices[0])
    ]
41
+
42
# Create a prompt template
# Instructs the model to answer strictly from the retrieved context and to
# admit ignorance rather than hallucinate.
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# Local LLM served by Ollama; requires `ollama pull phi3` beforehand.
llm = Ollama(
    model="phi3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)
61
+
62
def answer_question(query):
    """Answer `query` from the handbook.

    Retrieves the top matching sections, concatenates them into one context
    string, and lets the LLM chain produce the final answer.
    """
    hits = search_faiss(query)
    context = "\n\n".join(hit['content'] for hit in hits)
    return chain.run(context=context, question=query)

# Example usage
query = "What is mental health?"
answer = answer_question(query)

print(f"Question: {query}")
print(f"Answer: {answer}")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ evaluate
3
+ pypdf
4
+ langchain
5
+ sentence-transformers
6
+ langchain-community
7
+ opensearch-py
8
+ faiss-cpu
9
+ accelerate
10
+ bert_score
11
+
simple_retrieval.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import faiss
3
+ import numpy as np
4
+
5
# Load the FAISS index built by create_vectordb.py.
index = faiss.read_index("database/pdf_sections_index.faiss")

# Embedding model; must be the same one used when the index was built.
model = SentenceTransformer('all-MiniLM-L6-v2')

def search_faiss(query, k=3):
    """Embed `query` and return (distances, indices) of its k nearest sections."""
    vec = np.expand_dims(model.encode([query])[0].astype('float32'), axis=0)
    return index.search(vec, k)

# Example usage
query = "What is mental Health?"
distances, indices = search_faiss(query)

print(f"Query: {query}")
print(f"Distances: {distances}")
print(f"Indices: {indices}")