jobbler committed on
Commit
5dae651
·
1 Parent(s): 19d388b

fix: remove preprompt text from final API response and exclude preprompt from score normalization

Browse files
Files changed (1) hide show
  1. main.py +22 -6
main.py CHANGED
@@ -274,6 +274,12 @@ async def analyze_text(request: TextRequest):
274
 
275
  inputs = tokenizer(full_text, return_tensors="pt").to(device)
276
 
 
 
 
 
 
 
277
  with torch.no_grad():
278
  # Ensure we ask the model to output attentions explicitly
279
  outputs = model(**inputs, output_attentions=True)
@@ -301,13 +307,19 @@ async def analyze_text(request: TextRequest):
301
  # Calculate importance: sum of attention each token *receives* from the sequence
302
  importance = avg_attention.sum(dim=0).cpu().float().numpy()
303
 
304
- if len(importance) > 1:
305
- # Normalize to 0-1, optionally excluding the first token (<bos>) from max/min calculation
306
- # as <bos> often has very high attention, skewing the rest
307
- min_score = importance[1:].min()
308
- max_score = importance[1:].max()
 
309
 
310
- normalized_scores = (importance - min_score) / (max_score - min_score)
 
 
 
 
 
311
  # Keep <bos> at max score
312
  normalized_scores[0] = 1.0
313
  normalized_scores = normalized_scores.clip(0, 1)
@@ -332,6 +344,10 @@ async def analyze_text(request: TextRequest):
332
  "score": float(normalized_scores[i])
333
  })
334
 
 
 
 
 
335
  return {"words": result}
336
 
337
  if __name__ == "__main__":
 
274
 
275
  inputs = tokenizer(full_text, return_tensors="pt").to(device)
276
 
277
+ # Calculate how many tokens belong to the preprompt so we can strip them later
278
+ num_preprompt_tokens = 1 # default is 1 for <bos>
279
+ if preprompt:
280
+ p_toks = tokenizer(f"{preprompt}\n\n")["input_ids"]
281
+ num_preprompt_tokens = len(p_toks)
282
+
283
  with torch.no_grad():
284
  # Ensure we ask the model to output attentions explicitly
285
  outputs = model(**inputs, output_attentions=True)
 
307
  # Calculate importance: sum of attention each token *receives* from the sequence
308
  importance = avg_attention.sum(dim=0).cpu().float().numpy()
309
 
310
+ if len(importance) > num_preprompt_tokens:
311
+ # Normalize to 0-1, excluding the preprompt and <bos> from max/min calculation
312
+ # as they often have very high attention, skewing the rest
313
+ text_importance = importance[num_preprompt_tokens:]
314
+ min_score = text_importance.min()
315
+ max_score = text_importance.max()
316
 
317
+ # Avoid division by zero
318
+ if max_score > min_score:
319
+ normalized_scores = (importance - min_score) / (max_score - min_score)
320
+ else:
321
+ normalized_scores = importance - min_score
322
+
323
  # Keep <bos> at max score
324
  normalized_scores[0] = 1.0
325
  normalized_scores = normalized_scores.clip(0, 1)
 
344
  "score": float(normalized_scores[i])
345
  })
346
 
347
+ # Return only the <bos> token plus the actual text tokens
348
+ if num_preprompt_tokens > 1 and len(result) > num_preprompt_tokens:
349
+ result = [result[0]] + result[num_preprompt_tokens:]
350
+
351
  return {"words": result}
352
 
353
  if __name__ == "__main__":