alaajabari commited on
Commit
505f3f4
·
verified ·
1 Parent(s): 46171ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -128
app.py CHANGED
@@ -1,121 +1,128 @@
1
- from fastapi import FastAPI
2
- import torch
3
  import pickle
 
 
 
 
 
 
 
 
 
 
 
4
  from huggingface_hub import hf_hub_download, snapshot_download
 
 
 
 
 
 
 
 
5
  from Nested.nn.BertSeqTagger import BertSeqTagger
6
- from transformers import AutoTokenizer, AutoModel
7
- import inspect
8
- from collections import namedtuple
9
  from Nested.utils.helpers import load_checkpoint
10
  from Nested.utils.data import get_dataloaders, text2segments
11
- import json
12
- from pydantic import BaseModel
13
- from fastapi.responses import JSONResponse
14
  from IBO_to_XML import IBO_to_XML
15
  from XML_to_HTML import NER_XML_to_HTML
16
  from NER_Distiller import distill_entities
17
 
 
 
 
18
  app = FastAPI()
19
 
20
- pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
 
 
 
 
 
 
 
 
 
 
 
21
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
22
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
23
 
24
-
25
- checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
 
 
26
 
27
  args_path = hf_hub_download(
28
  repo_id="SinaLab/Nested",
29
  filename="args.json"
30
  )
31
 
32
- with open(args_path, 'r') as f:
33
  args_data = json.load(f)
34
-
35
- # Load model
36
  with open("Nested/utils/tag_vocab.pkl", "rb") as f:
37
  label_vocab = pickle.load(f)
38
 
39
- label_vocab = label_vocab[0] # the list loaded from pickle
40
  id2label = {i: s for i, s in enumerate(label_vocab.itos)}
41
 
 
 
 
 
 
42
  def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
43
- # Split the text into words
44
  words = sentence.split()
45
-
46
- # Initialize variables
47
  groups = []
48
  current_group = ""
49
  group_size = 0
50
-
51
- # Iterate through the words
52
  for word in words:
53
  if group_size < max_words_per_sentence - 1:
54
- if len(current_group) == 0:
55
- current_group = word
56
- else:
57
- current_group += " " + word
58
  group_size += 1
59
  else:
60
  current_group += " " + word
61
  groups.append(current_group)
62
  current_group = ""
63
  group_size = 0
64
-
65
- # Add the last group if it contains less than n words
66
  if current_group:
67
  groups.append(current_group)
68
-
69
- return groups
70
 
 
71
 
72
 
73
  def remove_empty_values(sentences):
74
- return [value for value in sentences if value != '']
75
-
76
-
77
- def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
78
- separators = []
79
- split_text = [text]
80
- if new_line==True:
81
- separators.append('\n')
82
- if dot==True:
83
- separators.append('.')
84
- if question_mark==True:
85
- separators.append('?')
86
- separators.append('؟')
87
- if exclamation_mark==True:
88
- separators.append('!')
89
-
90
- for sep in separators:
91
- new_split_text = []
92
- for part in split_text:
93
- tokens = part.split(sep)
94
- tokens_with_separator = [token + sep for token in tokens[:-1]]
95
- tokens_with_separator.append(tokens[-1].strip())
96
- new_split_text.extend(tokens_with_separator)
97
- split_text = new_split_text
98
-
99
- split_text = remove_empty_values(split_text)
100
  return split_text
101
 
 
102
  def jsons_to_list_of_lists(json_list):
103
- return [[d['token'], d['tags']] for d in json_list]
104
 
105
- tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
106
 
107
  def extract(sentence):
108
  dataset, token_vocab = text2segments(sentence)
109
 
110
- vocabs = namedtuple("Vocab", ["tags", "tokens"])
111
- vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
 
 
112
 
113
  dataloader = get_dataloaders(
114
  (dataset,),
115
  vocab,
116
  args_data,
117
  batch_size=32,
118
- shuffle=(False,),
119
  )[0]
120
 
121
  segments = tagger.infer(dataloader)
@@ -124,95 +131,246 @@ def extract(sentence):
124
 
125
  for segment in segments:
126
  for token in segment:
127
- item = {}
128
- item["token"] = token.text
129
 
130
- list_of_tags = [t["tag"] for t in token.pred_tag]
131
- list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
132
 
133
- if not list_of_tags:
134
- item["tags"] = "O"
135
- else:
136
- item["tags"] = " ".join(list_of_tags)
137
  lists.append(item)
 
138
  return lists
139
 
140
 
