Coercer commited on
Commit
9a81093
·
verified ·
1 Parent(s): 096141f

Update Python_Infer_Utils/Swan.py

Browse files
Files changed (1) hide show
  1. Python_Infer_Utils/Swan.py +1 -2
Python_Infer_Utils/Swan.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  from collections import namedtuple
5
  import cat, pigeon
6
  from pig import worm
7
- import snake
8
 
9
 
10
  ChickenFix = namedtuple('ChickenFix', ['offset', 'embedding'])
@@ -120,7 +119,7 @@ class Eagle:
120
  return tokenized
121
 
122
  def encode_with_transformers(self, tokens):
123
- target_device = snake.text_encoder_device()
124
 
125
  self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
126
  self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
 
4
  from collections import namedtuple
5
  import cat, pigeon
6
  from pig import worm
 
7
 
8
 
9
  ChickenFix = namedtuple('ChickenFix', ['offset', 'embedding'])
 
119
  return tokenized
120
 
121
  def encode_with_transformers(self, tokens):
122
+ target_device = "cuda"
123
 
124
  self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
125
  self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)