File size: 1,313 Bytes
548021b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6457a8e
548021b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
### DATASET.PY ###

import os
from torch.utils.data import Dataset
import torch


class SpeechesClassificationDataset(Dataset):
    """
    Dataset class for text classification task.
    This is the dataset you will use to train your encoder and classifier jointly,
    end-to-end, for the text classification task.

    Each non-blank line of the data file must have the form
    ``label<TAB>text`` where ``label`` is an integer. Only the first tab on a
    line is treated as the separator, so the text itself may contain tabs.
    Blank lines and lines whose text is empty are skipped.

    Args:
        tokenizer (Tokenizer): The tokenizer used to encode the text. It must
            provide an ``encode(text)`` method whose result can be turned into
            a ``torch.long`` tensor.
        file_path (str): The path to the file containing the speech classification data.
        encoding (str): Text encoding of the data file. Defaults to ``'utf-16'``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If a non-blank line has no tab separator or its label is
            not an integer.
    """

    def __init__(self, tokenizer, file_path: str, encoding: str = 'utf-16') -> None:
        self.tokenizer = tokenizer
        # List of (label, text) pairs; tokenization is deferred to __getitem__.
        self.samples = []

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"The file {file_path} does not exist.")

        with open(file_path, 'r', encoding=encoding) as file:
            for line_number, raw_line in enumerate(file, start=1):
                # Remove only the line terminator here: stripping all
                # whitespace first would also eat the tab of a "label<TAB>"
                # line and make it look like a line without a separator.
                line = raw_line.rstrip('\r\n')

                # Blank lines (e.g. a trailing newline at end of file) carry
                # no sample.
                if not line.strip():
                    continue

                # Split on the first tab only so tabs inside the text survive.
                label, separator, text = line.partition('\t')
                if not separator:
                    raise ValueError(
                        f"Line {line_number} of {file_path} has no tab "
                        f"separator: {line!r}"
                    )

                # Trailing whitespace is dropped, leading whitespace is kept.
                text = text.rstrip()
                if not text.strip():
                    continue

                try:
                    label_id = int(label)
                except ValueError as err:
                    raise ValueError(
                        f"Line {line_number} of {file_path} has a "
                        f"non-integer label: {label!r}"
                    ) from err

                self.samples.append((label_id, text))

    def __len__(self) -> int:
        """Return the number of (label, text) samples that were loaded."""
        return len(self.samples)

    def __getitem__(self, index: int):
        """
        Return the sample at ``index`` as a pair of tensors.

        Returns:
            tuple: ``(input_ids, label_tensor)`` where ``input_ids`` is the
            ``torch.long`` tensor built from ``tokenizer.encode(text)`` and
            ``label_tensor`` is a 0-dimensional ``torch.long`` tensor holding
            the label.
        """
        label, text = self.samples[index]
        # The text is encoded on every access; nothing is cached.
        input_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return input_ids, label_tensor