samithcs commited on
Commit
83d35cd
·
verified ·
1 Parent(s): 0105c1e

Update src/components/model_nlp_ner.py

Browse files
Files changed (1) hide show
  1. src/components/model_nlp_ner.py +32 -211
src/components/model_nlp_ner.py CHANGED
@@ -3,262 +3,83 @@ from transformers import DistilBertTokenizerFast, TFDistilBertForTokenClassifica
3
  import requests
4
  from io import BytesIO
5
  import numpy as np
6
- from sklearn.model_selection import train_test_split
7
- import numpy as np
8
  import joblib
9
- import sys
10
- from pathlib import Path
11
- sys.path.append(str(Path(__file__).resolve().parents[1]))
12
- from utils.logger import *
13
-
14
- import logging
15
- logger = logging.getLogger(__name__)
16
-
17
- EPOCHS = 30
18
- BATCH_SIZE = 8
19
- LEARNING_RATE = 5e-5
20
- VALIDATION_SPLIT = 0.15
21
- PATIENCE = 3
22
-
23
- try:
24
- from tensorflow_addons.optimizers import AdamW
25
- optimizer = AdamW(learning_rate=LEARNING_RATE, weight_decay=1e-2)
26
- except ImportError:
27
- optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
28
-
29
-
30
-
31
- examples = [
32
- (["Delay", "in", "Shanghai", "due", "to", "storms"], ["O", "O", "B-LOC", "O", "O", "B-EVENT"]),
33
- (["Any", "delay", "in", "vessel", "from", "USA", "to", "UAE", "?"], ["O", "O", "O", "O", "O", "B-LOC", "O", "B-LOC", "O"]),
34
- (["Cargo", "stuck", "at", "UAE", "port"], ["O", "O", "O", "B-LOC", "O"]),
35
- (["Weather", "alert", "for", "USA"], ["O", "O", "O", "B-LOC"]),
36
- (["Flood", "risk", "in", "Mumbai"], ["O", "O", "O", "B-LOC"]),
37
- (["Port", "closure", "Middle", "East"], ["O", "O", "B-LOC", "I-LOC"]),
38
- (["Is", "cargo", "delayed", "from", "USA", "to", "India", "?"], ["O", "O", "O", "O", "B-LOC", "O", "B-LOC", "O"]),
39
- (["Weather", "problems", "expected", "in", "USA"], ["O", "O", "O", "O", "B-LOC"]),
40
- (["Port", "strike", "at", "Singapore"], ["O", "O", "O", "B-LOC"]),
41
- (["Typhoon", "in", "Japan"], ["B-EVENT", "O", "B-LOC"]),
42
- (["Reroute", "shipments", "from", "Los", "Angeles"], ["O", "O", "O", "B-LOC", "I-LOC"]),
43
- (["Supply", "disruption", "Middle", "East"], ["O", "O", "B-LOC", "I-LOC"]),
44
- (["Severe", "fog", "in", "United", "Arab", "Emirates"], ["O", "O", "O", "B-LOC", "I-LOC", "I-LOC"]),
45
- (["Are", "shipments", "to", "Brazil", "affected", "by", "strike", "?"], ["O", "O", "O", "B-LOC", "O", "O", "B-EVENT", "O"]),
46
- (["Is", "Paris", "airport", "open", "after", "floods", "?"], ["O", "B-LOC", "O", "O", "O", "B-EVENT", "O"]),
47
- (["Delay", "reported", "in", "Berlin"], ["O", "O", "O", "B-LOC"]),
48
- (["Export", "hold", "at", "Los", "Angeles"], ["O", "O", "O", "B-LOC", "I-LOC"]),
49
- (["Typhoon", "warning", "for", "Japan"], ["B-EVENT", "O", "O", "B-LOC"]),
50
- (["Reroute", "cargo", "to", "Singapore"], ["O", "O", "O", "B-LOC"]),
51
- (["Is", "there", "labor", "strike", "in", "Canada", "?"], ["O", "O", "O", "B-EVENT", "O", "B-LOC", "O"]),
52
- (["Storm", "impact", "on", "United", "Kingdom"], ["B-EVENT", "O", "O", "B-LOC", "I-LOC"]),
53
- (["Supply", "disruption", "Italy"], ["O", "O", "B-LOC"]),
54
- (["Any", "hold-up", "in", "Dubai", "?",], ["O", "O", "O", "B-LOC", "O"]),
55
- (["Cargo", "delay", "at", "Rotterdam", "port"], ["O", "O", "O", "B-LOC", "O"]),
56
- (["Flood", "disrupts", "service", "in", "Turkey"], ["B-EVENT", "O", "O", "O", "B-LOC"]),
57
- (["Severe", "thunderstorm", "in", "New", "York", "City"], ["O", "B-EVENT", "O", "B-LOC", "I-LOC", "I-LOC"]),
58
- (["Is", "Shanghai", "port", "closed", "for", "holiday", "?"], ["O", "B-LOC", "O", "O", "O", "O", "O"]),
59
- (["France", "logistics", "strike"], ["B-LOC", "O", "B-EVENT"]),
60
- (["Export", "shipment", "to", "Spain", "delayed"], ["O", "O", "O", "B-LOC", "O"]),
61
- (["Cargo", "rerouted", "from", "Colombo", "to", "Sydney"], ["O", "O", "O", "B-LOC", "O", "B-LOC"]),
62
- (["Vessel", "from", "India", "held", "by", "customs"], ["O", "O", "B-LOC", "O", "O", "O"]),
63
- (["Is", "Singapore", "affected", "by", "monsoon", "season", "?"], ["O", "B-LOC", "O", "O", "B-EVENT", "I-EVENT", "O"]),
64
- (["Disruption", "in", "United", "Arab", "Emirates", "due", "to", "strike"], ["O", "O", "B-LOC", "I-LOC", "I-LOC", "O", "O", "B-EVENT"]),
65
- (["How", "long", "is", "the", "delay", "in", "Mexico", "?"], ["O", "O", "O", "O", "O", "O", "B-LOC", "O"]),
66
- (["Flood", "risk", "in", "Gujarat"], ["B-EVENT", "O", "O", "B-LOC"]),
67
- (["Severe", "weather", "disrupts", "Melbourne", "port"], ["B-EVENT", "O", "O", "B-LOC", "O"]),
68
- (["Export", "stopped", "from", "Jakarta", "because", "of", "strike"], ["O", "O", "O", "B-LOC", "O", "O", "B-EVENT"]),
69
- (["Storm", "warning", "for", "Delhi"], ["B-EVENT", "O", "O", "B-LOC"]),
70
- (["Any", "delay", "from", "United", "States", "to", "United", "Kingdom", "?"], ["O", "O", "O", "B-LOC", "I-LOC", "O", "B-LOC", "I-LOC", "O"]),
71
- (["Cargo", "stuck", "at", "Sao", "Paulo"], ["O", "O", "O", "B-LOC", "I-LOC"]),
72
- (["Shipping", "interruption", "in", "Cairo"], ["O", "O", "O", "B-LOC"]),
73
- (["Typhoon", "delays", "cargo", "to", "Hong", "Kong"], ["B-EVENT", "O", "O", "O", "B-LOC", "I-LOC"]),
74
- (["No", "disruption", "in", "Berlin"], ["O", "O", "O", "B-LOC"]),
75
- (["Port", "closure", "for", "Christmas", "in", "Canada"], ["O", "O", "O", "O", "O", "B-LOC"]),
76
- (["Is", "there", "a", "strike", "in", "Melbourne", "?"], ["O", "O", "O", "B-EVENT", "O", "B-LOC", "O"]),
77
- (["Shipment", "delayed", "in", "Mexico", "City"], ["O", "O", "O", "B-LOC", "I-LOC"]),
78
- (["Are", "vessels", "from", "Copenhagen", "blocked", "?"], ["O", "O", "O", "B-LOC", "O", "O"]),
79
- (["Heavy", "rains", "in", "Manila"], ["O", "B-EVENT", "O", "B-LOC"]),
80
- (["Strike", "at", "Johannesburg", "port"], ["B-EVENT", "O", "B-LOC", "O"]),
81
- (["Is", "the", "route", "from", "Italy", "to", "Brazil", "safe", "?"], ["O", "O", "O", "O", "B-LOC", "O", "B-LOC", "O", "O"]),
82
- (["Container", "stuck", "at", "Antwerp"], ["O", "O", "O", "B-LOC"]),
83
- (["Any", "blockade", "in", "Pakistan", "?"], ["O", "B-EVENT", "O", "B-LOC", "O"]),
84
- (["Flood", "alerts", "for", "Vietnam"], ["B-EVENT", "O", "O", "B-LOC"]),
85
- (["Are", "planes", "to", "Madrid", "canceled", "?"], ["O", "O", "O", "B-LOC", "O", "O"]),
86
- (["Shipments", "from", "Morocco", "are", "late"], ["O", "O", "B-LOC", "O", "O"]),
87
- (["Earthquake", "in", "Indonesia", "affecting", "deliveries"], ["B-EVENT", "O", "B-LOC", "O", "O"]),
88
- (["Rail", "disruption", "in", "Melbourne"], ["O", "B-EVENT", "O", "B-LOC"]),
89
- (["Any", "closure", "at", "Rotterdam", "port", "?"], ["O", "B-EVENT", "O", "B-LOC", "O", "O"]),
90
- (["Landslide", "blocks", "road", "to", "Lima"], ["B-EVENT", "O", "O", "O", "B-LOC"]),
91
- (["Flights", "to", "Bangkok", "suspended"], ["O", "O", "B-LOC", "O"]),
92
- (["Typhoon", "threat", "for", "Taipei"], ["B-EVENT", "O", "O", "B-LOC"]),
93
- (["Is", "Melbourne", "port", "operational", "today", "?"], ["O", "B-LOC", "O", "O", "O", "O"]),
94
- (["Japan", "export", "ban"], ["B-LOC", "O", "B-EVENT"]),
95
- (["Closure", "in", "Buenos", "Aires"], ["B-EVENT", "O", "B-LOC", "I-LOC"]),
96
- (["Truck", "strike", "delaying", "goods", "from", "Poland"], ["O", "B-EVENT", "O", "O", "O", "B-LOC"]),
97
- (["Shanghai", "flood", "disrupts", "cargo"], ["B-LOC", "B-EVENT", "O", "O"]),
98
- (["Supply", "held", "in", "Turkey", "because", "of", "strike"], ["O", "O", "O", "B-LOC", "O", "O", "B-EVENT"]),
99
- (["Port", "congestion", "in", "Los", "Angeles"], ["O", "B-EVENT", "O", "B-LOC", "I-LOC"]),
100
- (["Storm", "approaching", "Cape", "Town"], ["B-EVENT", "O", "B-LOC", "I-LOC"]),
101
- (["Bad", "weather", "New", "York"], ["O", "B-EVENT", "B-LOC", "I-LOC"]),
102
- (["Zambia", "roads", "closed", "due", "to", "flood"], ["B-LOC", "O", "O", "O", "O", "B-EVENT"]),
103
- (["Strike", "in", "Athens", "delays", "supply"], ["B-EVENT", "O", "B-LOC", "O", "O"]),
104
- (["Transport", "problem", "in", "Perth"], ["O", "B-EVENT", "O", "B-LOC"]),
105
- (["Typhoon", "interrupts", "shipments", "to", "Hong", "Kong"], ["B-EVENT", "O", "O", "O", "B-LOC", "I-LOC"]),
106
- (["Avalanche", "blocks", "Italian", "border"], ["B-EVENT", "O", "B-LOC", "O"]),
107
-
108
- ]
109
-
110
-
111
- sentences = [s for s, t in examples]
112
- tags = [t for s, t in examples]
113
- unique_tags = sorted(set(l for ts in tags for l in ts))
114
- label2id = {t: i for i, t in enumerate(unique_tags)}
115
- id2label = {i: t for t, i in label2id.items()}
116
- max_len = max(len(s) for s in sentences)
117
- tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
118
 
