Spaces:
Runtime error
Runtime error
| # Author: Ricardo Lisboa Santos | |
| # Creation date: 2024-01-10 | |
| import torch | |
| # import torch_directml | |
| from transformers import pipeline | |
def getDevice(DEVICE):
    """Map a device name to a ``torch.device``.

    Args:
        DEVICE: Device name. ``"cpu"`` and ``"cuda"`` are recognised; a
            DirectML option exists but is disabled (commented out).

    Returns:
        ``torch.device("cpu")`` for ``"cpu"``, ``torch.device("cuda")`` for
        ``"cuda"``, and ``None`` for any other value.
    """
    # No per-device dtype is computed here: it was never returned or used
    # by this function.
    if DEVICE == "cpu":
        return torch.device("cpu")
    if DEVICE == "cuda":
        return torch.device("cuda")
    # DirectML support is disabled (needs the optional torch_directml package):
    # if DEVICE == "directml":
    #     return torch_directml.device()
    return None
def loadClassifier(device):
    """Build a Hugging Face ``sentiment-analysis`` pipeline on *device*.

    Args:
        device: Target device for the pipeline, e.g. the ``torch.device``
            returned by ``getDevice``. ``None`` is ``pipeline()``'s own
            default for this argument, so passing it keeps the library's
            default placement.

    Returns:
        The ``transformers`` pipeline object.
    """
    # Device placement goes through pipeline()'s ``device`` argument.
    # Previously ``device`` was accepted but never used.
    classifier = pipeline("sentiment-analysis", device=device)
    return classifier
def classify(classifier, text):
    """Apply *classifier* to *text* and hand back its raw result unchanged."""
    return classifier(text)
def clearCache(DEVICE, classifier):
    """Save the classifier's tokenizer and model to ``cache``, then drop the local reference.

    Args:
        DEVICE: Device name string; referenced by the (currently
            commented-out) DirectML cache-clearing branch.
        classifier: Pipeline-like object exposing ``tokenizer`` and ``model``
            attributes, each with a ``save_pretrained`` method.
    """
    # Tokenizer and model are written to the same "cache" directory,
    # presumably so they can be reloaded together later — TODO confirm.
    classifier.tokenizer.save_pretrained("cache")
    classifier.model.save_pretrained("cache")
    # NOTE(review): this only unbinds the local name; the caller still holds
    # its own reference, so the model is not freed here unless the caller
    # drops it as well.
    del classifier
    # DirectML cache clearing is disabled (needs the optional torch_directml package):
    # if DEVICE == "directml":
    #     torch_directml.empty_cache()