fahmiaziz98 commited on
Commit
8786174
·
1 Parent(s): 8136c86
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -173,22 +173,29 @@ class QwenReranker(RerankerModel):
173
  )
174
 
175
  def _process_inputs(self, pairs: List[str]):
176
- """Process input pairs for Qwen model with prefix and suffix."""
177
- modified_pairs = [
178
- self.tokenizer.decode(self.prefix_tokens) + text + self.tokenizer.decode(self.suffix_tokens)
179
- for text in pairs
180
- ]
181
-
182
  inputs = self.tokenizer(
183
- modified_pairs,
184
- padding="max_length",
185
- truncation=True,
186
- max_length=self.max_length,
187
- return_tensors="pt"
188
- ).to(self.model.device)
189
-
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  return inputs
191
-
192
 
193
  @torch.no_grad()
194
  def _compute_logits(self, inputs):
 
173
  )
174
 
175
  def _process_inputs(self, pairs: List[str]):
176
+ """Process input pairs for Qwen model."""
 
 
 
 
 
177
  inputs = self.tokenizer(
178
+ pairs,
179
+ padding=False,
180
+ truncation='longest_first',
181
+ return_attention_mask=False,
182
+ max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
183
+ )
184
+
185
+ for i, ele in enumerate(inputs['input_ids']):
186
+ inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
187
+
188
+ inputs = self.tokenizer.pad(
189
+ inputs,
190
+ padding=True,
191
+ return_tensors="pt",
192
+ max_length=self.max_length
193
+ )
194
+
195
+ for key in inputs:
196
+ inputs[key] = inputs[key].to(self.model.device)
197
+
198
  return inputs
 
199
 
200
  @torch.no_grad()
201
  def _compute_logits(self, inputs):