rawcell committed on
Commit
3fcc17e
·
verified ·
1 Parent(s): cca395f

Add handler.py for Inference Endpoints

Browse files
Files changed (1) hide show
  1. handler.py +9 -5
handler.py CHANGED
@@ -15,6 +15,7 @@ class EndpointHandler:
15
  inputs = data.pop("inputs", data)
16
  parameters = data.pop("parameters", {})
17
 
 
18
  if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
19
  text = self.tokenizer.apply_chat_template(
20
  inputs,
@@ -26,11 +27,13 @@ class EndpointHandler:
26
 
27
  encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)
28
 
 
29
  gen_kwargs = {
30
  "max_new_tokens": parameters.get("max_new_tokens", 512),
31
  "temperature": parameters.get("temperature", 0.7),
32
  "top_p": parameters.get("top_p", 0.9),
33
  "do_sample": parameters.get("do_sample", True),
 
34
  }
35
 
36
  with torch.no_grad():
@@ -38,9 +41,10 @@ class EndpointHandler:
38
 
39
  decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
41
- if "<|im_start|>assistant" in decoded:
42
- decoded = decoded.split("<|im_start|>assistant")[-1].strip()
43
- if decoded.endswith("<|im_end|>"):
44
- decoded = decoded[:-10].strip()
 
45
 
46
- return [{"generated_text": decoded}]
 
15
  inputs = data.pop("inputs", data)
16
  parameters = data.pop("parameters", {})
17
 
18
+ # Handle chat format
19
  if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
20
  text = self.tokenizer.apply_chat_template(
21
  inputs,
 
27
 
28
  encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)
29
 
30
+ # Default generation parameters
31
  gen_kwargs = {
32
  "max_new_tokens": parameters.get("max_new_tokens", 512),
33
  "temperature": parameters.get("temperature", 0.7),
34
  "top_p": parameters.get("top_p", 0.9),
35
  "do_sample": parameters.get("do_sample", True),
36
+ "pad_token_id": self.tokenizer.eos_token_id,
37
  }
38
 
39
  with torch.no_grad():
 
41
 
42
  decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
43
 
44
+ # Return only the generated part (remove input)
45
+ if isinstance(inputs, str):
46
+ generated = decoded[len(inputs):].strip()
47
+ else:
48
+ generated = decoded
49
 
50
+ return [{"generated_text": generated}]