chaima01 committed on
Commit
eb628b0
·
verified ·
1 Parent(s): d1b7b71

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -10
handler.py CHANGED
@@ -1,28 +1,47 @@
 
1
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
 
 
 
 
4
def init():
    """
    Called once at container startup.

    Loads the fine-tuned T5 model and tokenizer from the repository root
    and builds a text2text-generation pipeline, which is stored in the
    module-level ``generator`` global for use by ``run``.
    """
    global generator

    model_dir = "."  # model artifacts live at the root of the repo

    # Select GPU 0 when available, otherwise fall back to CPU (-1 is the
    # transformers pipeline convention for CPU). The previous version
    # hard-coded device=0, which fails on CPU-only hosts.
    import torch  # local import: only needed once, at startup
    device = 0 if torch.cuda.is_available() else -1

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

    # Let the pipeline handle device placement of the model.
    generator = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
19
 
20
def run(payload: dict) -> list:
    """
    Handle one HTTP request.

    Expects JSON shaped like ``{"inputs": "<string>", "parameters": {...}}``
    and returns the pipeline output, typically a list of dicts such as
    ``[{"generated_text": "..."}]``.
    """
    # Feed the prompt straight into the pipeline, forwarding any
    # generation parameters the caller supplied.
    return generator(payload.get("inputs", ""), **payload.get("parameters", {}))
 
 
 
1
+ # handler.py
2
 
3
+ import os
4
+ import torch
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
+ # these globals will be filled in by init()
8
+ generator = None
9
+
10
def init():
    """
    Called once when the container starts up.

    Loads the fine-tuned Flan-T5 model and tokenizer from the repository
    root and builds a text2text-generation pipeline, stored in the
    module-level ``generator`` global for use by ``run``.
    """
    global generator

    model_dir = "."  # root of the repository
    device = 0 if torch.cuda.is_available() else -1  # GPU 0 or CPU (-1)

    # Load tokenizer + model from the repo root.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

    # Pass device= and let the pipeline place the model. The previous code
    # also called model.to(...) first, which moves the model twice and can
    # trigger a transformers device-placement warning.
    generator = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
31
 
32
def run(payload: dict) -> list:
    """
    Entry point for every HTTP request.

    Expected JSON payload::

        {
            "inputs": "<your prompt string>",
            "parameters": { ...generation kwargs... }
        }

    Returns the pipeline output, a list of dicts such as
    ``[{"generated_text": "..."}]``.
    """
    generation_params = payload.get("parameters", {})
    prompt = payload.get("inputs", "")

    # Invoke the text2text-generation pipeline and return its result as-is.
    return generator(prompt, **generation_params)