zakerytclarke committed on
Commit
62aec62
·
verified ·
1 Parent(s): 43f4eb6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -22
handler.py CHANGED
@@ -1,19 +1,14 @@
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
 
4
- MODEL_NAME = "teapotai/tinyteapot"
5
  MAX_INPUT_TOKENS = 512
6
 
7
-
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
- # EXACT same as your snippet BUT force slow tokenizer
11
- # This prevents the fast tokenizer crash from extra_special_tokens list
12
- self.tokenizer = AutoTokenizer.from_pretrained(
13
- MODEL_NAME,
14
- use_fast=False
15
- )
16
- self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
17
 
18
  self.model.eval()
19
  self.device = torch.device("cpu")
@@ -30,11 +25,9 @@ class EndpointHandler:
30
  @torch.inference_mode()
31
  def __call__(self, data):
32
  inputs = data.get("inputs")
33
-
34
  if inputs is None:
35
- raise ValueError("Missing 'inputs' field")
36
 
37
- # Match your ask() behavior
38
  if isinstance(inputs, str):
39
  prompt = inputs
40
  elif isinstance(inputs, dict):
@@ -42,15 +35,13 @@ class EndpointHandler:
42
  question = inputs.get("question", "")
43
  prompt = f"{context}\n{self.system_prompt}\n{question}\n"
44
  else:
45
- raise ValueError("inputs must be a string or dict")
46
 
47
- # EXACT tokenizer call like your code
48
  enc = self.tokenizer(prompt, return_tensors="pt")
49
-
50
  input_ids = enc["input_ids"]
51
  attention_mask = enc["attention_mask"]
52
 
53
- # NEW requirement: keep most recent 512 tokens
54
  if input_ids.shape[1] > MAX_INPUT_TOKENS:
55
  input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
56
  attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]
@@ -58,15 +49,11 @@ class EndpointHandler:
58
  input_ids = input_ids.to(self.device)
59
  attention_mask = attention_mask.to(self.device)
60
 
61
- # EXACT generation settings from your snippet
62
  outputs = self.model.generate(
63
  input_ids=input_ids,
64
  attention_mask=attention_mask,
65
- do_sample=False
66
  )
67
 
68
  answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
69
-
70
- return {
71
- "generated_text": answer
72
- }
 
1
+ # handler.py (repo root)
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
 
5
  MAX_INPUT_TOKENS = 512
6
 
 
7
  class EndpointHandler:
8
  def __init__(self, path: str = ""):
9
+ # Load exactly from the mounted model dir ("/repository")
10
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
11
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
 
 
 
 
12
 
13
  self.model.eval()
14
  self.device = torch.device("cpu")
 
25
  @torch.inference_mode()
26
  def __call__(self, data):
27
  inputs = data.get("inputs")
 
28
  if inputs is None:
29
+ raise ValueError("Missing required field 'inputs'.")
30
 
 
31
  if isinstance(inputs, str):
32
  prompt = inputs
33
  elif isinstance(inputs, dict):
 
35
  question = inputs.get("question", "")
36
  prompt = f"{context}\n{self.system_prompt}\n{question}\n"
37
  else:
38
+ raise ValueError("inputs must be a string or dict.")
39
 
 
40
  enc = self.tokenizer(prompt, return_tensors="pt")
 
41
  input_ids = enc["input_ids"]
42
  attention_mask = enc["attention_mask"]
43
 
44
+ # keep most recent 512 tokens
45
  if input_ids.shape[1] > MAX_INPUT_TOKENS:
46
  input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
47
  attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]
 
49
  input_ids = input_ids.to(self.device)
50
  attention_mask = attention_mask.to(self.device)
51
 
 
52
  outputs = self.model.generate(
53
  input_ids=input_ids,
54
  attention_mask=attention_mask,
55
+ do_sample=False,
56
  )
57
 
58
  answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
59
+ return {"generated_text": answer}