pszemraj committed on
Commit
57679c3
·
verified ·
1 Parent(s): a34fe64

explicit attention mask in generate_reply

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -142,12 +142,20 @@ def generate_reply(
142
  top_p: float = 0.8,
143
  max_retries: int = 10,
144
  ) -> str:
145
- """Implements the 4 guardrails from Appendix C.1."""
146
  messages = build_hf_messages(system_prompt, history_pairs)
147
  inputs = tokenizer.apply_chat_template(
148
  messages, return_tensors="pt", add_generation_prompt=True
149
  ).to(model.device)
150
 
 
 
 
 
 
 
 
 
151
  for _ in range(max_retries):
152
  lp = LogitsProcessorList(
153
  [ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
@@ -156,17 +164,18 @@ def generate_reply(
156
  with torch.no_grad():
157
  out = model.generate(
158
  input_ids=inputs,
 
159
  do_sample=True,
160
  top_p=top_p,
161
  temperature=temperature,
162
  max_new_tokens=max_new_tokens,
163
  eos_token_id=EOS_TOKEN_ID,
164
  pad_token_id=tokenizer.eos_token_id,
165
- bad_words_ids=BAD_WORDS_IDS, # Guardrail 2: block <|endconversation|>
166
- logits_processor=lp, # Guardrail 1
167
  )
168
 
169
- gen = out[0][inputs.shape[1] :]
170
  text = tokenizer.decode(gen, skip_special_tokens=True).strip()
171
 
172
  # Guardrails 3 & 4
@@ -179,6 +188,7 @@ def generate_reply(
179
  raise RuntimeError("Failed to generate a valid user utterance after retries.")
180
 
181
 
 
182
  # ======================
183
  # Gradio UI
184
  # ======================
 
142
  top_p: float = 0.8,
143
  max_retries: int = 10,
144
  ) -> str:
145
+ """Implements the 4 guardrails from Appendix C.1 and passes an explicit attention_mask."""
146
  messages = build_hf_messages(system_prompt, history_pairs)
147
  inputs = tokenizer.apply_chat_template(
148
  messages, return_tensors="pt", add_generation_prompt=True
149
  ).to(model.device)
150
 
151
+ # Robust attention mask even when pad_token_id == eos_token_id.
152
+ # If no padding is present (usual single-sequence case), use all-ones.
153
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
154
+ if pad_id is not None and (inputs == pad_id).any():
155
+ attention_mask = (inputs != pad_id).long()
156
+ else:
157
+ attention_mask = torch.ones_like(inputs, dtype=torch.long)
158
+
159
  for _ in range(max_retries):
160
  lp = LogitsProcessorList(
161
  [ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
 
164
  with torch.no_grad():
165
  out = model.generate(
166
  input_ids=inputs,
167
+ attention_mask=attention_mask, # <-- explicit mask to silence warning & be robust
168
  do_sample=True,
169
  top_p=top_p,
170
  temperature=temperature,
171
  max_new_tokens=max_new_tokens,
172
  eos_token_id=EOS_TOKEN_ID,
173
  pad_token_id=tokenizer.eos_token_id,
174
+ bad_words_ids=BAD_WORDS_IDS, # Guardrail 2: block <|endconversation|>
175
+ logits_processor=lp, # Guardrail 1: first-token filter
176
  )
177
 
178
+ gen = out[0][inputs.shape[1]:]
179
  text = tokenizer.decode(gen, skip_special_tokens=True).strip()
180
 
181
  # Guardrails 3 & 4
 
188
  raise RuntimeError("Failed to generate a valid user utterance after retries.")
189
 
190
 
191
+
192
  # ======================
193
  # Gradio UI
194
  # ======================