141
- def NER(sentence, mode):
142
- output_list = []
143
- xml = ""
144
- if mode.strip() == "1":
145
- output_list = jsons_to_list_of_lists(extract(sentence))
146
- return output_list
147
- elif mode.strip() == "2":
148
- if output_list != []:
149
- xml = IBO_to_XML(output_list)
150
- return xml
151
- else:
152
- output_list = jsons_to_list_of_lists(extract(sentence))
153
- xml = IBO_to_XML(output_list)
154
- return xml
155
-
156
- elif mode.strip() == "3":
157
- if xml != "":
158
- html = NER_XML_to_HTML(xml)
159
- return html
160
- else:
161
- output_list = jsons_to_list_of_lists(extract(sentence))
162
- xml = IBO_to_XML(output_list)
163
- html = NER_XML_to_HTML(xml)
164
- return html
165
-
166
- elif mode.strip() == "4": # json short
167
- if output_list != []:
168
- json_short = distill_entities(output_list)
169
- return json_short
170
- else:
171
- output_list = jsons_to_list_of_lists(extract(sentence))
172
- json_short = distill_entities(output_list)
173
- return json_short
174
 
 
 
175
 
 
 
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  class NERRequest(BaseModel):
178
  text: str
179
- mode: str
 
 
 
 
 
180
 
 
 
 
181
  @app.post("/predict")
182
  def predict(request: NERRequest):
183
- # Load tagger
184
- text = request.text
185
- mode = request.mode
186
 
187
- sentences = sentence_tokenizer(
188
- text, dot=False, new_line=True, question_mark=False, exclamation_mark=False
189
- )
 
 
 
190
 
191
- lists = []
192
  for sentence in sentences:
193
- se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300)
194
- for s in se:
195
- output_list = NER(s, mode)
196
- lists.append(output_list)
 
197
 
198
- content = {
199
- "resp": lists,
200
  "statusText": "OK",
201
- "statusCode": 0,
202
- }
203
 
204
- return JSONResponse(
205
- content=content,
206
- media_type="application/json",
207
- status_code=200,
208
- )
209
 
210
- from fastapi.staticfiles import StaticFiles
211
- from fastapi.responses import FileResponse
 
 
 
212
 
213
- # mount frontend
214
- app.mount("/static", StaticFiles(directory="static"), name="static")
215
 
216
- @app.get("/")
217
- def home():
218
- return FileResponse("static/index.html")
 
 
 
1
+ import os
2
+ import json
3
  import pickle
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from itertools import permutations
9
+ from collections import defaultdict
10
+ from pydantic import BaseModel
11
+ from fastapi import FastAPI
12
+ from fastapi.responses import JSONResponse, FileResponse
13
+ from fastapi.staticfiles import StaticFiles
14
+
15
  from huggingface_hub import hf_hub_download, snapshot_download
16
+
17
+ from transformers import (
18
+ AutoTokenizer,
19
+ AutoModel,
20
+ BertModel,
21
+ PreTrainedTokenizerFast
22
+ )
23
+
24
  from Nested.nn.BertSeqTagger import BertSeqTagger
 
 
 
25
  from Nested.utils.helpers import load_checkpoint
26
  from Nested.utils.data import get_dataloaders, text2segments
27
+
 
 
28
  from IBO_to_XML import IBO_to_XML
29
  from XML_to_HTML import NER_XML_to_HTML
30
  from NER_Distiller import distill_entities
31
 
32
+ # =========================
33
+ # App
34
+ # =========================
35
  app = FastAPI()
36
 
37
+ # mount frontend
38
+ app.mount("/static", StaticFiles(directory="static"), name="static")
39
+
40
+ @app.get("/")
41
+ def home():
42
+ return FileResponse("static/index.html")
43
+
44
+ # =========================
45
+ # NER MODEL (your working one)
46
+ # =========================
47
+ pretrained_path = "aubmindlab/bert-base-arabertv2"
48
+
49
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
50
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
51
 
52
+ checkpoint_path = snapshot_download(
53
+ repo_id="SinaLab/Nested",
54
+ allow_patterns="checkpoints/"
55
+ )
56
 
57
  args_path = hf_hub_download(
58
  repo_id="SinaLab/Nested",
59
  filename="args.json"
60
  )
61
 
62
+ with open(args_path, "r") as f:
63
  args_data = json.load(f)
64
+
 
65
  with open("Nested/utils/tag_vocab.pkl", "rb") as f:
66
  label_vocab = pickle.load(f)
67
 
68
+ label_vocab = label_vocab[0]
69
  id2label = {i: s for i, s in enumerate(label_vocab.itos)}
70
 
71
+ tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
72
+
73
+ # =========================
74
+ # Helpers (NER)
75
+ # =========================
76
  def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
 
77
  words = sentence.split()
 
 
78
  groups = []
79
  current_group = ""
