leygit committed on
Commit
62e2412
·
verified ·
1 Parent(s): a74de4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py CHANGED
@@ -16,6 +16,42 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  # Load tokenizer
17
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Load the trained model
20
  def load_model(model_path="distilbert_spam_model.pt"):
21
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
 
16
  # Load tokenizer
17
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
18
 
19
+ # Tokenize dataset
20
+ encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
21
+ labels = torch.tensor(df['label_num'].values)
22
+
23
+ # Custom Dataset
24
+ class SpamDataset(Dataset):
25
+ def __init__(self, encodings, labels):
26
+ self.encodings = encodings
27
+ self.labels = labels
28
+
29
+ def __len__(self):
30
+ return len(self.labels)
31
+
32
+ def __getitem__(self, idx):
33
+ item = {key: val[idx] for key, val in self.encodings.items()} # Keep as PyTorch tensors
34
+ item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long) # Ensure labels are `long`
35
+ return item
36
+
37
+ # Create dataset
38
+ dataset = SpamDataset(encodings, labels)
39
+
40
+ # Split dataset (80% train, 20% validation)
41
+ train_size = int(0.8 * len(dataset))
42
+ val_size = len(dataset) - train_size
43
+ train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
44
+
45
+ # DataLoader Function (Fix Collate)
46
+ def collate_fn(batch):
47
+ keys = batch[0].keys()
48
+ collated = {key: torch.stack([b[key] for b in batch]) for key in keys}
49
+ return collated
50
+
51
+ # Create DataLoader
52
+ train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
53
+ val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
54
+
55
  # Load the trained model
56
  def load_model(model_path="distilbert_spam_model.pt"):
57
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)