# DTECT/backend/models/dynamic_trainer.py
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import StepLR
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger
logger = Logger("WARNING")
class DynamicTrainer:
def __init__(self,
model,
dataset,
num_top_words=15,
epochs=200,
learning_rate=0.002,
batch_size=200,
lr_scheduler=None,
lr_step_size=125,
log_interval=5,
verbose=False
):
self.model = model
self.dataset = dataset
self.num_top_words = num_top_words
self.epochs = epochs
self.learning_rate = learning_rate
self.batch_size = batch_size
self.lr_scheduler = lr_scheduler
self.lr_step_size = lr_step_size
self.log_interval = log_interval
self.verbose = verbose
if verbose:
logger.set_level("DEBUG")
else:
logger.set_level("WARNING")
    def make_optimizer(self):
args_dict = {
'params': self.model.parameters(),
'lr': self.learning_rate,
}
optimizer = torch.optim.Adam(**args_dict)
return optimizer
def make_lr_scheduler(self, optimizer):
lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5, verbose=False)
return lr_scheduler
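    # A brief sketch of the schedule implied by the defaults (lr=0.002,
    # lr_step_size=125, gamma=0.5), assuming lr_scheduler is enabled and
    # scheduler.step() is called once per epoch as in train() below:
    #   epochs   1-125 -> lr = 0.002
    #   epochs 126-250 -> lr = 0.001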
def train(self):
optimizer = self.make_optimizer()
if self.lr_scheduler:
logger.info("using lr_scheduler")
lr_scheduler = self.make_lr_scheduler(optimizer)
data_size = len(self.dataset.train_dataloader.dataset)
for epoch in tqdm(range(1, self.epochs + 1)):
self.model.train()
loss_rst_dict = defaultdict(float)
for batch_data in self.dataset.train_dataloader:
rst_dict = self.model(batch_data['bow'], batch_data['times'])
batch_loss = rst_dict['loss']
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
                for key in rst_dict:
                    # Weight each batch's contribution by its size; batch_data is a
                    # dict, so len(batch_data) would count its keys, not documents.
                    loss_rst_dict[key] += float(rst_dict[key]) * len(batch_data['bow'])
if self.lr_scheduler:
lr_scheduler.step()
if epoch % self.log_interval == 0:
output_log = f'Epoch: {epoch:03d}'
for key in loss_rst_dict:
                    output_log += f' {key}: {loss_rst_dict[key] / data_size:.3f}'
logger.info(output_log)
top_words = self.get_top_words()
train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
return top_words, train_theta
def test(self, bow, times):
data_size = bow.shape[0]
theta = list()
all_idx = torch.split(torch.arange(data_size), self.batch_size)
with torch.no_grad():
self.model.eval()
for idx in all_idx:
batch_theta = self.model.get_theta(bow[idx], times[idx])
theta.extend(batch_theta.cpu().tolist())
theta = np.asarray(theta)
return theta
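    # Note (assumption): if model.get_theta returns one topic distribution per
    # document, the array returned by test() has shape [num_docs, num_topics].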
def get_beta(self):
self.model.eval()
beta = self.model.get_beta().detach().cpu().numpy()
return beta
def get_top_words(self, num_top_words=None):
if num_top_words is None:
num_top_words = self.num_top_words
beta = self.get_beta()
top_words_list = list()
for time in range(beta.shape[0]):
if self.verbose:
print(f"======= Time: {time} =======")
top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
top_words_list.append(top_words)
return top_words_list
def export_theta(self):
train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)
return train_theta, test_theta
    def get_top_words_at_time(self, topic_id, time, top_n):
        """
        Returns the top_n words for a single topic at a single time step,
        sorted by descending probability under beta.
        Output: List[str] of length top_n.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        topic_beta = beta[time, topic_id, :]
        top_indices = topic_beta.argsort()[-top_n:][::-1]
        return [self.dataset.vocab[i] for i in top_indices]
def get_topic_words_over_time(self, topic_id, top_n):
"""
Returns top_n words for the given topic_id over all time steps.
Output: List[List[str]], each inner list is the top_n words at a time step.
"""
beta = self.get_beta() # shape: [T, K, V]
T = beta.shape[0]
return [
self.get_top_words_at_time(topic_id=topic_id, time=t, top_n=top_n)
for t in range(T)
]
def get_all_topics_at_time(self, time, top_n):
"""
Returns top_n words for each topic at the given time step.
Output: List[List[str]], each inner list is the top_n words for a topic.
"""
beta = self.get_beta() # shape: [T, K, V]
K = beta.shape[1]
return [
self.get_top_words_at_time(topic_id=k, time=time, top_n=top_n)
for k in range(K)
]
def get_all_topics_over_time(self, top_n=10):
"""
Returns the top_n words for all topics over all time steps.
Output shape: List[List[List[str]]] = T x K x top_n
"""
beta = self.get_beta() # shape: [T, K, V]
T, K, _ = beta.shape
return [
[
self.get_top_words_at_time(topic_id=k, time=t, top_n=top_n)
for k in range(K)
]
for t in range(T)
]
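
# Minimal usage sketch (hypothetical: `model` and `dataset` are stand-ins, not
# defined in this module). It assumes `model` implements the interface used
# above -- forward(bow, times) -> {'loss': ...}, get_theta(bow, times), and
# get_beta() -> [T, K, V] -- and that `dataset` exposes train_dataloader,
# train_bow, train_times, test_bow, test_times, and vocab:
#
#   trainer = DynamicTrainer(model, dataset, epochs=200, verbose=True)
#   top_words, train_theta = trainer.train()
#   train_theta, test_theta = trainer.export_theta()
#   topic0_over_time = trainer.get_topic_words_over_time(topic_id=0, top_n=10)
#   all_topics_t0 = trainer.get_all_topics_at_time(time=0, top_n=10)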