zakerytclarke committed on
Commit
60ae096
·
verified ·
1 Parent(s): 33112c4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -45
handler.py CHANGED
@@ -1,14 +1,7 @@
1
  # handler.py
2
- #
3
- # Hugging Face Inference Endpoints custom handler for teapotai/tinyteapot (T5/Flan-T5 style seq2seq).
4
- # - Uses the mounted model directory (`path`, typically "/repository") exactly like your notebook loads from Hub.
5
- # - Forces the *slow* SentencePiece tokenizer (use_fast=False) to avoid tokenizer.json / fast-tokenizer mismatch issues.
6
- # => Requires `spiece.model` to be present in the repo root.
7
- # - Left-truncates inputs to keep only the most recent 512 tokens (matches your request).
8
- # - Deterministic generation (do_sample=False).
9
-
10
  from __future__ import annotations
11
 
 
12
  from typing import Any, Dict, Union
13
 
14
  import torch
@@ -27,75 +20,89 @@ DEFAULT_SYSTEM_PROMPT = (
27
  )
28
 
29
 
30
- class EndpointHandler:
31
- """
32
- HF Inference Endpoints will instantiate this class once, then call it per-request.
33
- """
 
 
34
 
 
35
  def __init__(self, path: str = ""):
36
- # Force slow tokenizer to guarantee consistency with SentencePiece vocab (spiece.model).
37
- # This avoids fast-tokenizer init paths that can diverge across environments.
 
 
 
 
 
 
 
38
  self.tokenizer = AutoTokenizer.from_pretrained(
39
  path,
40
  use_fast=False,
41
  model_max_length=MAX_INPUT_TOKENS,
42
  )
43
-
44
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
45
 
46
- # CPU by default on small models; endpoints sets device to CPU in your logs.
47
  self.device = torch.device("cpu")
48
  self.model.to(self.device)
49
  self.model.eval()
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  self.system_prompt = DEFAULT_SYSTEM_PROMPT
52
 
53
  @torch.inference_mode()
54
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
55
- """
56
- Accepts either:
57
- - {"inputs": "<full prompt string>"} (raw mode)
58
- - {"inputs": {"context": "...", "question": "...", "system_prompt": "..."}}
59
- Optional generation knobs:
60
- - {"parameters": {"max_new_tokens": 128}}
61
- """
62
- if not isinstance(data, dict):
63
- raise ValueError("Request payload must be a JSON object.")
64
-
65
- if "inputs" not in data:
66
- raise ValueError("Missing required field: 'inputs'.")
67
 
68
  inputs: Union[str, Dict[str, Any]] = data["inputs"]
69
-
70
- # Optional: generation parameters
71
  params = data.get("parameters") or {}
72
- try:
73
- max_new_tokens = int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
74
- except Exception:
75
- max_new_tokens = DEFAULT_MAX_NEW_TOKENS
76
 
77
- # Build prompt exactly like your notebook logic:
78
- # prompt = f"{context}\n{system_prompt}\n{question}\n"
79
  if isinstance(inputs, str):
80
  prompt = inputs
81
  elif isinstance(inputs, dict):
82
  context = inputs.get("context", "")
83
  question = inputs.get("question", "")
84
  system_prompt = inputs.get("system_prompt", self.system_prompt)
85
-
86
- if not isinstance(context, str) or not isinstance(question, str) or not isinstance(system_prompt, str):
87
- raise ValueError("'context', 'question', and 'system_prompt' must be strings.")
88
-
89
  prompt = f"{context}\n{system_prompt}\n{question}\n"
90
  else:
91
  raise ValueError("'inputs' must be a string or an object with {context, question}.")
92
 
93
- # Tokenize
94
  enc = self.tokenizer(prompt, return_tensors="pt")
95
  input_ids = enc["input_ids"]
96
- attention_mask = enc.get("attention_mask", None)
97
 
98
- # Left-truncate to keep only most recent tokens (last 512)
99
  if input_ids.shape[1] > MAX_INPUT_TOKENS:
100
  input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
101
  if attention_mask is not None:
@@ -105,14 +112,15 @@ class EndpointHandler:
105
  if attention_mask is not None:
106
  attention_mask = attention_mask.to(self.device)
107
 
108
- # Generate deterministically
109
  out = self.model.generate(
110
  input_ids=input_ids,
111
  attention_mask=attention_mask,
112
  do_sample=False,
113
  num_beams=1,
114
  max_new_tokens=max_new_tokens,
115
- use_cache=True,
 
 
116
  )
117
 
118
  text = self.tokenizer.decode(out[0], skip_special_tokens=True)
 
1
  # handler.py
 
 
 
 
 
 
 
 
2
  from __future__ import annotations
3
 
4
+ import os
5
  from typing import Any, Dict, Union
6
 
7
  import torch
 
20
  )
21
 
22
 
23
+ def _path_exists(p: str) -> bool:
24
+ try:
25
+ return os.path.exists(p)
26
+ except Exception:
27
+ return False
28
+
29
 
