Nursing Citizen Development committed on
Commit
dc8b89c
·
1 Parent(s): c02224c

Fix: Explicitly pass attention_mask to model.generate to resolve warning

Browse files
Files changed (1) hide show
  1. pna_client.py +2 -0
pna_client.py CHANGED
@@ -57,10 +57,12 @@ class PNAAssistantClient:
57
  ]
58
 
59
  inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
 
60
 
61
  with torch.no_grad():
62
  outputs = self.model.generate(
63
  inputs,
 
64
  max_new_tokens=300,
65
  temperature=0.7,
66
  do_sample=True,
 
57
  ]
58
 
59
  inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
60
+ attention_mask = torch.ones_like(inputs).to(self.device)
61
 
62
  with torch.no_grad():
63
  outputs = self.model.generate(
64
  inputs,
65
+ attention_mask=attention_mask,
66
  max_new_tokens=300,
67
  temperature=0.7,
68
  do_sample=True,