Update README.md
Browse files
README.md
CHANGED
|
@@ -135,7 +135,6 @@ from transformers import AutoTokenizer
|
|
| 135 |
from proto_model.configuration_proto import ProtoConfig
|
| 136 |
from proto_model.modeling_proto import ProtoForMultiLabelClassification
|
| 137 |
|
| 138 |
-
# Load & configure
|
| 139 |
cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
|
| 140 |
cfg.pretrained_model_name_or_path = "bert-base-uncased"
|
| 141 |
cfg.use_cuda = False
|
|
@@ -149,7 +148,6 @@ model = ProtoForMultiLabelClassification.from_pretrained(
|
|
| 149 |
model.eval()
|
| 150 |
model.cpu()
|
| 151 |
|
| 152 |
-
# Helper
|
| 153 |
def get_proto_logits(texts):
|
| 154 |
enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
| 155 |
batch = {
|
|
@@ -162,7 +160,6 @@ def get_proto_logits(texts):
|
|
| 162 |
logits, _ = model.proto_module(batch)
|
| 163 |
return logits
|
| 164 |
|
| 165 |
-
# Run
|
| 166 |
texts = [
|
| 167 |
"Patient shows elevated heart rate and low oxygen saturation.",
|
| 168 |
"No significant findings; patient is healthy."
|
|
|
|
| 135 |
from proto_model.configuration_proto import ProtoConfig
|
| 136 |
from proto_model.modeling_proto import ProtoForMultiLabelClassification
|
| 137 |
|
|
|
|
| 138 |
cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
|
| 139 |
cfg.pretrained_model_name_or_path = "bert-base-uncased"
|
| 140 |
cfg.use_cuda = False
|
|
|
|
| 148 |
model.eval()
|
| 149 |
model.cpu()
|
| 150 |
|
|
|
|
| 151 |
def get_proto_logits(texts):
|
| 152 |
enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
| 153 |
batch = {
|
|
|
|
| 160 |
logits, _ = model.proto_module(batch)
|
| 161 |
return logits
|
| 162 |
|
|
|
|
| 163 |
texts = [
|
| 164 |
"Patient shows elevated heart rate and low oxygen saturation.",
|
| 165 |
"No significant findings; patient is healthy."
|