canovich commited on
Commit
e9b6e3f
·
1 Parent(s): 00fb7d5

Upload code/ with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +7 -5
code/inference.py CHANGED
@@ -1,7 +1,7 @@
1
  import logging, requests, os, io, glob, time
2
  import json
3
 
4
-
5
  from transformers import BertTokenizer
6
  from transformers import PreTrainedModel
7
  import torch
@@ -276,10 +276,12 @@ JSON_CONTENT_TYPE = 'application/json'
276
 
277
 
278
  # loads the model into memory from disk and returns it
279
- def model_fn():
280
-
281
- model = AutoModelForSeq2SeqLM.from_pretrained("canovich/myprivateee")
282
- return model
 
 
283
 
284
 
285
  # Perform prediction on the deserialized object, with the loaded model
 
1
  import logging, requests, os, io, glob, time
2
  import json
3
 
4
+ from transformers import T5TokenizerFast
5
  from transformers import BertTokenizer
6
  from transformers import PreTrainedModel
7
  import torch
 
276
 
277
 
278
  # loads the model into memory from disk and returns it
279
+ def model_fn(model_dir):
280
+ # Load model from HuggingFace Hub
281
+ tokenizer = T5TokenizerFast(model_dir, extra_ids=0)
282
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
283
+ return model, tokenizer
284
+
285
 
286
 
287
  # Perform prediction on the deserialized object, with the loaded model