| import time
|
| from torch.utils.data.dataset import random_split
|
| from torchtext.data.functional import to_map_style_dataset
|
| import torch
|
| import gzip
|
| import json
|
| import numpy as np
|
| import nltk
|
| from nltk.corpus import stopwords
|
| from nltk.tokenize import word_tokenize, sent_tokenize
|
| from nltk.stem import PorterStemmer, WordNetLemmatizer
|
| from torchtext.data.utils import get_tokenizer
|
| from torchtext.vocab import build_vocab_from_iterator
|
| from torch.utils.data import DataLoader
|
| import argparse
|
| from torch import nn
|
| import json
|
|
|
# Fetch the NLTK data packages this script may need. Only the Punkt sentence
# tokenizer is exercised by the code in this file (``sent_tokenize`` in
# ``TextClassificationModel.preprocess_text``); the remaining packages appear
# unused here but are still requested so existing environments keep them.
#
# Newer NLTK releases (3.9+) load Punkt from the ``punkt_tab`` package rather
# than the pickled ``punkt`` model, so request both. By default nltk.download
# reports an unknown package id instead of raising, so older NLTK versions
# simply skip ``punkt_tab``.
for _nltk_package in (
    'punkt',
    'punkt_tab',
    'stopwords',
    'averaged_perceptron_tagger',
    'maxent_ne_chunker',
    'words',
    'wordnet',
):
    nltk.download(_nltk_package)
