syberWolf committed on
Commit
eff3ac4
·
1 Parent(s): e8628b3

update handler

Browse files
Files changed (1) hide show
  1. handler.py +18 -15
handler.py CHANGED
@@ -4,34 +4,37 @@ import torch
4
 
5
class EndpointHandler:
    """Inference-endpoint wrapper around a Qwen2 text-generation pipeline."""

    def __init__(self, path=""):
        # `path` is accepted for endpoint-API compatibility but unused:
        # the model id is fixed below.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load tokenizer and model: fp16 on GPU halves memory; CPU stays
        # fp32 because half precision is poorly supported there.
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2-1.5B-Instruct",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
        )

        # device_map="auto" already placed the model, so the pipeline is
        # created without an explicit device argument.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Any) -> List[List[Dict[str, Any]]]:
        """Run generation.

        data: {"inputs": str | list[str], "parameters": dict (optional)}
        Returns the pipeline's prediction list (one entry per prompt).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Normalize a single prompt to a batch of one.
        if isinstance(inputs, str):
            inputs = [inputs]

        # FIX: the original tried `torch.tensor(inputs).cuda()` inside a bare
        # `except: pass`. Text prompts are never valid tensor input for a
        # text-generation pipeline, and the bare except hid real errors, so
        # that conversion was removed entirely.
        prediction = self.pipeline(inputs, **parameters)

        return prediction
 
 
 
 
 
 
 
 
 
 
 
4
 
5
class EndpointHandler:
    """Inference-endpoint wrapper around a Qwen2 text-generation pipeline."""

    def __init__(self, path=""):
        # `path` is accepted for endpoint-API compatibility but unused:
        # the model id is fixed below.
        device = 0 if torch.cuda.is_available() else -1  # 0 = first GPU, -1 = CPU

        # Load tokenizer and model: fp16 on GPU halves memory; CPU stays
        # fp32 because half precision is poorly supported there.
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2-1.5B-Instruct",
            # FIX: the committed version was missing the comma after this
            # argument (SyntaxError) and hard-coded device_map="cuda",
            # which fails outright on CPU-only hosts.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )

        # Place the model via the pipeline's `device` argument only:
        # combining accelerate's device_map with pipeline(device=...) raises.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

    def __call__(self, data: Any) -> List[List[Dict[str, Any]]]:
        """Run generation.

        data: {"inputs": str | list[str], "parameters": dict (optional)}
        Returns the pipeline's prediction list (one entry per prompt).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Normalize a single prompt to a batch of one.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Forward any generation kwargs (max_length, num_return_sequences, ...)
        prediction = self.pipeline(inputs, **parameters)

        return prediction
31
+
32
# Quick smoke test when run as a script.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Hello, how can I",
        "parameters": {"max_length": 50, "num_return_sequences": 1},
    }
    print(handler(payload))