Update README.md

README.md (changed) — previous version of the section, with removed lines marked `-`:

@@ -130,23 +130,33 @@ export TOKENIZERS_PARALLELISM=false

### 2. Load the Model via Hugging Face

```python
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

- cfg …                     # removed line (content truncated in this view)
cfg.pretrained_model_name_or_path = "bert-base-uncased"
- cfg.use_cuda …            # removed line (content truncated in this view)

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
- model …                   # removed lines (content truncated in this view)
-
-
-
-
model.eval()
- model.cpu()
def get_proto_logits(texts):
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
```

@@ -167,9 +177,6 @@ texts = [

```python
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
-
- probs = torch.sigmoid(logits)
- print("Probabilities:\n", probs)
```

## 3. Training Data & Licenses
README.md (changed) — updated version of the section:

### 2. Load the Model via Hugging Face

```python
import os
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging
hf_logging.set_verbosity_error()

warnings.filterwarnings("ignore", category=UserWarning)
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if cfg.use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()

def get_proto_logits(texts):
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # … (lines 163–176 elided by the diff view) …

logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```

## 3. Training Data & Licenses