Update app/agents/crew_pipeline.py
Browse files
app/agents/crew_pipeline.py
CHANGED
|
@@ -79,7 +79,8 @@ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
| 79 |
device_map = 'auto' if DEVICE == 'cuda' else None
|
| 80 |
|
| 81 |
)
|
| 82 |
-
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
|
|
@@ -126,7 +127,7 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
|
|
| 126 |
padding = True,
|
| 127 |
truncation = True,
|
| 128 |
max_length = 400
|
| 129 |
-
).to(translation_model
|
| 130 |
|
| 131 |
#setting the target language token
|
| 132 |
forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)
|
|
|
|
| 79 |
device_map = 'auto' if DEVICE == 'cuda' else None
|
| 80 |
|
| 81 |
)
|
| 82 |
+
if DEVICE == 'cpu':
|
| 83 |
+
translation_model = translation_model.to('cpu')
|
| 84 |
|
| 85 |
|
| 86 |
|
|
|
|
| 127 |
padding = True,
|
| 128 |
truncation = True,
|
| 129 |
max_length = 400
|
| 130 |
+
).to(translation_model.device)
|
| 131 |
|
| 132 |
#setting the target language token
|
| 133 |
forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)
|