row56 committed on
Commit
caf29aa
·
verified ·
1 Parent(s): cafe337

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +18 -11
README.md CHANGED
@@ -130,23 +130,33 @@ export TOKENIZERS_PARALLELISM=false
130
  ### 2. Load the Model via Hugging Face
131
 
132
  ```python
 
 
 
 
 
 
 
 
 
133
  import torch
134
  from transformers import AutoTokenizer
135
  from proto_model.configuration_proto import ProtoConfig
136
  from proto_model.modeling_proto import ProtoForMultiLabelClassification
137
 
138
- cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
139
  cfg.pretrained_model_name_or_path = "bert-base-uncased"
140
- cfg.use_cuda = False
 
 
141
 
142
  tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
143
- model = ProtoForMultiLabelClassification.from_pretrained(
144
- "row56/ProtoPatient",
145
- config=cfg,
146
- ignore_mismatched_sizes=True
147
- )
148
  model.eval()
149
- model.cpu()
150
 
151
  def get_proto_logits(texts):
152
  enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
@@ -167,9 +177,6 @@ texts = [
167
  logits = get_proto_logits(texts)
168
  print("Logits shape:", logits.shape)
169
  print("Logits:\n", logits)
170
-
171
- probs = torch.sigmoid(logits)
172
- print("Probabilities:\n", probs)
173
  ```
174
 
175
  ## 3. Training Data & Licenses
 
130
  ### 2. Load the Model via Hugging Face
131
 
132
  ```python
133
+ import os
134
+ import warnings
135
+
136
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
137
+
138
+ from transformers import logging as hf_logging
139
+ hf_logging.set_verbosity_error()
140
+
141
+ warnings.filterwarnings("ignore", category=UserWarning)
142
  import torch
143
  from transformers import AutoTokenizer
144
  from proto_model.configuration_proto import ProtoConfig
145
  from proto_model.modeling_proto import ProtoForMultiLabelClassification
146
 
147
+ cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
148
  cfg.pretrained_model_name_or_path = "bert-base-uncased"
149
+ cfg.use_cuda = torch.cuda.is_available()
150
+
151
+ device = torch.device("cuda" if cfg.use_cuda else "cpu")
152
 
153
  tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
154
+ model = ProtoForMultiLabelClassification.from_pretrained(
155
+ "row56/ProtoPatient",
156
+ config=cfg,
157
+ )
158
+ model.to(device)
159
  model.eval()
 
160
 
161
  def get_proto_logits(texts):
162
  enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
 
177
  logits = get_proto_logits(texts)
178
  print("Logits shape:", logits.shape)
179
  print("Logits:\n", logits)
 
 
 
180
  ```
181
 
182
  ## 3. Training Data & Licenses