drrobot9 commited on
Commit
0702aa1
·
verified ·
1 Parent(s): 5423133

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +3 -2
app/agents/crew_pipeline.py CHANGED
@@ -79,7 +79,8 @@ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
79
  device_map = 'auto' if DEVICE == 'cuda' else None
80
 
81
  )
82
-
 
83
 
84
 
85
 
@@ -126,7 +127,7 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
126
  padding = True,
127
  truncation = True,
128
  max_length = 400
129
- ).to(translation_model,device)
130
 
131
  #setting the target language token
132
  forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)
 
79
  device_map = 'auto' if DEVICE == 'cuda' else None
80
 
81
  )
82
+ if DEVICE == 'cpu':
83
+ translation_model = translation_model.to('cpu')
84
 
85
 
86
 
 
127
  padding = True,
128
  truncation = True,
129
  max_length = 400
130
+ ).to(translation_model.device)
131
 
132
  #setting the target language token
133
  forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)