AhmedGaver committed on
Commit
e7b7078
·
verified ·
1 Parent(s): 6d46009

Upload v2 of URL classifier model (hybrid BERT + tabular)

Browse files

Training Metrics:
- Eval Loss: 0.07472482323646545
- Eval F1 Macro: 0.9319480242737234
- Eval Accuracy: 0.9817232375979112

Files changed (3) hide show
  1. handler.py +138 -0
  2. known_platforms.json +254 -0
  3. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import json
4
+ import torch
5
+ import torch.nn as nn
6
+ from urllib.parse import urlparse
7
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
8
+ from transformers.modeling_outputs import SequenceClassifierOutput
9
+
10
# Path slugs that typically mark a profile/listing page on a multi-tenant
# platform (marketplace, directory, booking site) rather than a business's
# own website.
PROFILE_SLUGS = re.compile(
    r'/(profile|store|shop|freelancers?|biz|therapists?|counsellors?|'
    r'restaurants?|menu|cottage|actors?|celebrants?|broker-finder|'
    r'users?|usr|sellers?|vendors?|merchants?|dealers?|agents?|'
    r'members?|str|book|booking|appointments?)(/|$)', re.IGNORECASE
)

# Number of hand-crafted tabular features fed to the model alongside BERT;
# must match the value used at training time.
NUM_TABULAR_FEATURES = 6
# Matches a numeric ID segment of 3+ digits in a URL path, e.g. /user/12345/.
NUMERIC_ID_IN_PATH = re.compile(r'/\d{3,}(/|$)')
# Width of the hidden layer that projects the tabular features.
TABULAR_HIDDEN_SIZE = 128

# Registered domains of known platforms (marketplaces, site builders, social
# networks), shipped next to this handler. Explicit UTF-8: the original
# relied on the platform default encoding, which differs on Windows.
KNOWN_PLATFORMS_PATH = os.path.join(os.path.dirname(__file__), "known_platforms.json")
with open(KNOWN_PLATFORMS_PATH, encoding="utf-8") as _f:
    KNOWN_PLATFORMS = set(json.load(_f))