119
- def encode(sentences, labels, label2id, max_len):
120
- encodings = tokenizer(sentences, is_split_into_words=True, padding='max_length', truncation=True, max_length=max_len, return_tensors='tf')
121
- label_ids = []
122
- sample_weights = []
123
- for i, labs in enumerate(labels):
124
- ids = [label2id[l] for l in labs]
125
- padding_length = max_len - len(ids)
126
- ids += [0]*padding_length
127
- weights = [1]*len(labs) + [0]*padding_length
128
- label_ids.append(ids)
129
- sample_weights.append(weights)
130
- encodings['labels'] = tf.convert_to_tensor(label_ids)
131
- encodings['sample_weights'] = tf.convert_to_tensor(sample_weights, dtype=tf.float32)
132
- return encodings
133
 
134
- def train_ner_model():
135
- X_train, X_val, y_train, y_val = train_test_split(sentences, tags, test_size=VALIDATION_SPLIT, random_state=42)
136
- train_inputs = encode(X_train, y_train, label2id, max_len)
137
- val_inputs = encode(X_val, y_val, label2id, max_len)
138
 
139
 
140
- model = TFDistilBertForTokenClassification.from_pretrained(
141
- 'distilbert-base-uncased',
142
- num_labels=len(label2id),
143
- id2label=id2label,
144
- label2id=label2id
145
- )
146
- loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
147
- model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'],weighted_metrics=['accuracy'])
148
-
149
 
