Spaces:
Runtime error
Runtime error
| # Author: Ricardo Lisboa Santos | |
| # Creation date: 2024-01-10 | |
| import torch | |
| # import torch_directml | |
| from transformers import pipeline | |
def getDevice(DEVICE):
    """Map a device name to a ``torch.device``.

    Args:
        DEVICE: Device name; ``"cpu"`` and ``"cuda"`` are recognized.

    Returns:
        ``torch.device(DEVICE)`` for a recognized name, otherwise ``None``.
        Constructing the CUDA device does not check that CUDA is available.
    """
    # DirectML ("directml" via torch_directml) is currently disabled, matching
    # the commented-out torch_directml import at the top of the file.
    if DEVICE in ("cpu", "cuda"):
        return torch.device(DEVICE)
    return None
def loadSummarizer(device):
    """Create a Hugging Face summarization pipeline placed on *device*.

    Args:
        device: A ``torch.device`` (e.g. the result of ``getDevice``), a
            device string/index accepted by ``transformers.pipeline``, or
            ``None`` to keep the pipeline's default placement.

    Returns:
        The ``transformers`` summarization pipeline.
    """
    # Device placement is requested through pipeline()'s ``device`` argument;
    # previously the parameter was ignored and the model stayed on the default.
    return pipeline("summarization", device=device)
def summarize(summarizer, text):
    """Apply *summarizer* to *text* and hand back its raw result.

    Args:
        summarizer: Callable to invoke, typically the pipeline built by
            ``loadSummarizer``.
        text: Input forwarded unchanged to the summarizer.

    Returns:
        Exactly what ``summarizer(text)`` returns, unmodified.
    """
    return summarizer(text)
def clearCache(DEVICE, summarizer, cache_dir="cache"):
    """Save the summarizer's tokenizer and model, then release cached GPU memory.

    Args:
        DEVICE: Device name as passed to ``getDevice`` (e.g. ``"cpu"``, ``"cuda"``).
        summarizer: Object exposing ``tokenizer`` and ``model``, each with a
            ``save_pretrained(path)`` method (such as a transformers pipeline).
        cache_dir: Directory both components are saved to. Defaults to
            ``"cache"``.

    Note:
        This function cannot free the pipeline itself. ``del`` inside a
        function only unbinds the local name, so the caller must drop its own
        reference before the model's memory becomes reclaimable.
    """
    summarizer.tokenizer.save_pretrained(cache_dir)
    summarizer.model.save_pretrained(cache_dir)
    if DEVICE == "cuda" and torch.cuda.is_available():
        # Hand unoccupied blocks held by PyTorch's CUDA caching allocator back
        # to the driver; memory still referenced by live tensors is unaffected.
        torch.cuda.empty_cache()
    # DirectML (torch_directml.empty_cache) is disabled together with its import.