TymaaHammouda commited on
Commit
42ba242
·
1 Parent(s): 0d0ecdd
Files changed (2) hide show
  1. app.py +39 -19
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import pickle
4
  from huggingface_hub import hf_hub_download
5
  from Nested.nn.BertSeqTagger import BertSeqTagger
6
  from transformers import AutoTokenizer, AutoModel
 
7
  app = FastAPI()
8
  print("Version 2...")
9
 
@@ -14,10 +15,9 @@ print("Version 2...")
14
  # )
15
 
16
 
17
- pretrained_path = "aubmindlab/bert-base-arabertv2" # set to your training backbone
18
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
19
- encoder = AutoModel.from_pretrained(pretrained_path)
20
- encoder.eval()
21
 
22
  checkpoint_path = hf_hub_download(
23
  repo_id="SinaLab/Nested",
@@ -80,7 +80,7 @@ ckpt = torch.load(checkpoint_path, map_location="cpu")
80
  model = load_model_from_checkpoint(model, ckpt, strict=False)
81
  # model.eval()
82
 
83
- def predict_ner_with_external_encoder(sentence, tagger, encoder, tokenizer, id2label, device="cpu", max_length=128):
84
  tagger.to(device).eval()
85
  encoder.to(device).eval()
86
 
@@ -101,21 +101,11 @@ def predict_ner_with_external_encoder(sentence, tagger, encoder, tokenizer, id2l
101
  attention_mask=enc.get("attention_mask", None)
102
  ).last_hidden_state # [1, seq_len, hidden]
103
 
104
- ignore_idx = getattr(tagger, "label_ignore_idx", 0)
105
- dummy_labels = torch.full((1, x.size(1)), ignore_idx, dtype=torch.long, device=device)
106
-
107
- out = tagger(x, dummy_labels)
108
-
109
- if isinstance(out, (tuple, list)):
110
- logits = out[-1]
111
- elif hasattr(out, "logits"):
112
- logits = out.logits
113
- else:
114
- logits = out
115
 
116
  pred_ids = logits.argmax(dim=-1)[0].tolist()
117
-
118
- word_ids = tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
119
 
120
  results = []
121
  seen = set()
@@ -127,7 +117,6 @@ def predict_ner_with_external_encoder(sentence, tagger, encoder, tokenizer, id2l
127
 
128
  return results
129
 
130
-
131
  def find_label_vocab(vocabs):
132
  for i, v in enumerate(vocabs):
133
  if hasattr(v, "itos"):
@@ -137,6 +126,37 @@ def find_label_vocab(vocabs):
137
  return None, None
138
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  label_vocab = label_vocab[0] # the list loaded from pickle
@@ -150,5 +170,5 @@ id2label = {i: s for i, s in enumerate(label_vocab.itos)}
150
  sentence = "ذهب احمد الى السوق"
151
  # id2label = {i: s for i, s in enumerate(label_vocab.itos)}
152
  # pairs = predict_ner(sentence, model, label_vocab, device="cpu")
153
- pairs = predict_ner_with_external_encoder(sentence, model, encoder, tokenizer, id2label, device="cpu")
154
  print(pairs)
 
4
  from huggingface_hub import hf_hub_download
5
  from Nested.nn.BertSeqTagger import BertSeqTagger
6
  from transformers import AutoTokenizer, AutoModel
7
+ import inspect
8
  app = FastAPI()
9
  print("Version 2...")
10
 
 
15
  # )
16
 
17
 
18
+ pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
19
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
20
+ encoder = AutoModel.from_pretrained(pretrained_path).eval()
 
21
 
22
  checkpoint_path = hf_hub_download(
23
  repo_id="SinaLab/Nested",
 
80
  model = load_model_from_checkpoint(model, ckpt, strict=False)
81
  # model.eval()
82
 
83
+ def predict_ner(sentence, tagger, encoder, tokenizer, id2label, device="cpu", max_length=128):
84
  tagger.to(device).eval()
85
  encoder.to(device).eval()
86
 
 
101
  attention_mask=enc.get("attention_mask", None)
102
  ).last_hidden_state # [1, seq_len, hidden]
103
 
104
+ logits = _call_tagger(tagger, x, device)
 
 
 
 
 
 
 
 
 
 
105
 
106
  pred_ids = logits.argmax(dim=-1)[0].tolist()
107
+ word_ids = _get_word_ids(tokenizer, words, tokenizer(words, is_split_into_words=True, return_tensors="pt",
108
+ truncation=True, max_length=max_length), max_length)
109
 
110
  results = []
111
  seen = set()
 
117
 
118
  return results
119
 
 
120
  def find_label_vocab(vocabs):
121
  for i, v in enumerate(vocabs):
122
  if hasattr(v, "itos"):
 
126
  return None, None
127
 
128
 
129
+ def _get_word_ids(tokenizer, words, enc, max_length):
130
+ # Fast tokenizers: BatchEncoding has word_ids()
131
+ if hasattr(enc, "word_ids"):
132
+ return enc.word_ids(batch_index=0)
133
+ # Fallback
134
+ return tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
135
+
136
+ def _call_tagger(tagger, x, device):
137
+ # Calls forward in a compatible way (x only vs x+labels, etc.)
138
+ params = list(inspect.signature(tagger.forward).parameters.keys())
139
+ # common: ['x'] or ['x','labels',...]
140
+ if "labels" in params:
141
+ ignore_idx = getattr(tagger, "label_ignore_idx", 0)
142
+ labels = torch.full((x.size(0), x.size(1)), ignore_idx, dtype=torch.long, device=device)
143
+
144
+ kwargs = {}
145
+ if "segments_mask" in params:
146
+ kwargs["segments_mask"] = None
147
+ if "get_sent_repr" in params:
148
+ kwargs["get_sent_repr"] = False
149
+
150
+ out = tagger(x, labels, **kwargs)
151
+ else:
152
+ out = tagger(x)
153
+
154
+ # normalize outputs to logits tensor
155
+ if isinstance(out, (tuple, list)):
156
+ return out[-1]
157
+ if hasattr(out, "logits"):
158
+ return out.logits
159
+ return out
160
 
161
 
162
  label_vocab = label_vocab[0] # the list loaded from pickle
 
170
  sentence = "ذهب احمد الى السوق"
171
  # id2label = {i: s for i, s in enumerate(label_vocab.itos)}
172
  # pairs = predict_ner(sentence, model, label_vocab, device="cpu")
173
+ pairs = predict_ner(sentence, model, encoder, tokenizer, id2label, device="cpu")
174
  print(pairs)
requirements.txt CHANGED
@@ -3,4 +3,5 @@ fastapi
3
  uvicorn
4
  numpy
5
  huggingface_hub
6
- transformers
 
 
3
  uvicorn
4
  numpy
5
  huggingface_hub
6
+ transformers
7
+ inspect