|
|
|
|
|
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: ``nn.EmbeddingBag`` followed by ``nn.Linear``.

    The model keeps the vocabulary it was built with so that raw text can be
    turned into token ids (:meth:`predict`) and so that :meth:`save_model` /
    :meth:`load_model` round-trip everything needed to rebuild it.
    """

    def __init__(self, vocab_size, embed_dim, num_class, vocab):
        """Create the layers and the text pipeline.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Width of each embedding vector.
            num_class: Number of output logits.
            vocab: Callable mapping a list of tokens to a list of token ids
                (a torchtext vocab in the ``__main__`` script).
        """
        super().__init__()
        # ``nn.Module.__init__`` returns None; the previous code stored that
        # value here. Kept only so ``self.model`` still exists for old callers.
        self.model = None

        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

        self.vocab_size = vocab_size
        self.emsize = embed_dim
        self.num_class = num_class
        self.vocab = vocab

        # Bound method: ``tokenizer_convert`` is only looked up when the
        # pipeline is called, so assigning it afterwards is fine.
        self.text_pipeline = self.tokenizer
        self.tokenizer_convert = get_tokenizer("basic_english")

    def tokenizer(self, text):
        """Return the vocabulary ids of the ``basic_english`` tokens of *text*."""
        return self.vocab(self.tokenizer_convert(text))

    def init_weights(self):
        """Draw embedding and linear weights from U(-0.5, 0.5); zero the linear bias."""
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return class logits, one row per bag.

        Args:
            text: 1-D tensor of concatenated token ids for all samples.
            offsets: 1-D tensor with the start index of each sample in *text*.
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

    def train_model(self, train_dataloader, valid_dataloader, *, epochs=None,
                    optim=None, loss_fn=None, lr_scheduler=None):
        """Train for a number of epochs, validating after each one.

        After every epoch the validation accuracy is compared with the best
        seen so far; when it dropped, ``lr_scheduler.step()`` is called,
        otherwise the new accuracy becomes the best.

        Args:
            train_dataloader: Iterable of ``(label, text, offsets)`` batches.
            valid_dataloader: Same format, used for per-epoch validation.
            epochs: Number of epochs; defaults to the module-level ``EPOCHS``.
            optim: Optimizer; defaults to the module-level ``optimizer``.
            loss_fn: Loss callable; defaults to the module-level ``criterion``.
            lr_scheduler: Scheduler; defaults to the module-level ``scheduler``.
        """
        # Fall back to the globals the ``__main__`` script defines, so the
        # original two-argument call keeps working unchanged.
        epochs = EPOCHS if epochs is None else epochs
        optim = optimizer if optim is None else optim
        loss_fn = criterion if loss_fn is None else loss_fn
        lr_scheduler = scheduler if lr_scheduler is None else lr_scheduler

        best_accu = None
        log_interval = 500

        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            self.train()
            total_acc, total_count = 0, 0

            for idx, (label, text, offsets) in enumerate(train_dataloader):
                optim.zero_grad()
                predicted_label = self(text, offsets)
                loss = loss_fn(predicted_label, label)
                loss.backward()
                # Clip *this* model's gradients (previously the global
                # ``model`` was clipped, which only worked inside the script).
                torch.nn.utils.clip_grad_norm_(self.parameters(), 0.1)
                optim.step()

                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
                if idx % log_interval == 0 and idx > 0:
                    print(
                        "| epoch {:3d} | {:5d}/{:5d} batches "
                        "| accuracy {:8.3f}".format(
                            epoch, idx, len(train_dataloader), total_acc / total_count
                        )
                    )
                    # Running accuracy is reported per logging window.
                    total_acc, total_count = 0, 0

            accu_val = self.evaluate(valid_dataloader)
            if best_accu is not None and best_accu > accu_val:
                lr_scheduler.step()
            else:
                best_accu = accu_val

            print("-" * 59)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | "
                "valid accuracy {:8.3f} ".format(
                    epoch, time.time() - epoch_start_time, accu_val
                )
            )
            print("-" * 59)

    def save_model(self, file_path):
        """Save the weights plus the constructor arguments (incl. vocab) to *file_path*."""
        model_state = {
            'state_dict': self.state_dict(),
            'vocab_size': self.vocab_size,
            'embed_dim': self.emsize,
            'num_class': self.num_class,
            'vocab': self.vocab
        }
        torch.save(model_state, file_path)
        print("Model saved successfully.")

    @classmethod
    def load_model(cls, file_path):
        """Rebuild a model saved by :meth:`save_model`, on CPU and in eval mode.

        NOTE: ``torch.load`` unpickles arbitrary objects (the vocab is stored
        as a pickled object). Only load checkpoints from trusted sources.
        """
        model_state = torch.load(file_path, map_location=torch.device('cpu'))

        model = cls(
            model_state['vocab_size'],
            model_state['embed_dim'],
            model_state['num_class'],
            model_state['vocab'],
        )
        model.load_state_dict(model_state['state_dict'])
        model.eval()
        print("Model loaded successfully.")
        return model

    def evaluate(self, dataloader):
        """Return the accuracy over all ``(label, text, offsets)`` batches.

        Leaves the model in eval mode. Returns 0.0 when *dataloader* yields
        no samples instead of raising ZeroDivisionError.
        """
        self.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for label, text, offsets in dataloader:
                predicted_label = self(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)

        # An empty split (e.g. a tiny validation/test set) must not crash the
        # script after training has finished but before the model is saved.
        return total_acc / total_count if total_count else 0.0

    def predict(self, text):
        """Return the logits (shape ``(1, num_class)``) for a single raw *text*."""
        # Build inputs on the same device as the weights, and force int64 so
        # text with no tokens yields a valid (empty) index tensor rather than
        # the float32 tensor ``torch.tensor([])`` would create.
        device = self.embedding.weight.device
        with torch.no_grad():
            token_ids = torch.tensor(
                self.text_pipeline(text), dtype=torch.int64, device=device
            )
            offsets = torch.tensor([0], dtype=torch.int64, device=device)
            return self(token_ids, offsets)

    @staticmethod
    def read_gz_json(file_path):
        """Yield ``(text, category)`` pairs from a gzip-compressed JSON file.

        The decoded JSON is iterated directly and each element must have
        ``'text'`` and ``'category'`` keys (presumably a top-level array of
        objects; confirm against the data files).
        """
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            data = json.load(f)
        for obj in data:
            yield obj['text'], obj['category']

    @staticmethod
    def preprocess_text(text):
        """Split *text* into sentences with NLTK's ``sent_tokenize``."""
        sentences = sent_tokenize(text)
        return sentences

    @staticmethod
    def data_iter(file_paths, categories):
        """Yield ``(label_index, sentence)`` for every sentence in *file_paths*.

        ``label_index`` is the position of the record's category in
        *categories* (first occurrence if duplicated).

        Raises:
            ValueError: if a record with at least one sentence has a category
                not present in *categories*.
        """
        # Dict lookup instead of a numpy scan per sentence; setdefault keeps
        # the first index for duplicated names, like ``np.where(...)[0][0]``.
        category_index = {}
        for position, name in enumerate(categories):
            category_index.setdefault(name, position)

        for path in file_paths:
            for text, category in TextClassificationModel.read_gz_json(path):
                sentences = TextClassificationModel.preprocess_text(text)
                if not sentences:
                    continue
                try:
                    label = category_index[category]
                except KeyError:
                    raise ValueError(
                        f"Unknown category {category!r} in {path!r}; "
                        f"expected one of {list(category_index)}"
                    ) from None
                for sentence in sentences:
                    yield label, sentence

    @staticmethod
    def collate_batch(batch):
        """Collate ``(label, text)`` samples into ``(labels, token_ids, offsets)`` tensors.

        NOTE: relies on the module-level ``label_pipeline``, ``text_pipeline``
        and ``device`` names that the ``__main__`` script defines.
        """
        label_list, text_list, offsets = [], [], [0]
        for _label, _text in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        # Cumulative sum of the lengths (dropping the last) gives each
        # sample's start index in the concatenated token tensor.
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
|
|
|
|
|
def parse_arguments(argv=None):
    """Parse the command-line options of the text-classification script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``train_path`` and ``test_path`` (lists of
        paths), ``epochs`` (int), ``lr`` (float) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(description="Text Classification Model")
    parser.add_argument("--train_path", type=str, nargs='+', required=True, help="Path to the training data")
    parser.add_argument("--test_path", type=str, nargs='+', required=True, help="Path to the test data")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs for training")
    # argparse does not run ``type`` on non-string defaults, so the default
    # must already be a float to keep ``args.lr`` consistently typed.
    parser.add_argument("--lr", type=float, default=3.0, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    return parser.parse_args(argv)
|
|
|
if __name__ == '__main__':
    # NOTE: everything here runs at module scope on purpose. The class methods
    # read several of these names as globals (``text_pipeline``,
    # ``label_pipeline``, ``device``, ``EPOCHS``, ``optimizer``, ``criterion``,
    # ``scheduler``, ``model``), so they must keep these exact names.

    args = parse_arguments()

    # Label ids are positions in this list (see TextClassificationModel.data_iter).
    categories = ['Geography', 'Religion', 'Philosophy', 'Trash', 'Mythology', 'Literature', 'Science', 'Social Science', 'History', 'Current Events', 'Fine Arts']

    test_path = args.test_path
    train_path = args.train_path

    EPOCHS = args.epochs
    LR = args.lr
    BATCH_SIZE = args.batch_size

    tokenizer = get_tokenizer("basic_english")

    # Read each gzip corpus exactly once; the vocabulary, the train/valid split
    # and the loaders all reuse these in-memory datasets.
    train_dataset = to_map_style_dataset(
        TextClassificationModel.data_iter(train_path, categories)
    )
    test_dataset = to_map_style_dataset(
        TextClassificationModel.data_iter(test_path, categories)
    )

    def yield_tokens(dataset):
        """Yield the token list of every ``(label, text)`` sample in *dataset*."""
        for idx in range(len(dataset)):
            _, text = dataset[idx]
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    def text_pipeline(x):
        """Map raw text to a list of vocabulary ids."""
        return vocab(tokenizer(x))

    def label_pipeline(x):
        """Map a label index to a plain int."""
        return int(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Labels are indices into ``categories``. Sizing the output layer from the
    # classes that happen to occur in the training data would leave a high
    # label index without a logit and make CrossEntropyLoss fail.
    num_class = len(categories)
    print(num_class)
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModel(vocab_size, emsize, num_class, vocab).to(device)
    print(model)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(
        train_dataset, [num_train, len(train_dataset) - num_train]
    )

    train_dataloader = DataLoader(
        split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TextClassificationModel.collate_batch
    )
    # Order does not affect accuracy, so the evaluation loaders are not shuffled.
    valid_dataloader = DataLoader(
        split_valid_, batch_size=BATCH_SIZE, shuffle=False, collate_fn=TextClassificationModel.collate_batch
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=TextClassificationModel.collate_batch
    )

    model.train_model(train_dataloader, valid_dataloader)

    print("Checking the results of test dataset.")
    accu_test = model.evaluate(test_dataloader)
    print("test accuracy {:8.3f}".format(accu_test))

    model.save_model("text_classification_model.pth")
|
|
|
|
|
|
|