import numpy as np
from tqdm import tqdm
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import StepLR
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger
logger = Logger("WARNING")


class DynamicTrainer:
    def __init__(self,
                 model,
                 dataset,
                 num_top_words=15,
                 epochs=200,
                 learning_rate=0.002,
                 batch_size=200,
                 lr_scheduler=None,
                 lr_step_size=125,
                 log_interval=5,
                 verbose=False
                 ):
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.lr_scheduler = lr_scheduler
        self.lr_step_size = lr_step_size
        self.log_interval = log_interval
        self.verbose = verbose

        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def make_optimizer(self):
        args_dict = {
            'params': self.model.parameters(),
            'lr': self.learning_rate,
        }
        optimizer = torch.optim.Adam(**args_dict)
        return optimizer

    def make_lr_scheduler(self, optimizer):
        # Halve the learning rate every `lr_step_size` epochs. The `verbose`
        # argument is deprecated in recent PyTorch releases, so it is omitted.
        lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5)
        return lr_scheduler
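
    # With the defaults above (learning_rate=0.002, lr_step_size=125,
    # gamma=0.5), a 200-epoch run halves the learning rate exactly once,
    # at epoch 125, giving 0.001 for the remaining epochs.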

    def train(self):
        optimizer = self.make_optimizer()

        if self.lr_scheduler:
            logger.info("using lr_scheduler")
            lr_scheduler = self.make_lr_scheduler(optimizer)

        data_size = len(self.dataset.train_dataloader.dataset)

        for epoch in tqdm(range(1, self.epochs + 1)):
            self.model.train()
            loss_rst_dict = defaultdict(float)

            for batch_data in self.dataset.train_dataloader:
                rst_dict = self.model(batch_data['bow'], batch_data['times'])
                batch_loss = rst_dict['loss']

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # Accumulate each loss term weighted by the batch size, not by
                # len(batch_data), which would count the dict's keys. Use
                # .item() so no autograd graph is kept alive across batches.
                n_samples = len(batch_data['bow'])
                for key in rst_dict:
                    loss_rst_dict[key] += rst_dict[key].item() * n_samples

            if self.lr_scheduler:
                lr_scheduler.step()

            if epoch % self.log_interval == 0:
                output_log = f'Epoch: {epoch:03d}'
                for key in loss_rst_dict:
                    output_log += f' {key}: {loss_rst_dict[key] / data_size:.3f}'
                logger.info(output_log)

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)

        return top_words, train_theta

    def test(self, bow, times):
        data_size = bow.shape[0]
        theta = list()
        all_idx = torch.split(torch.arange(data_size), self.batch_size)

        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_theta = self.model.get_theta(bow[idx], times[idx])
                theta.extend(batch_theta.cpu().tolist())

        theta = np.asarray(theta)
        return theta
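
    # Inference helper: test() runs minibatched forward passes under no_grad;
    # assuming model.get_theta returns one topic distribution per document,
    # the returned theta is a [num_docs, num_topics] numpy array.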

    def get_beta(self):
        self.model.eval()
        beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        if num_top_words is None:
            num_top_words = self.num_top_words

        beta = self.get_beta()
        top_words_list = list()
        for time in range(beta.shape[0]):
            if self.verbose:
                print(f"======= Time: {time} =======")
            top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
            top_words_list.append(top_words)

        return top_words_list

    def export_theta(self):
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
        test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)
        return train_theta, test_theta

    def get_top_words_at_time(self, topic_id, time, top_n):
        """
        Returns the top_n words for the given topic_id at a single time step.
        Output: List[str] of length top_n, highest-probability words first.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        topic_beta = beta[time, topic_id, :]
        # argsort ascending, take the last top_n indices, then reverse so the
        # highest-probability word comes first.
        top_indices = topic_beta.argsort()[-top_n:][::-1]
        return [self.dataset.vocab[i] for i in top_indices]

    def get_topic_words_over_time(self, topic_id, top_n):
        """
        Returns top_n words for the given topic_id over all time steps.
        Output: List[List[str]], each inner list is the top_n words at a time step.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        T = beta.shape[0]
        return [
            self.get_top_words_at_time(topic_id=topic_id, time=t, top_n=top_n)
            for t in range(T)
        ]
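
    # Example (hypothetical numbers): with beta of shape [T=5, K=10, V=2000],
    # get_topic_words_over_time(topic_id=3, top_n=5) returns 5 lists of 5
    # words each, tracing how topic 3's vocabulary shifts across time steps.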

    def get_all_topics_at_time(self, time, top_n):
        """
        Returns top_n words for each topic at the given time step.
        Output: List[List[str]], each inner list is the top_n words for a topic.
        """
        beta = self.get_beta()  # shape: [T, K, V]
        K = beta.shape[1]
        return [
            self.get_top_words_at_time(topic_id=k, time=time, top_n=top_n)
            for k in range(K)
        ]
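
    # Example (hypothetical numbers): with the same [T=5, K=10, V=2000] beta,
    # get_all_topics_at_time(time=0, top_n=5) returns 10 lists of 5 words,
    # a snapshot of every topic at the first time step.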

    def get_all_topics_over_time(self, top_n=10):
        """
        Returns the top_n words for all topics over all time steps.
        Output shape: List[List[List[str]]] = T x K x top_n
        """
        beta = self.get_beta()  # shape: [T, K, V]
        T, K, _ = beta.shape
        return [
            [
                self.get_top_words_at_time(topic_id=k, time=t, top_n=top_n)
                for k in range(K)
            ]
            for t in range(T)
        ]
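

# Hypothetical usage sketch; `SomeDynamicModel` and `SomeDynamicDataset` are
# illustrative stand-ins for whatever model/dataset objects the surrounding
# backend provides, not names defined in this module:
#
#     trainer = DynamicTrainer(SomeDynamicModel(), SomeDynamicDataset(), verbose=True)
#     top_words, train_theta = trainer.train()
#     train_theta, test_theta = trainer.export_theta()
#     snapshot = trainer.get_all_topics_at_time(time=0, top_n=10)
#
# The self-contained demo below exercises only the numpy top-word selection
# used by get_top_words_at_time, on random data with a toy vocabulary.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, K, V, top_n = 3, 4, 20, 5
    beta = rng.random((T, K, V))            # stands in for model.get_beta()
    vocab = [f"word{i}" for i in range(V)]  # stands in for dataset.vocab

    topic_beta = beta[0, 1, :]              # topic 1 at time step 0
    top_indices = topic_beta.argsort()[-top_n:][::-1]
    print([vocab[i] for i in top_indices])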