ZinebSN commited on
Commit
fb4f1d2
·
1 Parent(s): d0f7f8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -17,9 +17,18 @@ tokenizer = AutoTokenizer.from_pretrained('t5-small')
17
  #model = torch.load(model_name+'/model.pt')
18
 
19
  #if model_name == "T5 Small":
20
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, checkpoint_file="t5_epoch9.ckpt")
21
  #model = torch.load(model_name+'/model.pt')
22
  checkpoint_path=model_name+'/t5_epoch9.ckpt'
 
 
 
 
 
 
 
 
 
23
  #model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
24
 
25
  #else:
 
17
  #model = torch.load(model_name+'/model.pt')
18
 
19
  #if model_name == "T5 Small":
20
+ #model = AutoModelForSeq2SeqLM.from_pretrained(model_name, checkpoint_file="t5_epoch9.ckpt")
21
  #model = torch.load(model_name+'/model.pt')
22
  checkpoint_path=model_name+'/t5_epoch9.ckpt'
23
+
24
+
25
+ # Choose the appropriate device
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model_state_dict = torch.load(checkpoint_path, map_location=device)
28
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
29
+ model.load_state_dict(model_state_dict)
30
+
31
+
32
  #model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
33
 
34
  #else: