Spaces:
Sleeping
Sleeping
Gourisankar Padihary
commited on
Commit
·
f7c2fa3
1
Parent(s):
b1b2c27
Compute RMSE and AUCROC
Browse files- data/load_dataset.py +1 -1
- generator/compute_metrics.py +32 -7
- generator/compute_rmse_auc_roc_metrics.py +101 -0
- generator/extract_attributes.py +10 -3
- generator/initialize_llm.py +5 -0
- main.py +11 -21
data/load_dataset.py
CHANGED
|
@@ -5,5 +5,5 @@ def load_data():
|
|
| 5 |
logging.info("Loading dataset")
|
| 6 |
dataset = load_dataset("rungalileo/ragbench", 'covidqa', split="test")
|
| 7 |
logging.info("Dataset loaded successfully")
|
| 8 |
-
logging.info(dataset)
|
| 9 |
return dataset
|
|
|
|
| 5 |
logging.info("Loading dataset")
|
| 6 |
dataset = load_dataset("rungalileo/ragbench", 'covidqa', split="test")
|
| 7 |
logging.info("Dataset loaded successfully")
|
| 8 |
+
logging.info(f"Number of documents found: {dataset.num_rows}")
|
| 9 |
return dataset
|
generator/compute_metrics.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def compute_metrics(attributes, total_sentences):
|
| 2 |
# Extract relevant information from attributes
|
| 3 |
all_relevant_sentence_keys = attributes.get("all_relevant_sentence_keys", [])
|
|
@@ -8,17 +11,39 @@ def compute_metrics(attributes, total_sentences):
|
|
| 8 |
context_relevance = len(all_relevant_sentence_keys) / total_sentences if total_sentences else 0
|
| 9 |
|
| 10 |
# Compute Context Utilization
|
| 11 |
-
context_utilization = len(all_utilized_sentence_keys) /
|
| 12 |
-
|
| 13 |
-
# Compute Completeness
|
| 14 |
-
completeness = all(info.get("fully_supported", False) for info in sentence_support_information)
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# Compute Adherence
|
| 17 |
-
adherence =
|
| 18 |
|
| 19 |
return {
|
| 20 |
"Context Relevance": context_relevance,
|
| 21 |
"Context Utilization": context_utilization,
|
| 22 |
-
"Completeness":
|
| 23 |
"Adherence": adherence
|
| 24 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
def compute_metrics(attributes, total_sentences):
|
| 5 |
# Extract relevant information from attributes
|
| 6 |
all_relevant_sentence_keys = attributes.get("all_relevant_sentence_keys", [])
|
|
|
|
| 11 |
context_relevance = len(all_relevant_sentence_keys) / total_sentences if total_sentences else 0
|
| 12 |
|
| 13 |
# Compute Context Utilization
|
| 14 |
+
context_utilization = len(all_utilized_sentence_keys) / total_sentences if total_sentences else 0
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
# Compute Completeness score
|
| 17 |
+
Ri = set(all_relevant_sentence_keys)
|
| 18 |
+
Ui = set(all_utilized_sentence_keys)
|
| 19 |
+
|
| 20 |
+
completeness_score = len(Ri & Ui) / len(Ri) if len(Ri) else 0
|
| 21 |
+
|
| 22 |
# Compute Adherence
|
| 23 |
+
adherence = all(info.get("fully_supported", False) for info in sentence_support_information)
|
| 24 |
|
| 25 |
return {
|
| 26 |
"Context Relevance": context_relevance,
|
| 27 |
"Context Utilization": context_utilization,
|
| 28 |
+
"Completeness Score": completeness_score,
|
| 29 |
"Adherence": adherence
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
def get_metrics(attributes, total_sentences):
|
| 33 |
+
if attributes.content:
|
| 34 |
+
result_content = attributes.content # Access the content attribute
|
| 35 |
+
# Extract the JSON part from the result_content
|
| 36 |
+
json_start = result_content.find("{")
|
| 37 |
+
json_end = result_content.rfind("}") + 1
|
| 38 |
+
json_str = result_content[json_start:json_end]
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
result_json = json.loads(json_str)
|
| 42 |
+
print(json.dumps(result_json, indent=2))
|
| 43 |
+
|
| 44 |
+
# Compute metrics using the extracted attributes
|
| 45 |
+
metrics = compute_metrics(result_json, total_sentences)
|
| 46 |
+
print(metrics)
|
| 47 |
+
return metrics
|
| 48 |
+
except json.JSONDecodeError as e:
|
| 49 |
+
logging.error(f"JSONDecodeError: {e}")
|
generator/compute_rmse_auc_roc_metrics.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from sklearn.metrics import roc_auc_score, root_mean_squared_error
|
| 3 |
+
from generator.compute_metrics import get_metrics
|
| 4 |
+
from generator.extract_attributes import extract_attributes
|
| 5 |
+
from generator.generate_response import generate_response
|
| 6 |
+
from retriever.retrieve_documents import retrieve_top_k_documents
|
| 7 |
+
|
| 8 |
+
def compute_rmse_auc_roc_metrics(llm, dataset, vector_store):
|
| 9 |
+
|
| 10 |
+
# Lists to accumulate ground truths and predictions for AUC-ROC computation
|
| 11 |
+
all_ground_truth_relevance = []
|
| 12 |
+
all_predicted_relevance = []
|
| 13 |
+
|
| 14 |
+
all_ground_truth_utilization = []
|
| 15 |
+
all_predicted_utilization = []
|
| 16 |
+
|
| 17 |
+
all_ground_truth_adherence = []
|
| 18 |
+
all_predicted_adherence = []
|
| 19 |
+
|
| 20 |
+
# To store RMSE scores for each question
|
| 21 |
+
relevance_scores = []
|
| 22 |
+
utilization_scores = []
|
| 23 |
+
adherence_scores = []
|
| 24 |
+
|
| 25 |
+
for i, sample in enumerate(dataset):
|
| 26 |
+
print(sample)
|
| 27 |
+
sample_question = sample['question']
|
| 28 |
+
|
| 29 |
+
# Extract ground truth metrics from dataset
|
| 30 |
+
ground_truth_relevance = dataset[i]['relevance_score']
|
| 31 |
+
ground_truth_utilization = dataset[i]['utilization_score']
|
| 32 |
+
ground_truth_completeness = dataset[i]['completeness_score']
|
| 33 |
+
|
| 34 |
+
# Step 1: Retrieve relevant documents
|
| 35 |
+
relevant_docs = retrieve_top_k_documents(vector_store, sample_question, top_k=5)
|
| 36 |
+
|
| 37 |
+
# Step 2: Generate a response using LLM
|
| 38 |
+
response, source_docs = generate_response(llm, vector_store, sample_question, relevant_docs)
|
| 39 |
+
|
| 40 |
+
# Step 3: Extract attributes
|
| 41 |
+
attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
|
| 42 |
+
|
| 43 |
+
# Call the process_attributes method in the main block
|
| 44 |
+
metrics = get_metrics(attributes, total_sentences)
|
| 45 |
+
|
| 46 |
+
# Extract predicted metrics (ensure these are continuous if possible)
|
| 47 |
+
predicted_relevance = metrics['Context Relevance']
|
| 48 |
+
predicted_utilization = metrics['Context Utilization']
|
| 49 |
+
predicted_completeness = metrics['Completeness Score']
|
| 50 |
+
|
| 51 |
+
# === Handle Continuous Inputs for RMSE ===
|
| 52 |
+
relevance_rmse = root_mean_squared_error([ground_truth_relevance], [predicted_relevance])
|
| 53 |
+
utilization_rmse = root_mean_squared_error([ground_truth_utilization], [predicted_utilization])
|
| 54 |
+
#adherence_rmse = mean_squared_error([ground_truth_adherence], [predicted_adherence], squared=False)
|
| 55 |
+
|
| 56 |
+
# === Handle Binary Conversion for AUC-ROC ===
|
| 57 |
+
binary_ground_truth_relevance = 1 if ground_truth_relevance > 0.5 else 0
|
| 58 |
+
binary_predicted_relevance = 1 if predicted_relevance > 0.5 else 0
|
| 59 |
+
|
| 60 |
+
binary_ground_truth_utilization = 1 if ground_truth_utilization > 0.5 else 0
|
| 61 |
+
binary_predicted_utilization = 1 if predicted_utilization > 0.5 else 0
|
| 62 |
+
|
| 63 |
+
#binary_ground_truth_adherence = 1 if ground_truth_adherence > 0.5 else 0
|
| 64 |
+
#binary_predicted_adherence = 1 if predicted_adherence > 0.5 else 0
|
| 65 |
+
|
| 66 |
+
# === Accumulate data for overall AUC-ROC computation ===
|
| 67 |
+
all_ground_truth_relevance.append(binary_ground_truth_relevance)
|
| 68 |
+
all_predicted_relevance.append(predicted_relevance) # Use probability-based predictions
|
| 69 |
+
|
| 70 |
+
all_ground_truth_utilization.append(binary_ground_truth_utilization)
|
| 71 |
+
all_predicted_utilization.append(predicted_utilization)
|
| 72 |
+
|
| 73 |
+
#all_ground_truth_adherence.append(binary_ground_truth_adherence)
|
| 74 |
+
#all_predicted_adherence.append(predicted_adherence)
|
| 75 |
+
|
| 76 |
+
# Store RMSE scores for each question
|
| 77 |
+
relevance_scores.append(relevance_rmse)
|
| 78 |
+
utilization_scores.append(utilization_rmse)
|
| 79 |
+
#adherence_scores.append(adherence_rmse)
|
| 80 |
+
if i == 9: # Stop after processing the first 10 rows
|
| 81 |
+
break
|
| 82 |
+
# === Compute AUC-ROC for the Entire Dataset ===
|
| 83 |
+
try:
|
| 84 |
+
print(f"All Ground Truth Relevance: {all_ground_truth_relevance}")
|
| 85 |
+
print(f"All Predicted Relevance: {all_predicted_relevance}")
|
| 86 |
+
relevance_auc = roc_auc_score(all_ground_truth_relevance, all_predicted_relevance)
|
| 87 |
+
except ValueError:
|
| 88 |
+
relevance_auc = None
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
print(f"All Ground Truth Utilization: {all_ground_truth_utilization}")
|
| 92 |
+
print(f"All Predicted Utilization: {all_predicted_utilization}")
|
| 93 |
+
utilization_auc = roc_auc_score(all_ground_truth_utilization, all_predicted_utilization)
|
| 94 |
+
except ValueError:
|
| 95 |
+
utilization_auc = None
|
| 96 |
+
|
| 97 |
+
print(f"Relevance RMSE (per question): {relevance_scores}")
|
| 98 |
+
print(f"Utilization RMSE (per question): {utilization_scores}")
|
| 99 |
+
#print(f"Adherence RMSE (per question): {adherence_scores}")
|
| 100 |
+
print(f"\nOverall Relevance AUC-ROC: {relevance_auc}")
|
| 101 |
+
print(f"Overall Utilization AUC-ROC: {utilization_auc}")
|
generator/extract_attributes.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
from generator.create_prompt import create_prompt
|
| 2 |
-
from generator.initialize_llm import
|
| 3 |
from generator.document_utils import Document, apply_sentence_keys_documents, apply_sentence_keys_response
|
| 4 |
|
| 5 |
# Initialize the LLM
|
| 6 |
-
llm =
|
| 7 |
|
| 8 |
# Function to extract attributes
|
| 9 |
def extract_attributes(question, relevant_docs, response):
|
|
@@ -12,9 +12,16 @@ def extract_attributes(question, relevant_docs, response):
|
|
| 12 |
formatted_documents = apply_sentence_keys_documents(relevant_docs)
|
| 13 |
formatted_responses = apply_sentence_keys_response(response)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# Calculate the total number of sentences from formatted_documents
|
| 16 |
total_sentences = sum(len(doc) for doc in formatted_documents)
|
| 17 |
-
|
|
|
|
| 18 |
attribute_prompt = create_prompt(formatted_documents, question, formatted_responses)
|
| 19 |
|
| 20 |
# Instead of using BaseMessage, pass the formatted prompt directly to invoke
|
|
|
|
| 1 |
from generator.create_prompt import create_prompt
|
| 2 |
+
from generator.initialize_llm import initialize_validation_llm
|
| 3 |
from generator.document_utils import Document, apply_sentence_keys_documents, apply_sentence_keys_response
|
| 4 |
|
| 5 |
# Initialize the LLM
|
| 6 |
+
llm = initialize_validation_llm()
|
| 7 |
|
| 8 |
# Function to extract attributes
|
| 9 |
def extract_attributes(question, relevant_docs, response):
|
|
|
|
| 12 |
formatted_documents = apply_sentence_keys_documents(relevant_docs)
|
| 13 |
formatted_responses = apply_sentence_keys_response(response)
|
| 14 |
|
| 15 |
+
#print(f"Formatted documents : {formatted_documents}")
|
| 16 |
+
# Print the number of sentences in each document
|
| 17 |
+
for i, doc in enumerate(formatted_documents):
|
| 18 |
+
num_sentences = len(doc)
|
| 19 |
+
print(f"Document {i} has {num_sentences} sentences.")
|
| 20 |
+
|
| 21 |
# Calculate the total number of sentences from formatted_documents
|
| 22 |
total_sentences = sum(len(doc) for doc in formatted_documents)
|
| 23 |
+
print(f"Total number of sentences {total_sentences}")
|
| 24 |
+
|
| 25 |
attribute_prompt = create_prompt(formatted_documents, question, formatted_responses)
|
| 26 |
|
| 27 |
# Instead of using BaseMessage, pass the formatted prompt directly to invoke
|
generator/initialize_llm.py
CHANGED
|
@@ -4,4 +4,9 @@ from langchain_groq import ChatGroq
|
|
| 4 |
def initialize_llm():
|
| 5 |
os.environ["GROQ_API_KEY"] = "your_groq_api_key"
|
| 6 |
llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
return llm
|
|
|
|
| 4 |
def initialize_llm():
|
| 5 |
os.environ["GROQ_API_KEY"] = "your_groq_api_key"
|
| 6 |
llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)
|
| 7 |
+
return llm
|
| 8 |
+
|
| 9 |
+
def initialize_validation_llm():
|
| 10 |
+
os.environ["GROQ_API_KEY"] = "your_groq_api_key"
|
| 11 |
+
llm = ChatGroq(model="llama3-70b-8192", temperature=0.7)
|
| 12 |
return llm
|
main.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
-
import logging
|
| 2 |
from data.load_dataset import load_data
|
|
|
|
| 3 |
from retriever.chunk_documents import chunk_documents
|
| 4 |
from retriever.embed_documents import embed_documents
|
| 5 |
from retriever.retrieve_documents import retrieve_top_k_documents
|
| 6 |
from generator.initialize_llm import initialize_llm
|
| 7 |
from generator.generate_response import generate_response
|
| 8 |
from generator.extract_attributes import extract_attributes
|
| 9 |
-
from generator.compute_metrics import
|
| 10 |
|
| 11 |
# Configure logging
|
| 12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -27,7 +28,8 @@ def main():
|
|
| 27 |
logging.info("Documents embedded")
|
| 28 |
|
| 29 |
# Sample question
|
| 30 |
-
|
|
|
|
| 31 |
logging.info(f"Sample question: {sample_question}")
|
| 32 |
|
| 33 |
# Retrieve relevant documents
|
|
@@ -52,23 +54,11 @@ def main():
|
|
| 52 |
# Valuations : Extract attributes from the response and source documents
|
| 53 |
attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
json_str = result_content[json_start:json_end]
|
| 62 |
-
|
| 63 |
-
try:
|
| 64 |
-
result_json = json.loads(json_str)
|
| 65 |
-
print(json.dumps(result_json, indent=2))
|
| 66 |
-
|
| 67 |
-
# Compute metrics using the extracted attributes
|
| 68 |
-
metrics = compute_metrics(result_json, total_sentences)
|
| 69 |
-
print(metrics)
|
| 70 |
-
except json.JSONDecodeError as e:
|
| 71 |
-
logging.error(f"JSONDecodeError: {e}")
|
| 72 |
-
|
| 73 |
if __name__ == "__main__":
|
| 74 |
main()
|
|
|
|
| 1 |
+
import logging
|
| 2 |
from data.load_dataset import load_data
|
| 3 |
+
from generator import compute_rmse_auc_roc_metrics
|
| 4 |
from retriever.chunk_documents import chunk_documents
|
| 5 |
from retriever.embed_documents import embed_documents
|
| 6 |
from retriever.retrieve_documents import retrieve_top_k_documents
|
| 7 |
from generator.initialize_llm import initialize_llm
|
| 8 |
from generator.generate_response import generate_response
|
| 9 |
from generator.extract_attributes import extract_attributes
|
| 10 |
+
from generator.compute_metrics import get_metrics
|
| 11 |
|
| 12 |
# Configure logging
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 28 |
logging.info("Documents embedded")
|
| 29 |
|
| 30 |
# Sample question
|
| 31 |
+
row_num = 1
|
| 32 |
+
sample_question = dataset[row_num]['question']
|
| 33 |
logging.info(f"Sample question: {sample_question}")
|
| 34 |
|
| 35 |
# Retrieve relevant documents
|
|
|
|
| 54 |
# Valuations : Extract attributes from the response and source documents
|
| 55 |
attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
|
| 56 |
|
| 57 |
+
# Call the process_attributes method in the main block
|
| 58 |
+
metrics = get_metrics(attributes, total_sentences)
|
| 59 |
+
|
| 60 |
+
#Compute RMSE and AUC-ROC for entire dataset
|
| 61 |
+
#compute_rmse_auc_roc_metrics(llm, dataset, vector_store)
|
| 62 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
if __name__ == "__main__":
|
| 64 |
main()
|