# (Hugging Face Spaces page-status residue from extraction: "Spaces: Sleeping")
"""Load a pretrained Switch Transformer checkpoint and its tokenizer.

Rebuilds the model skeleton (multi-head attention + switch feed-forward
mixture-of-experts) with the hyperparameters the checkpoint was trained
with, restores the weights from ``switch_transformer.pt``, and loads the
matching tokenizer for inference.
"""
import torch
import torch.nn as nn
from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

# Kept as the string 'cpu' (not torch.device) in case project code compares it
# as a string; reused everywhere below instead of repeating the literal.
device = 'cpu'

# Expert prototype: d_model=768 with the conventional 4x hidden expansion.
ff = FeedForward(768, 768 * 4)
# 8 attention heads over d_model=768, dropout 0.2.
attn = MultiHeadAttention(8, 768, 0.2)

# Mixture-of-experts feed-forward: 4 experts, capacity factor 1.25.
# drop_tokens=False keeps every token even when an expert overflows capacity.
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2,
)
# NOTE(review): 'load_balancing_loss_ceof' (sic) must match the parameter name
# declared in model.py — do not "fix" the typo here without changing model.py.
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_ceof=0.05,
).to(device)

# Restore trained weights; map_location=device instead of a second hard-coded
# torch.device('cpu').
# NOTE(review): torch.load unpickles arbitrary objects — only ever point this
# at a trusted checkpoint file.
model.load_state_dict(torch.load("switch_transformer.pt", map_location=device))
# This script loads a checkpoint + tokenizer for inference, so disable dropout
# (p=0.2 is configured twice above) and use eval-mode statistics in norm
# layers. eval() is idempotent, so a repeated call elsewhere is harmless.
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")