150
- callback = tf.keras.callbacks.EarlyStopping(
151
- monitor='val_loss',
152
- patience=PATIENCE,
153
- restore_best_weights=True
154
- )
155
 
156
- logger.info("Starting NER model training (tuned).")
157
-
158
-
159
- history = model.fit(
160
- {k: v for k, v in train_inputs.items() if k not in ['labels', 'sample_weights']},
161
- train_inputs['labels'],
162
- sample_weight=train_inputs['sample_weights'],
163
- epochs=EPOCHS,
164
- batch_size=BATCH_SIZE,
165
- validation_data=(
166
- {k: v for k, v in val_inputs.items() if k not in ['labels', 'sample_weights']},
167
- val_inputs['labels'],
168
- val_inputs['sample_weights']
169
- ),
170
- callbacks=[callback]
171
- )
172
-
173
- logger.info("Training complete.")
174
- logger.info(f"Best validation accuracy: {max(history.history['val_accuracy'])}")
175
-
176
- out_dir = Path(__file__).resolve().parents[2] / "artifacts" / "models" / "nlp_ner"
177
- out_dir.mkdir(parents=True, exist_ok=True)
178
- model.save_pretrained(out_dir / "ner_model")
179
- tokenizer.save_pretrained(out_dir / "ner_tokenizer")
180
- joblib.dump(label2id, out_dir / "label2id.joblib")
181
- logger.info(f"NER (TF) model, tokenizer, and label map saved to {out_dir}")
182
 
