Spaces:
Runtime error
Runtime error
| import torch | |
# Large negative number used as the fill value for masked (padded) entries
# in place of a true -inf.
NEGATIVE_INF = -1e5
# float16-safe variant: -1e5 exceeds float16's largest finite magnitude
# (65504), while -6e4 is still representable.
HALF_NEGATIVE_INF = -6e4
def get_first_sentence(txt, min_len=5):
    """Return the first sentence of ``txt``, ignoring anything after EOS.

    Everything from the first ``'<|endoftext|>'`` marker onward is discarded,
    newlines are flattened to spaces, and the text is split on ``'. '``.
    If the first piece is at least ``min_len`` characters long, it is returned
    stripped and terminated by a single ``'.'``. Otherwise the whole flattened
    text is returned as-is.

    Args:
        txt: Input string (e.g. decoded model output).
        min_len: Minimum length, in characters and measured before stripping,
            that the first piece must have to count as a sentence.

    Returns:
        The extracted first sentence, or the flattened text when the first
        piece is shorter than ``min_len``.
    """
    eos = '<|endoftext|>'
    eos_idx = txt.find(eos)
    if eos_idx != -1:
        # Drop the end-of-text marker and everything generated after it.
        txt = txt[:eos_idx]
    txt = txt.replace('\n', ' ')
    first = txt.split('. ', 1)[0]
    if len(first) < min_len:
        return txt
    first = first.strip()
    # A single sentence keeps its own final '.' (split on '. ' does not remove
    # it), so only append one when it is missing, to avoid "..".
    return first if first.endswith('.') else f'{first}.'
def logits_to_entropy(logits):
    """Compute the entropy of the categorical distribution(s) given by ``logits``.

    ``logits`` are unnormalized log-probabilities. As with
    ``torch.distributions.Categorical``, the last dimension indexes the
    categories, and the returned tensor has the remaining leading shape.
    """
    return torch.distributions.Categorical(logits=logits).entropy()
def mask_pad(value, mask, pad_value=None):
    """Replace masked-out entries of ``value`` with a large negative fill value.

    Computes ``value * mask + pad_value * (1 - mask)``. Where ``mask`` is 1 the
    original value is kept, and where it is 0 the entry becomes ``pad_value``.
    NOTE(review): ``mask`` is presumably a 0/1 padding mask; fractional values
    would blend the two terms linearly.

    Args:
        value: Tensor of values to pad.
        mask: Tensor broadcastable against ``value``.
        pad_value: Fill value for masked positions. When ``None`` (the
            default), ``NEGATIVE_INF`` is used, preserving the original
            behavior. Pass ``HALF_NEGATIVE_INF`` for float16 tensors:
            ``NEGATIVE_INF`` (-1e5) is outside float16's finite range and
            would turn masked entries into ``-inf``.

    Returns:
        Tensor with the broadcast shape of ``value`` and ``mask``.
    """
    if pad_value is None:
        # Resolved at call time so the module constant remains the single
        # source of truth for the default fill value.
        pad_value = NEGATIVE_INF
    return value * mask + pad_value * (1 - mask)