| | import time
|
| | from torch.utils.data.dataset import random_split
|
| | from torchtext.data.functional import to_map_style_dataset
|
| | import torch
|
| | import gzip
|
| | import json
|
| | import numpy as np
|
| | import nltk
|
| | from nltk.corpus import stopwords
|
| | from nltk.tokenize import word_tokenize, sent_tokenize
|
| | from nltk.stem import PorterStemmer, WordNetLemmatizer
|
| | from torchtext.data.utils import get_tokenizer
|
| | from torchtext.vocab import build_vocab_from_iterator
|
| | from torch.utils.data import DataLoader
|
| | import argparse
|
| | from torch import nn
|
| | import json
|
| |
|
# Fetch the NLTK data packages this script requests at import time.
# 'punkt' backs sent_tokenize on older NLTK releases, while NLTK 3.8.2+
# loads 'punkt_tab' instead, so both are requested. nltk.download defaults
# to raise_on_error=False, so a package unknown to an older NLTK index is
# expected to be reported rather than raised.
for _nltk_package in (
    'punkt',
    'punkt_tab',
    'stopwords',
    'averaged_perceptron_tagger',
    'maxent_ne_chunker',
    'words',
    'wordnet',
):
    nltk.download(_nltk_package)
|
| |
|
| |
|
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: ``nn.EmbeddingBag`` then ``nn.Linear``.

    The model keeps the vocabulary it was built with, so raw strings can be
    classified with :meth:`predict` and the vocabulary is persisted by
    :meth:`save_model`.

    NOTE: ``train_model`` reads the module-level globals ``EPOCHS``,
    ``optimizer``, ``criterion`` and ``scheduler``. ``collate_batch`` reads
    ``label_pipeline``, ``text_pipeline`` and ``device``. All of these are
    assigned by this script's ``__main__`` section.
    """

    def __init__(self, vocab_size, embed_dim, num_class, vocab):
        """Build the layers and remember the hyper-parameters.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Size of each embedding vector.
            num_class: Number of output logits.
            vocab: Callable mapping a list of tokens to a list of token ids
                (e.g. a torchtext vocab). Stored so it is saved with the model.
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

        self.vocab_size = vocab_size
        self.emsize = embed_dim
        self.num_class = num_class
        self.vocab = vocab

        # Public alias used by predict(): text -> list of token ids.
        self.text_pipeline = self.tokenizer
        self.tokenizer_convert = get_tokenizer("basic_english")

    def tokenizer(self, text):
        """Return the vocabulary ids of the ``basic_english`` tokens of *text*."""
        return self.vocab(self.tokenizer_convert(text))

    def init_weights(self):
        """Draw embedding and linear weights from U(-0.5, 0.5); zero the bias."""
        initrange = 0.5
        nn.init.uniform_(self.embedding.weight, -initrange, initrange)
        nn.init.uniform_(self.fc.weight, -initrange, initrange)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return class logits, one row per bag.

        Args:
            text: 1-D tensor of token ids for all samples, concatenated.
            offsets: 1-D tensor holding the start index of each sample in *text*.
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

    def train_model(self, train_dataloader, valid_dataloader):
        """Train for ``EPOCHS`` epochs and validate after each epoch.

        Uses the module-level ``optimizer``, ``criterion`` and ``scheduler``.
        ``scheduler.step()`` is called only when validation accuracy falls
        below the best accuracy seen so far.

        Args:
            train_dataloader: Yields ``(label, text, offsets)`` batches.
            valid_dataloader: Same batch format; used for per-epoch accuracy.
        """
        log_interval = 500
        total_accu = None  # best validation accuracy so far

        for epoch in range(1, EPOCHS + 1):
            epoch_start_time = time.time()
            self.train()
            total_acc, total_count = 0, 0

            for idx, (label, text, offsets) in enumerate(train_dataloader):
                optimizer.zero_grad()
                predicted_label = self(text, offsets)
                loss = criterion(predicted_label, label)
                loss.backward()
                # Clip the gradients of *this* model instance.
                torch.nn.utils.clip_grad_norm_(self.parameters(), 0.1)
                optimizer.step()

                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
                if idx % log_interval == 0 and idx > 0:
                    print(
                        "| epoch {:3d} | {:5d}/{:5d} batches "
                        "| accuracy {:8.3f}".format(
                            epoch, idx, len(train_dataloader), total_acc / total_count
                        )
                    )
                    # Report a running accuracy per logging window.
                    total_acc, total_count = 0, 0

            accu_val = self.evaluate(valid_dataloader)
            if total_accu is not None and total_accu > accu_val:
                scheduler.step()
            else:
                total_accu = accu_val
            print("-" * 59)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | "
                "valid accuracy {:8.3f} ".format(
                    epoch, time.time() - epoch_start_time, accu_val
                )
            )
            print("-" * 59)

    def save_model(self, file_path):
        """Save weights, constructor arguments and vocab to *file_path* via ``torch.save``."""
        model_state = {
            'state_dict': self.state_dict(),
            'vocab_size': self.vocab_size,
            'embed_dim': self.emsize,
            'num_class': self.num_class,
            'vocab': self.vocab
        }
        torch.save(model_state, file_path)
        print("Model saved successfully.")

    @classmethod
    def load_model(cls, file_path):
        """Rebuild a model from a checkpoint written by :meth:`save_model`.

        Tensors are mapped to the CPU and the model is returned in eval mode.
        """
        # NOTE(review): the checkpoint contains a pickled vocab object, so this
        # relies on torch.load's full unpickling. Only load trusted files. On
        # torch versions where ``weights_only`` defaults to True this call may
        # need ``weights_only=False`` -- TODO confirm for the installed version.
        model_state = torch.load(file_path, map_location=torch.device('cpu'))

        model = cls(
            model_state['vocab_size'],
            model_state['embed_dim'],
            model_state['num_class'],
            model_state['vocab'],
        )
        model.load_state_dict(model_state['state_dict'])
        model.eval()
        print("Model loaded successfully.")
        return model

    def evaluate(self, dataloader):
        """Return the classification accuracy over *dataloader*.

        Args:
            dataloader: Yields ``(label, text, offsets)`` batches.

        Returns:
            Fraction of correctly predicted labels, or ``0.0`` if the loader
            yields no samples.
        """
        self.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for label, text, offsets in dataloader:
                predicted_label = self(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
        return total_acc / total_count if total_count else 0.0

    def predict(self, text):
        """Return the logits (shape ``(1, num_class)``) for a single string."""
        # Build the inputs on the same device as the model's parameters.
        device = next(self.parameters()).device
        with torch.no_grad():
            # int64 explicitly: an empty token list would otherwise become a
            # float tensor, which EmbeddingBag rejects.
            token_ids = torch.tensor(self.text_pipeline(text), dtype=torch.int64, device=device)
            offsets = torch.tensor([0], device=device)  # one bag starting at 0
            return self(token_ids, offsets)

    @staticmethod
    def read_gz_json(file_path):
        """Yield ``(text, category)`` from a gzip-compressed JSON list of objects.

        Each element of the top-level JSON value must have ``'text'`` and
        ``'category'`` keys.
        """
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            data = json.load(f)
        for obj in data:
            yield obj['text'], obj['category']

    @staticmethod
    def preprocess_text(text):
        """Split *text* into a list of sentences with NLTK's ``sent_tokenize``."""
        return sent_tokenize(text)

    @staticmethod
    def data_iter(file_paths, categories):
        """Yield ``(label_index, sentence)`` for every sentence in *file_paths*.

        ``label_index`` is the position of the record's category in
        *categories* (first occurrence if duplicated).

        Raises:
            ValueError: A record's category is not in *categories*.
        """
        # Build the category -> index table once instead of scanning per sentence.
        index_of = {}
        for position, name in enumerate(categories):
            index_of.setdefault(name, position)

        for path in file_paths:
            for text, category in TextClassificationModel.read_gz_json(path):
                try:
                    label = index_of[category]
                except KeyError:
                    raise ValueError(
                        f"Unknown category {category!r} in {path}"
                    ) from None
                for sentence in TextClassificationModel.preprocess_text(text):
                    yield label, sentence

    @staticmethod
    def collate_batch(batch):
        """Collate ``(label, text)`` pairs into ``(labels, token_ids, offsets)``.

        Token ids of all samples are concatenated into one 1-D tensor, and
        *offsets* holds each sample's start index, as EmbeddingBag expects.
        Relies on the module-level ``label_pipeline``, ``text_pipeline`` and
        ``device``.
        """
        label_list, text_list, offsets = [], [], [0]
        for _label, _text in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        # Lengths -> start positions: drop the last length, then prefix-sum.
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
|
| |
|
| |
|
def parse_arguments(argv=None):
    """Parse the command-line options for training and evaluating the classifier.

    Args:
        argv: Optional list of argument strings. When ``None``,
            ``sys.argv[1:]`` is parsed, as before.

    Returns:
        argparse.Namespace with ``train_path`` and ``test_path`` (lists of
        paths), ``epochs`` (int), ``lr`` (float) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(description="Text Classification Model")
    parser.add_argument("--train_path", type=str, nargs='+', required=True, help="Path to the training data")
    parser.add_argument("--test_path", type=str, nargs='+', required=True, help="Path to the test data")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs for training")
    # argparse does not apply ``type`` to non-string defaults, so spell the
    # default as a float to keep args.lr's type consistent.
    parser.add_argument("--lr", type=float, default=3.0, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    return parser.parse_args(argv)
|
| |
|
if __name__ == '__main__':
    # Script entry point: build the vocabulary, train, evaluate on the test
    # files and save a checkpoint. The names assigned here are module-level
    # globals that TextClassificationModel.train_model / evaluate /
    # collate_batch look up (EPOCHS, optimizer, criterion, scheduler, model,
    # label_pipeline, text_pipeline, device), so this must stay top-level code.
    args = parse_arguments()

    # Label index i corresponds to categories[i] (see data_iter).
    categories = ['Geography', 'Religion', 'Philosophy', 'Trash', 'Mythology', 'Literature', 'Science', 'Social Science', 'History', 'Current Events', 'Fine Arts']

    test_path = args.test_path
    train_path = args.train_path

    EPOCHS = args.epochs
    LR = args.lr
    BATCH_SIZE = args.batch_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = get_tokenizer("basic_english")

    # Read and sentence-split each data source once. The same samples feed
    # both vocabulary construction and training.
    train_samples = list(TextClassificationModel.data_iter(train_path, categories))
    test_samples = list(TextClassificationModel.data_iter(test_path, categories))

    def yield_tokens(samples):
        for _, text in samples:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_samples), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    def text_pipeline(x):
        return vocab(tokenizer(x))

    def label_pipeline(x):
        return int(x)

    # Labels are indices into ``categories``, so size the output layer from
    # the category list rather than from the labels seen in training. This
    # keeps every possible label in range for CrossEntropyLoss.
    num_class = len(categories)
    print(num_class)
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModel(vocab_size, emsize, num_class, vocab).to(device)
    print(model)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # Each scheduler.step() (issued by train_model when validation accuracy
    # regresses) multiplies the learning rate by 0.1.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

    train_dataset = to_map_style_dataset(train_samples)
    test_dataset = to_map_style_dataset(test_samples)
    # Hold out 5% of the training sentences for validation.
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(
        train_dataset, [num_train, len(train_dataset) - num_train]
    )

    train_dataloader = DataLoader(
        split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TextClassificationModel.collate_batch
    )
    # Evaluation order does not affect accuracy, so these are not shuffled.
    valid_dataloader = DataLoader(
        split_valid_, batch_size=BATCH_SIZE, shuffle=False, collate_fn=TextClassificationModel.collate_batch
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=TextClassificationModel.collate_batch
    )

    model.train_model(train_dataloader, valid_dataloader)

    print("Checking the results of test dataset.")
    accu_test = model.evaluate(test_dataloader)
    print("test accuracy {:8.3f}".format(accu_test))

    model.save_model("text_classification_model.pth")
|
| |
|
| |
|
| |
|