Commit ·
6439837
1
Parent(s): d89e25d
added device to handler
Browse files- handler.py +10 -5
handler.py
CHANGED
|
@@ -5,8 +5,12 @@ import torch
|
|
| 5 |
|
| 6 |
class EndpointHandler:
|
| 7 |
def __init__(self, path=""):
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
self.model = AutoModelForSequenceClassification.from_pretrained(path)
|
|
|
|
|
|
|
| 10 |
def tokenize(batch):
|
| 11 |
return self.tokenizer(
|
| 12 |
batch['topic'],
|
|
@@ -16,7 +20,8 @@ class EndpointHandler:
|
|
| 16 |
return_offsets_mapping=False,
|
| 17 |
padding="max_length",
|
| 18 |
return_tensors='pt'
|
| 19 |
-
)
|
|
|
|
| 20 |
self.tokenize = tokenize
|
| 21 |
|
| 22 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
@@ -36,11 +41,11 @@ class EndpointHandler:
|
|
| 36 |
}
|
| 37 |
|
| 38 |
for topic in topics:
|
| 39 |
-
for text in texts:
|
| 40 |
batch['id'].append(text['id'])
|
| 41 |
batch['text'].append(text['text'])
|
| 42 |
batch['topic'].append(topic)
|
| 43 |
-
|
| 44 |
tokenized_inputs = self.tokenize(batch)
|
| 45 |
|
| 46 |
# run normal prediction
|
|
@@ -48,4 +53,4 @@ class EndpointHandler:
|
|
| 48 |
predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
|
| 49 |
batch['label'] = [self.model.config.id2label[p] for p in predictions]
|
| 50 |
batch.pop('text')
|
| 51 |
-
return batch
|
|
|
|
| 5 |
|
| 6 |
class EndpointHandler:
|
| 7 |
def __init__(self, path=""):
|
| 8 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 9 |
+
|
| 10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path, device=device)
|
| 11 |
self.model = AutoModelForSequenceClassification.from_pretrained(path)
|
| 12 |
+
self.model.to(device)
|
| 13 |
+
|
| 14 |
def tokenize(batch):
|
| 15 |
return self.tokenizer(
|
| 16 |
batch['topic'],
|
|
|
|
| 20 |
return_offsets_mapping=False,
|
| 21 |
padding="max_length",
|
| 22 |
return_tensors='pt'
|
| 23 |
+
).to(device)
|
| 24 |
+
|
| 25 |
self.tokenize = tokenize
|
| 26 |
|
| 27 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
for topic in topics:
|
| 44 |
+
for text in texts:
|
| 45 |
batch['id'].append(text['id'])
|
| 46 |
batch['text'].append(text['text'])
|
| 47 |
batch['topic'].append(topic)
|
| 48 |
+
|
| 49 |
tokenized_inputs = self.tokenize(batch)
|
| 50 |
|
| 51 |
# run normal prediction
|
|
|
|
| 53 |
predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
|
| 54 |
batch['label'] = [self.model.config.id2label[p] for p in predictions]
|
| 55 |
batch.pop('text')
|
| 56 |
+
return batch
|