Spaces:
Running
Running
Commit
·
18d1c8f
1
Parent(s):
2100725
Added penalty for reasoning and future prediction questions
Browse files- app.py +42 -7
- data_filters.py +4 -0
app.py
CHANGED
|
@@ -24,6 +24,7 @@ from data_filters import (
|
|
| 24 |
FINANCIAL_ENTITY_LABELS,
|
| 25 |
GENERAL_KNOWLEDGE_PATTERNS,
|
| 26 |
sensitive_terms,
|
|
|
|
| 27 |
FINANCIAL_TERMS,
|
| 28 |
)
|
| 29 |
|
|
@@ -266,6 +267,15 @@ def is_general_knowledge_query(query):
|
|
| 266 |
return False
|
| 267 |
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
def is_irrelevant_query(query):
|
| 270 |
"""Check if the query is not finance related"""
|
| 271 |
# If the query is general knowledge and not finance-related
|
|
@@ -365,6 +375,20 @@ def compute_entropy(logits):
|
|
| 365 |
return entropy.mean().item()
|
| 366 |
|
| 367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
# A confidence score is computed using FAISS and BM25 ranking
|
| 369 |
# FAISS: The similarity score between the response and the retrieved chunks is normalized
|
| 370 |
# BM25: The BM25 scores for the combined query and response tokens are normalized
|
|
@@ -375,12 +399,14 @@ def compute_response_confidence(
|
|
| 375 |
response,
|
| 376 |
retrieved_chunks,
|
| 377 |
bm25,
|
| 378 |
-
model_conf_signal
|
| 379 |
-
lambda_faiss=0.
|
| 380 |
-
lambda_conf=0.
|
| 381 |
-
lambda_bm25=1.
|
|
|
|
|
|
|
| 382 |
):
|
| 383 |
-
"""Calculates a confidence score
|
| 384 |
if not retrieved_chunks:
|
| 385 |
return 0.0
|
| 386 |
# Compute FAISS similarity
|
|
@@ -406,15 +432,24 @@ def compute_response_confidence(
|
|
| 406 |
normalized_bm25 = max(0, min(1, normalized_bm25))
|
| 407 |
else:
|
| 408 |
normalized_bm25 = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
logger.info(
|
| 410 |
-
f"Faiss score: {normalized_faiss},
|
| 411 |
-
f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
|
|
|
|
| 412 |
)
|
| 413 |
# Weighted sum of all the normalized scores
|
| 414 |
confidence_score = (
|
| 415 |
lambda_faiss * normalized_faiss
|
| 416 |
+ model_conf_signal * lambda_conf
|
| 417 |
+ lambda_bm25 * normalized_bm25
|
|
|
|
|
|
|
| 418 |
)
|
| 419 |
return round(min(100, max(0, confidence_score.item() * 100)), 2)
|
| 420 |
|
|
|
|
| 24 |
FINANCIAL_ENTITY_LABELS,
|
| 25 |
GENERAL_KNOWLEDGE_PATTERNS,
|
| 26 |
sensitive_terms,
|
| 27 |
+
EXPLANATORY_PATTERNS,
|
| 28 |
FINANCIAL_TERMS,
|
| 29 |
)
|
| 30 |
|
|
|
|
| 267 |
return False
|
| 268 |
|
| 269 |
|
| 270 |
+
def get_latest_available_year(retrieved_chunks, default_year=2024):
    """Extracts the latest available year from retrieved financial data.

    Scans every chunk for standalone four-digit numbers of the form 20XX and
    returns the largest one found.

    Args:
        retrieved_chunks: Iterable of text chunks to scan. ``None`` is treated
            as an empty collection.
        default_year: Year returned when no 20XX value is found in any chunk.
            Defaults to 2024, the value that was previously hard-coded.

    Returns:
        The largest matching year as an int, or ``default_year`` when there is
        no match.
    """
    # Compile once so the pattern is not re-resolved for every chunk.
    year_pattern = re.compile(r"\b(20\d{2})\b")
    years = {
        int(found)
        for chunk in (retrieved_chunks or ())
        for found in year_pattern.findall(chunk)
    }
    return max(years) if years else default_year
|
| 277 |
+
|
| 278 |
+
|
| 279 |
def is_irrelevant_query(query):
|
| 280 |
"""Check if the query is not finance related"""
|
| 281 |
# If the query is general knowledge and not finance-related
|
|
|
|
| 375 |
return entropy.mean().item()
|
| 376 |
|
| 377 |
|
| 378 |
+
def contains_future_year(query, retrieved_chunks):
    """Detects if the query asks for future data beyond available reports.

    Returns True when at least one standalone 20XX number in ``query`` is
    greater than the year reported by ``get_latest_available_year`` for
    ``retrieved_chunks``; otherwise returns False.
    """
    cutoff_year = get_latest_available_year(retrieved_chunks)
    # De-duplicate the year tokens mentioned in the query before comparing.
    mentioned = set(re.findall(r"\b(20\d{2})\b", query))
    for token in mentioned:
        if int(token) > cutoff_year:
            return True
    return False
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def is_explanatory_query(query):
    """Checks if the query requires an explanation rather than factual data.

    The query is lower-cased and tested against each regex in
    EXPLANATORY_PATTERNS; the first match makes the result True.
    """
    lowered = query.lower()
    for explanatory_regex in EXPLANATORY_PATTERNS:
        if re.search(explanatory_regex, lowered):
            return True
    return False
|
| 390 |
+
|
| 391 |
+
|
| 392 |
# A confidence score is computed using FAISS and BM25 ranking
|
| 393 |
# FAISS: The similarity score between the response and the retrieved chunks is normalized
|
| 394 |
# BM25: The BM25 scores for the combined query and response tokens are normalized
|
|
|
|
| 399 |
response,
|
| 400 |
retrieved_chunks,
|
| 401 |
bm25,
|
| 402 |
+
model_conf_signal,
|
| 403 |
+
lambda_faiss=0.6,
|
| 404 |
+
lambda_conf=0.3,
|
| 405 |
+
lambda_bm25=1.0,
|
| 406 |
+
future_penalty=-0.3,
|
| 407 |
+
explanation_penalty=-0.2,
|
| 408 |
):
|
| 409 |
+
"""Calculates a confidence score for the model response"""
|
| 410 |
if not retrieved_chunks:
|
| 411 |
return 0.0
|
| 412 |
# Compute FAISS similarity
|
|
|
|
| 432 |
normalized_bm25 = max(0, min(1, normalized_bm25))
|
| 433 |
else:
|
| 434 |
normalized_bm25 = 0.0
|
| 435 |
+
# Penalize if query contains future years
|
| 436 |
+
future_penalty = -0.3 if contains_future_year(query, retrieved_chunks) else 0.0
|
| 437 |
+
# Penalize if query is reasoning based
|
| 438 |
+
explanation_penalty_value = (
|
| 439 |
+
explanation_penalty if is_explanatory_query(query) else 0.0
|
| 440 |
+
)
|
| 441 |
logger.info(
|
| 442 |
+
f"Faiss score: {normalized_faiss}, BM25: {normalized_bm25}\n"
|
| 443 |
+
f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}\n"
|
| 444 |
+
f"Future penalty: {future_penalty}, Reasoning penalty: {explanation_penalty_value}"
|
| 445 |
)
|
| 446 |
# Weighted sum of all the normalized scores
|
| 447 |
confidence_score = (
|
| 448 |
lambda_faiss * normalized_faiss
|
| 449 |
+ model_conf_signal * lambda_conf
|
| 450 |
+ lambda_bm25 * normalized_bm25
|
| 451 |
+
+ future_penalty
|
| 452 |
+
+ explanation_penalty_value
|
| 453 |
)
|
| 454 |
return round(min(100, max(0, confidence_score.item() * 100)), 2)
|
| 455 |
|
data_filters.py
CHANGED
|
@@ -48,6 +48,10 @@ sensitive_terms = {
|
|
| 48 |
"wages",
|
| 49 |
}
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
FINANCIAL_DATA_PATTERNS = (
|
| 53 |
r"\b(\₹?\s?\d{1,3}(?:,\d{2,3})*(?:\.\d+)?\s*(million|billion|crore|lakh|%)"
|
|
|
|
| 48 |
"wages",
|
| 49 |
}
|
| 50 |
|
| 51 |
+
# Regexes that flag a query as asking for reasons or explanations.
# The patterns are written in lower case. Each alternative is wrapped in word
# boundaries, so the nouns carry an optional plural "s"; without it the
# trailing \b rejects "reasons", "causes", "factors", and similar plurals.
EXPLANATORY_PATTERNS = [
    # Causal keywords and connectives.
    r"\b(why|reasons?|causes?|explanations?|due to|because|factors?|impact of|effect of|influence of|driven by)\b",
    # Question openers that ask for a cause or a narrative.
    r"\b(how did|what led to|what caused|why did|how was|contributing factors?|explain)\b",
]
|
| 55 |
|
| 56 |
FINANCIAL_DATA_PATTERNS = (
|
| 57 |
r"\b(\₹?\s?\d{1,3}(?:,\d{2,3})*(?:\.\d+)?\s*(million|billion|crore|lakh|%)"
|