Update multipurpose_chatbot/engines/transformers_engine.py
multipurpose_chatbot/engines/transformers_engine.py
CHANGED
@@ -429,7 +429,7 @@ class TransformersEngine(BaseEngine):
 
         # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
         import sys
-
+        self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
         with torch.no_grad():
             inputs = self.tokenizer(prompt, return_tensors='pt')
             num_tokens = inputs.input_ids.size(1)
@@ -450,7 +450,7 @@ class TransformersEngine(BaseEngine):
                 out_tokens.extend(token.tolist())
                 response = self.tokenizer.decode(out_tokens)
                 if "<|im_start|>assistant\n" in response:
-                    response = response.split("<|im_start|>assistant\n")
+                    response = response.split("<|im_start|>assistant\n")[-1]
                 num_tokens += 1
                 print(f"{response}", end='\r')
                 sys.stdout.flush()
@@ -458,7 +458,7 @@ class TransformersEngine(BaseEngine):
 
         if response is not None:
             if "<|im_start|>assistant\n" in response:
-                response = response.split("<|im_start|>assistant\n")
+                response = response.split("<|im_start|>assistant\n")[-1]
             full_text = prompt + response
             num_tokens = len(self.tokenizer.encode(full_text))
             yield response, num_tokens
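For context, here is a minimal sketch of the two Python behaviors this commit relies on. It uses a hypothetical decoded string and a toy class rather than the repository's actual objects: str.split() returns a list, so the added [-1] is what keeps response a string, and types.MethodType is how the stream-sampling function gets bound onto the model instance.

import types

# --- 1. Why the added [-1] matters ---------------------------------------
# Hypothetical decoded output; in the engine it comes from tokenizer.decode().
decoded = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there!"

as_list = decoded.split("<|im_start|>assistant\n")
print(type(as_list))   # <class 'list'> -- the old code left `response` as a list,
                       # so `prompt + response` later would raise a TypeError.

reply = decoded.split("<|im_start|>assistant\n")[-1]
print(reply)           # "Hi there!" -- only the text after the assistant marker.

# --- 2. What the types.MethodType line does ------------------------------
# Toy stand-ins for self._model and for NewGenerationMixin.sample_stream.
class ToyModel:
    def _sample(self):
        return "original sampling"

def sample_stream(self):   # hypothetical replacement function
    return "streaming sampling"

model = ToyModel()
# Bind the replacement to this one instance, analogous to the added line
# self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
model._sample = types.MethodType(sample_stream, model)
print(model._sample())  # "streaming sampling"

The marker split is applied in both the streaming loop and the final yield, so the fix is made in both places.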