Michelangiolo committed on
Commit
cf172ac
·
1 Parent(s): 77a74de
Files changed (4) hide show
  1. 1_data_processing.ipynb +215 -0
  2. 2_gradio.ipynb +145 -0
  3. app.py +36 -0
  4. df_encoded.parquet +3 -0
1_data_processing.ipynb ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%pip install sentence-transformers==2.0.0"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "1. Load dataset with pandas"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 7,
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "data": {
26
+ "text/html": [
27
+ "<div>\n",
28
+ "<style scoped>\n",
29
+ " .dataframe tbody tr th:only-of-type {\n",
30
+ " vertical-align: middle;\n",
31
+ " }\n",
32
+ "\n",
33
+ " .dataframe tbody tr th {\n",
34
+ " vertical-align: top;\n",
35
+ " }\n",
36
+ "\n",
37
+ " .dataframe thead th {\n",
38
+ " text-align: right;\n",
39
+ " }\n",
40
+ "</style>\n",
41
+ "<table border=\"1\" class=\"dataframe\">\n",
42
+ " <thead>\n",
43
+ " <tr style=\"text-align: right;\">\n",
44
+ " <th></th>\n",
45
+ " <th>Description</th>\n",
46
+ " <th>UnitPrice</th>\n",
47
+ " <th>Country</th>\n",
48
+ " </tr>\n",
49
+ " </thead>\n",
50
+ " <tbody>\n",
51
+ " <tr>\n",
52
+ " <th>0</th>\n",
53
+ " <td>WHITE HANGING HEART T-LIGHT HOLDER</td>\n",
54
+ " <td>2.55</td>\n",
55
+ " <td>United Kingdom</td>\n",
56
+ " </tr>\n",
57
+ " <tr>\n",
58
+ " <th>1</th>\n",
59
+ " <td>WHITE METAL LANTERN</td>\n",
60
+ " <td>3.39</td>\n",
61
+ " <td>United Kingdom</td>\n",
62
+ " </tr>\n",
63
+ " <tr>\n",
64
+ " <th>2</th>\n",
65
+ " <td>CREAM CUPID HEARTS COAT HANGER</td>\n",
66
+ " <td>2.75</td>\n",
67
+ " <td>United Kingdom</td>\n",
68
+ " </tr>\n",
69
+ " <tr>\n",
70
+ " <th>3</th>\n",
71
+ " <td>KNITTED UNION FLAG HOT WATER BOTTLE</td>\n",
72
+ " <td>3.39</td>\n",
73
+ " <td>United Kingdom</td>\n",
74
+ " </tr>\n",
75
+ " <tr>\n",
76
+ " <th>4</th>\n",
77
+ " <td>RED WOOLLY HOTTIE WHITE HEART.</td>\n",
78
+ " <td>3.39</td>\n",
79
+ " <td>United Kingdom</td>\n",
80
+ " </tr>\n",
81
+ " <tr>\n",
82
+ " <th>...</th>\n",
83
+ " <td>...</td>\n",
84
+ " <td>...</td>\n",
85
+ " <td>...</td>\n",
86
+ " </tr>\n",
87
+ " <tr>\n",
88
+ " <th>535327</th>\n",
89
+ " <td>????damages????</td>\n",
90
+ " <td>0.00</td>\n",
91
+ " <td>United Kingdom</td>\n",
92
+ " </tr>\n",
93
+ " <tr>\n",
94
+ " <th>535329</th>\n",
95
+ " <td>mixed up</td>\n",
96
+ " <td>0.00</td>\n",
97
+ " <td>United Kingdom</td>\n",
98
+ " </tr>\n",
99
+ " <tr>\n",
100
+ " <th>535335</th>\n",
101
+ " <td>lost</td>\n",
102
+ " <td>0.00</td>\n",
103
+ " <td>United Kingdom</td>\n",
104
+ " </tr>\n",
105
+ " <tr>\n",
106
+ " <th>537621</th>\n",
107
+ " <td>CREAM HANGING HEART T-LIGHT HOLDER</td>\n",
108
+ " <td>2.95</td>\n",
109
+ " <td>United Kingdom</td>\n",
110
+ " </tr>\n",
111
+ " <tr>\n",
112
+ " <th>540421</th>\n",
113
+ " <td>PAPER CRAFT , LITTLE BIRDIE</td>\n",
114
+ " <td>2.08</td>\n",
115
+ " <td>United Kingdom</td>\n",
116
+ " </tr>\n",
117
+ " </tbody>\n",
118
+ "</table>\n",
119
+ "<p>4223 rows × 3 columns</p>\n",
120
+ "</div>"
121
+ ],
122
+ "text/plain": [
123
+ " Description UnitPrice Country\n",
124
+ "0 WHITE HANGING HEART T-LIGHT HOLDER 2.55 United Kingdom\n",
125
+ "1 WHITE METAL LANTERN 3.39 United Kingdom\n",
126
+ "2 CREAM CUPID HEARTS COAT HANGER 2.75 United Kingdom\n",
127
+ "3 KNITTED UNION FLAG HOT WATER BOTTLE 3.39 United Kingdom\n",
128
+ "4 RED WOOLLY HOTTIE WHITE HEART. 3.39 United Kingdom\n",
129
+ "... ... ... ...\n",
130
+ "535327 ????damages???? 0.00 United Kingdom\n",
131
+ "535329 mixed up 0.00 United Kingdom\n",
132
+ "535335 lost 0.00 United Kingdom\n",
133
+ "537621 CREAM HANGING HEART T-LIGHT HOLDER 2.95 United Kingdom\n",
134
+ "540421 PAPER CRAFT , LITTLE BIRDIE 2.08 United Kingdom\n",
135
+ "\n",
136
+ "[4223 rows x 3 columns]"
137
+ ]
138
+ },
139
+ "execution_count": 7,
140
+ "metadata": {},
141
+ "output_type": "execute_result"
142
+ }
143
+ ],
144
+ "source": [
145
+ "import pandas as pd\n",
146
+ "\n",
147
+ "df = pd.read_csv('products.csv')\n",
148
+ "df = df[['Description', 'UnitPrice', 'Country']]\n",
149
+ "df = df.dropna().drop_duplicates(subset=['Description'])\n",
150
+ "df"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "metadata": {},
156
+ "source": [
157
+ "2. Encode every product description into a vector (1 column with product text, 1 column with vectors)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": null,
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "import pandas as pd\n",
167
+ "from tqdm import tqdm\n",
168
+ "from sentence_transformers import SentenceTransformer\n",
169
+ "tqdm.pandas()\n",
170
+ "\n",
171
+ "model = SentenceTransformer('all-mpnet-base-v2') #all-MiniLM-L6-v2 #all-mpnet-base-v2\n",
172
+ "\n",
173
+ "#encode df version: for small dataset only\n",
174
+ "df['text_vector_'] = df['Description'].progress_apply(lambda x : model.encode(x).tolist())\n",
175
+ "df"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": 9,
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "df.to_parquet('df_encoded.parquet', index=None)"
185
+ ]
186
+ }
187
+ ],
188
+ "metadata": {
189
+ "kernelspec": {
190
+ "display_name": "Python 3.9.0 64-bit",
191
+ "language": "python",
192
+ "name": "python3"
193
+ },
194
+ "language_info": {
195
+ "codemirror_mode": {
196
+ "name": "ipython",
197
+ "version": 3
198
+ },
199
+ "file_extension": ".py",
200
+ "mimetype": "text/x-python",
201
+ "name": "python",
202
+ "nbconvert_exporter": "python",
203
+ "pygments_lexer": "ipython3",
204
+ "version": "3.9.13"
205
+ },
206
+ "orig_nbformat": 4,
207
+ "vscode": {
208
+ "interpreter": {
209
+ "hash": "fdf377d643bc1cb065454f0ad2ceac75d834452ecf289e7ba92c6b3f59a7cee1"
210
+ }
211
+ }
212
+ },
213
+ "nbformat": 4,
214
+ "nbformat_minor": 2
215
+ }
2_gradio.ipynb ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# !pip install sentence-transformers==2.0.0"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 3,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pandas as pd\n",
19
+ "from tqdm import tqdm\n",
20
+ "from sentence_transformers import SentenceTransformer\n",
21
+ "\n",
22
+ "model = SentenceTransformer('all-mpnet-base-v2') #all-MiniLM-L6-v2 #all-mpnet-base-v2"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 5,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import pandas as pd\n",
32
+ "\n",
33
+ "df = pd.read_parquet('df_encoded.parquet')"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 6,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "from sklearn.neighbors import NearestNeighbors\n",
43
+ "import numpy as np\n",
44
+ "import pandas as pd\n",
45
+ "\n",
46
+ "from sentence_transformers import SentenceTransformer\n",
47
+ "\n",
48
+ "# model = SentenceTransformer('all-mpnet-base-v2') #all-MiniLM-L6-v2 #all-mpnet-base-v2\n",
49
+ "\n",
50
+ "#prepare model\n",
51
+ "nbrs = NearestNeighbors(n_neighbors=8, algorithm='ball_tree').fit(df['text_vector_'].values.tolist())"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 7,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "def search(query):\n",
61
+ " product = model.encode(query).tolist()\n",
62
+ " # product = df.iloc[0]['text_vector_'] #use one of the products as sample\n",
63
+ "\n",
64
+ " distances, indices = nbrs.kneighbors([product]) #input the vector of the reference object\n",
65
+ "\n",
66
+ " #print out the description of every recommended product\n",
67
+ " return df.iloc[list(indices)[0]][['Description', 'UnitPrice', 'Country']]"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 10,
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "Running on local URL: http://127.0.0.1:7860\n",
80
+ "\n",
81
+ "To create a public link, set `share=True` in `launch()`.\n"
82
+ ]
83
+ },
84
+ {
85
+ "data": {
86
+ "text/html": [
87
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
88
+ ],
89
+ "text/plain": [
90
+ "<IPython.core.display.HTML object>"
91
+ ]
92
+ },
93
+ "metadata": {},
94
+ "output_type": "display_data"
95
+ },
96
+ {
97
+ "data": {
98
+ "text/plain": []
99
+ },
100
+ "execution_count": 10,
101
+ "metadata": {},
102
+ "output_type": "execute_result"
103
+ }
104
+ ],
105
+ "source": [
106
+ "import gradio as gr\n",
107
+ "import os\n",
108
+ "\n",
109
+ "#the first module becomes text1, the second module file1\n",
110
+ "def greet(text1): \n",
111
+ " return search(text1)\n",
112
+ "\n",
113
+ "iface = gr.Interface(fn=greet, inputs=['text'], outputs=[\"dataframe\"])\n",
114
+ "iface.launch(share=False)"
115
+ ]
116
+ }
117
+ ],
118
+ "metadata": {
119
+ "kernelspec": {
120
+ "display_name": "Python 3.9.0 64-bit",
121
+ "language": "python",
122
+ "name": "python3"
123
+ },
124
+ "language_info": {
125
+ "codemirror_mode": {
126
+ "name": "ipython",
127
+ "version": 3
128
+ },
129
+ "file_extension": ".py",
130
+ "mimetype": "text/x-python",
131
+ "name": "python",
132
+ "nbconvert_exporter": "python",
133
+ "pygments_lexer": "ipython3",
134
+ "version": "3.9.13"
135
+ },
136
+ "orig_nbformat": 4,
137
+ "vscode": {
138
+ "interpreter": {
139
+ "hash": "fdf377d643bc1cb065454f0ad2ceac75d834452ecf289e7ba92c6b3f59a7cee1"
140
+ }
141
+ }
142
+ },
143
+ "nbformat": 4,
144
+ "nbformat_minor": 2
145
+ }
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys

# Install runtime dependencies before the third-party imports below run.
# pip is invoked as `sys.executable -m pip` with an argument list (no shell),
# so the packages land in the environment of the interpreter executing this
# script rather than whichever `pip` happens to be first on PATH.
# A failed install is deliberately not fatal (the exit status was ignored
# before as well): a genuinely missing package surfaces as ImportError below.
for _package in ('openpyxl', 'sentence-transformers'):
    subprocess.run([sys.executable, '-m', 'pip', 'install', _package], check=False)
4
+ import pandas as pd
5
+ import gradio as gr
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.neighbors import NearestNeighbors
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from sentence_transformers import SentenceTransformer
12
+
13
# Sentence-embedding model used to encode incoming search queries
# (alternative checkpoint that was tried: all-MiniLM-L6-v2).
model = SentenceTransformer('all-mpnet-base-v2')

# Product table; its 'text_vector_' column supplies one vector per row.
df = pd.read_parquet('df_encoded.parquet')

# Nearest-neighbour index over the stored vectors, returning 8 rows per query.
nbrs = NearestNeighbors(n_neighbors=8, algorithm='ball_tree')
nbrs.fit(df['text_vector_'].values.tolist())
18
+
19
def search(df, query):
    """Return the product rows whose stored vectors are closest to *query*.

    Args:
        df: DataFrame providing the 'Description', 'UnitPrice' and 'Country'
            columns. Rows are selected by position, so this is assumed to be
            the same frame (same row order) that the module-level ``nbrs``
            index was fitted on.
        query: Free-text search string; it is encoded with the module-level
            ``model``.

    Returns:
        DataFrame with the 'Description', 'UnitPrice' and 'Country' columns
        of the matched rows, in the order reported by ``nbrs.kneighbors``.
    """
    query_vector = model.encode(query).tolist()

    # Only the neighbour positions are used, so do not request distances.
    # kneighbors takes a batch of queries; we send one and keep its row.
    neighbor_rows = nbrs.kneighbors([query_vector], return_distance=False)[0]

    return df.iloc[neighbor_rows][['Description', 'UnitPrice', 'Country']]
27
+
28
+ import gradio as gr
29
+ import os
30
+
31
+ #the first module becomes text1, the second module file1
32
+ def greet(text1):
33
+ return search(df, text1)
34
+
35
# One text input wired to greet(); the result is rendered as a dataframe.
iface = gr.Interface(
    fn=greet,
    inputs=['text'],
    outputs=["dataframe"],
)
# Serve locally only; no public share link is requested.
iface.launch(share=False)
df_encoded.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46c74a19104ae10b2c173f39825f3e08174e0f5f213c2e2392d95ca364e49c60
3
+ size 20362183