Srijith Rajamohan committed on
Commit
161cf85
·
1 Parent(s): ef2841e

Updated custom handler

Browse files
Files changed (1) hide show
  1. handler.py +18 -2
handler.py CHANGED
@@ -1,10 +1,26 @@
1
  from typing import Dict, List, Any
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  class EndpointHandler():
4
  def __init__(self, path=""):
5
  # Preload all the elements you are going to need at inference.
6
- # pseudo:
7
- self.model= load_model(path)
 
 
 
 
8
 
9
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
10
  """
 
1
  from typing import Dict, List, Any
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer)
5
+ import torch
6
+
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ "sjster/test_medium",
9
+ trust_remote_code=True,
10
+ quantization_config=None,
11
+ torch_dtype=torch.float, # data type is float
12
+ device_map="auto",
13
+ )
14
 
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
  # Preload all the elements you are going to need at inference.
18
+ self.model = AutoModelForCausalLM.from_pretrained(
19
+ path,
20
+ trust_remote_code=True,
21
+ quantization_config=None,
22
+ torch_dtype=torch.float, # data type is float
23
+ device_map="auto",
24
 
25
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
26
  """