chaima01 commited on
Commit
9d37bf0
·
verified ·
1 Parent(s): 8708b02

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +22 -8
handler.py CHANGED
@@ -4,20 +4,34 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
 
5
class EndpointHandler:
    """Inference-endpoint handler: loads a seq2seq model from the repo root
    and serves it through a transformers text2text-generation pipeline."""

    def __init__(self, model_dir: str):
        """Load tokenizer + model from ``model_dir`` (the repo root).

        BUG FIX: the original did ``.to("cuda")`` and ``device=0``
        unconditionally, which crashes on CPU-only hosts. We now fall back
        to CPU when no GPU is present.
        """
        import torch  # local import: only needed for device detection

        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

        # pipeline device convention: 0 = first GPU, -1 = CPU
        device = 0 if torch.cuda.is_available() else -1
        if device == 0:
            model = model.to("cuda")

        # build a text2text pipeline on the selected device
        self.generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

    def __call__(self, payload: dict) -> list:
        """Handle a request payload of the form
        ``{"inputs": "...", "parameters": {...}}`` and return the pipeline's
        output list. Missing keys default to empty text / no parameters."""
        text = payload.get("inputs", "")
        params = payload.get("parameters", {})
        return self.generator(text, **params)
 
4
 
5
class EndpointHandler:
    """Inference-endpoint handler that tolerates checkpoints uploaded into a
    single subfolder of the repo and serves them via a text2text pipeline."""

    def __init__(self, model_dir: str):
        """Locate the real checkpoint folder under ``model_dir`` and build
        the generation pipeline.

        BUG FIX: the original used ``os.listdir``/``os.path`` but never
        imported ``os`` (the file only imports from ``transformers``),
        raising ``NameError`` at startup. Also, the comment promised
        "GPU if available" while ``device=0`` was hard-coded — we now
        actually check availability.
        """
        import os      # fix: required below, was never imported at file level
        import torch   # used only for GPU detection

        # Some repos upload into a subfolder. If exactly one subdirectory
        # of model_dir contains a config.json, assume that's the real
        # checkpoint folder; otherwise load from model_dir itself.
        candidates = [
            d for d in os.listdir(model_dir)
            if os.path.isdir(os.path.join(model_dir, d))
            and os.path.exists(os.path.join(model_dir, d, "config.json"))
        ]
        if len(candidates) == 1:
            real_dir = os.path.join(model_dir, candidates[0])
        else:
            real_dir = model_dir

        # load from the folder that actually has the fine-tuned files
        self.tokenizer = AutoTokenizer.from_pretrained(real_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(real_dir)

        # build the pipeline on GPU if available, else CPU (-1)
        self.generator = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            max_new_tokens=500,  # default generation budget for every call
            # NOTE(review): temperature only takes effect when do_sample=True
            # is also set (per call or here) — confirm intended sampling mode.
            temperature=0.7,
        )

    def __call__(self, payload: dict) -> list:
        """Handle ``{"inputs": "...", "parameters": {...}}``; per-call
        parameters override the pipeline defaults set in ``__init__``."""
        text = payload.get("inputs", "")
        params = payload.get("parameters", {})
        # forwards whatever per-call overrides the caller passes
        return self.generator(text, **params)