mohmad017's picture
Multi-Agent Research Assistant — LangGraph + FAISS + RAG + Evaluation
4619ed7
Raw
History Blame Contribute Delete
5.07 kB
import json
import os
import sys
import time
import uuid
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from dotenv import load_dotenv
load_dotenv()
def score_faithfulness(question: str, answer: str, context: str, llm) -> float:
    """Ask an LLM judge how consistent ``answer`` is with ``context``.

    The judge is prompted for an integer from 0 to 10. The first run of
    digits found in its reply is divided by 10 and clamped to
    ``[0.0, 1.0]``. Only the first 400 characters of the reference and of
    the answer are placed in the prompt.

    Args:
        question: The question being evaluated. It is accepted for
            signature symmetry but is not included in the prompt.
        answer: The AI answer to judge.
        context: Reference text the answer is compared against.
        llm: Object exposing ``invoke(prompt)``; the returned value must
            have a string ``content`` attribute.

    Returns:
        The normalized score, or ``0.5`` when the reply contains no digits
        or when invoking/parsing raises any ``Exception``.
    """
    import re

    # Built outside the try block so slicing errors on bad inputs still
    # propagate to the caller instead of being turned into 0.5.
    judge_prompt = (
        "Rate from 0 to 10 how factually consistent this AI answer is with the reference.\n"
        "10 = all facts match, 5 = some facts match, 0 = contradicts reference.\n"
        "Only output a single integer number, nothing else.\n"
        f"Reference: {context[:400]}\n"
        f"AI Answer: {answer[:400]}"
    )

    try:
        reply = llm.invoke(judge_prompt)
        first_number = re.search(r'\d+', reply.content.strip())
        if first_number is None:
            # The judge did not produce a number: fall back to neutral.
            return 0.5
        normalized = int(first_number.group()) / 10.0
        return max(0.0, min(1.0, normalized))
    except Exception:
        # Best-effort scoring: any judge failure yields the neutral score.
        return 0.5
def score_relevancy(question: str, answer: str, llm) -> float:
    """Ask an LLM judge how well ``answer`` addresses ``question``.

    The judge is prompted for an integer from 0 to 10. The first run of
    digits found in its reply is divided by 10 and clamped to
    ``[0.0, 1.0]``. The question is sent in full; only the first 400
    characters of the answer are placed in the prompt.

    Args:
        question: The question that was asked.
        answer: The AI answer to judge.
        llm: Object exposing ``invoke(prompt)``; the returned value must
            have a string ``content`` attribute.

    Returns:
        The normalized score, or ``0.5`` when the reply contains no digits
        or when invoking/parsing raises any ``Exception``.
    """
    import re

    # Built outside the try block so slicing errors on bad inputs still
    # propagate to the caller instead of being turned into 0.5.
    judge_prompt = "\n".join(
        [
            "Rate from 0 to 10 how well this answer addresses the question.",
            "10 = completely answers it, 5 = partially, 0 = off topic.",
            "Only output a single integer number, nothing else.",
            f"Question: {question}",
            f"Answer: {answer[:400]}",
        ]
    )

    try:
        reply_text = llm.invoke(judge_prompt).content.strip()
        digits = re.search(r'\d+', reply_text)
        if digits is None:
            # The judge did not produce a number: fall back to neutral.
            return 0.5
        normalized = int(digits.group()) / 10.0
        return max(0.0, min(1.0, normalized))
    except Exception:
        # Best-effort scoring: any judge failure yields the neutral score.
        return 0.5
def run_evaluation(qa_path="tests/qa_pairs.json", sample=None):
    """Run a QA set through the agent graph and report quality and latency.

    Each question is sent to the graph in its own thread id, the last AI
    message is taken as the answer, and the answer is scored with
    ``score_faithfulness`` (against the pair's ``ground_truth``) and
    ``score_relevancy``. Averages, latency percentiles and pass/fail flags
    are printed, written to ``evaluation_report.json`` in the current
    working directory, and returned.

    Side effects on ``src.agents.graph``: ``_THREAD_RETRIEVERS`` and
    ``_THREAD_META`` are cleared and ``search_tool`` is set to ``None``.

    Args:
        qa_path: Path to a UTF-8 JSON file containing a list of objects
            with a ``"question"`` key and an optional ``"ground_truth"``.
        sample: When truthy, only the first ``sample`` pairs are evaluated.

    Returns:
        The report dict with ``num_questions``, ``metrics``,
        ``latency_ms`` and ``targets_met``.

    Raises:
        ValueError: If there are no QA pairs to evaluate.
    """
    print("\n" + "="*50)
    print("Evaluation Pipeline")
    print("="*50)

    with open(qa_path, encoding="utf-8") as f:
        qa_pairs = json.load(f)
    if sample:
        qa_pairs = qa_pairs[:sample]
    if not qa_pairs:
        # Without this guard the averages below divide by zero.
        raise ValueError(f"No QA pairs to evaluate in {qa_path!r}")

    # Project and LLM imports happen after input validation so that a bad
    # QA file fails before they are loaded.
    from langchain_groq import ChatGroq
    from langchain_core.messages import HumanMessage, AIMessage
    from src.agents.graph import get_graph
    import src.agents.graph as g

    total = len(qa_pairs)
    print(f"\nEvaluating {total} questions...\n")

    graph = get_graph()
    # clear stale FAISS stores and disable web search
    g._THREAD_RETRIEVERS.clear()
    g._THREAD_META.clear()
    g.search_tool = None

    llm = ChatGroq(
        model="llama-3.1-8b-instant",
        api_key=os.getenv("GROQ_API_KEY"),
        temperature=0,
    )

    faithfulness_scores = []
    relevancy_scores = []
    latencies = []

    for i, pair in enumerate(qa_pairs):
        q = pair["question"]
        gt = pair.get("ground_truth", "")
        # A fresh thread id per question keeps conversations independent.
        thread_id = f"eval_{uuid.uuid4().hex[:8]}"
        config = {"configurable": {"thread_id": thread_id}}

        t0 = time.perf_counter()
        try:
            result = graph.invoke(
                {"messages": [HumanMessage(content=q)]},
                config=config,
            )
            last_ai = next(
                (m for m in reversed(result["messages"]) if isinstance(m, AIMessage)),
                None,
            )
            answer = last_ai.content if last_ai else ""
        except Exception as e:
            # Best-effort: a failed question is still scored (as an empty
            # answer), but the failure is reported instead of being hidden.
            print(
                f"  [{i+1:2d}/{total}] graph invocation failed: "
                f"{type(e).__name__}: {e}",
                file=sys.stderr,
            )
            answer = ""
        # Latency covers the graph call only, including failed calls.
        latency = (time.perf_counter() - t0) * 1000
        latencies.append(latency)

        faith = score_faithfulness(q, answer, gt, llm)
        relevancy = score_relevancy(q, answer, llm)
        faithfulness_scores.append(faith)
        relevancy_scores.append(relevancy)

        print(f" [{i+1:2d}/{len(qa_pairs)}] F:{faith:.2f} R:{relevancy:.2f} {latency:.0f}ms | {q[:45]}")

    faith_avg = sum(faithfulness_scores) / len(faithfulness_scores)
    rel_avg = sum(relevancy_scores) / len(relevancy_scores)

    # Index-based percentiles: the element at floor(fraction * n) of the
    # sorted latencies (always in range because fraction < 1).
    latencies_sorted = sorted(latencies)
    p50 = latencies_sorted[len(latencies_sorted) // 2]
    p90 = latencies_sorted[int(len(latencies_sorted) * 0.9)]

    # Evaluate each threshold once so the report and the printout agree.
    targets_met = {
        "faithfulness_gt_085": faith_avg > 0.85,
        "answer_relevancy_gt_080": rel_avg > 0.80,
        "p90_lt_2000ms": p90 < 2000,
    }

    report = {
        "num_questions": len(qa_pairs),
        "metrics": {
            "faithfulness": round(faith_avg, 4),
            "answer_relevancy": round(rel_avg, 4),
        },
        "latency_ms": {
            "p50": round(p50, 1),
            "p90": round(p90, 1),
        },
        "targets_met": targets_met,
    }

    faith_mark = '✅' if targets_met["faithfulness_gt_085"] else '❌'
    rel_mark = '✅' if targets_met["answer_relevancy_gt_080"] else '❌'
    p90_mark = '✅' if targets_met["p90_lt_2000ms"] else '❌'

    print("\n" + "="*50)
    print("RESULTS")
    print("="*50)
    print(f"Faithfulness: {faith_avg:.4f} {faith_mark} (target >0.85)")
    print(f"Answer Relevancy: {rel_avg:.4f} {rel_mark} (target >0.80)")
    print(f"P50 Latency: {p50:.0f}ms")
    print(f"P90 Latency: {p90:.0f}ms {p90_mark} (target <2000ms)")
    print(f"\nQuestions tested: {len(qa_pairs)}")

    with open("evaluation_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print("\nSaved to evaluation_report.json")

    return report
if __name__ == "__main__":
    # Script entry point: read --sample / --qa from the command line and
    # hand them to run_evaluation.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--sample", type=int, default=None)
    cli.add_argument("--qa", default="tests/qa_pairs.json")
    options = cli.parse_args()

    run_evaluation(qa_path=options.qa, sample=options.sample)