explicit attention mask in generate_reply
Browse files
app.py
CHANGED
|
@@ -142,12 +142,20 @@ def generate_reply(
|
|
| 142 |
top_p: float = 0.8,
|
| 143 |
max_retries: int = 10,
|
| 144 |
) -> str:
|
| 145 |
-
"""Implements the 4 guardrails from Appendix C.1."""
|
| 146 |
messages = build_hf_messages(system_prompt, history_pairs)
|
| 147 |
inputs = tokenizer.apply_chat_template(
|
| 148 |
messages, return_tensors="pt", add_generation_prompt=True
|
| 149 |
).to(model.device)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
for _ in range(max_retries):
|
| 152 |
lp = LogitsProcessorList(
|
| 153 |
[ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
|
|
@@ -156,17 +164,18 @@ def generate_reply(
|
|
| 156 |
with torch.no_grad():
|
| 157 |
out = model.generate(
|
| 158 |
input_ids=inputs,
|
|
|
|
| 159 |
do_sample=True,
|
| 160 |
top_p=top_p,
|
| 161 |
temperature=temperature,
|
| 162 |
max_new_tokens=max_new_tokens,
|
| 163 |
eos_token_id=EOS_TOKEN_ID,
|
| 164 |
pad_token_id=tokenizer.eos_token_id,
|
| 165 |
-
bad_words_ids=BAD_WORDS_IDS,
|
| 166 |
-
logits_processor=lp,
|
| 167 |
)
|
| 168 |
|
| 169 |
-
gen = out[0][inputs.shape[1]:]
|
| 170 |
text = tokenizer.decode(gen, skip_special_tokens=True).strip()
|
| 171 |
|
| 172 |
# Guardrails 3 & 4
|
|
@@ -179,6 +188,7 @@ def generate_reply(
|
|
| 179 |
raise RuntimeError("Failed to generate a valid user utterance after retries.")
|
| 180 |
|
| 181 |
|
|
|
|
| 182 |
# ======================
|
| 183 |
# Gradio UI
|
| 184 |
# ======================
|
|
|
|
| 142 |
top_p: float = 0.8,
|
| 143 |
max_retries: int = 10,
|
| 144 |
) -> str:
|
| 145 |
+
"""Implements the 4 guardrails from Appendix C.1 and passes an explicit attention_mask."""
|
| 146 |
messages = build_hf_messages(system_prompt, history_pairs)
|
| 147 |
inputs = tokenizer.apply_chat_template(
|
| 148 |
messages, return_tensors="pt", add_generation_prompt=True
|
| 149 |
).to(model.device)
|
| 150 |
|
| 151 |
+
# Robust attention mask even when pad_token_id == eos_token_id.
|
| 152 |
+
# If no padding is present (usual single-sequence case), use all-ones.
|
| 153 |
+
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
| 154 |
+
if pad_id is not None and (inputs == pad_id).any():
|
| 155 |
+
attention_mask = (inputs != pad_id).long()
|
| 156 |
+
else:
|
| 157 |
+
attention_mask = torch.ones_like(inputs, dtype=torch.long)
|
| 158 |
+
|
| 159 |
for _ in range(max_retries):
|
| 160 |
lp = LogitsProcessorList(
|
| 161 |
[ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
|
|
|
|
| 164 |
with torch.no_grad():
|
| 165 |
out = model.generate(
|
| 166 |
input_ids=inputs,
|
| 167 |
+
attention_mask=attention_mask, # <-- explicit mask to silence warning & be robust
|
| 168 |
do_sample=True,
|
| 169 |
top_p=top_p,
|
| 170 |
temperature=temperature,
|
| 171 |
max_new_tokens=max_new_tokens,
|
| 172 |
eos_token_id=EOS_TOKEN_ID,
|
| 173 |
pad_token_id=tokenizer.eos_token_id,
|
| 174 |
+
bad_words_ids=BAD_WORDS_IDS, # Guardrail 2: block <|endconversation|>
|
| 175 |
+
logits_processor=lp, # Guardrail 1: first-token filter
|
| 176 |
)
|
| 177 |
|
| 178 |
+
gen = out[0][inputs.shape[1]:]
|
| 179 |
text = tokenizer.decode(gen, skip_special_tokens=True).strip()
|
| 180 |
|
| 181 |
# Guardrails 3 & 4
|
|
|
|
| 188 |
raise RuntimeError("Failed to generate a valid user utterance after retries.")
|
| 189 |
|
| 190 |
|
| 191 |
+
|
| 192 |
# ======================
|
| 193 |
# Gradio UI
|
| 194 |
# ======================
|