BlueDice commited on
Commit
876c3f6
·
1 Parent(s): 607c845

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +10 -6
code/inference.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  import re
4
 
@@ -21,11 +21,15 @@ Alice Gate: *Alice strides into the room with a smile, her eyes lighting up when
21
  Alice Gate:"""
22
 
23
  def model_fn(model_dir):
24
- # Load model from HuggingFace Hub
25
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
26
- model = torch.load(f"{model_dir}/torch_model.pt")
27
- return model, tokenizer
28
-
 
 
 
 
29
 
30
  def create_new_response(result, user_name):
31
  result = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import re
4
 
 
21
  Alice Gate:"""
22
 
23
  def model_fn(model_dir):
24
+ # Load model from HuggingFace Hub
25
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ model_dir,
28
+ low_cpu_mem_usage = True,
29
+ trust_remote_code = False,
30
+ torch_dtype = torch.float16,
31
+ ).to('cuda')
32
+ return model, tokenizer
33
 
34
  def create_new_response(result, user_name):
35
  result = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()