Update app.py
Browse files
app.py
CHANGED
|
@@ -17,9 +17,18 @@ tokenizer = AutoTokenizer.from_pretrained('t5-small')
|
|
| 17 |
#model = torch.load(model_name+'/model.pt')
|
| 18 |
|
| 19 |
#if model_name == "T5 Small":
|
| 20 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, checkpoint_file="t5_epoch9.ckpt")
|
| 21 |
#model = torch.load(model_name+'/model.pt')
|
| 22 |
checkpoint_path=model_name+'/t5_epoch9.ckpt'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
#model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
|
| 24 |
|
| 25 |
#else:
|
|
|
|
| 17 |
#model = torch.load(model_name+'/model.pt')
|
| 18 |
|
| 19 |
#if model_name == "T5 Small":
|
| 20 |
+
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name, checkpoint_file="t5_epoch9.ckpt")
|
| 21 |
#model = torch.load(model_name+'/model.pt')
|
| 22 |
checkpoint_path=model_name+'/t5_epoch9.ckpt'
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Choose the appropriate device
|
| 26 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
model_state_dict = torch.load(checkpoint_path, map_location=device)
|
| 28 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
| 29 |
+
model.load_state_dict(model_state_dict)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
#model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
|
| 33 |
|
| 34 |
#else:
|