chaima01 commited on
Commit
84e87c0
·
verified ·
1 Parent(s): 45fd832

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -32
handler.py CHANGED
@@ -1,39 +1,28 @@
1
  # handler.py
2
-
3
- import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
-
6
-
7
- _generator = None
8
-
9
- def init():
10
-
11
- global _generator
12
- model_dir = "."
13
- device = 0 if torch.cuda.is_available() else -1
14
-
15
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
16
- model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device if device>=0 else "cpu")
17
-
18
- _generator = pipeline(
19
- "text2text-generation",
20
- model=model,
21
- tokenizer=tokenizer,
22
- device=device
23
- )
24
-
25
- def run(payload: dict) -> list:
26
-
27
- text = payload.get("inputs", "")
28
- params = payload.get("parameters", {})
29
- return _generator(text, **params)
30
 
31
  class EndpointHandler:
32
-
33
  def __init__(self, model_dir: str):
34
- # simply delegate to init()
35
- init()
 
 
 
 
 
 
 
 
36
 
37
  def __call__(self, payload: dict) -> list:
38
- # delegate to run()
39
- return run(payload)
 
 
 
 
 
 
 
 
 
1
  # handler.py
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  class EndpointHandler:
 
6
  def __init__(self, model_dir: str):
7
+ # load tokenizer & model from the same folder where handler.py lives
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
9
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
+ # build a HF pipeline; device_map=“auto” will pick GPU if available
11
+ self.generator = pipeline(
12
+ "text2text-generation",
13
+ model=self.model,
14
+ tokenizer=self.tokenizer,
15
+ device=0 # set to -1 if you want CPU only
16
+ )
17
 
18
  def __call__(self, payload: dict) -> list:
19
+ """
20
+ Expects a JSON payload like:
21
+ {"inputs": "<your question here>", "parameters": {"max_new_tokens": 200}}
22
+ Returns the raw list of dicts that HF pipeline emits.
23
+ """
24
+ text = payload.get("inputs", "")
25
+ params = payload.get("parameters", {})
26
+ # run generation
27
+ outputs = self.generator(text, **params)
28
+ return outputs