maryam1320 commited on
Commit
c3660a4
·
verified ·
1 Parent(s): 80419b3

fine tuning model

Browse files
Files changed (1) hide show
  1. fine_tuning_.ipynb +632 -0
fine_tuning_.ipynb ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {
21
+ "colab": {
22
+ "base_uri": "https://localhost:8080/"
23
+ },
24
+ "id": "pHzbx3WTkGcl",
25
+ "outputId": "2da56fff-5007-42d1-d924-4b0bc2ec08e6"
26
+ },
27
+ "outputs": [
28
+ {
29
+ "output_type": "stream",
30
+ "name": "stdout",
31
+ "text": [
32
+ "Collecting gradientai\n",
33
+ " Downloading gradientai-1.11.0-py3-none-any.whl (375 kB)\n",
34
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m375.5/375.5 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
35
+ "\u001b[?25hCollecting aenum>=3.1.11 (from gradientai)\n",
36
+ " Downloading aenum-3.1.15-py3-none-any.whl (137 kB)\n",
37
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m137.6/137.6 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
38
+ "\u001b[?25hCollecting pydantic<2.0.0,>=1.10.5 (from gradientai)\n",
39
+ " Downloading pydantic-1.10.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
40
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
41
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from gradientai) (2.8.2)\n",
42
+ "Requirement already satisfied: urllib3>=1.25.3 in /usr/local/lib/python3.10/dist-packages (from gradientai) (2.0.7)\n",
43
+ "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2.0.0,>=1.10.5->gradientai) (4.11.0)\n",
44
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->gradientai) (1.16.0)\n",
45
+ "Installing collected packages: aenum, pydantic, gradientai\n",
46
+ " Attempting uninstall: pydantic\n",
47
+ " Found existing installation: pydantic 2.7.1\n",
48
+ " Uninstalling pydantic-2.7.1:\n",
49
+ " Successfully uninstalled pydantic-2.7.1\n",
50
+ "Successfully installed aenum-3.1.15 gradientai-1.11.0 pydantic-1.10.15\n"
51
+ ]
52
+ }
53
+ ],
54
+ "source": [
55
+ "!pip install gradientai --upgrade"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "import os\n",
62
+ "import pandas as pd\n",
63
+ "os.environ['GRADIENT_WORKSPACE_ID']='9d0447f2-fcd4-4177-9145-9f019fd59f1e_workspace'\n",
64
+ "os.environ['GRADIENT_ACCESS_TOKEN']='cPErsUMgadGMbzeq8z8W36eJn7UA0Uob'"
65
+ ],
66
+ "metadata": {
67
+ "id": "XJfF9GXCkM1f"
68
+ },
69
+ "execution_count": null,
70
+ "outputs": []
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "source": [
75
+ "df = pd.read_csv(\"https://raw.githubusercontent.com/CS-5302/CS-5302-Project-Group-15/main/Datasets/testing/combined_df.csv\")\n",
76
+ "df"
77
+ ],
78
+ "metadata": {
79
+ "colab": {
80
+ "base_uri": "https://localhost:8080/",
81
+ "height": 424
82
+ },
83
+ "id": "1hAAan9ikNdE",
84
+ "outputId": "b84b6caf-1bc0-4aca-b4f5-3cc47e17293b"
85
+ },
86
+ "execution_count": null,
87
+ "outputs": [
88
+ {
89
+ "output_type": "execute_result",
90
+ "data": {
91
+ "text/plain": [
92
+ " prompts \\\n",
93
+ "0 I have Fever, Fatigue, Difficulty Breathing. W... \n",
94
+ "1 I have Cough, Fatigue. What disease do i have? \n",
95
+ "2 I have Cough, Fatigue. What disease do i have? \n",
96
+ "3 I have Fever, Cough, Difficulty Breathing. Wha... \n",
97
+ "4 I have Fever, Cough, Difficulty Breathing. Wha... \n",
98
+ "... ... \n",
99
+ "1344 i have leg pain, neck pain, paresthesia, leg c... \n",
100
+ "1345 i have nasal congestion, white discharge from ... \n",
101
+ "1346 i have cough, white discharge from eye, dimini... \n",
102
+ "1347 i have diminished hearing, headache, facial pa... \n",
103
+ "1348 i have diminished hearing, facial pain. what d... \n",
104
+ "\n",
105
+ " results \n",
106
+ "0 You have Influenza. \n",
107
+ "1 You don't have a disease. \n",
108
+ "2 You don't have a disease. \n",
109
+ "3 You have Asthma. \n",
110
+ "4 You have Asthma. \n",
111
+ "... ... \n",
112
+ "1344 you have spondylolisthesis. \n",
113
+ "1345 you have conjunctivitis due to virus. \n",
114
+ "1346 you have conjunctivitis due to virus. \n",
115
+ "1347 you have open wound of the nose. \n",
116
+ "1348 you have open wound of the nose. \n",
117
+ "\n",
118
+ "[1349 rows x 2 columns]"
119
+ ],
120
+ "text/html": [
121
+ "\n",
122
+ " <div id=\"df-3fcbebb9-218f-4ac3-b9f7-9218062579df\" class=\"colab-df-container\">\n",
123
+ " <div>\n",
124
+ "<style scoped>\n",
125
+ " .dataframe tbody tr th:only-of-type {\n",
126
+ " vertical-align: middle;\n",
127
+ " }\n",
128
+ "\n",
129
+ " .dataframe tbody tr th {\n",
130
+ " vertical-align: top;\n",
131
+ " }\n",
132
+ "\n",
133
+ " .dataframe thead th {\n",
134
+ " text-align: right;\n",
135
+ " }\n",
136
+ "</style>\n",
137
+ "<table border=\"1\" class=\"dataframe\">\n",
138
+ " <thead>\n",
139
+ " <tr style=\"text-align: right;\">\n",
140
+ " <th></th>\n",
141
+ " <th>prompts</th>\n",
142
+ " <th>results</th>\n",
143
+ " </tr>\n",
144
+ " </thead>\n",
145
+ " <tbody>\n",
146
+ " <tr>\n",
147
+ " <th>0</th>\n",
148
+ " <td>I have Fever, Fatigue, Difficulty Breathing. W...</td>\n",
149
+ " <td>You have Influenza.</td>\n",
150
+ " </tr>\n",
151
+ " <tr>\n",
152
+ " <th>1</th>\n",
153
+ " <td>I have Cough, Fatigue. What disease do i have?</td>\n",
154
+ " <td>You don't have a disease.</td>\n",
155
+ " </tr>\n",
156
+ " <tr>\n",
157
+ " <th>2</th>\n",
158
+ " <td>I have Cough, Fatigue. What disease do i have?</td>\n",
159
+ " <td>You don't have a disease.</td>\n",
160
+ " </tr>\n",
161
+ " <tr>\n",
162
+ " <th>3</th>\n",
163
+ " <td>I have Fever, Cough, Difficulty Breathing. Wha...</td>\n",
164
+ " <td>You have Asthma.</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <th>4</th>\n",
168
+ " <td>I have Fever, Cough, Difficulty Breathing. Wha...</td>\n",
169
+ " <td>You have Asthma.</td>\n",
170
+ " </tr>\n",
171
+ " <tr>\n",
172
+ " <th>...</th>\n",
173
+ " <td>...</td>\n",
174
+ " <td>...</td>\n",
175
+ " </tr>\n",
176
+ " <tr>\n",
177
+ " <th>1344</th>\n",
178
+ " <td>i have leg pain, neck pain, paresthesia, leg c...</td>\n",
179
+ " <td>you have spondylolisthesis.</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>1345</th>\n",
183
+ " <td>i have nasal congestion, white discharge from ...</td>\n",
184
+ " <td>you have conjunctivitis due to virus.</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <th>1346</th>\n",
188
+ " <td>i have cough, white discharge from eye, dimini...</td>\n",
189
+ " <td>you have conjunctivitis due to virus.</td>\n",
190
+ " </tr>\n",
191
+ " <tr>\n",
192
+ " <th>1347</th>\n",
193
+ " <td>i have diminished hearing, headache, facial pa...</td>\n",
194
+ " <td>you have open wound of the nose.</td>\n",
195
+ " </tr>\n",
196
+ " <tr>\n",
197
+ " <th>1348</th>\n",
198
+ " <td>i have diminished hearing, facial pain. what d...</td>\n",
199
+ " <td>you have open wound of the nose.</td>\n",
200
+ " </tr>\n",
201
+ " </tbody>\n",
202
+ "</table>\n",
203
+ "<p>1349 rows × 2 columns</p>\n",
204
+ "</div>\n",
205
+ " <div class=\"colab-df-buttons\">\n",
206
+ "\n",
207
+ " <div class=\"colab-df-container\">\n",
208
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-3fcbebb9-218f-4ac3-b9f7-9218062579df')\"\n",
209
+ " title=\"Convert this dataframe to an interactive table.\"\n",
210
+ " style=\"display:none;\">\n",
211
+ "\n",
212
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
213
+ " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
214
+ " </svg>\n",
215
+ " </button>\n",
216
+ "\n",
217
+ " <style>\n",
218
+ " .colab-df-container {\n",
219
+ " display:flex;\n",
220
+ " gap: 12px;\n",
221
+ " }\n",
222
+ "\n",
223
+ " .colab-df-convert {\n",
224
+ " background-color: #E8F0FE;\n",
225
+ " border: none;\n",
226
+ " border-radius: 50%;\n",
227
+ " cursor: pointer;\n",
228
+ " display: none;\n",
229
+ " fill: #1967D2;\n",
230
+ " height: 32px;\n",
231
+ " padding: 0 0 0 0;\n",
232
+ " width: 32px;\n",
233
+ " }\n",
234
+ "\n",
235
+ " .colab-df-convert:hover {\n",
236
+ " background-color: #E2EBFA;\n",
237
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
238
+ " fill: #174EA6;\n",
239
+ " }\n",
240
+ "\n",
241
+ " .colab-df-buttons div {\n",
242
+ " margin-bottom: 4px;\n",
243
+ " }\n",
244
+ "\n",
245
+ " [theme=dark] .colab-df-convert {\n",
246
+ " background-color: #3B4455;\n",
247
+ " fill: #D2E3FC;\n",
248
+ " }\n",
249
+ "\n",
250
+ " [theme=dark] .colab-df-convert:hover {\n",
251
+ " background-color: #434B5C;\n",
252
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
253
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
254
+ " fill: #FFFFFF;\n",
255
+ " }\n",
256
+ " </style>\n",
257
+ "\n",
258
+ " <script>\n",
259
+ " const buttonEl =\n",
260
+ " document.querySelector('#df-3fcbebb9-218f-4ac3-b9f7-9218062579df button.colab-df-convert');\n",
261
+ " buttonEl.style.display =\n",
262
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
263
+ "\n",
264
+ " async function convertToInteractive(key) {\n",
265
+ " const element = document.querySelector('#df-3fcbebb9-218f-4ac3-b9f7-9218062579df');\n",
266
+ " const dataTable =\n",
267
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
268
+ " [key], {});\n",
269
+ " if (!dataTable) return;\n",
270
+ "\n",
271
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
272
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
273
+ " + ' to learn more about interactive tables.';\n",
274
+ " element.innerHTML = '';\n",
275
+ " dataTable['output_type'] = 'display_data';\n",
276
+ " await google.colab.output.renderOutput(dataTable, element);\n",
277
+ " const docLink = document.createElement('div');\n",
278
+ " docLink.innerHTML = docLinkHtml;\n",
279
+ " element.appendChild(docLink);\n",
280
+ " }\n",
281
+ " </script>\n",
282
+ " </div>\n",
283
+ "\n",
284
+ "\n",
285
+ "<div id=\"df-7a3554a8-e00d-4ba4-88ed-15c7da18b7db\">\n",
286
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-7a3554a8-e00d-4ba4-88ed-15c7da18b7db')\"\n",
287
+ " title=\"Suggest charts\"\n",
288
+ " style=\"display:none;\">\n",
289
+ "\n",
290
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
291
+ " width=\"24px\">\n",
292
+ " <g>\n",
293
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
294
+ " </g>\n",
295
+ "</svg>\n",
296
+ " </button>\n",
297
+ "\n",
298
+ "<style>\n",
299
+ " .colab-df-quickchart {\n",
300
+ " --bg-color: #E8F0FE;\n",
301
+ " --fill-color: #1967D2;\n",
302
+ " --hover-bg-color: #E2EBFA;\n",
303
+ " --hover-fill-color: #174EA6;\n",
304
+ " --disabled-fill-color: #AAA;\n",
305
+ " --disabled-bg-color: #DDD;\n",
306
+ " }\n",
307
+ "\n",
308
+ " [theme=dark] .colab-df-quickchart {\n",
309
+ " --bg-color: #3B4455;\n",
310
+ " --fill-color: #D2E3FC;\n",
311
+ " --hover-bg-color: #434B5C;\n",
312
+ " --hover-fill-color: #FFFFFF;\n",
313
+ " --disabled-bg-color: #3B4455;\n",
314
+ " --disabled-fill-color: #666;\n",
315
+ " }\n",
316
+ "\n",
317
+ " .colab-df-quickchart {\n",
318
+ " background-color: var(--bg-color);\n",
319
+ " border: none;\n",
320
+ " border-radius: 50%;\n",
321
+ " cursor: pointer;\n",
322
+ " display: none;\n",
323
+ " fill: var(--fill-color);\n",
324
+ " height: 32px;\n",
325
+ " padding: 0;\n",
326
+ " width: 32px;\n",
327
+ " }\n",
328
+ "\n",
329
+ " .colab-df-quickchart:hover {\n",
330
+ " background-color: var(--hover-bg-color);\n",
331
+ " box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
332
+ " fill: var(--button-hover-fill-color);\n",
333
+ " }\n",
334
+ "\n",
335
+ " .colab-df-quickchart-complete:disabled,\n",
336
+ " .colab-df-quickchart-complete:disabled:hover {\n",
337
+ " background-color: var(--disabled-bg-color);\n",
338
+ " fill: var(--disabled-fill-color);\n",
339
+ " box-shadow: none;\n",
340
+ " }\n",
341
+ "\n",
342
+ " .colab-df-spinner {\n",
343
+ " border: 2px solid var(--fill-color);\n",
344
+ " border-color: transparent;\n",
345
+ " border-bottom-color: var(--fill-color);\n",
346
+ " animation:\n",
347
+ " spin 1s steps(1) infinite;\n",
348
+ " }\n",
349
+ "\n",
350
+ " @keyframes spin {\n",
351
+ " 0% {\n",
352
+ " border-color: transparent;\n",
353
+ " border-bottom-color: var(--fill-color);\n",
354
+ " border-left-color: var(--fill-color);\n",
355
+ " }\n",
356
+ " 20% {\n",
357
+ " border-color: transparent;\n",
358
+ " border-left-color: var(--fill-color);\n",
359
+ " border-top-color: var(--fill-color);\n",
360
+ " }\n",
361
+ " 30% {\n",
362
+ " border-color: transparent;\n",
363
+ " border-left-color: var(--fill-color);\n",
364
+ " border-top-color: var(--fill-color);\n",
365
+ " border-right-color: var(--fill-color);\n",
366
+ " }\n",
367
+ " 40% {\n",
368
+ " border-color: transparent;\n",
369
+ " border-right-color: var(--fill-color);\n",
370
+ " border-top-color: var(--fill-color);\n",
371
+ " }\n",
372
+ " 60% {\n",
373
+ " border-color: transparent;\n",
374
+ " border-right-color: var(--fill-color);\n",
375
+ " }\n",
376
+ " 80% {\n",
377
+ " border-color: transparent;\n",
378
+ " border-right-color: var(--fill-color);\n",
379
+ " border-bottom-color: var(--fill-color);\n",
380
+ " }\n",
381
+ " 90% {\n",
382
+ " border-color: transparent;\n",
383
+ " border-bottom-color: var(--fill-color);\n",
384
+ " }\n",
385
+ " }\n",
386
+ "</style>\n",
387
+ "\n",
388
+ " <script>\n",
389
+ " async function quickchart(key) {\n",
390
+ " const quickchartButtonEl =\n",
391
+ " document.querySelector('#' + key + ' button');\n",
392
+ " quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
393
+ " quickchartButtonEl.classList.add('colab-df-spinner');\n",
394
+ " try {\n",
395
+ " const charts = await google.colab.kernel.invokeFunction(\n",
396
+ " 'suggestCharts', [key], {});\n",
397
+ " } catch (error) {\n",
398
+ " console.error('Error during call to suggestCharts:', error);\n",
399
+ " }\n",
400
+ " quickchartButtonEl.classList.remove('colab-df-spinner');\n",
401
+ " quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
402
+ " }\n",
403
+ " (() => {\n",
404
+ " let quickchartButtonEl =\n",
405
+ " document.querySelector('#df-7a3554a8-e00d-4ba4-88ed-15c7da18b7db button');\n",
406
+ " quickchartButtonEl.style.display =\n",
407
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
408
+ " })();\n",
409
+ " </script>\n",
410
+ "</div>\n",
411
+ "\n",
412
+ " <div id=\"id_63d1b66c-c721-4f22-ae5a-f946ac5f93d4\">\n",
413
+ " <style>\n",
414
+ " .colab-df-generate {\n",
415
+ " background-color: #E8F0FE;\n",
416
+ " border: none;\n",
417
+ " border-radius: 50%;\n",
418
+ " cursor: pointer;\n",
419
+ " display: none;\n",
420
+ " fill: #1967D2;\n",
421
+ " height: 32px;\n",
422
+ " padding: 0 0 0 0;\n",
423
+ " width: 32px;\n",
424
+ " }\n",
425
+ "\n",
426
+ " .colab-df-generate:hover {\n",
427
+ " background-color: #E2EBFA;\n",
428
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
429
+ " fill: #174EA6;\n",
430
+ " }\n",
431
+ "\n",
432
+ " [theme=dark] .colab-df-generate {\n",
433
+ " background-color: #3B4455;\n",
434
+ " fill: #D2E3FC;\n",
435
+ " }\n",
436
+ "\n",
437
+ " [theme=dark] .colab-df-generate:hover {\n",
438
+ " background-color: #434B5C;\n",
439
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
440
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
441
+ " fill: #FFFFFF;\n",
442
+ " }\n",
443
+ " </style>\n",
444
+ " <button class=\"colab-df-generate\" onclick=\"generateWithVariable('df')\"\n",
445
+ " title=\"Generate code using this dataframe.\"\n",
446
+ " style=\"display:none;\">\n",
447
+ "\n",
448
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
449
+ " width=\"24px\">\n",
450
+ " <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
451
+ " </svg>\n",
452
+ " </button>\n",
453
+ " <script>\n",
454
+ " (() => {\n",
455
+ " const buttonEl =\n",
456
+ " document.querySelector('#id_63d1b66c-c721-4f22-ae5a-f946ac5f93d4 button.colab-df-generate');\n",
457
+ " buttonEl.style.display =\n",
458
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
459
+ "\n",
460
+ " buttonEl.onclick = () => {\n",
461
+ " google.colab.notebook.generateWithVariable('df');\n",
462
+ " }\n",
463
+ " })();\n",
464
+ " </script>\n",
465
+ " </div>\n",
466
+ "\n",
467
+ " </div>\n",
468
+ " </div>\n"
469
+ ],
470
+ "application/vnd.google.colaboratory.intrinsic+json": {
471
+ "type": "dataframe",
472
+ "variable_name": "df",
473
+ "summary": "{\n \"name\": \"df\",\n \"rows\": 1349,\n \"fields\": [\n {\n \"column\": \"prompts\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1052,\n \"samples\": [\n \"i have sharp chest pain, abusing alcohol, sharp abdominal pain, vomiting, diarrhea, back pain, burning abdominal pain, side pain, lower body pain, upper abdominal pain. what disease do i have?\",\n \"i have wrist pain, hand or finger swelling, arm pain, knee pain, foot or toe pain, ankle pain, shoulder pain. what disease do i have?\",\n \"i have sharp chest pain, leg pain, sharp abdominal pain, vomiting, lower abdominal pain, low back pain. what disease do i have?\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"results\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 455,\n \"samples\": [\n \"you have temporary or benign blood in urine.\",\n \"You have Hypoglycemia.\",\n \"you have sebaceous cyst.\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
474
+ }
475
+ },
476
+ "metadata": {},
477
+ "execution_count": 3
478
+ }
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "source": [
484
+ "\n",
485
+ "BATCH_SIZE = 100\n",
486
+ "NUM_EPOCHS = 1\n"
487
+ ],
488
+ "metadata": {
489
+ "id": "0_f3sqBtpXpQ"
490
+ },
491
+ "execution_count": null,
492
+ "outputs": []
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "source": [
497
+ "def create_model_adapter(gradient):\n",
498
+ " base_model = gradient.get_base_model(base_model_slug=\"nous-hermes2\")\n",
499
+ " new_model_adapter = base_model.create_model_adapter(\n",
500
+ " name=\"meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9\"\n",
501
+ " )\n",
502
+ " print(f\"Created model adapter with id {new_model_adapter.id}\")\n",
503
+ " return new_model_adapter"
504
+ ],
505
+ "metadata": {
506
+ "id": "bW9Qin9MpcRU"
507
+ },
508
+ "execution_count": null,
509
+ "outputs": []
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "source": [
514
+ "def fine_tune_in_batches(df, gradient, batch_size, num_epochs):\n",
515
+ " new_model_adapter = create_model_adapter(gradient)\n",
516
+ "\n",
517
+ " # Split the DataFrame into batches\n",
518
+ " batches = [df[i:i + batch_size] for i in range(0, len(df), batch_size)]\n",
519
+ "\n",
520
+ " # Iterate over batches and perform fine-tuning\n",
521
+ " for batch_index, batch in enumerate(batches):\n",
522
+ " fine_tuning_samples = []\n",
523
+ " for _, row in batch.iterrows():\n",
524
+ " fine_tuning_samples.append({\n",
525
+ " \"inputs\": f\"### Instruction: {row['prompts']}\",\n",
526
+ " \"targets\": f\"### Response: {row['results']}\"\n",
527
+ " })\n",
528
+ "\n",
529
+ " # Fine-tune for the given number of epochs\n",
530
+ " for epoch in range(num_epochs):\n",
531
+ " print(f\"Fine-tuning batch {batch_index + 1} (epoch {epoch + 1})\")\n",
532
+ " new_model_adapter.fine_tune(samples=fine_tuning_samples)\n",
533
+ "\n",
534
+ " return new_model_adapter"
535
+ ],
536
+ "metadata": {
537
+ "id": "XfDSZVfspe9c"
538
+ },
539
+ "execution_count": null,
540
+ "outputs": []
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "source": [
545
+ "from gradientai import Gradient\n",
546
+ "\n",
547
+ "def main():\n",
548
+ " # Initialize the Gradient API\n",
549
+ " gradient = Gradient()\n",
550
+ "\n",
551
+ "\n",
552
+ " # Fine-tune in batches and retain the final model adapter\n",
553
+ " model_adapter = fine_tune_in_batches(df, gradient, BATCH_SIZE, NUM_EPOCHS)\n",
554
+ "\n",
555
+ " # Test the model after fine-tuning\n",
556
+ " sample_query = f\"### Instruction: {df['prompts'][0]} \\n\\n### Response:\"\n",
557
+ " completion = model_adapter.complete(query=sample_query, max_generated_token_count=100).generated_output\n",
558
+ " print(f\"Generated (after fine-tuning): {completion}\")\n",
559
+ "\n",
560
+ " # Clean up the model adapter\n",
561
+ " model_adapter.delete()\n",
562
+ " gradient.close()\n",
563
+ "\n",
564
+ "if __name__ == \"__main__\":\n",
565
+ " main()"
566
+ ],
567
+ "metadata": {
568
+ "colab": {
569
+ "base_uri": "https://localhost:8080/"
570
+ },
571
+ "id": "RzHdyX0nkPtl",
572
+ "outputId": "b5ae54fd-6246-414b-cffb-5421a5e8c8a5"
573
+ },
574
+ "execution_count": null,
575
+ "outputs": [
576
+ {
577
+ "output_type": "stream",
578
+ "name": "stdout",
579
+ "text": [
580
+ "Created model adapter with id 9aebddd0-336b-4b61-8910-d8be7ef38f43_model_adapter\n",
581
+ "Fine-tuning batch 1 (epoch 1)\n",
582
+ "Fine-tuning batch 2 (epoch 1)\n",
583
+ "Fine-tuning batch 3 (epoch 1)\n",
584
+ "Fine-tuning batch 4 (epoch 1)\n",
585
+ "Fine-tuning batch 5 (epoch 1)\n",
586
+ "Fine-tuning batch 6 (epoch 1)\n",
587
+ "Fine-tuning batch 7 (epoch 1)\n",
588
+ "Fine-tuning batch 8 (epoch 1)\n",
589
+ "Fine-tuning batch 9 (epoch 1)\n",
590
+ "Fine-tuning batch 10 (epoch 1)\n",
591
+ "Fine-tuning batch 11 (epoch 1)\n",
592
+ "Fine-tuning batch 12 (epoch 1)\n",
593
+ "Fine-tuning batch 13 (epoch 1)\n",
594
+ "Fine-tuning batch 14 (epoch 1)\n",
595
+ "Generated (after fine-tuning): You may have Pneumonia.\n"
596
+ ]
597
+ }
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "source": [
603
+ "sample_query = f\"### Instruction: {df['prompts'][0]} \\n\\n### Response:\"\n",
604
+ "sample_query"
605
+ ],
606
+ "metadata": {
607
+ "colab": {
608
+ "base_uri": "https://localhost:8080/",
609
+ "height": 35
610
+ },
611
+ "id": "LLfG-8dG3gF3",
612
+ "outputId": "bd7192a7-6a07-4262-eaf6-fb62a5b278c6"
613
+ },
614
+ "execution_count": null,
615
+ "outputs": [
616
+ {
617
+ "output_type": "execute_result",
618
+ "data": {
619
+ "text/plain": [
620
+ "'### Instruction: I have Fever, Fatigue, Difficulty Breathing. What disease do i have? \\n\\n### Response:'"
621
+ ],
622
+ "application/vnd.google.colaboratory.intrinsic+json": {
623
+ "type": "string"
624
+ }
625
+ },
626
+ "metadata": {},
627
+ "execution_count": 20
628
+ }
629
+ ]
630
+ }
631
+ ]
632
+ }