Upload load.py with huggingface_hub
Browse files
load.py
CHANGED
|
@@ -3,11 +3,49 @@ from model import FakeBERT
|
|
| 3 |
|
| 4 |
# Checkpoint configuration for the fine-tuned DistilBERT classifier.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
MODEL_PATH = "distilbert_best.pth"
NUM_CLASSES = 3
# Run on the GPU when one is present, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier skeleton, then restore the fine-tuned weights
# from the local checkpoint file.
model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint configuration for the fine-tuned DistilBERT classifier.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
MODEL_PATH = "distilbert_best.pth"
MAX_LENGTH = 512  # tokenizer truncation length, in tokens
NUM_CLASSES = 3
# Run on the GPU when one is present, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 9 |
|
| 10 |
|
| 11 |
+
def predict_veracity(texts, model, tokenizer, device, max_length=MAX_LENGTH):
|
| 12 |
+
model.eval()
|
| 13 |
+
id2label = {0: "F", 1: "U", 2: "T"}
|
| 14 |
+
|
| 15 |
+
encodings = tokenizer(
|
| 16 |
+
texts,
|
| 17 |
+
padding=True,
|
| 18 |
+
truncation=True,
|
| 19 |
+
max_length=max_length,
|
| 20 |
+
return_tensors="pt"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
input_ids = encodings["input_ids"].to(device)
|
| 24 |
+
attention_mask = encodings["attention_mask"].to(device)
|
| 25 |
+
token_type_ids = encodings.get("token_type_ids")
|
| 26 |
+
if token_type_ids is not None:
|
| 27 |
+
token_type_ids = token_type_ids.to(device)
|
| 28 |
+
|
| 29 |
+
with torch.inference_mode():
|
| 30 |
+
logits = model(input_ids, attention_mask, token_type_ids)
|
| 31 |
+
preds = torch.argmax(logits, dim=1).tolist()
|
| 32 |
+
|
| 33 |
+
return [id2label.get(p, "U") for p in preds]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Load resources
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 40 |
model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
|
| 41 |
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 42 |
model.load_state_dict(state_dict)
|
| 43 |
|
| 44 |
+
# Label a list of texts
|
| 45 |
+
labels = predict_sentiment(texts, model, tokenizer, DEVICE)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|