Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +19 -12
tasks/text.py
CHANGED
|
@@ -3,6 +3,14 @@ from datetime import datetime
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sklearn.metrics import accuracy_score
|
| 5 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from .utils.evaluation import TextEvaluationRequest
|
| 8 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
@@ -55,13 +63,16 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 55 |
# YOUR MODEL INFERENCE CODE HERE
|
| 56 |
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
| 57 |
#--------------------------------------------------------------------------------------------
|
| 58 |
-
class CovidTwitterBertClassifier(
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
| 61 |
super().__init__()
|
| 62 |
-
self.n_classes =
|
| 63 |
self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
|
| 64 |
-
self.bert.cls.seq_relationship = nn.Linear(1024,
|
| 65 |
|
| 66 |
self.sigmoid = nn.Sigmoid()
|
| 67 |
|
|
@@ -71,11 +82,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 71 |
logits = outputs[1]
|
| 72 |
|
| 73 |
return logits
|
| 74 |
-
|
| 75 |
-
model = CovidTwitterBertClassifier(8)
|
| 76 |
-
|
| 77 |
-
model.to(device)
|
| 78 |
-
model.load_state_dict(torch.load('ypesk/ct_baseline/CTBert_128_e15_0.692.pth'))
|
| 79 |
model.eval()
|
| 80 |
|
| 81 |
|
|
@@ -83,7 +90,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 83 |
|
| 84 |
test_texts = [t['quote'] for t in data_test]
|
| 85 |
|
| 86 |
-
MAX_LEN =
|
| 87 |
|
| 88 |
tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
|
| 89 |
test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
|
|
@@ -92,7 +99,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 92 |
test_input_ids = torch.tensor(test_input_ids)
|
| 93 |
test_attention_mask = torch.tensor(test_attention_mask)
|
| 94 |
|
| 95 |
-
batch_size =
|
| 96 |
test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
|
| 97 |
|
| 98 |
test_sampler = SequentialSampler(test_data)
|
|
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sklearn.metrics import accuracy_score
|
| 5 |
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
| 12 |
+
from transformers import BertForPreTraining, BertModel, AutoTokenizer, BertForSequenceClassification, RobertaForSequenceClassification
|
| 13 |
+
|
| 14 |
|
| 15 |
from .utils.evaluation import TextEvaluationRequest
|
| 16 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
|
|
| 63 |
# YOUR MODEL INFERENCE CODE HERE
|
| 64 |
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
| 65 |
#--------------------------------------------------------------------------------------------
|
| 66 |
+
class CovidTwitterBertClassifier(
|
| 67 |
+
nn.Module,
|
| 68 |
+
PyTorchModelHubMixin,
|
| 69 |
+
# optionally, you can add metadata which gets pushed to the model card
|
| 70 |
+
):
|
| 71 |
+
def __init__(self, num_classes):
|
| 72 |
super().__init__()
|
| 73 |
+
self.n_classes = num_classes
|
| 74 |
self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
|
| 75 |
+
self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
|
| 76 |
|
| 77 |
self.sigmoid = nn.Sigmoid()
|
| 78 |
|
|
|
|
| 82 |
logits = outputs[1]
|
| 83 |
|
| 84 |
return logits
|
| 85 |
+
model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
model.eval()
|
| 87 |
|
| 88 |
|
|
|
|
| 90 |
|
| 91 |
test_texts = [t['quote'] for t in data_test]
|
| 92 |
|
| 93 |
+
MAX_LEN = 256 #1024 # < m some tweets will be truncated
|
| 94 |
|
| 95 |
tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
|
| 96 |
test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
|
|
|
|
| 99 |
test_input_ids = torch.tensor(test_input_ids)
|
| 100 |
test_attention_mask = torch.tensor(test_attention_mask)
|
| 101 |
|
| 102 |
+
batch_size = 12 #
|
| 103 |
test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
|
| 104 |
|
| 105 |
test_sampler = SequentialSampler(test_data)
|