TymaaHammouda committed
Commit bea52c9 · 1 Parent(s): 7d2277b
Files changed (1):
  1. app.py +35 -10
app.py CHANGED
@@ -77,7 +77,7 @@ ckpt = torch.load(checkpoint_path, map_location="cpu")
 model = load_model_from_checkpoint(model, ckpt, strict=False)
 # model.eval()
 
-def predict_ner(sentence: str, model, tokenizer, id2label: dict, device="cpu", max_length=128):
+def predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu", max_length=128):
     model.to(device)
     model.eval()
 
@@ -93,15 +93,37 @@ def predict_ner(sentence: str, model, tokenizer, id2label: dict, device="cpu", max_length=128):
     enc = {k: v.to(device) for k, v in enc.items()}
 
     with torch.no_grad():
-        out = model(**enc)
-        logits = out.logits if hasattr(out, "logits") else out
-
+        # 1) Get contextual token embeddings from the internal transformer
+        tf_out = model.transformer(
+            input_ids=enc["input_ids"],
+            attention_mask=enc.get("attention_mask", None)
+        )
+        x = tf_out.last_hidden_state  # [1, seq_len, hidden]
+
+        # 2) Dummy labels (because forward requires labels)
+        seq_len = x.size(1)
+        ignore_idx = getattr(model, "label_ignore_idx", 0)
+        dummy_labels = torch.full((1, seq_len), ignore_idx, dtype=torch.long, device=device)
+
+        # 3) Get logits
+        out = model(x, dummy_labels, segments_mask=None, get_sent_repr=False)
+
+        # The forward may return logits, (loss, logits), or a dict-like output
+        if isinstance(out, (tuple, list)):
+            logits = out[-1]
+        elif hasattr(out, "logits"):
+            logits = out.logits
+        else:
+            logits = out  # assume a bare tensor
+
     pred_ids = logits.argmax(dim=-1)[0].tolist()
 
-    word_ids = enc["input_ids"].new_zeros(enc["input_ids"].shape[1]).tolist()
-    word_ids = tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
-
-    # first subtoken per word -> label
+    # Map tokens back to words (first subtoken per word)
+    if hasattr(enc, "word_ids"):
+        word_ids = enc.word_ids(batch_index=0)
+    else:
+        word_ids = tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
 
     results = []
     seen = set()
     for tok_i, w_i in enumerate(word_ids):
@@ -121,13 +143,16 @@ def find_label_vocab(vocabs):
     return None, None
 
 
-idx, label_vocab = find_label_vocab(label_vocab)
-print("label vocab index:", idx)
+label_vocab = label_vocab[0]  # the list loaded from pickle
 id2label = {i: s for i, s in enumerate(label_vocab.itos)}
 
+# idx, label_vocab = find_label_vocab(label_vocab)
+# print("label vocab index:", idx)
+# id2label = {i: s for i, s in enumerate(label_vocab.itos)}
+
 
 sentence = "ذهب احمد الى السوق"
 # id2label = {i: s for i, s in enumerate(label_vocab.itos)}
 # pairs = predict_ner(sentence, model, label_vocab, device="cpu")
-pairs = predict_ner(sentence, model, tokenizer, id2label, device="cpu")
+pairs = predict_ner_nested(sentence, model, tokenizer, id2label, device="cpu")
 print(pairs)
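
The body of the `for tok_i, w_i in enumerate(word_ids)` loop falls outside the hunks above, so here is a minimal self-contained sketch of the first-subtoken-per-word alignment it implements, run on the same sample sentence (Arabic for "Ahmed went to the market"). The tokenizer checkpoint and the placeholder `pred_ids`/`id2label` values are assumptions for illustration, not taken from app.py:

# Sketch only: aubmindlab/bert-base-arabertv02 is an assumed Arabic checkpoint,
# and pred_ids/id2label are placeholders standing in for the model's output.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02")
words = "ذهب احمد الى السوق".split()
enc = tokenizer(words, is_split_into_words=True, truncation=True, max_length=128)
word_ids = enc.word_ids()          # one entry per token; None for [CLS]/[SEP]

pred_ids = [0] * len(word_ids)     # stand-in for logits.argmax(dim=-1)[0].tolist()
id2label = {0: "O"}                # stand-in for the label vocab

results, seen = [], set()
for tok_i, w_i in enumerate(word_ids):
    if w_i is None or w_i in seen:  # skip special tokens and non-first subtokens
        continue
    seen.add(w_i)
    results.append((words[w_i], id2label[pred_ids[tok_i]]))

print(results)  # [('ذهب', 'O'), ('احمد', 'O'), ('الى', 'O'), ('السوق', 'O')]

As the diff's own comment notes, the dummy-labels step exists only because the model's forward signature requires a labels tensor even at inference; filling it with the ignore index keeps the loss term inert while the logits are still produced.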