183
 
 
 
 
 
 
184
 
185
 
186
  def extract_entities_pipeline(text: str) -> dict:
187
-
188
- custom_model = TFDistilBertForTokenClassification.from_pretrained(
189
- "samithcs/nlp_ner/nlp_ner/ner_model", from_tf=True
190
- )
191
- custom_tokenizer = DistilBertTokenizerFast.from_pretrained("samithcs/nlp_ner/nlp_ner/ner_tokenizer")
192
-
193
-
194
- label_url = "https://huggingface.co/samithcs/nlp_ner/tree/main/nlp_ner/label2id.joblib"
195
- response = requests.get(label_url)
196
- label2id = joblib.load(BytesIO(response.content))
197
- id2label = {i: t for t, i in label2id.items()}
198
-
199
-
200
- max_len = 32
201
  tokens = text.split()
202
- encoding = custom_tokenizer(
203
  [tokens],
204
  is_split_into_words=True,
205
  return_tensors='tf',
206
  padding='max_length',
207
  truncation=True,
208
- max_length=max_len
209
  )
210
 
211
-
212
- outputs = custom_model({k: v for k, v in encoding.items() if k != "labels"})
213
- logits = outputs.logits.numpy()[0]
214
- pred_ids = np.argmax(logits, axis=-1)
215
-
216
 
217
- custom_entities = {"location": [], "event": []}
 
218
  current_loc, current_evt = [], []
 
219
  for w, id in zip(tokens, pred_ids[:len(tokens)]):
220
  label = id2label[id]
221
-
 
222
  if label == "B-LOC":
223
  if current_loc:
224
- custom_entities["location"].append(" ".join(current_loc))
225
- current_loc = []
226
  current_loc = [w]
227
  elif label == "I-LOC" and current_loc:
228
  current_loc.append(w)
229
  else:
230
  if current_loc:
231
- custom_entities["location"].append(" ".join(current_loc))
232
  current_loc = []
 
