TymaaHammouda committed
Commit 39af8fe · 1 Parent(s): f316449

Update app file

Files changed (1):
  1. app.py +37 -138
app.py CHANGED
@@ -5,6 +5,10 @@ from huggingface_hub import hf_hub_download
 from Nested.nn.BertSeqTagger import BertSeqTagger
 from transformers import AutoTokenizer, AutoModel
 import inspect
+from collections import namedtuple
+from Nested.utils.helpers import load_checkpoint
+from Nested.utils.data import get_dataloaders, text2segments
+
 app = FastAPI()
 print("Version 2...")

@@ -24,150 +28,45 @@ checkpoint_path = hf_hub_download(
     filename="checkpoints/checkpoint_2.pt"
 )

+args_path = hf_hub_download(
+    repo_id="SinaLab/Nested",
+    filename="args.json"
+)
+
 # Load model
 with open("Nested/utils/tag_vocab.pkl", "rb") as f:
     label_vocab = pickle.load(f)

-# model = torch.load(checkpoint_path, map_location="cpu")
-model = BertSeqTagger(
-    bert_model="aubmindlab/bert-base-arabertv2",
-    dropout=0.1
-)
-
-
-def load_model_from_checkpoint(model, checkpoint, strict=True):
-    if isinstance(checkpoint, torch.nn.Module):
-        return checkpoint
-
-    if not isinstance(checkpoint, dict):
-        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint)}")
-
-    candidates = [
-        "state_dict",
-        "model_state_dict",
-        "model",
-        "net",
-        "network",
-        "model_state",
-    ]
-
-    state_dict = None
-    for k in candidates:
-        if k in checkpoint and isinstance(checkpoint[k], dict):
-            state_dict = checkpoint[k]
-            break
-
-    if state_dict is None:
-        looks_like_state = (
-            len(checkpoint) > 0
-            and all(isinstance(v, torch.Tensor) for v in checkpoint.values())
-            and all(isinstance(k, str) for k in checkpoint.keys())
-        )
-        if looks_like_state:
-            state_dict = checkpoint
-        else:
-            raise KeyError(f"No model weights found. Keys: {list(checkpoint.keys())}")
-
-    if len(state_dict) > 0:
-        any_key = next(iter(state_dict.keys()))
-        if any_key.startswith("module."):
-            state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
-
-    model.load_state_dict(state_dict, strict=strict)
-    return model
-
-ckpt = torch.load(checkpoint_path, map_location="cpu")
-model = load_model_from_checkpoint(model, ckpt, strict=False)
-# model.eval()
-
-def predict_ner(sentence, tagger, encoder, tokenizer, id2label, device="cpu", max_length=128):
-    tagger.to(device).eval()
-    encoder.to(device).eval()
-
-    words = sentence.split()
-
-    enc = tokenizer(
-        words,
-        is_split_into_words=True,
-        return_tensors="pt",
-        truncation=True,
-        max_length=max_length
-    )
-    enc = {k: v.to(device) for k, v in enc.items()}
-
-    with torch.no_grad():
-        x = encoder(
-            input_ids=enc["input_ids"],
-            attention_mask=enc.get("attention_mask", None)
-        ).last_hidden_state  # [1, seq_len, hidden]
-
-        logits = _call_tagger(tagger, x, device)
-
-    pred_ids = logits.argmax(dim=-1)[0].tolist()
-    word_ids = _get_word_ids(tokenizer, words,
-                             tokenizer(words, is_split_into_words=True, return_tensors="pt",
-                                       truncation=True, max_length=max_length),
-                             max_length)
-
-    results = []
-    seen = set()
-    for tok_i, w_i in enumerate(word_ids):
-        if w_i is None or w_i in seen:
-            continue
-        seen.add(w_i)
-        results.append((words[w_i], id2label[pred_ids[tok_i]]))
-
-    return results
-
-def find_label_vocab(vocabs):
-    for i, v in enumerate(vocabs):
-        if hasattr(v, "itos"):
-            itos = v.itos
-            if isinstance(itos, (list, tuple)) and any(x in itos for x in ["O", "B-PER", "I-PER"]):
-                return i, v
-    return None, None
-
-
-def _get_word_ids(tokenizer, words, enc, max_length):
-    # Fast tokenizers: BatchEncoding has word_ids()
-    if hasattr(enc, "word_ids"):
-        return enc.word_ids(batch_index=0)
-    # Fallback
-    return tokenizer(words, is_split_into_words=True, truncation=True, max_length=max_length).word_ids()
-
-def _call_tagger(tagger, x, device):
-    # Calls forward in a compatible way (x only vs x+labels, etc.)
-    params = list(inspect.signature(tagger.forward).parameters.keys())
-    # common: ['x'] or ['x','labels',...]
-    if "labels" in params:
-        ignore_idx = getattr(tagger, "label_ignore_idx", 0)
-        labels = torch.full((x.size(0), x.size(1)), ignore_idx, dtype=torch.long, device=device)
-
-        kwargs = {}
-        if "segments_mask" in params:
-            kwargs["segments_mask"] = None
-        if "get_sent_repr" in params:
-            kwargs["get_sent_repr"] = False
-
-        out = tagger(x, labels, **kwargs)
-    else:
-        out = tagger(x)
-
-    # normalize outputs to logits tensor
-    if isinstance(out, (tuple, list)):
-        return out[-1]
-    if hasattr(out, "logits"):
-        return out.logits
-    return out
-
 label_vocab = label_vocab[0]  # the list loaded from pickle
 id2label = {i: s for i, s in enumerate(label_vocab.itos)}

-# idx, label_vocab = find_label_vocab(label_vocab)
-# print("label vocab index:", idx)
-# id2label = {i: s for i, s in enumerate(label_vocab.itos)}
-

 sentence = "ذهب احمد الى السوق"
-# id2label = {i: s for i, s in enumerate(label_vocab.itos)}
-# pairs = predict_ner(sentence, model, label_vocab, device="cpu")
-pairs = predict_ner(sentence, model, encoder, tokenizer, id2label, device="cpu")
-print(pairs)
+# Load tagger
+tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
+
+# Convert the text into a tagger dataset and index its tokens
+dataset, token_vocab = text2segments(sentence)
+
+vocabs = namedtuple("Vocab", ["tags", "tokens"])
+vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
+
+# From the datasets generate the dataloaders
+dataloader = get_dataloaders(
+    (dataset,),
+    vocab,
+    args_path,
+    batch_size=32,
+    shuffle=(False,),
+)[0]
+
+# Perform inference on the text and get back the tagged segments
+segments = tagger.infer(dataloader)
+
+# Print results
+for segment in segments:
+    s = [
+        f"{token.text} ({'|'.join([t['tag'] for t in token.pred_tag])})"
+        for token in segment
+    ]
+    print(" ".join(s))
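For reference, the deleted load_model_from_checkpoint followed a standard PyTorch recipe: locate the state_dict nested inside a training checkpoint, then strip the "module." prefix that DataParallel adds to parameter names. A minimal standalone sketch of the same recipe (plain PyTorch; extract_state_dict is an illustrative name, not part of this repo):

import torch

def extract_state_dict(checkpoint):
    # Training scripts often nest the weights under a wrapper key.
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model_state_dict", "model"):
            if isinstance(checkpoint.get(key), dict):
                checkpoint = checkpoint[key]
                break
    # DataParallel/DistributedDataParallel prefix parameter names with "module."
    return {k.removeprefix("module."): v for k, v in checkpoint.items()}

# Usage:
# ckpt = torch.load("checkpoint_2.pt", map_location="cpu")
# model.load_state_dict(extract_state_dict(ckpt), strict=False)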
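The deleted predict_ner mapped subword-level predictions back to whole words through the fast tokenizer's word_ids(), keeping the first subword's label per word. A self-contained sketch of that alignment step (align_to_words is an illustrative helper, not repo API):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")

def align_to_words(words, token_pred_ids, enc, id2label):
    # word_ids() (fast tokenizers only) maps each subword to its source word;
    # keep the first subword's prediction per word, as predict_ner did.
    results, seen = [], set()
    for tok_i, w_i in enumerate(enc.word_ids(batch_index=0)):
        if w_i is None or w_i in seen:
            continue
        seen.add(w_i)
        results.append((words[w_i], id2label[token_pred_ids[tok_i]]))
    return results

words = "ذهب احمد الى السوق".split()
enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
# token_pred_ids would be logits.argmax(-1)[0].tolist() from the tagger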
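The deleted _call_tagger coped with unknown forward() signatures by inspecting them at runtime and building dummy labels only when required. A condensed sketch of the same idea (call_tagger is illustrative, not repo API):

import inspect
import torch

def call_tagger(tagger, x):
    # Inspect forward() to decide whether dummy labels must be passed,
    # then normalize whatever comes back to a logits tensor.
    params = inspect.signature(tagger.forward).parameters
    if "labels" in params:
        ignore_idx = getattr(tagger, "label_ignore_idx", 0)
        labels = torch.full(x.shape[:2], ignore_idx, dtype=torch.long, device=x.device)
        out = tagger(x, labels)
    else:
        out = tagger(x)
    if isinstance(out, (tuple, list)):  # e.g. (loss, logits)
        return out[-1]
    return getattr(out, "logits", out)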
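As committed, app.py runs inference once at import time on a hard-coded sentence. A hedged sketch of how the same Nested pipeline could back a FastAPI route instead; the /tag route and TagRequest model are hypothetical, and the checkpoint is assumed to live in the same SinaLab/Nested repo as args.json:

from collections import namedtuple

from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download

from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments

app = FastAPI()

# Resolve artifacts and load the tagger once at startup.
checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested",
                                  filename="checkpoints/checkpoint_2.pt")
args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
Vocab = namedtuple("Vocab", ["tags", "tokens"])

class TagRequest(BaseModel):
    text: str

@app.post("/tag")  # hypothetical route
def tag(req: TagRequest):
    # Same steps as the committed script, applied to the request text.
    dataset, token_vocab = text2segments(req.text)
    vocab = Vocab(tokens=token_vocab, tags=tag_vocab)
    dataloader = get_dataloaders((dataset,), vocab, args_path,
                                 batch_size=32, shuffle=(False,))[0]
    segments = tagger.infer(dataloader)
    return [
        [{"token": tok.text, "tags": [t["tag"] for t in tok.pred_tag]}
         for tok in segment]
        for segment in segments
    ]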