30
+ class EndpointHandler:
31
  def __init__(self, path: str = ""):
32
+ # Sanity: ensure key files exist in the mounted repo
33
+ spiece_path = os.path.join(path, "spiece.model")
34
+ tokjson_path = os.path.join(path, "tokenizer.json")
35
+ cfg_path = os.path.join(path, "config.json")
36
+
37
+ print(f"[teapot] model_dir={path}")
38
+ print(f"[teapot] exists config.json={_path_exists(cfg_path)} tokenizer.json={_path_exists(tokjson_path)} spiece.model={_path_exists(spiece_path)}")
39
+
40
+ # Force SentencePiece tokenizer (slow)
41
  self.tokenizer = AutoTokenizer.from_pretrained(
42
  path,
43
  use_fast=False,
44
  model_max_length=MAX_INPUT_TOKENS,
45
  )
 
46
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
47
 
 
48
  self.device = torch.device("cpu")
49
  self.model.to(self.device)
50
  self.model.eval()
51
 
52
+ # ----------------------------
53
+ # CRITICAL CONSISTENCY CHECKS
54
+ # ----------------------------
55
+ tok_len = len(self.tokenizer) # includes added tokens
56
+ tok_vocab_size = getattr(self.tokenizer, "vocab_size", None) # base vocab (T5 SP)
57
+ cfg_vocab = getattr(self.model.config, "vocab_size", None)
58
+ emb_rows = int(self.model.get_input_embeddings().weight.shape[0])
59
+
60
+ print(f"[teapot] tokenizer_class={type(self.tokenizer).__name__} use_fast={getattr(self.tokenizer, 'is_fast', None)}")
61
+ print(f"[teapot] len(tokenizer)={tok_len} tokenizer.vocab_size={tok_vocab_size} model.config.vocab_size={cfg_vocab} embedding_rows={emb_rows}")
62
+ print(f"[teapot] special_tokens: pad={self.tokenizer.pad_token} eos={self.tokenizer.eos_token} unk={self.tokenizer.unk_token}")
63
+
64
+ # If you ever resized embeddings, these MUST match:
65
+ # - embedding rows must equal len(tokenizer)
66
+ # - config vocab_size should match embedding rows
67
+ if emb_rows != tok_len:
68
+ raise RuntimeError(
69
+ f"[teapot] FATAL: embedding_rows ({emb_rows}) != len(tokenizer) ({tok_len}). "
70
+ "This means your model weights and tokenizer files are out of sync in the repo. "
71
+ "Fix by re-saving model+tokenizer together after resize_token_embeddings."
72
+ )
73
+ if cfg_vocab is not None and cfg_vocab != emb_rows:
74
+ raise RuntimeError(
75
+ f"[teapot] FATAL: model.config.vocab_size ({cfg_vocab}) != embedding_rows ({emb_rows}). "
76
+ "Your config.json is inconsistent with the weights. Re-save model to update config."
77
+ )
78
+
79
  self.system_prompt = DEFAULT_SYSTEM_PROMPT
80
 
81
  @torch.inference_mode()
82
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
83
+ if not isinstance(data, dict) or "inputs" not in data:
84
+ raise ValueError("Request must be JSON with an 'inputs' field.")
 
 
 
 
 
 
 
 
 
 
85
 
86
  inputs: Union[str, Dict[str, Any]] = data["inputs"]
 
 
87
  params = data.get("parameters") or {}
 
 
 
 
88
 
89
+ max_new_tokens = int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
90
+
91
  if isinstance(inputs, str):
92
  prompt = inputs
93
  elif isinstance(inputs, dict):
94
  context = inputs.get("context", "")
95
  question = inputs.get("question", "")
96
  system_prompt = inputs.get("system_prompt", self.system_prompt)
 
 
 
 
97
  prompt = f"{context}\n{system_prompt}\n{question}\n"
98
  else:
99
  raise ValueError("'inputs' must be a string or an object with {context, question}.")
100
 
 
101
  enc = self.tokenizer(prompt, return_tensors="pt")
102
  input_ids = enc["input_ids"]
103
+ attention_mask = enc.get("attention_mask")
104
 
105
+ # Keep most recent tokens (left truncate)
106
  if input_ids.shape[1] > MAX_INPUT_TOKENS:
107
  input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
108
  if attention_mask is not None:
 
112
  if attention_mask is not None:
113
  attention_mask = attention_mask.to(self.device)
114
 
 
115
  out = self.model.generate(
116
  input_ids=input_ids,
117
  attention_mask=attention_mask,
118
  do_sample=False,
119
  num_beams=1,
120
  max_new_tokens=max_new_tokens,
121
+ # Band-aid to prevent pathological repeats, but not a real fix:
122
+ repetition_penalty=1.05,
123
+ no_repeat_ngram_size=3,
124
  )
125
 
126
  text = self.tokenizer.decode(out[0], skip_special_tokens=True)