Honzus24 commited on
Commit
5f38146
·
verified ·
1 Parent(s): 28e5ab6

Update models/T5_encoder_per_token.py

Browse files
Files changed (1) hide show
  1. models/T5_encoder_per_token.py +4 -4
models/T5_encoder_per_token.py CHANGED
@@ -124,11 +124,11 @@ def PT5_classification_model(half_precision, class_config):
124
  # Load PT5 and tokenizer
125
  # possible to load the half preciion model (thanks to @pawel-rezo for pointing that out)
126
  if not half_precision:
127
- model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
128
- tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
129
  elif half_precision and torch.cuda.is_available():
130
- tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=True)
131
- model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=True).to(torch.device('cuda'))
132
  else:
133
  raise ValueError('Half precision can be run on GPU only.')
134
 
 
124
  # Load PT5 and tokenizer
125
  # possible to load the half preciion model (thanks to @pawel-rezo for pointing that out)
126
  if not half_precision:
127
+ model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
128
+ tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
129
  elif half_precision and torch.cuda.is_available():
130
+ tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=False)
131
+ model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=False).to(torch.device('cuda'))
132
  else:
133
  raise ValueError('Half precision can be run on GPU only.')
134