Update handler.py
handler.py +70 -31
handler.py CHANGED
Removed: the old single-method implementation, which unpacked the payload, built the prompt, and generated inline. Removed lines that were lost or truncated in the page extraction are marked with `…`.

```diff
@@ -135,48 +135,87 @@ class EndpointHandler:
         # best-effort: return canonicalized even if schema still complains
         return obj

-    def …
-    …
         if not messages:
-            …
-            messages = [{"role": "user", "content": …

-        # Build a …
-        sys_text = tools_to_system_text(tools) if …
-        msgs = []
-        if sys_text:
-            msgs.append({"role": "system", "content": sys_text})
-        msgs.extend(messages)

-        # …
         user_text = ""
-        for m in msgs:
             if m.get("role") == "user":
                 user_text = m.get("content", "")

-        …

         with torch.inference_mode():
-            out = self.model.generate(
-                input_ids=prompt,
-                max_new_tokens=max_new,
-                do_sample=temperature > 0,
-                temperature=temperature if temperature > 0 else None,
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-        raw = self.tokenizer.decode(out[0][prompt.shape[-1]:], skip_special_tokens=True).strip()

         guarded = self._apply_guard(user_text, tools, raw)

-        # Return both for convenience
         return {
-            "generated_text": raw,
-            "envelope": guarded
         }
```
Added: payload normalization moves into a `_unpack` helper that accepts both Inference Endpoints-style nested payloads and flat ones.

```diff
+    def _unpack(self, data: Dict[str, Any]):
+        """Normalize the payload coming from Inference Endpoints:
+        - accept top-level or inputs-nested messages/tools
+        - accept parameters both top-level and nested
+        """
+        body = data.get("inputs", data)  # if no "inputs", body == data
+        params = data.get("parameters") or {}
+        # pull messages/tools from body if dict
+        messages = None
+        tools = None
+        if isinstance(body, dict):
+            messages = body.get("messages")
+            tools = body.get("tools") or body.get("functions")
+        # allow top-level fallbacks
+        if messages is None:
+            messages = data.get("messages")
+        if tools is None:
+            tools = data.get("tools") or data.get("functions") or []
+
+        # if still no messages, treat body as raw text
         if not messages:
+            raw = body if isinstance(body, str) else data.get("text", "")
+            messages = [{"role": "user", "content": str(raw)}]
+
+        # generation params (support both locations)
+        temperature = float(params.get("temperature", data.get("temperature", 0.0)))
+        max_new = int(params.get("max_new_tokens", data.get("max_new_tokens", 192)))
+        top_p = float(params.get("top_p", data.get("top_p", 1.0)))
+
+        return messages, tools, temperature, max_new, top_p
```
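For reference, a few payload shapes this `_unpack` accepts; the request contents below are made up for illustration:

```python
# 1) Inference-Endpoints style: conversation nested under "inputs".
nested = {
    "inputs": {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "tools": [],
    },
    "parameters": {"temperature": 0.2, "max_new_tokens": 64},
}

# 2) Flat style: messages and generation params at the top level.
flat = {
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    "temperature": 0.2,
    "max_new_tokens": 64,
}

# 3) Bare string: wrapped into a single user message.
bare = {"inputs": "What is 2 + 2?"}

# 1) and 2) normalize to the same messages with temperature=0.2, max_new=64;
# 3) falls back to the defaults: temperature=0.0, max_new=192, top_p=1.0.
```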
Added: prompt encoding moves into `_encode_messages`, which tries the tokenizer's chat template and falls back to a plain role-tagged prompt.

```diff
+    def _encode_messages(self, msgs: List[dict]):
+        # Try chat template; fallback to a simple role-tagged prompt
+        try:
+            return self.tokenizer.apply_chat_template(
+                msgs, add_generation_prompt=True, return_tensors="pt"
+            ).to(self.model.device)
+        except Exception:
+            lines = []
+            for m in msgs:
+                role = m.get("role", "user")
+                content = m.get("content", "")
+                lines.append(f"{role}: {content}")
+            lines.append("assistant:")
+            prompt_text = "\n".join(lines)
+            toks = self.tokenizer(prompt_text, return_tensors="pt")
+            return toks["input_ids"].to(self.model.device)
```
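To see what the fallback branch produces, here is the same string-building logic run standalone, with a made-up conversation:

```python
# Standalone rendering of the fallback prompt format used above.
msgs = [
    {"role": "system", "content": "You may call tools."},
    {"role": "user", "content": "Convert 10 miles to km."},
]

lines = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in msgs]
lines.append("assistant:")
print("\n".join(lines))
# system: You may call tools.
# user: Convert 10 miles to km.
# assistant:
```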
The rewritten `__call__` composes the two helpers:

```diff
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        messages, tools, temperature, max_new, top_p = self._unpack(data)

+        # Build a system message from tools; prepend to conversation
+        sys_text = tools_to_system_text(tools) if tools else None
+        msgs = [{"role": "system", "content": sys_text}] + messages if sys_text else messages
```
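With tools present, the conversation handed to the encoder gains a leading system turn. A sketch with made-up content (the real text comes from `tools_to_system_text`, which is defined elsewhere in handler.py); note that `+` binds tighter than the conditional expression, so the prepend happens only when `sys_text` is truthy:

```python
messages = [{"role": "user", "content": "Convert 10 miles to km."}]
sys_text = "You may call these tools: convert_units(miles) ..."  # illustrative only

msgs = [{"role": "system", "content": sys_text}] + messages if sys_text else messages
# -> [{"role": "system", "content": "You may call these tools: ..."},
#     {"role": "user", "content": "Convert 10 miles to km."}]
```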
Continuing in `__call__`:

```diff
+        # Remember last user text for the guard's heuristics
         user_text = ""
+        for m in reversed(msgs):
             if m.get("role") == "user":
                 user_text = m.get("content", "")
+                break

+        input_ids = self._encode_messages(msgs)
+
+        gen_kwargs = dict(
+            input_ids=input_ids,
+            max_new_tokens=max_new,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
+        if temperature > 0:
+            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
+        else:
+            gen_kwargs.update(do_sample=False)
```
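The old call passed `temperature=None` when decoding greedily; building `gen_kwargs` conditionally instead keeps sampling flags out of the greedy path entirely, which recent `transformers` versions would otherwise warn about. A minimal sketch of the two resulting kwarg sets (the `input_ids` tensor is omitted and the eos id is illustrative):

```python
base = dict(max_new_tokens=192, eos_token_id=2)

greedy = {**base, "do_sample": False}                                      # temperature == 0
sampled = {**base, "do_sample": True, "temperature": 0.7, "top_p": 0.95}   # temperature > 0
```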
The generation call itself shrinks to a single kwargs expansion:

```diff
         with torch.inference_mode():
+            out = self.model.generate(**gen_kwargs)

+        raw = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
         guarded = self._apply_guard(user_text, tools, raw)

         return {
+            "generated_text": raw,
+            "envelope": guarded
         }
```
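For decoder-only models, `generate` returns the prompt tokens followed by the completion, so slicing `out[0][input_ids.shape[-1]:]` decodes only the newly generated text. The endpoint then returns both fields side by side; a hypothetical response (the envelope layout depends on `_apply_guard`, which is outside this hunk):

```python
# Hypothetical response body; field values are illustrative.
response = {
    "generated_text": '{"tool": "convert_units", "arguments": {"miles": 10}}',
    "envelope": {
        # whatever _apply_guard produced for this exchange
    },
}
```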