Update app.py
Browse files
app.py
CHANGED
|
@@ -17,10 +17,10 @@ tokenizer = AutoTokenizer.from_pretrained('t5-small')
|
|
| 17 |
#model = torch.load(model_name+'/model.pt')
|
| 18 |
|
| 19 |
#if model_name == "T5 Small":
|
| 20 |
-
|
| 21 |
#model = torch.load(model_name+'/model.pt')
|
| 22 |
checkpoint_path=model_name+'/t5_epoch9.ckpt'
|
| 23 |
-
model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
|
| 24 |
|
| 25 |
#else:
|
| 26 |
#model = GPT2().from_pretrained(model_name)
|
|
|
|
| 17 |
#model = torch.load(model_name+'/model.pt')
|
| 18 |
|
| 19 |
#if model_name == "T5 Small":
|
| 20 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, checkpoint_file="t5_epoch9.ckpt")
|
| 21 |
#model = torch.load(model_name+'/model.pt')
|
| 22 |
checkpoint_path=model_name+'/t5_epoch9.ckpt'
|
| 23 |
+
#model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
|
| 24 |
|
| 25 |
#else:
|
| 26 |
#model = GPT2().from_pretrained(model_name)
|