fix: remove preprompt text from final API response and exclude preprompt from score normalization
Browse files
main.py
CHANGED
|
@@ -274,6 +274,12 @@ async def analyze_text(request: TextRequest):
|
|
| 274 |
|
| 275 |
inputs = tokenizer(full_text, return_tensors="pt").to(device)
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
with torch.no_grad():
|
| 278 |
# Ensure we ask the model to output attentions explicitly
|
| 279 |
outputs = model(**inputs, output_attentions=True)
|
|
@@ -301,13 +307,19 @@ async def analyze_text(request: TextRequest):
|
|
| 301 |
# Calculate importance: sum of attention each token *receives* from the sequence
|
| 302 |
importance = avg_attention.sum(dim=0).cpu().float().numpy()
|
| 303 |
|
| 304 |
-
if len(importance) >
|
| 305 |
-
# Normalize to 0-1,
|
| 306 |
-
# as
|
| 307 |
-
|
| 308 |
-
|
|
|
|
| 309 |
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
# Keep <bos> at max score
|
| 312 |
normalized_scores[0] = 1.0
|
| 313 |
normalized_scores = normalized_scores.clip(0, 1)
|
|
@@ -332,6 +344,10 @@ async def analyze_text(request: TextRequest):
|
|
| 332 |
"score": float(normalized_scores[i])
|
| 333 |
})
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
return {"words": result}
|
| 336 |
|
| 337 |
if __name__ == "__main__":
|
|
|
|
| 274 |
|
| 275 |
inputs = tokenizer(full_text, return_tensors="pt").to(device)
|
| 276 |
|
| 277 |
+
# Calculate how many tokens belong to the preprompt so we can strip them later
|
| 278 |
+
num_preprompt_tokens = 1 # default is 1 for <bos>
|
| 279 |
+
if preprompt:
|
| 280 |
+
p_toks = tokenizer(f"{preprompt}\n\n")["input_ids"]
|
| 281 |
+
num_preprompt_tokens = len(p_toks)
|
| 282 |
+
|
| 283 |
with torch.no_grad():
|
| 284 |
# Ensure we ask the model to output attentions explicitly
|
| 285 |
outputs = model(**inputs, output_attentions=True)
|
|
|
|
| 307 |
# Calculate importance: sum of attention each token *receives* from the sequence
|
| 308 |
importance = avg_attention.sum(dim=0).cpu().float().numpy()
|
| 309 |
|
| 310 |
+
if len(importance) > num_preprompt_tokens:
|
| 311 |
+
# Normalize to 0-1, excluding the preprompt and <bos> from max/min calculation
|
| 312 |
+
# as they often have very high attention, skewing the rest
|
| 313 |
+
text_importance = importance[num_preprompt_tokens:]
|
| 314 |
+
min_score = text_importance.min()
|
| 315 |
+
max_score = text_importance.max()
|
| 316 |
|
| 317 |
+
# Avoid division by zero
|
| 318 |
+
if max_score > min_score:
|
| 319 |
+
normalized_scores = (importance - min_score) / (max_score - min_score)
|
| 320 |
+
else:
|
| 321 |
+
normalized_scores = importance - min_score
|
| 322 |
+
|
| 323 |
# Keep <bos> at max score
|
| 324 |
normalized_scores[0] = 1.0
|
| 325 |
normalized_scores = normalized_scores.clip(0, 1)
|
|
|
|
| 344 |
"score": float(normalized_scores[i])
|
| 345 |
})
|
| 346 |
|
| 347 |
+
# Return only the <bos> token plus the actual text tokens
|
| 348 |
+
if num_preprompt_tokens > 1 and len(result) > num_preprompt_tokens:
|
| 349 |
+
result = [result[0]] + result[num_preprompt_tokens:]
|
| 350 |
+
|
| 351 |
return {"words": result}
|
| 352 |
|
| 353 |
if __name__ == "__main__":
|