zakerytclarke committed on
Commit
33112c4
·
verified ·
1 Parent(s): 21be093

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +85 -25
handler.py CHANGED
@@ -1,59 +1,119 @@
1
- # handler.py (repo root)
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
 
 
4
 
5
  MAX_INPUT_TOKENS = 512
 
 
 
 
 
 
 
 
 
 
6
 
7
class EndpointHandler:
    """Inference Endpoints handler: loads a seq2seq model/tokenizer from the
    mounted model directory and serves deterministic (greedy) generations."""

    def __init__(self, path: str = ""):
        # `path` is the mounted model dir ("/repository" on HF Endpoints).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.eval()
        self.device = torch.device("cpu")
        self.model.to(self.device)

        # Instruction prepended between context and question in QA mode.
        self.system_prompt = (
            "You are Teapot, an open-source AI assistant optimized for low-end devices, "
            "providing short, accurate responses without hallucinating while excelling at "
            "information extraction and text summarization. "
            "If the context does not answer the question, reply exactly: "
            "'I am sorry but I don't have any information on that'."
        )

    @torch.inference_mode()
    def __call__(self, data):
        """Handle one request.

        Accepts {"inputs": "<prompt>"} or {"inputs": {"context": ..., "question": ...}}
        and returns {"generated_text": "<decoded output>"}.

        Raises:
            ValueError: if 'inputs' is missing or has an unsupported type.
        """
        payload = data.get("inputs")
        if payload is None:
            raise ValueError("Missing required field 'inputs'.")

        # Raw passthrough for strings; context/question assembly for dicts.
        if isinstance(payload, str):
            prompt = payload
        elif isinstance(payload, dict):
            ctx = payload.get("context", "")
            q = payload.get("question", "")
            prompt = f"{ctx}\n{self.system_prompt}\n{q}\n"
        else:
            raise ValueError("inputs must be a string or dict.")

        encoded = self.tokenizer(prompt, return_tensors="pt")
        ids, mask = encoded["input_ids"], encoded["attention_mask"]

        # Left-truncate: keep only the most recent MAX_INPUT_TOKENS tokens.
        if ids.shape[1] > MAX_INPUT_TOKENS:
            ids = ids[:, -MAX_INPUT_TOKENS:]
            mask = mask[:, -MAX_INPUT_TOKENS:]

        ids = ids.to(self.device)
        mask = mask.to(self.device)

        generated = self.model.generate(
            input_ids=ids,
            attention_mask=mask,
            do_sample=False,
        )

        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return {"generated_text": decoded}
 
1
+ # handler.py
2
+ #
3
+ # Hugging Face Inference Endpoints custom handler for teapotai/tinyteapot (T5/Flan-T5 style seq2seq).
4
+ # - Uses the mounted model directory (`path`, typically "/repository") exactly like your notebook loads from Hub.
5
+ # - Forces the *slow* SentencePiece tokenizer (use_fast=False) to avoid tokenizer.json / fast-tokenizer mismatch issues.
6
+ # => Requires `spiece.model` to be present in the repo root.
7
+ # - Left-truncates inputs to keep only the most recent 512 tokens (matches your request).
8
+ # - Deterministic generation (do_sample=False).
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Dict, Union
13
+
14
  import torch
15
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
16
+
17
 
18
# Inference-time limits and the default instruction prepended to QA prompts.
MAX_INPUT_TOKENS = 512        # keep only the most recent 512 input tokens
DEFAULT_MAX_NEW_TOKENS = 128  # generation budget unless overridden per request

DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for low-end devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
28
+
29
 
30
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    The platform instantiates this class once at startup (with `path` pointing
    at the mounted model directory, typically "/repository") and then invokes
    the instance once per request.
    """

    def __init__(self, path: str = ""):
        # Force the slow SentencePiece tokenizer (use_fast=False) so tokenization
        # stays consistent with the repo's spiece.model and avoids fast-tokenizer
        # init paths that can diverge across environments.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            use_fast=False,
            model_max_length=MAX_INPUT_TOKENS,
        )

        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        # Small seq2seq model; endpoint runs on CPU per deployment logs.
        self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()

        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Handle one inference request.

        Accepts either:
          - {"inputs": "<full prompt string>"}                         (raw mode)
          - {"inputs": {"context": "...", "question": "...",
                        "system_prompt": "..."}}                       (QA mode)
        Optional generation knobs:
          - {"parameters": {"max_new_tokens": 128}}

        Returns:
            {"generated_text": "<decoded model output>"}

        Raises:
            ValueError: if the payload shape is invalid.
        """
        if not isinstance(data, dict):
            raise ValueError("Request payload must be a JSON object.")

        if "inputs" not in data:
            raise ValueError("Missing required field: 'inputs'.")

        inputs: Union[str, Dict[str, Any]] = data["inputs"]

        # Build the prompt: raw passthrough, or f"{context}\n{system_prompt}\n{question}\n".
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, dict):
            context = inputs.get("context", "")
            question = inputs.get("question", "")
            system_prompt = inputs.get("system_prompt", self.system_prompt)

            if not isinstance(context, str) or not isinstance(question, str) or not isinstance(system_prompt, str):
                raise ValueError("'context', 'question', and 'system_prompt' must be strings.")

            prompt = f"{context}\n{system_prompt}\n{question}\n"
        else:
            raise ValueError("'inputs' must be a string or an object with {context, question}.")

        # Optional generation parameters.
        # Fix: a non-dict "parameters" value (e.g. a string) is truthy, so the old
        # code crashed with AttributeError on .get(); reject it explicitly instead.
        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise ValueError("'parameters' must be a JSON object.")
        try:
            max_new_tokens = int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        except (TypeError, ValueError):
            # Unparseable value -> fall back to the default rather than erroring.
            max_new_tokens = DEFAULT_MAX_NEW_TOKENS
        # Fix: generate() rejects non-positive token budgets; clamp to at least 1.
        max_new_tokens = max(1, max_new_tokens)

        # Tokenize (slow tokenizers also return an attention_mask, but stay defensive).
        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"]
        attention_mask = enc.get("attention_mask", None)

        # Left-truncate: keep only the most recent MAX_INPUT_TOKENS tokens.
        if input_ids.shape[1] > MAX_INPUT_TOKENS:
            input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]

        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        # Deterministic (greedy) generation.
        out = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": text}