80
  group_size = 0
81
+
 
82
  for word in words:
83
  if group_size < max_words_per_sentence - 1:
84
+ current_group = word if current_group == "" else current_group + " " + word
 
 
 
85
  group_size += 1
86
  else:
87
  current_group += " " + word
88
  groups.append(current_group)
89
  current_group = ""
90
  group_size = 0
91
+
 
92
  if current_group:
93
  groups.append(current_group)
 
 
94
 
95
+ return groups
96
 
97
 
98
  def remove_empty_values(sentences):
99
+ return [v for v in sentences if v != ""]
100
+
101
+
102
+ def sentence_tokenizer(text):
103
+ split_text = text.split(".")
104
+ split_text = remove_empty_values(split_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  return split_text
106
 
107
+
108
  def jsons_to_list_of_lists(json_list):
109
+ return [[d["token"], d["tags"]] for d in json_list]
110
 
 
111
 
112
  def extract(sentence):
113
  dataset, token_vocab = text2segments(sentence)
114
 
115
+ vocab = type("Vocab", (), {})(
116
+ tokens=token_vocab,
117
+ tags=tag_vocab
118
+ )
119
 
120
  dataloader = get_dataloaders(
121
  (dataset,),
122
  vocab,
123
  args_data,
124
  batch_size=32,
125
+ shuffle=(False,)
126
  )[0]
127
 
128
  segments = tagger.infer(dataloader)
 
131
 
132
  for segment in segments:
133
  for token in segment:
134
+ item = {"token": token.text}
 
135
 
136
+ tags = [t["tag"] for t in token.pred_tag]
137
+ tags = [i for i in tags if i not in ("O", " ", "")]
138
 
139
+ item["tags"] = "O" if not tags else " ".join(tags)
 
 
 
140
  lists.append(item)
141
+
142
  return lists
143
 
144
 
145
+ # =========================
146
+ # NER distillation (your logic)
147
+ # =========================
148
+ def distill_entities(entities):
149
+ list_output = []
150
+ temp_entities = sortTags(entities)
151
+
152
+ temp_list = [["", "", 0, 0]]
153
+ word_position = 0
154
+
155
+ for entity in temp_entities:
156
+ token = entity["token"]
157
+ tags = entity["tags"].split()
158
+
159
+ counter_tag = 0
160
+ for tag in tags:
161
+ if counter_tag >= len(temp_list):
162
+ temp_list.append(["", "", 0, 0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ if tag == "O":
165
+ pass
166
 
167
+ elif tag.startswith("B-"):
168
+ temp_list[counter_tag] = [token + " ", tag[2:], word_position, word_position]
169
 
170
+ elif tag.startswith("I-"):
171
+ for j in range(counter_tag, len(temp_list)):
172
+ if temp_list[j][1] == tag[2:]:
173
+ temp_list[j][0] += token + " "
174
+ temp_list[j][3] = word_position
175
+ break
176
+
177
+ counter_tag += 1
178
+
179
+ word_position += 1
180
+
181
+ for j in range(len(temp_list)):
182
+ if temp_list[j][1] != "":
183
+ list_output.append(temp_list[j])
184
+
185
+ return list_output
186
+
187
+
188
+ def sortTags(entities):
189
+ return entities
190
+
191
+
192
+ def entities_and_types(sentence):
193
+ token_tags = extract(sentence)
194
+ entities = distill_entities(token_tags)
195
+
196
+ entity_dict = {}
197
+ for name, entity_type, _, _ in entities:
198
+ entity_dict[name.strip()] = entity_type
199
+
200
+ return entity_dict
201
+
202
+
203
+ # =========================
204
+ # Relation Model
205
+ # =========================
206
+ repo_id_rel = "aaljabari/arabic-relation-extraction-v1"
207
+
208
+ relation_tokenizer = PreTrainedTokenizerFast(
209
+ tokenizer_file=hf_hub_download(repo_id_rel, "tokenizer.json")
210
+ )
211
+
212
+ weights_path = hf_hub_download(repo_id_rel, "pytorch_model.bin")
213
+
214
+ with open(hf_hub_download(repo_id_rel, "tag_vocab.pkl"), "rb") as f:
215
+ vocab = pickle.load(f)
216
+
217
+ rel2id = vocab["rel2id"]
218
+ id2rel = vocab["id2rel"]
219
+
220
+
221
+ class BertRE(nn.Module):
222
+ def __init__(self, num_labels):
223
+ super().__init__()
224
+ self.bert = BertModel.from_pretrained(repo_id_rel)
225
+
226
+ hidden = self.bert.config.hidden_size
227
+ self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
228
+ self.classifier = nn.Linear(hidden * 2, num_labels)
229
+
230
+ def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
231
+ outputs = self.bert(
232
+ input_ids=input_ids,
233
+ attention_mask=attention_mask
234
+ )
235
+
236
+ hidden = outputs.last_hidden_state
237
+ batch = hidden.shape[0]
238
+
239
+ sub_vec = hidden[torch.arange(batch), sub_pos]
240
+ obj_vec = hidden[torch.arange(batch), obj_pos]
241
+
242
+ pair = torch.cat([sub_vec, obj_vec], dim=1)
243
+ pair = self.dropout(pair)
244
+
245
+ return self.classifier(pair)
246
+
247
+
248
+ model_re = BertRE(num_labels=len(rel2id))
249
+ model_re.load_state_dict(torch.load(weights_path, map_location="cpu"))
250
+ model_re.eval()
251
+
252
+ # =========================
253
+ # Relation utilities
254
+ # =========================
255
+ relation_lookup = defaultdict(lambda: defaultdict(list))
256
+
257
+
258
+ def insert_markers(sentence, ent1, ent2):
259
+ if ent1 not in sentence or ent2 not in sentence:
260
+ return None
261
+
262
+ s = sentence
263
+ s = s.replace(ent1, f"[Sub] {ent1} [/Sub]", 1)
264
+ s = s.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
265
+
266
+ return s
267
+
268
+
269
+ def encode(sentence):
270
+ enc = relation_tokenizer(
271
+ sentence,
272
+ max_length=128,
273
+ padding="max_length",
274
+ truncation=True,
275
+ return_tensors="pt"
276
+ )
277
+
278
+ input_ids = enc["input_ids"]
279
+ attention_mask = enc["attention_mask"]
280
+
281
+ sub_id = relation_tokenizer.convert_tokens_to_ids("[Sub]")
282
+ obj_id = relation_tokenizer.convert_tokens_to_ids("[Obj]")
283
+
284
+ sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1]
285
+ obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1]
286
+
287
+ return input_ids, attention_mask, sub_pos, obj_pos
288
+
289
+
290
+ def predict_relation(sentence):
291
+ input_ids, mask, sub_pos, obj_pos = encode(sentence)
292
+
293
+ with torch.no_grad():
294
+ logits = model_re(input_ids, mask, sub_pos, obj_pos)
295
+
296
+ probs = F.softmax(logits, dim=-1)
297
+
298
+ pred = torch.argmax(probs, dim=-1).item()
299
+ conf = probs[0, pred].item()
300
+
301
+ return id2rel[pred], conf
302
+
303
+
304
+ def relation_extractor(sentence):
305
+ entities = entities_and_types(sentence)
306
+ output = []
307
+
308
+ entity_items = list(entities.items())
309
+ pairs = [(e1, e2) for e1, e2 in permutations(entity_items, 2)]
310
+
311
+ for (ent1, type1), (ent2, type2) in pairs:
312
+
313
+ marked = insert_markers(sentence, ent1, ent2)
314
+ if not marked:
315
+ continue
316
+
317
+ rel, conf = predict_relation(marked)
318
+
319
+ if conf > 0.80 and rel != "no_relation":
320
+ output.append([ent1, rel, ent2, conf])
321
+
322
+ return output
323
+
324
+
325
+ # =========================
326
+ # API Models
327
+ # =========================
328
  class NERRequest(BaseModel):
329
  text: str
330
+ mode: str = "1"
331
+
332
+
333
+ class RERequest(BaseModel):
334
+ text: str
335
+
336
 
337
+ # =========================
338
+ # NER endpoint
339
+ # =========================
340
  @app.post("/predict")
341
  def predict(request: NERRequest):
 
 
 
342
 
343
+ text = request.text
344
+ mode = request.mode
345
+
346
+ sentences = sentence_tokenizer(text)
347
+
348
+ results = []
349
 
 
350
  for sentence in sentences:
351
+ chunks = split_text_into_groups_of_Ns(sentence, 300)
352
+
353
+ for c in chunks:
354
+ output_list = jsons_to_list_of_lists(extract(c))
355
+ results.append(output_list)
356
 
357
+ return JSONResponse({
358
+ "resp": results,
359
  "statusText": "OK",
360
+ "statusCode": 0
361
+ })
362
 
 
 
 
 
 
363
 
364
+ # =========================
365
+ # Relation endpoint
366
+ # =========================
367
+ @app.post("/predict_re")
368
+ def predict_re(request: RERequest):
369
 
370
+ results = relation_extractor(request.text)
 
371
 
372
+ return JSONResponse({
373
+ "resp": results,
374
+ "statusText": "OK",
375
+ "statusCode": 0
376
+ })