TymaaHammouda commited on
Commit
79eb00d
·
1 Parent(s): d6fe8b7

Update model eval

Browse files
Files changed (2) hide show
  1. Nested/nn/BertSeqTagger.py +14 -0
  2. app.py +51 -4
Nested/nn/BertSeqTagger.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import BertModel
3
+
4
+ class BertSeqTagger(nn.Module):
5
+ def __init__(self, bert_model, num_labels=2, dropout=0.1):
6
+ super().__init__()
7
+ self.bert = BertModel.from_pretrained(bert_model)
8
+ self.dropout = nn.Dropout(dropout)
9
+ self.linear = nn.Linear(768, num_labels)
10
+ def forward(self, x):
11
+ y = self.bert(x)
12
+ y = self.dropout(y["last_hidden_state"])
13
+ logits = self.linear(y)
14
+ return logits
app.py CHANGED
@@ -2,10 +2,8 @@ from fastapi import FastAPI
2
  import torch
3
  import pickle
4
  from huggingface_hub import hf_hub_download
 
5
 
6
- import os
7
-
8
- print(os.getcwd())
9
  app = FastAPI()
10
  print("Version 2...")
11
 
@@ -24,5 +22,54 @@ checkpoint_path = hf_hub_download(
24
  with open("Nested/utils/tag_vocab.pkl", "rb") as f:
25
  id2label = pickle.load(f)
26
 
27
- model = torch.load(checkpoint_path, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  model.eval()
 
2
  import torch
3
  import pickle
4
  from huggingface_hub import hf_hub_download
5
+ from Nested.nn.BertSeqTagger import BertSeqTagger
6
 
 
 
 
7
  app = FastAPI()
8
  print("Version 2...")
9
 
 
22
  with open("Nested/utils/tag_vocab.pkl", "rb") as f:
23
  id2label = pickle.load(f)
24
 
25
+ # model = torch.load(checkpoint_path, map_location="cpu")
26
+ model = BertSeqTagger(
27
+ pretrained_path="aubmindlab/bert-base-arabertv2",
28
+ dropout_p=0.1
29
+ )
30
+
31
+
32
+ def load_model_from_checkpoint(model, checkpoint, strict=True):
33
+ if isinstance(checkpoint, torch.nn.Module):
34
+ return checkpoint
35
+
36
+ if not isinstance(checkpoint, dict):
37
+ raise TypeError(f"Unsupported checkpoint type: {type(checkpoint)}")
38
+
39
+ candidates = [
40
+ "state_dict",
41
+ "model_state_dict",
42
+ "model",
43
+ "net",
44
+ "network",
45
+ "model_state",
46
+ ]
47
+
48
+ state_dict = None
49
+ for k in candidates:
50
+ if k in checkpoint and isinstance(checkpoint[k], dict):
51
+ state_dict = checkpoint[k]
52
+ break
53
+
54
+ if state_dict is None:
55
+ looks_like_state = (
56
+ len(checkpoint) > 0
57
+ and all(isinstance(v, torch.Tensor) for v in checkpoint.values())
58
+ and all(isinstance(k, str) for k in checkpoint.keys())
59
+ )
60
+ if looks_like_state:
61
+ state_dict = checkpoint
62
+ else:
63
+ raise KeyError(f"No model weights found. Keys: {list(checkpoint.keys())}")
64
+
65
+ if len(state_dict) > 0:
66
+ any_key = next(iter(state_dict.keys()))
67
+ if any_key.startswith("module."):
68
+ state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
69
+
70
+ model.load_state_dict(state_dict, strict=strict)
71
+ return model
72
+
73
+ ckpt = torch.load(checkpoint_path, map_location="cpu")
74
+ model = load_model_from_checkpoint(model, ckpt, strict=False)
75
  model.eval()