233
 
234
  if label == "B-EVENT":
235
  if current_evt:
236
- custom_entities["event"].append(" ".join(current_evt))
237
- current_evt = []
238
  current_evt = [w]
239
  elif label == "I-EVENT" and current_evt:
240
  current_evt.append(w)
241
  else:
242
  if current_evt:
243
- custom_entities["event"].append(" ".join(current_evt))
244
  current_evt = []
245
 
246
  if current_loc:
247
- custom_entities["location"].append(" ".join(current_loc))
248
  if current_evt:
249
- custom_entities["event"].append(" ".join(current_evt))
 
250
 
251
-
252
- hf_ner = pipeline("ner", grouped_entities=True, model="dbmdz/bert-large-cased-finetuned-conll03-english")
253
  hf_results = hf_ner(text)
254
  hf_locations = [ent['word'] for ent in hf_results if ent['entity_group'] == "LOC"]
 
255
 
256
-
257
- all_locations = set(custom_entities["location"]) | set(hf_locations)
258
- all_events = custom_entities["event"]
259
-
260
- return {"location": list(all_locations), "event": all_events}
261
-
262
-
263
- if __name__ == "__main__":
264
- train_ner_model()
 
3
  import requests
4
  from io import BytesIO
5
  import numpy as np
 
 
6
  import joblib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
9
 
10
 
11
+ ner_model = TFDistilBertForTokenClassification.from_pretrained(
12
+ "samithcs/nlp_ner/nlp_ner/ner_model", from_tf=True
13
+ )
14
+ ner_tokenizer = DistilBertTokenizerFast.from_pretrained("samithcs/nlp_ner/nlp_ner/ner_tokenizer")
 
 
 
 
 
15
 
 
 
 
 
 
16
 
17
+ label_url = "https://huggingface.co/samithcs/nlp_ner/resolve/main/nlp_ner/label2id.joblib"
18
+ response = requests.get(label_url)
19
+ label2id = joblib.load(BytesIO(response.content))
20
+ id2label = {i: t for t, i in label2id.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
+ hf_ner = pipeline(
24
+ "ner",
25
+ grouped_entities=True,
26
+ model="dbmdz/bert-large-cased-finetuned-conll03-english"
27
+ )
28
 
29
 
30
  def extract_entities_pipeline(text: str) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  tokens = text.split()
32
+ encoding = ner_tokenizer(
33
  [tokens],
34
  is_split_into_words=True,
35
  return_tensors='tf',
36
  padding='max_length',
37
  truncation=True,
38
+ max_length=32
39
  )
40
 
41
+ outputs = ner_model({k: v for k, v in encoding.items() if k != "labels"})
42
+ pred_ids = np.argmax(outputs.logits.numpy()[0], axis=-1)
 
 
 
43
 
44
+
45
+ entities = {"location": [], "event": []}
46
  current_loc, current_evt = [], []
47
+
48
  for w, id in zip(tokens, pred_ids[:len(tokens)]):
49
  label = id2label[id]
50
+
51
+
52
  if label == "B-LOC":
53
  if current_loc:
54
+ entities["location"].append(" ".join(current_loc))
 
55
  current_loc = [w]
56
  elif label == "I-LOC" and current_loc:
57
  current_loc.append(w)
58
  else:
59
  if current_loc:
60
+ entities["location"].append(" ".join(current_loc))
61
  current_loc = []
62
+
63
 
64
  if label == "B-EVENT":
65
  if current_evt:
66
+ entities["event"].append(" ".join(current_evt))
 
67
  current_evt = [w]
68
  elif label == "I-EVENT" and current_evt:
69
  current_evt.append(w)
70
  else:
71
  if current_evt:
72
+ entities["event"].append(" ".join(current_evt))
73
  current_evt = []
74
 
75
  if current_loc:
76
+ entities["location"].append(" ".join(current_loc))
77
  if current_evt:
78
+ entities["event"].append(" ".join(current_evt))
79
+
80
 
 
 
81
  hf_results = hf_ner(text)
82
  hf_locations = [ent['word'] for ent in hf_results if ent['entity_group'] == "LOC"]
83
+ entities["location"] = list(set(entities["location"]) | set(hf_locations))
84
 
85
+ return entities