24
+
25
+ try:
26
+ import tldextract
27
+ _get_registered_domain = lambda url: tldextract.extract(url).registered_domain.lower()
28
+ _tld = lambda url: tldextract.extract(url).suffix.lower()
29
+ except ImportError:
30
+ _get_registered_domain = lambda url: '.'.join(urlparse(url).netloc.lower().split('.')[-2:])
31
+ _tld = lambda url: urlparse(url).netloc.lower().split('.')[-1]
32
+
33
+ _subdomain_dot_count = lambda url: max(0, urlparse(url).netloc.count('.') - 1)
34
+ _path_depth = lambda url: len([s for s in urlparse(url).path.split('/') if s])
35
+
36
+ extract_tabular_features = lambda url: [
37
+ 1.0 if PROFILE_SLUGS.search(urlparse(url).path.lower()) else 0.0,
38
+ 1.0 if _get_registered_domain(url) in KNOWN_PLATFORMS else 0.0,
39
+ min(_path_depth(url) / 10.0, 1.0),
40
+ min(_subdomain_dot_count(url) / 3.0, 1.0),
41
+ 1.0 if NUMERIC_ID_IN_PATH.search(urlparse(url).path) else 0.0,
42
+ 1.0 if _tld(url) == 'jp' else 0.0,
43
+ ]
44
+
45
+
46
class UrlBertWithTabular(nn.Module):
    """URL classifier fusing a BERT [CLS] embedding with tabular features.

    The [CLS] vector is concatenated with a small MLP projection of the
    hand-crafted tabular features and passed through one linear head.
    """

    def __init__(self, bert_model_name, num_labels, num_tabular_features=NUM_TABULAR_FEATURES):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.hidden_size = self.bert.config.hidden_size
        self.num_labels = num_labels
        self.num_tabular_features = num_tabular_features
        # NOTE: submodule names (bert / tabular_proj / classifier) are baked
        # into the checkpoint's state-dict keys — do not rename.
        self.tabular_proj = nn.Sequential(
            nn.Linear(num_tabular_features, TABULAR_HIDDEN_SIZE),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.classifier = nn.Linear(self.hidden_size + TABULAR_HIDDEN_SIZE, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, tabular_features=None, **kwargs):
        """Return a SequenceClassifierOutput with logits (batch, num_labels)."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = encoded.last_hidden_state[:, 0, :]  # [CLS] token embedding
        projected = self.tabular_proj(tabular_features.float())
        fused = torch.cat([pooled, projected], dim=1)
        return SequenceClassifierOutput(logits=self.classifier(fused))

    @classmethod
    def from_pretrained(cls, save_directory):
        """Rebuild the model from a directory saved by the training script.

        The BERT trunk is instantiated from config only (no weight
        download); the fine-tuned fused weights are then loaded from disk.
        """
        with open(os.path.join(save_directory, "tabular_config.json")) as f:
            tabular_config = json.load(f)
        bert_config = AutoConfig.from_pretrained(save_directory)

        # Bypass __init__ deliberately: it would download pretrained trunk
        # weights that the checkpoint below immediately overwrites.
        model = cls.__new__(cls)
        nn.Module.__init__(model)
        model.bert = AutoModel.from_config(bert_config)
        model.hidden_size = bert_config.hidden_size
        model.num_labels = tabular_config["num_labels"]
        model.num_tabular_features = tabular_config["num_tabular_features"]
        model.tabular_proj = nn.Sequential(
            nn.Linear(model.num_tabular_features, TABULAR_HIDDEN_SIZE),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        model.classifier = nn.Linear(model.hidden_size + TABULAR_HIDDEN_SIZE, model.num_labels)

        # Prefer the safetensors checkpoint; fall back to the legacy .bin.
        safetensors_path = os.path.join(save_directory, "model.safetensors")
        bin_path = os.path.join(save_directory, "pytorch_model.bin")
        if os.path.exists(safetensors_path):
            from safetensors.torch import load_file
            state_dict = load_file(safetensors_path)
        else:
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model
94
+
95
+
96
# Logit index → class-name mapping for the two classifier outputs.
LABEL_MAP = {0: "official_website", 1: "platform"}
97
+
98
+
99
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the hybrid URL classifier.

    Accepts ``{"inputs": "<url>"}`` or ``{"inputs": ["<url>", ...]}`` (a bare
    string or list payload is also tolerated) and returns, per input URL, a
    list of ``{"label", "score"}`` dicts sorted by descending score.
    """

    def __init__(self, path=""):
        # Load the fused model (config + fine-tuned weights) and tokenizer
        # from the repo directory, then move the model to GPU if available.
        self.model = UrlBertWithTabular.from_pretrained(path)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data):
        # Robustly unwrap the payload. The original `data.get("inputs", data)`
        # raised AttributeError whenever `data` was a raw string or list
        # rather than a dict.
        if isinstance(data, dict):
            inputs = data.get("inputs", data)
        else:
            inputs = data
        if isinstance(inputs, str):
            inputs = [inputs]

        encodings = self.tokenizer(
            inputs, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(self.device)

        # One row of hand-crafted features per URL, aligned with the batch.
        tabular = torch.tensor(
            [extract_tabular_features(url) for url in inputs], dtype=torch.float32
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=encodings["input_ids"],
                attention_mask=encodings["attention_mask"],
                tabular_features=tabular,
            )

        # Softmax over the label dimension, then sort each URL's labels
        # best-first for the response payload.
        probs = torch.softmax(outputs.logits, dim=-1)
        results = []
        for scores in probs.tolist():
            predictions = [
                {"label": LABEL_MAP.get(j, f"LABEL_{j}"), "score": score}
                for j, score in enumerate(scores)
            ]
            predictions.sort(key=lambda p: p["score"], reverse=True)
            results.append(predictions)

        return results
known_platforms.json ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "alibaba.com",
3
+ "vinted.com",
4
+ "vinted.co.uk",
5
+ "jalan.net",
6
+ "instagram.com",
7
+ "linkedin.com",
8
+ "facebook.com",
9
+ "salla.sa",
10
+ "trustpilot.com",
11
+ "youtube.com",
12
+ "twitter.com",
13
+ "tiktok.com",
14
+ "twitch.tv",
15
+ "crunchbase.com",
16
+ "wa.me",
17
+ "zid.store",
18
+ "pinterest.com",
19
+ "whatsapp.com",
20
+ "ebay.co.uk",
21
+ "myshopify.com",
22
+ "etsy.com",
23
+ "discord.gg",
24
+ "wixsite.com",
25
+ "square.site",
26
+ "reddit.com",
27
+ "onlyfans.com",
28
+ "vinted.co.uk",
29
+ "yelp.com",
30
+ "stan.store",
31
+ "threads.net",
32
+ "bigcartel.com",
33
+ "t.me",
34
+ "snapchat.com",
35
+ "fresha.com",
36
+ "discord.com",
37
+ "planity.com",
38
+ "bbb.org",
39
+ "tip4serv.com",
40
+ "tumblr.com",
41
+ "sumupstore.com",
42
+ "booksy.com",
43
+ "depop.com",
44
+ "nextdoor.com",
45
+ "doctolib.fr",
46
+ "bookinbeautiful.com",
47
+ "bandcamp.com",
48
+ "spotlight.com",
49
+ "shopify.com",
50
+ "whatnot.com",
51
+ "squarespace.com",
52
+ "beacons.ai",
53
+ "fanvue.com",
54
+ "linktr.ee",
55
+ "company.site",
56
+ "webflow.io",
57
+ "wordpress.com",
58
+ "weebly.com",
59
+ "blogspot.com",
60
+ "ueniweb.com",
61
+ "canva.site",
62
+ "booking.com",
63
+ "godaddysites.com",
64
+ "nearcut.com",
65
+ "yell.com",
66
+ "substack.com",
67
+ "cargo.site",
68
+ "business.site",
69
+ "sell.app",
70
+ "goo.gl",
71
+ "g.co",
72
+ "g.page",
73
+ "google.com",
74
+ "hiboutik.com",
75
+ "inteletravel.uk",
76
+ "jimdofree.com",
77
+ "jimdo.com",
78
+ "just-eat.co.uk",
79
+ "ubereats.com",
80
+ "deliveroo.co.uk",
81
+ "gofundme.com",
82
+ "toasttab.com",
83
+ "holidaycottages.co.uk",
84
+ "imdb.com",
85
+ "upwork.com",
86
+ "fiverr.com",
87
+ "amazon.com",
88
+ "amazon.co.uk",
89
+ "amazon.co.jp",
90
+ "rakuten.co.jp",
91
+ "yahoo.co.jp",
92
+ "line.me",
93
+ "hirameki7.site",
94
+ "umin.ac.jp",
95
+ "guinot.com",
96
+ "motrio.fr",
97
+ "chiens-de-france.com",
98
+ "counselling-directory.org.uk",
99
+ "helloself.com",
100
+ "realpeople.co.uk",
101
+ "justmortgages.co.uk",
102
+ "linktree.com",
103
+ "shopifypreview.com",
104
+ "glossgenius.com",
105
+ "rekaz.io",
106
+ "systeme.io",
107
+ "direct.me",
108
+ "fato.me",
109
+ "notjusttravel.com",
110
+ "vivastreet.co.uk",
111
+ "mysellauth.com",
112
+ "garage-auto.info",
113
+
114
+ "alibaba.com",
115
+ "airbnb.com",
116
+ "airbnb.co.uk",
117
+ "airbnb.com.sg",
118
+ "airbnb.com.au",
119
+ "vinted.com",
120
+ "vinted.fr",
121
+ "vinted.de",
122
+ "herbalife.com",
123
+ "gonnaorder.com",
124
+ "keeq.io",
125
+ "soundbetter.com",
126
+ "10web-site.ai",
127
+ "github.io",
128
+ "sellsn.io",
129
+ "egift-store.com",
130
+ "tdc.ne.jp",
131
+
132
+ "ebay.com",
133
+ "ebay.de",
134
+ "ebay.fr",
135
+ "ebay.com.au",
136
+ "amazon.de",
137
+ "amazon.fr",
138
+ "amazon.es",
139
+ "amazon.it",
140
+ "amazon.in",
141
+ "amazon.com.au",
142
+ "rakuten.com",
143
+
144
+ "tripadvisor.com",
145
+ "tripadvisor.co.uk",
146
+ "tripadvisor.co.jp",
147
+ "zillow.com",
148
+ "realtor.com",
149
+ "rightmove.co.uk",
150
+ "zoopla.co.uk",
151
+ "indeed.com",
152
+ "glassdoor.com",
153
+ "glassdoor.co.uk",
154
+ "monster.com",
155
+ "freelancer.com",
156
+ "peopleperhour.com",
157
+ "toptal.com",
158
+ "99designs.com",
159
+
160
+ "x.com",
161
+ "bsky.app",
162
+ "mastodon.social",
163
+ "spotify.com",
164
+ "soundcloud.com",
165
+ "deezer.com",
166
+
167
+ "medium.com",
168
+ "patreon.com",
169
+ "ko-fi.com",
170
+ "buymeacoffee.com",
171
+ "gumroad.com",
172
+ "teachable.com",
173
+ "udemy.com",
174
+ "skillshare.com",
175
+ "coursera.org",
176
+ "eventbrite.com",
177
+ "eventbrite.co.uk",
178
+ "meetup.com",
179
+
180
+ "uber.com",
181
+ "lyft.com",
182
+ "doordash.com",
183
+ "grubhub.com",
184
+ "deliveroo.com",
185
+ "deliveroo.fr",
186
+ "deliveroo.com.au",
187
+ "just-eat.com",
188
+ "justeat.it",
189
+
190
+ "poshmark.com",
191
+ "mercari.com",
192
+ "rover.com",
193
+ "taskrabbit.com",
194
+ "thumbtack.com",
195
+ "temu.com",
196
+ "shein.com",
197
+ "wish.com",
198
+ "asos.com",
199
+ "zalando.com",
200
+ "zalando.co.uk",
201
+ "zalando.de",
202
+ "shopee.com",
203
+ "lazada.com",
204
+ "tokopedia.com",
205
+ "flipkart.com",
206
+ "olx.com",
207
+ "leboncoin.fr",
208
+ "allegro.pl",
209
+ "bol.com",
210
+ "cdiscount.com",
211
+ "fnac.com",
212
+ "otto.de",
213
+
214
+ "tabelog.com",
215
+ "hotpepper.jp",
216
+ "gnavi.co.jp",
217
+ "kakaku.com",
218
+ "zozo.jp",
219
+ "mercari.jp",
220
+ "minne.com",
221
+ "creema.jp",
222
+ "stores.jp",
223
+ "base.shop",
224
+ "booth.pm",
225
+ "note.com",
226
+ "ameblo.jp",
227
+ "fc2.com",
228
+
229
+ "wix.com",
230
+ "strikingly.com",
231
+ "site123.com",
232
+ "webnode.com",
233
+ "duda.co",
234
+ "format.com",
235
+ "cargocollective.com",
236
+ "contently.com",
237
+ "about.me",
238
+ "bio.link",
239
+ "lnk.bio",
240
+ "campsite.bio",
241
+ "carrd.co",
242
+ "typedream.com",
243
+ "super.so",
244
+
245
+ "behance.net",
246
+ "dribbble.com",
247
+ "deviantart.com",
248
+ "flickr.com",
249
+ "500px.com",
250
+ "vimeo.com",
251
+ "dailymotion.com",
252
+ "bookipi.com",
253
+ "odoo.com"
254
+ ]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tldextract