# model_session: runs a GPT-NeoX model in a background worker process so that
# token generation does not block the main R session.
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    #' Set up generation defaults; no model is loaded until `load_model()`.
    initialize = function() {
      self$task_q <- NULL    # background task queue; NULL means "not loaded"
      self$temperature <- 1  # sampling temperature applied to logits
      self$top_k <- 50       # number of candidate tokens kept before sampling
      self$is_loaded <- NULL
    },

    #' Load the tokenizer locally and the model inside a background worker.
    #' @param repo Hugging Face repo id passed to `from_pretrained()`.
    #' @return A task promise that resolves to "done".
    load_model = function(repo) {
      # FIX: the guard previously tested `self$sess`, a field that is never
      # assigned anywhere in this class, so the "already loaded" branch could
      # never run. `task_q` is the field that actually signals a loaded model.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.", "\n")
        return(self$task_q$push(function() "done"))
      }
      # the tokenizer doesn't need to live in the remote session.
      self$tok <- tok::tokenizer$from_pretrained(repo)
      self$task_q <- callq::task_q$new(num_workers = 1)
      self$task_q$push(args = list(repo = repo), function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)
        device <- if (cuda_is_available()) "cuda" else "cpu"
        # `<<-` makes `model` persist in the worker's global environment so
        # later generate() tasks pushed to the same worker can see it.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        if (device == "cuda") {
          # half precision on GPU to reduce memory; full float on CPU
          model$to(dtype = torch_half())
          model$to(device = device)
        } else {
          model$to(dtype = torch_float())
        }
        "done"
      })
    },

    #' Sample the next token id given a prompt.
    #' @param idx Integer vector of 0-based token ids from the tokenizer.
    #' @return A task promise resolving to a single 0-based integer token id.
    generate = function(idx) {
      if (is.null(self$task_q)) {
        # FIX: previously this branch pushed onto the NULL queue
        # (`NULL$push(...)`), which itself failed with an obscure error
        # before the intended one could be signalled. Error directly instead.
        stop("Model is not loaded.", call. = FALSE)
      }
      args <- list(
        idx = idx,
        temperature = self$temperature,
        top_k = self$top_k
      )
      self$task_q$push(args = args, function(idx, temperature, top_k) {
        device <- if (cuda_is_available()) "cuda" else "cpu"
        idx <- torch_tensor(idx, device = device)$view(c(1, -1))
        with_no_grad({
          # `+ 1L`: shift the tokenizer's 0-based ids to torch-for-R's
          # 1-based indexing expected by the embedding layer.
          logits <- model(idx + 1L)$to(dtype = "float", device = "cpu")
        })
        # keep only the logits at the last position (torch-for-R negative
        # indexing counts from the end), scaled by temperature
        logits <- logits[, -1, ] / temperature
        # top-k filtering: everything outside the k best gets a large
        # negative logit so softmax assigns it ~zero probability
        c(prob, ind) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -1e7)$scatter_(-1, ind, prob)
        probs <- nnf_softmax(logits, dim = -1)
        # sample one token; `- 1L` converts back to 0-based ids
        id_next <- torch::torch_multinomial(probs, num_samples = 1)$cpu() - 1L
        as.integer(id_next)
      })
    }
  )
)