TymaaHammouda committed
Commit 0d0ecdd · 1 Parent(s): bea52c9
Files changed (1):
  1. app.py +19 -23
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import pickle
 from huggingface_hub import hf_hub_download
 from Nested.nn.BertSeqTagger import BertSeqTagger
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModel
 app = FastAPI()
 print("Version 2...")
 
@@ -13,8 +13,11 @@ print("Version 2...")
 # filename="tag_vocab.pkl"
 # )
 
-pretrained_path = "aubmindlab/bert-base-arabertv2"  # change if different in your training
+
+pretrained_path = "aubmindlab/bert-base-arabertv2"  # set to your training backbone
 tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+encoder = AutoModel.from_pretrained(pretrained_path)
+encoder.eval()
 
 checkpoint_path = hf_hub_download(
     repo_id="SinaLab/Nested",
@@ -77,9 +80,9 @@ ckpt = torch.load(checkpoint_path, map_location="cpu")
 model = load_model_from_checkpoint(model, ckpt, strict=False)
 # model.eval()
 
-def predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu", max_length=128):
-    model.to(device)
-    model.eval()
+def predict_ner_with_external_encoder(sentence, tagger, encoder, tokenizer, id2label, device="cpu", max_length=128):
+    tagger.to(device).eval()
+    encoder.to(device).eval()
 
     words = sentence.split()
 
@@ -93,36 +96,26 @@ def predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu", max_l
     enc = {k: v.to(device) for k, v in enc.items()}
 
     with torch.no_grad():
-        # 1) Get contextual token embeddings from the internal transformer
-        tf_out = model.transformer(
+        x = encoder(
             input_ids=enc["input_ids"],
             attention_mask=enc.get("attention_mask", None)
-        )
-        x = tf_out.last_hidden_state  # [1, seq_len, hidden]
+        ).last_hidden_state  # [1, seq_len, hidden]
 
-        # 2) Dummy labels (because forward requires labels)
-        seq_len = x.size(1)
-        ignore_idx = getattr(model, "label_ignore_idx", 0)
-        dummy_labels = torch.full((1, seq_len), ignore_idx, dtype=torch.long, device=device)
+        ignore_idx = getattr(tagger, "label_ignore_idx", 0)
+        dummy_labels = torch.full((1, x.size(1)), ignore_idx, dtype=torch.long, device=device)
 
-        # 3) Get logits
-        out = model(x, dummy_labels, segments_mask=None, get_sent_repr=False)
+        out = tagger(x, dummy_labels)
 
-        # Your forward may return logits or (loss, logits) or dict-like
         if isinstance(out, (tuple, list)):
             logits = out[-1]
         elif hasattr(out, "logits"):
             logits = out.logits
         else:
-            logits = out  # assume tensor
+            logits = out
 
     pred_ids = logits.argmax(dim=-1)[0].tolist()
 
-    # Map tokens back to words (first subtoken per word)
-    if hasattr(enc, "word_ids"):
-        word_ids = enc.word_ids(batch_index=0)
-    else:
-        word_ids = tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
+    word_ids = tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
 
     results = []
     seen = set()
@@ -134,6 +127,7 @@ def predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu", max_l
 
     return results
 
+
 def find_label_vocab(vocabs):
     for i, v in enumerate(vocabs):
         if hasattr(v, "itos"):
@@ -143,6 +137,8 @@ def find_label_vocab(vocabs):
     return None, None
 
 
+
+
 label_vocab = label_vocab[0]  # the list loaded from pickle
 id2label = {i: s for i, s in enumerate(label_vocab.itos)}
 
@@ -154,5 +150,5 @@ id2label = {i: s for i, s in enumerate(label_vocab.itos)}
 sentence = "ذهب احمد الى السوق"
 # id2label = {i: s for i, s in enumerate(label_vocab.itos)}
 # pairs = predict_ner(sentence, model, label_vocab, device="cpu")
-pairs = predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu")
+pairs = predict_ner_with_external_encoder(sentence, model, encoder, tokenizer, id2label, device="cpu")
 print(pairs)
 
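The substance of the change: instead of reaching into the tagger's internal transformer (the old model.transformer call), the app now loads the backbone separately with AutoModel and feeds its hidden states to the tagger. A minimal sketch of that encoder path, assuming the backbone named in the diff; this is an illustration, not the app's exact code:

import torch
from transformers import AutoTokenizer, AutoModel

pretrained_path = "aubmindlab/bert-base-arabertv2"  # backbone named in the diff
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path)
encoder.eval()

words = "ذهب احمد الى السوق".split()  # pre-split words, as in the app's sentence.split()
enc = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                truncation=True, max_length=128)
with torch.no_grad():
    x = encoder(**enc).last_hidden_state  # [1, seq_len, hidden]
print(x.shape)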
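The tagger weights still come from the SinaLab/Nested repo via hf_hub_download. A hedged sketch of that load path; the checkpoint filename and the loader body are cut off in the diff, so both are placeholders here:

import torch
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="checkpoint.pt",  # hypothetical: the real filename is not visible in the diff
)
ckpt = torch.load(checkpoint_path, map_location="cpu")

# The diff then applies the weights non-strictly:
#     model = load_model_from_checkpoint(model, ckpt, strict=False)
# load_model_from_checkpoint is defined elsewhere in app.py; a non-strict load
# tolerates missing/unexpected keys, in the spirit of
#     model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)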
 
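BertSeqTagger's forward apparently requires a labels argument even at inference, hence the dummy_labels tensor filled with the ignore index, and the defensive unwrapping of its return value. A self-contained sketch of that pattern; ToyTagger is a stand-in, since the diff does not show BertSeqTagger's real signature:

import torch
import torch.nn as nn

class ToyTagger(nn.Module):
    # stand-in for BertSeqTagger: consumes precomputed encoder hidden states
    label_ignore_idx = 0

    def __init__(self, hidden=768, n_labels=9):
        super().__init__()
        self.classifier = nn.Linear(hidden, n_labels)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.label_ignore_idx)

    def forward(self, x, labels):
        logits = self.classifier(x)  # [batch, seq_len, n_labels]
        loss = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits  # (loss, logits): one of the cases handled above

tagger = ToyTagger().eval()
x = torch.randn(1, 12, 768)  # fake encoder output
with torch.no_grad():
    ignore_idx = getattr(tagger, "label_ignore_idx", 0)
    dummy_labels = torch.full((1, x.size(1)), ignore_idx, dtype=torch.long)
    out = tagger(x, dummy_labels)  # loss is unused (and meaningless) at inference
logits = out[-1] if isinstance(out, (tuple, list)) else getattr(out, "logits", out)
pred_ids = logits.argmax(dim=-1)[0].tolist()  # one predicted id per subtoken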
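Mapping subtoken predictions back to words now always re-tokenizes and uses word_ids(): for each word, the prediction of its first subtoken wins. The loop body is not shown in the diff, so the results/seen logic below is one plausible reading, with placeholder predictions and labels:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
words = "ذهب احمد الى السوق".split()
enc = tokenizer(words, is_split_into_words=True, truncation=True, max_length=128)
word_ids = enc.word_ids()  # None for special tokens, else the word index

pred_ids = [0] * len(enc["input_ids"])  # placeholder per-subtoken predictions
id2label = {0: "O"}                     # placeholder label map

results, seen = [], set()
for tok_idx, w_idx in enumerate(word_ids):
    if w_idx is None or w_idx in seen:  # skip specials and non-first subtokens
        continue
    seen.add(w_idx)
    results.append((words[w_idx], id2label[pred_ids[tok_idx]]))
print(results)  # [('ذهب', 'O'), ('احمد', 'O'), ('الى', 'O'), ('السوق', 'O')]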
 
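Finally, id2label is built from a pickled vocabulary object exposing itos (index-to-string); find_label_vocab() scans the unpickled container for the element carrying that attribute. The pickle's layout and origin are not shown (the tag_vocab.pkl download is commented out above), so the loading below is an assumption for illustration:

import pickle

def find_label_vocab(vocabs):
    # return the first element that looks like a label vocab (has `itos`)
    for i, v in enumerate(vocabs):
        if hasattr(v, "itos"):
            return i, v
    return None, None

with open("tag_vocab.pkl", "rb") as f:  # hypothetical local path
    vocabs = pickle.load(f)

_, label_vocab = find_label_vocab(vocabs)
id2label = {i: s for i, s in enumerate(label_vocab.itos)}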