Spaces:
Runtime error
Runtime error
Update sentence_embeddings.py
Browse files
- sentence_embeddings.py +3 -2
sentence_embeddings.py
CHANGED
|
@@ -7,6 +7,7 @@ from datetime import datetime
|
|
| 7 |
from logger import log
|
| 8 |
from config import TEST_MODE
|
| 9 |
|
|
|
|
| 10 |
router = APIRouter()
|
| 11 |
|
| 12 |
class SentenceEmbeddingsInput(BaseModel):
|
|
@@ -59,11 +60,11 @@ def generic_sentence_embeddings(model_name: str):
|
|
| 59 |
tokenizer, model = loaded_models[model_name]
|
| 60 |
else:
|
| 61 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 62 |
-
model = AutoModel.from_pretrained(model_name)
|
| 63 |
loaded_models[model] = (tokenizer, model)
|
| 64 |
|
| 65 |
# Tokenize sentences
|
| 66 |
-
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
|
| 67 |
with torch.no_grad():
|
| 68 |
model_output = model(**encoded_input)
|
| 69 |
sentence_embeddings = model_output[0][:, 0]
|
|
|
|
| 7 |
from logger import log
|
| 8 |
from config import TEST_MODE
|
| 9 |
|
| 10 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 11 |
router = APIRouter()
|
| 12 |
|
| 13 |
class SentenceEmbeddingsInput(BaseModel):
|
|
|
|
| 60 |
tokenizer, model = loaded_models[model_name]
|
| 61 |
else:
|
| 62 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 63 |
+
model = AutoModel.from_pretrained(model_name).to(device)
|
| 64 |
loaded_models[model] = (tokenizer, model)
|
| 65 |
|
| 66 |
# Tokenize sentences
|
| 67 |
+
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
|
| 68 |
with torch.no_grad():
|
| 69 |
model_output = model(**encoded_input)
|
| 70 |
sentence_embeddings = model_output[0][:, 0]
|