Spaces:
Runtime error
Runtime error
Update multipurpose_chatbot/engines/transformers_engine.py
Browse files
multipurpose_chatbot/engines/transformers_engine.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
|
|
|
|
| 2 |
import os
|
| 3 |
import numpy as np
|
| 4 |
import argparse
|
|
@@ -420,7 +421,8 @@ class TransformersEngine(BaseEngine):
|
|
| 420 |
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
| 421 |
print(self._model)
|
| 422 |
print(f"{self.max_position_embeddings=}")
|
| 423 |
-
|
|
|
|
| 424 |
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
| 425 |
|
| 426 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
|
@@ -428,7 +430,7 @@ class TransformersEngine(BaseEngine):
|
|
| 428 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
| 429 |
num_tokens = inputs.input_ids.size(1)
|
| 430 |
|
| 431 |
-
inputs = inputs.to(self.
|
| 432 |
|
| 433 |
generator = self._model.generate(
|
| 434 |
**inputs,
|
|
|
|
| 1 |
|
| 2 |
+
import spaces
|
| 3 |
import os
|
| 4 |
import numpy as np
|
| 5 |
import argparse
|
|
|
|
| 421 |
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
| 422 |
print(self._model)
|
| 423 |
print(f"{self.max_position_embeddings=}")
|
| 424 |
+
|
| 425 |
+
@spaces.GPU
|
| 426 |
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
| 427 |
|
| 428 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
|
|
|
| 430 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
| 431 |
num_tokens = inputs.input_ids.size(1)
|
| 432 |
|
| 433 |
+
inputs = inputs.to(self._model.device)
|
| 434 |
|
| 435 |
generator = self._model.generate(
|
| 436 |
**inputs,
|