ZinebSN commited on
Commit
a535e2c
·
1 Parent(s): 9f63629

Update t5_model.py

Browse files
Files changed (1) hide show
  1. t5_model.py +5 -0
t5_model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
 
2
  class T5(pl.LightningModule):
3
  def __init__(self, lr=5e-5, num_train_epochs=15, warmup_steps=1000):
@@ -85,6 +87,9 @@ class T5(pl.LightningModule):
85
  # Save the model
86
  self.model.push_to_hub(model_name, organization)
87
 
 
 
 
88
 
89
  def train_dataloader(self):
90
  return train_dataloader
 
1
+ import pytorch_lightning as pl
2
+ from transformers import AutoModelForSeq2SeqLM
3
 
4
  class T5(pl.LightningModule):
5
  def __init__(self, lr=5e-5, num_train_epochs=15, warmup_steps=1000):
 
87
  # Save the model
88
  self.model.push_to_hub(model_name, organization)
89
 
90
+ def from_pretrained(self, model_path):
91
+ AutoModelForSeq2SeqLM.from_pretrained(model_path)
92
+
93
 
94
  def train_dataloader(self):
95
  return train_dataloader