Spaces:
Sleeping
Sleeping
| from torch import nn | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer | |
| import torch | |
| from torch.utils.data import Dataset | |
| from utils import read_yaml | |
class BanglaHSDataset(Dataset):
    """Inference-only dataset that tokenizes a raw text when indexed.

    Indexing takes the text itself rather than an integer position. The item
    is the tokenizer's encoding with every field turned into a ``torch.long``
    tensor that carries a leading dimension of size 1, paired with a
    placeholder float label of 0.
    """

    def __init__(self, tokenizer, max_length):
        # NOTE(review): tokenizer is presumably a Hugging Face tokenizer
        # (AutoTokenizer is imported by this module) — confirm at call site.
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # No stored samples; items are produced on demand from text.
        return 0

    def __getitem__(self, text):
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_offsets_mapping=False,
        )
        # Convert each field in place so the encoding object (and its type)
        # returned by the tokenizer is the one handed back to the caller.
        for field in list(encoding.keys()):
            encoding[field] = torch.tensor(encoding[field], dtype=torch.long).unsqueeze(0)
        # Placeholder 0-d float label; inference has no ground truth.
        placeholder_label = torch.zeros((), dtype=torch.float)
        return encoding, placeholder_label
# Category names, indexed by class id. Kept as a module-level tuple so it is
# built once and cannot be mutated by callers.
_IND2CAT = (
    'Geopolitical',
    'Personal',
    'Political',
    'Religious',
)


def get_class(index):
    """Return the category name for a class index.

    Args:
        index: Integer-like class index (anything implementing
            ``__index__``, e.g. ``int`` or a 0-d integer tensor), expected
            to lie in ``[0, 4)``.

    Returns:
        The category name as a string.

    Raises:
        IndexError: if ``index`` is outside ``[0, 4)``. Negative values are
            rejected instead of silently wrapping to the end of the table,
            as plain list indexing would.
        TypeError: if ``index`` is not integer-like.
    """
    import operator

    position = operator.index(index)
    if not 0 <= position < len(_IND2CAT):
        raise IndexError(
            f'class index {position} out of range for {len(_IND2CAT)} categories'
        )
    return _IND2CAT[position]
if __name__ == '__main__':
    # Script entry point: currently only loads the baseline configuration.
    config_file = './baseline.yaml'
    cfg = read_yaml(config_file)

    # Disabled inference demo, kept for reference.
    # NOTE(review): before re-enabling, reconcile it with the code above:
    #   * BanglaHS_Model is not defined or imported in the visible code;
    #   * BanglaHSDataset takes (tokenizer, max_length), not (cfg.Dataset, model);
    #   * target_size = 6 does not match the 4 categories known to get_class.
    # cfg.Model.target_size = 6
    # model = BanglaHS_Model(cfg.Model)
    # #model.load_state_dict(torch.load('./model_fold-0_best.pt', map_location=torch.device('cpu')))
    # model.eval()
    # ds = BanglaHSDataset(cfg.Dataset, model)
    # x = ds['Hello hi'][0]
    # with torch.no_grad():
    #     y = model(x)
    # print('y:', y)