codemogul committed on
Commit
03edd98
·
1 Parent(s): b865bd7

Upload model.ipynb

Browse files
Files changed (1) hide show
  1. model.ipynb +297 -0
model.ipynb ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 8,
6
+ "id": "ace57031",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "text/html": [
12
+ "<div>\n",
13
+ "<style scoped>\n",
14
+ " .dataframe tbody tr th:only-of-type {\n",
15
+ " vertical-align: middle;\n",
16
+ " }\n",
17
+ "\n",
18
+ " .dataframe tbody tr th {\n",
19
+ " vertical-align: top;\n",
20
+ " }\n",
21
+ "\n",
22
+ " .dataframe thead th {\n",
23
+ " text-align: right;\n",
24
+ " }\n",
25
+ "</style>\n",
26
+ "<table border=\"1\" class=\"dataframe\">\n",
27
+ " <thead>\n",
28
+ " <tr style=\"text-align: right;\">\n",
29
+ " <th></th>\n",
30
+ " <th>Question_ID</th>\n",
31
+ " <th>Questions</th>\n",
32
+ " <th>Answers</th>\n",
33
+ " </tr>\n",
34
+ " </thead>\n",
35
+ " <tbody>\n",
36
+ " <tr>\n",
37
+ " <th>0</th>\n",
38
+ " <td>1590140</td>\n",
39
+ " <td>What does it mean to have a mental illness?</td>\n",
40
+ " <td>Mental illnesses are health conditions that di...</td>\n",
41
+ " </tr>\n",
42
+ " <tr>\n",
43
+ " <th>1</th>\n",
44
+ " <td>2110618</td>\n",
45
+ " <td>Who does mental illness affect?</td>\n",
46
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
47
+ " </tr>\n",
48
+ " <tr>\n",
49
+ " <th>2</th>\n",
50
+ " <td>6361820</td>\n",
51
+ " <td>What causes mental illness?</td>\n",
52
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
53
+ " </tr>\n",
54
+ " <tr>\n",
55
+ " <th>3</th>\n",
56
+ " <td>9434130</td>\n",
57
+ " <td>What are some of the warning signs of mental i...</td>\n",
58
+ " <td>Symptoms of mental health disorders vary depen...</td>\n",
59
+ " </tr>\n",
60
+ " <tr>\n",
61
+ " <th>4</th>\n",
62
+ " <td>7657263</td>\n",
63
+ " <td>Can people with mental illness recover?</td>\n",
64
+ " <td>When healing from mental illness, early identi...</td>\n",
65
+ " </tr>\n",
66
+ " </tbody>\n",
67
+ "</table>\n",
68
+ "</div>"
69
+ ],
70
+ "text/plain": [
71
+ " Question_ID Questions \\\n",
72
+ "0 1590140 What does it mean to have a mental illness? \n",
73
+ "1 2110618 Who does mental illness affect? \n",
74
+ "2 6361820 What causes mental illness? \n",
75
+ "3 9434130 What are some of the warning signs of mental i... \n",
76
+ "4 7657263 Can people with mental illness recover? \n",
77
+ "\n",
78
+ " Answers \n",
79
+ "0 Mental illnesses are health conditions that di... \n",
80
+ "1 It is estimated that mental illness affects 1 ... \n",
81
+ "2 It is estimated that mental illness affects 1 ... \n",
82
+ "3 Symptoms of mental health disorders vary depen... \n",
83
+ "4 When healing from mental illness, early identi... "
84
+ ]
85
+ },
86
+ "execution_count": 8,
87
+ "metadata": {},
88
+ "output_type": "execute_result"
89
+ }
90
+ ],
91
+ "source": [
92
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
93
+ "from sklearn.model_selection import train_test_split\n",
94
+ "from sklearn.linear_model import LogisticRegression\n",
95
+ "from sklearn.metrics import accuracy_score\n",
96
+ "import pandas as pd\n",
97
+ "import numpy as np\n",
98
+ "import torch\n",
99
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
100
+ "from huggingface_hub import notebook_login\n",
101
+ "# notebook_login()\n",
102
+ "# Step 1: Collect and preprocess data\n",
103
+ "# Get all the questions from the Questions column and the responses from the Answers column in the dataset data.csv\n",
104
+ "# questions = data[\"Questions\"].tolist()\n",
105
+ "# responses = data[\"Responses\"].tolist()\n",
106
+ "questions = []\n",
107
+ "responses = []\n",
108
+ "q_id = []\n",
109
+ "with open(\"mental_health_bot.csv\", \"r\") as f:\n",
110
+ " for line in f:\n",
111
+ " \n",
112
+ " array = line.split(\",\") \n",
113
+ " # questions.append(question)\n",
114
+ " # responses.append(response)\n",
115
+ " # q_id.append(question_id)\n",
116
+ " try:\n",
117
+ " question = array[1]\n",
118
+ " response = array[2]\n",
119
+ " question_id = array[0]\n",
120
+ " questions.append(question)\n",
121
+ " responses.append(response)\n",
122
+ " q_id.append(question_id)\n",
123
+ " except:\n",
124
+ " pass\n",
125
+ "\n",
126
+ "data = pd.read_csv(\"data.csv\")\n",
127
+ "data.head()"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 9,
133
+ "id": "8f51e39d",
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "missing values: Question_ID 0\n",
141
+ "Questions 0\n",
142
+ "Answers 0\n",
143
+ "dtype: int64\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "print('missing values:', data.isnull().sum())"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 10,
154
+ "id": "1d697a39",
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "name": "stdout",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "<class 'pandas.core.frame.DataFrame'>\n",
162
+ "RangeIndex: 149 entries, 0 to 148\n",
163
+ "Data columns (total 3 columns):\n",
164
+ " # Column Non-Null Count Dtype \n",
165
+ "--- ------ -------------- ----- \n",
166
+ " 0 Question_ID 149 non-null object\n",
167
+ " 1 Questions 149 non-null object\n",
168
+ " 2 Answers 149 non-null object\n",
169
+ "dtypes: object(3)\n",
170
+ "memory usage: 3.6+ KB\n",
171
+ "None\n"
172
+ ]
173
+ }
174
+ ],
175
+ "source": [
176
+ "print(data.info())"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 12,
182
+ "id": "c5dde0e4",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stdout",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "Accuracy: 0.03333333333333333\n"
190
+ ]
191
+ }
192
+ ],
193
+ "source": [
194
+ "# print(questions)\n",
195
+ "# print(responses)\n",
196
+ "\n",
197
+ "\n",
198
+ "# questions = [\"What are some symptoms of depression?\",\n",
199
+ "# \"How can I manage my anxiety?\",\n",
200
+ "# \"What are the treatments for bipolar disorder?\"]\n",
201
+ "# responses = [\"Symptoms of depression include sadness, lack of energy, and loss of interest in activities.\",\n",
202
+ "# \"You can manage your anxiety through techniques such as deep breathing, meditation, and therapy.\",\n",
203
+ "# \"Treatments for bipolar disorder include medication, therapy, and lifestyle changes.\"]\n",
204
+ "\n",
205
+ "vectorizer = TfidfVectorizer()\n",
206
+ "X = vectorizer.fit_transform(questions)\n",
207
+ "y = responses\n",
208
+ "\n",
209
+ "# Step 2: Split data into training and testing sets\n",
210
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n",
211
+ "\n",
212
+ "# Step 3: Choose a machine learning algorithm\n",
213
+ "model = LogisticRegression()\n",
214
+ "\n",
215
+ "# Step 4: Train the model\n",
216
+ "model.fit(X_train, y_train)\n",
217
+ "\n",
218
+ "model.push_to_hub(\"tabibu-ai/mental-health-chatbot\")\n",
219
+ "pt_model = DistilBertForSequenceClassification.from_pretrained(\"model.ipynb\", from_tf=True)\n",
220
+ "pt_model.save_pretrained(\"model.ipynb\")\n",
221
+ "# load model from hub\n",
222
+ "\n",
223
+ "# Step 5: Evaluate the model\n",
224
+ "y_pred = model.predict(X_test)\n",
225
+ "accuracy = accuracy_score(y_test, y_pred)\n",
226
+ "print(\"Accuracy:\", accuracy)\n",
227
+ "\n",
228
+ "# Step 6: Use the model to make predictions\n",
229
+ "\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 18,
235
+ "id": "14406312",
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "Ask me anything : I feel sad\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "new_question = input(\"Ask me anything : \")\n"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 17,
253
+ "id": "6b9198db",
254
+ "metadata": {},
255
+ "outputs": [
256
+ {
257
+ "name": "stdout",
258
+ "output_type": "stream",
259
+ "text": [
260
+ "Prediction: ['\"It is estimated that mental illness affects 1 in 5 adults in America']\n"
261
+ ]
262
+ }
263
+ ],
264
+ "source": [
265
+ "new_question_vector = vectorizer.transform([new_question])\n",
266
+ "prediction = model.predict(new_question_vector)\n",
267
+ "print(\"Prediction:\", prediction)"
268
+ ]
269
+ }
270
+ ],
271
+ "metadata": {
272
+ "kernelspec": {
273
+ "display_name": "Python 3 (ipykernel)",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.10.7"
288
+ },
289
+ "vscode": {
290
+ "interpreter": {
291
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
292
+ }
293
+ }
294
+ },
295
+ "nbformat": 4,
296
+ "nbformat_minor": 5
297
+ }