zakerytclarke committed on
Commit
8334c0b
·
verified ·
1 Parent(s): 21235f2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +51 -70
handler.py CHANGED
@@ -1,71 +1,70 @@
1
  # handler.py
2
- from typing import Any, Dict, List, Union
3
- import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
5
 
6
-
7
- MAX_INPUT_TOKENS = 512
8
 
9
 
10
  class EndpointHandler:
11
- """
12
- HF Inference Endpoints custom handler that reproduces the exact style of
13
- your shared Colab code:
14
- - slow tokenizer (use_fast=False)
15
- - Seq2Seq model
16
- - deterministic generation by default (do_sample=False)
17
- - decode skip_special_tokens=True
18
- - if input > 512 tokens, keep only the MOST RECENT tokens (left-truncate)
19
- """
20
-
21
  def __init__(self, path: str = ""):
22
- # Match your working code path and avoid fast tokenizer init issues on HF endpoints.
23
- self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
24
- self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
25
 
26
  self.model.eval()
27
  self.device = torch.device("cpu")
28
  self.model.to(self.device)
29
 
 
 
 
 
 
 
 
 
 
30
  @torch.inference_mode()
31
- def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], List[Dict[str, str]]]:
32
  """
33
- Request schema:
34
- {
35
- "inputs": "<full prompt string>" OR ["<prompt1>", "<prompt2>", ...],
36
- "parameters": { ... optional generate kwargs ... }
37
- }
38
-
39
- Response schema (kept simple):
40
- - single input -> {"generated_text": "..."}
41
- - list inputs -> [{"generated_text": "..."}, ...]
 
 
42
  """
43
- if "inputs" not in data:
44
- raise ValueError("Missing required field 'inputs'.")
45
 
46
- inputs = data["inputs"]
47
- params = data.get("parameters") or {}
 
 
48
 
49
- # Normalize to a batch of prompts
 
 
50
  if isinstance(inputs, str):
51
- prompts = [inputs]
52
- single = True
 
 
 
53
  else:
54
- prompts = list(inputs)
55
- single = False
56
-
57
- # --- Tokenize WITHOUT truncation first so we can left-truncate manually ---
58
- enc = self.tokenizer(
59
- prompts,
60
- return_tensors="pt",
61
- padding=True,
62
- truncation=False,
63
- )
64
 
65
  input_ids = enc["input_ids"]
66
  attention_mask = enc["attention_mask"]
67
 
68
- # Left-truncate to keep the most recent tokens (right side)
69
  if input_ids.shape[1] > MAX_INPUT_TOKENS:
70
  input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
71
  attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]
@@ -73,34 +72,16 @@ class EndpointHandler:
73
  input_ids = input_ids.to(self.device)
74
  attention_mask = attention_mask.to(self.device)
75
 
76
- # Defaults that match your code: model.generate(**inputs, do_sample=False)
77
- # Keep them overrideable via "parameters".
78
- gen_kwargs = {
79
- "do_sample": params.pop("do_sample", False),
80
- }
81
-
82
- # Optional knobs (only applied if provided)
83
- if "max_new_tokens" in params:
84
- gen_kwargs["max_new_tokens"] = params.pop("max_new_tokens")
85
- if "num_beams" in params:
86
- gen_kwargs["num_beams"] = params.pop("num_beams")
87
- if "temperature" in params:
88
- gen_kwargs["temperature"] = params.pop("temperature")
89
- if "top_p" in params:
90
- gen_kwargs["top_p"] = params.pop("top_p")
91
- if "top_k" in params:
92
- gen_kwargs["top_k"] = params.pop("top_k")
93
-
94
- # Allow any remaining generate() kwargs through, in case you pass them
95
- gen_kwargs.update(params)
96
-
97
  outputs = self.model.generate(
98
  input_ids=input_ids,
99
  attention_mask=attention_mask,
100
- **gen_kwargs,
101
  )
102
 
103
- texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
104
 
105
- result = [{"generated_text": t} for t in texts]
106
- return result[0] if single else result
 
 
1
  # handler.py
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
+ MODEL_NAME = "." # HF mounts the repo at /repository, so "." loads local files
6
+ MAX_INPUT_TOKENS = 512
7
 
8
 
9
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for a seq2seq model.

    Loads a tokenizer/model pair on CPU and answers requests of the form
    ``{"inputs": "<full prompt string>"}`` or
    ``{"inputs": {"context": "...", "question": "..."}}``, returning
    ``{"generated_text": "..."}``.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model.

        Args:
            path: Repository path supplied by HF Inference Endpoints.
                Falls back to MODEL_NAME (".") when empty.
        """
        # Fix: the original hardcoded MODEL_NAME and silently ignored the
        # `path` argument that HF Inference Endpoints pass in; that breaks
        # whenever the process CWD is not the mounted repository.
        model_source = path or MODEL_NAME
        self.tokenizer = AutoTokenizer.from_pretrained(model_source)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_source)

        self.model.eval()
        self.device = torch.device("cpu")
        self.model.to(self.device)

        # System prompt inserted between context and question when the
        # request supplies a structured {context, question} input.
        self.system_prompt = (
            "You are Teapot, an open-source AI assistant optimized for low-end devices, "
            "providing short, accurate responses without hallucinating while excelling at "
            "information extraction and text summarization. "
            "If the context does not answer the question, reply exactly: "
            "'I am sorry but I don't have any information on that'."
        )

    @torch.inference_mode()
    def __call__(self, data):
        """Handle one inference request.

        Expected input format:
            {"inputs": {"context": "...", "question": "..."}}
            OR
            {"inputs": "full prebuilt prompt string"}

        Returns:
            {"generated_text": "<decoded model output>"}

        Raises:
            ValueError: if 'inputs' is missing or has an unsupported type.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' field")

        # Accept either a prebuilt prompt string or a structured dict.
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, dict):
            context = inputs.get("context", "")
            question = inputs.get("question", "")
            prompt = f"{context}\n{self.system_prompt}\n{question}\n"
        else:
            raise ValueError("inputs must be a string or dict with context/question")

        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        # Left-truncate: keep only the MOST RECENT MAX_INPUT_TOKENS tokens
        # so overly long prompts don't exceed the model's input limit.
        if input_ids.shape[1] > MAX_INPUT_TOKENS:
            input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
            attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Deterministic generation (greedy); matches the original snippet.
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
        )

        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": answer}