jrawa committed on
Commit
2bbbe37
·
verified ·
1 Parent(s): 2c495cb

Upload load.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. load.py +38 -0
load.py CHANGED
@@ -3,11 +3,49 @@ from model import FakeBERT
3
 
4
  MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
5
  MODEL_PATH = "distilbert_best.pth"
 
6
  NUM_CLASSES = 3
7
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
11
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
12
  model.load_state_dict(state_dict)
13
 
 
 
 
 
 
 
 
 
 
3
 
4
  MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
5
  MODEL_PATH = "distilbert_best.pth"
6
+ MAX_LENGTH = 512
7
  NUM_CLASSES = 3
8
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
 
11
+ def predict_veracity(texts, model, tokenizer, device, max_length=MAX_LENGTH):
12
+ model.eval()
13
+ id2label = {0: "F", 1: "U", 2: "T"}
14
+
15
+ encodings = tokenizer(
16
+ texts,
17
+ padding=True,
18
+ truncation=True,
19
+ max_length=max_length,
20
+ return_tensors="pt"
21
+ )
22
+
23
+ input_ids = encodings["input_ids"].to(device)
24
+ attention_mask = encodings["attention_mask"].to(device)
25
+ token_type_ids = encodings.get("token_type_ids")
26
+ if token_type_ids is not None:
27
+ token_type_ids = token_type_ids.to(device)
28
+
29
+ with torch.inference_mode():
30
+ logits = model(input_ids, attention_mask, token_type_ids)
31
+ preds = torch.argmax(logits, dim=1).tolist()
32
+
33
+ return [id2label.get(p, "U") for p in preds]
34
+
35
+
36
+
37
+
38
+ # Load resources
39
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
40
  model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
41
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
42
  model.load_state_dict(state_dict)
43
 
44
+ # Label a list of texts
45
+ labels = predict_sentiment(texts, model, tokenizer, DEVICE)
46
+
47
+
48
+
49
+
50
+
51
+