edwjin committed on
Commit
548021b
·
verified ·
1 Parent(s): b25aa2d

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +44 -69
dataset.py CHANGED
@@ -1,69 +1,44 @@
1
- import os
2
- from torch.utils.data import Dataset
3
- import torch
4
-
5
-
6
class SpeechesClassificationDataset(Dataset):
    """
    Dataset class for the text classification task.

    This is the dataset used to train the encoder and the classifier
    jointly, end-to-end, for the text classification task.

    Args:
        tokenizer (Tokenizer): The tokenizer used to encode the text.
        file_path (str): Path to a tab-separated file with one
            "label<TAB>text" pair per line; labels must be '0', '1' or '2'.

    Raises:
        FileNotFoundError: If file_path does not exist.
        ValueError: If a line carries a label other than '0', '1' or '2'.
    """

    def __init__(self, tokenizer, file_path):
        self.tokenizer = tokenizer
        # List of (int label, str text) pairs.
        self.samples = []

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"The file {file_path} does not exist.")

        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                # Skip blank lines up front: the previous unpacking
                # (label, text = line.split('\t')) raised ValueError on them.
                if not line:
                    continue
                # maxsplit=1 keeps any tab characters inside the speech
                # text from breaking the two-way unpacking.
                label, text = line.split('\t', 1)
                if label not in ('0', '1', '2'):
                    raise ValueError(f"Invalid label: {label}")
                if len(text.strip()) == 0:
                    continue
                self.samples.append((int(label), text))

    def __len__(self):
        """Return the number of (label, text) samples loaded."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return (input_ids, label) as torch.long tensors for sample *index*."""
        label, text = self.samples[index]
        input_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return input_ids, label_tensor
43
-
44
-
45
-
46
-
47
class LanguageModelingDataset(torch.utils.data.Dataset):
    """
    Dataset class for the language modeling task: yields (input, target)
    pairs where the target is the input shifted one token to the right.
    This is the dataset used to train the encoder for language modeling.

    Args:
        tokenizer (Tokenizer): The tokenizer used to encode the text.
        text (str): The raw text data, tokenized once up front.
        block_size (int): The number of tokens per training example.
    """

    def __init__(self, tokenizer, text, block_size):
        self.tokenizer = tokenizer
        # Encode the full corpus once; __getitem__ only slices into it.
        self.data = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        # Clamp at zero: if the corpus is shorter than block_size the
        # original returned a negative value, which makes len() raise
        # ValueError.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return (x, y) where y is x shifted right by one token."""
        # Grab block_size + 1 tokens so x and y can overlap by block_size.
        chunk = self.data[idx:idx + self.block_size + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y
 
1
+ ### DATASET.PY ###
2
+
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ import torch
6
+
7
+
8
+ class SpeechesClassificationDataset(Dataset):
9
+ """
10
+ Dataset class for text classification task.
11
+ This the dataset you will use to train your encoder, and classifier jointly,
12
+ end-to-end for the text classification task.
13
+
14
+ Args:
15
+ tokenizer (Tokenizer): The tokenizer used to encode the text.
16
+ file_path (str): The path to the file containing the speech classification data.
17
+
18
+ """
19
+
20
+ def __init__(self, tokenizer, file_path):
21
+ self.tokenizer = tokenizer
22
+ self.samples = []
23
+
24
+ if not os.path.exists(file_path):
25
+ raise FileNotFoundError(f"The file {file_path} does not exist.")
26
+
27
+ with open(file_path, 'r', encoding='utf-8') as file:
28
+ for line in file:
29
+ label, text = line.strip().split('\t')
30
+
31
+ if len(text.strip()) == 0:
32
+ continue
33
+
34
+ self.samples.append((int(label), text))
35
+
36
+ def __len__(self):
37
+ return len(self.samples)
38
+
39
+ def __getitem__(self, index):
40
+ label, text = self.samples[index]
41
+ input_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
42
+ label_tensor = torch.tensor(label, dtype=torch.long)
43
+
44
+ return input_ids, label_tensor