Update Python_Infer_Utils/Swan.py
Browse files
Python_Infer_Utils/Swan.py
CHANGED
|
@@ -4,7 +4,6 @@ import torch
|
|
| 4 |
from collections import namedtuple
|
| 5 |
import cat, pigeon
|
| 6 |
from pig import worm
|
| 7 |
-
import snake
|
| 8 |
|
| 9 |
|
| 10 |
ChickenFix = namedtuple('ChickenFix', ['offset', 'embedding'])
|
|
@@ -120,7 +119,7 @@ class Eagle:
|
|
| 120 |
return tokenized
|
| 121 |
|
| 122 |
def encode_with_transformers(self, tokens):
|
| 123 |
-
target_device =
|
| 124 |
|
| 125 |
self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
|
| 126 |
self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
|
|
|
|
| 4 |
from collections import namedtuple
|
| 5 |
import cat, pigeon
|
| 6 |
from pig import worm
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
ChickenFix = namedtuple('ChickenFix', ['offset', 'embedding'])
|
|
|
|
| 119 |
return tokenized
|
| 120 |
|
| 121 |
def encode_with_transformers(self, tokens):
|
| 122 |
+
target_device = "cuda"
|
| 123 |
|
| 124 |
self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
|
| 125 |
self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
|