import numpy as np
from tqdm import tqdm
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import StepLR

from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger

logger = Logger("WARNING")


class DynamicTrainer:
    def __init__(self,
                 model,
                 dataset,
                 num_top_words=15,
                 epochs=200,
                 learning_rate=0.002,
                 batch_size=200,
                 lr_scheduler=None,
                 lr_step_size=125,
                 log_interval=5,
                 verbose=False
                 ):
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.lr_scheduler = lr_scheduler
        self.lr_step_size = lr_step_size
        self.log_interval = log_interval
        self.verbose = verbose

        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def make_optimizer(self):
        args_dict = {
            'params': self.model.parameters(),
            'lr': self.learning_rate,
        }
        optimizer = torch.optim.Adam(**args_dict)
        return optimizer

    def make_lr_scheduler(self, optimizer):
        lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5, verbose=False)
        return lr_scheduler

    def train(self):
        optimizer = self.make_optimizer()

        if self.lr_scheduler:
            logger.info("using lr_scheduler")
            lr_scheduler = self.make_lr_scheduler(optimizer)

        data_size = len(self.dataset.train_dataloader.dataset)

        for epoch in tqdm(range(1, self.epochs + 1)):
            self.model.train()
            loss_rst_dict = defaultdict(float)

            for batch_data in self.dataset.train_dataloader:
                rst_dict = self.model(batch_data['bow'], batch_data['times'])
                batch_loss = rst_dict['loss']

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                for key in rst_dict:
                    # weight each loss term by the number of documents in the batch
                    # (batch_data is a dict, so len(batch_data) would only count its keys),
                    # so the logging below averages over documents rather than batches
                    loss_rst_dict[key] += rst_dict[key] * len(batch_data['bow'])

            if self.lr_scheduler:
                lr_scheduler.step()

            if epoch % self.log_interval == 0:
                output_log = f'Epoch: {epoch:03d}'
                for key in loss_rst_dict:
                    output_log += f' {key}: {loss_rst_dict[key] / data_size:.3f}'
                logger.info(output_log)

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)

        return top_words, train_theta

    def test(self, bow, times):
        data_size = bow.shape[0]
        theta = list()
        all_idx = torch.split(torch.arange(data_size), self.batch_size)

        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_theta = self.model.get_theta(bow[idx], times[idx])
                theta.extend(batch_theta.cpu().tolist())

        theta = np.asarray(theta)
        return theta

    def get_beta(self):
        self.model.eval()
        beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        if num_top_words is None:
            num_top_words = self.num_top_words

        beta = self.get_beta()
        top_words_list = list()

        for time in range(beta.shape[0]):
            if self.verbose:
                print(f"======= Time: {time} =======")
            top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
            top_words_list.append(top_words)

        return top_words_list

    def export_theta(self):
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
        test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)
        return train_theta, test_theta

    def get_top_words_at_time(self, topic_id, time, top_n):
        beta = self.get_beta()  # shape: [T, K, V]
        topic_beta = beta[time, topic_id, :]
        top_indices = topic_beta.argsort()[-top_n:][::-1]
        return [self.dataset.vocab[i] for i in top_indices]

    def get_topic_words_over_time(self, topic_id, top_n):
        """
        Returns the top_n words for the given topic_id over all time steps.
        Output: List[List[str]], each inner list is the top_n words at a time step.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        T = beta.shape[0]
        return [
            self.get_top_words_at_time(topic_id=topic_id, time=t, top_n=top_n)
            for t in range(T)
        ]

    def get_all_topics_at_time(self, time, top_n):
        """
        Returns the top_n words for each topic at the given time step.
        Output: List[List[str]], each inner list is the top_n words for a topic.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        K = beta.shape[1]
        return [
            self.get_top_words_at_time(topic_id=k, time=time, top_n=top_n)
            for k in range(K)
        ]

    def get_all_topics_over_time(self, top_n=10):
        """
        Returns the top_n words for all topics over all time steps.
        Output shape: List[List[List[str]]] = T x K x top_n
        """
        beta = self.get_beta()  # shape: [T, K, V]
        T, K, _ = beta.shape
        return [
            [
                self.get_top_words_at_time(topic_id=k, time=t, top_n=top_n)
                for k in range(K)
            ]
            for t in range(T)
        ]
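

# --- Usage sketch (illustrative, not part of the trainer) -------------------
# A minimal sketch of how DynamicTrainer is meant to be driven, assuming a
# dynamic topic model that exposes forward(bow, times), get_theta(bow, times),
# and get_beta() returning a [T, K, V] tensor, plus a dataset handler providing
# train_dataloader (batches with 'bow' and 'times'), train_bow/train_times,
# test_bow/test_times, and vocab. The model and dataset classes named below are
# placeholders, not real imports from this repository.
#
# dataset = SomeDynamicDataset("path/to/corpus")                              # hypothetical
# model = SomeDynamicTopicModel(vocab_size=len(dataset.vocab), num_topics=50)  # hypothetical
# trainer = DynamicTrainer(model, dataset, epochs=200, verbose=True)
# top_words, train_theta = trainer.train()            # top words per time step + train doc-topic matrix
# train_theta, test_theta = trainer.export_theta()    # doc-topic distributions for both splits
# topic0_evolution = trainer.get_topic_words_over_time(topic_id=0, top_n=10)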