Spaces:
Configuration error
Configuration error
t5
Browse files
stable/stable_audio_tools/models/conditioners.py
CHANGED
|
@@ -283,9 +283,9 @@ class T5Conditioner(Conditioner):
|
|
| 283 |
# self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
|
| 284 |
# model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
|
| 285 |
self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
|
| 286 |
-
cwd = os.listdir('./')
|
| 287 |
-
print("==========", cwd)
|
| 288 |
-
ckpt = torch.load('
|
| 289 |
model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
|
| 290 |
model.load_state_dict(ckpt,strict=True)
|
| 291 |
|
|
|
|
| 283 |
# self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
|
| 284 |
# model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
|
| 285 |
self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
|
| 286 |
+
# cwd = os.listdir('./')
|
| 287 |
+
# print("==========", cwd)
|
| 288 |
+
ckpt = torch.load('./stable/stable_audio_tools/try_t5.pt')
|
| 289 |
model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
|
| 290 |
model.load_state_dict(ckpt,strict=True)
|
| 291 |
|