ML-Chatbot / source_eval_sweep.py
kmanche4675
chore: clean up repo, add benchmark logs, and ignore dev scripts
333621f
Raw
History Blame Contribute Delete
4.11 kB
import os
import pandas as pd
import re
import json
from app import rag_reply, llm
# --- CONFIG ---
# Input: gold-standard questions; columns read below are 'question' and
# 'relevant_docs' (a ';'-separated list of source filenames).
GOLD_FILE: str = "gold.csv"
# Input: catalogue of source papers; columns read below are 'id' and 'name'.
SOURCES_FILE: str = "sources.csv"
# Output: one JSON object per evaluated question, opened in append mode,
# so successive runs accumulate in the same file.
OUTPUT_LOG: str = "source_accuracy_report-llama.jsonl"
def get_id_from_filename(filename):
    """Standardize a source filename to a lookup ID.

    A name that starts with an S-code (case-insensitive) is reduced to that
    upper-cased code, e.g. ``'S42- Paper.pdf' -> 'S42'``. Any other string is
    returned stripped and lower-cased. Non-string input is returned as
    ``str(filename)`` without further normalization.

    Surrounding whitespace is removed *before* the S-code match, so
    ``'  s7 notes.pdf'`` also yields ``'S7'``.
    """
    if not isinstance(filename, str):
        return str(filename)
    cleaned = filename.strip()
    # re.match anchors at the start, same as the former '^' with re.search.
    match = re.match(r'S\d+', cleaned, re.IGNORECASE)
    if match:
        return match.group(0).upper()
    return cleaned.lower()
# --- INITIALIZE MAPPINGS ---
print("📊 Loading Source Mappings...")
sources_df = pd.read_csv(SOURCES_FILE)

# Lookup table: lower-cased filename -> canonical S-code built from the
# PAPER_xxx id. When a filename itself begins with an S-prefix, that prefix
# (lower-cased) is indexed too. gold.csv lists long filenames while the model
# cites S-codes, so expected sources are translated through this table.
filename_to_s_code = {}
for raw_name, raw_id in zip(sources_df['name'], sources_df['id']):
    fname = str(raw_name).strip().lower()
    # 'PAPER_042' -> '042' -> '42'; an all-zero id collapses to '0'.
    numeric_id = str(raw_id).replace("PAPER_", "").lstrip("0") or "0"
    s_code = f"S{numeric_id}"
    filename_to_s_code[fname] = s_code

    # Also index the filename's literal S-prefix ('s42- paper.pdf' -> 's42').
    # get_id_from_filename upper-cases real S-codes and lower-cases anything
    # else, so only a genuine prefix passes the startswith('S') check.
    s_prefix = get_id_from_filename(fname)
    if s_prefix.startswith('S'):
        filename_to_s_code[s_prefix.lower()] = s_code
def extract_sources_from_text(text):
    """Return the set of S-codes cited in *text*, upper-cased.

    Codes are matched case-insensitively whether written bare (``S42``) or
    in brackets (``[S42]``). An ``S`` that ends a longer word is ignored, so
    ``Windows10`` is not mistaken for a citation of ``S10``. Empty or ``None``
    text yields an empty set.
    """
    if not text:
        return set()
    # The lookbehind rejects an 'S' glued to a preceding letter; brackets,
    # punctuation, whitespace or digits in front are still accepted.
    codes = re.findall(r'(?<![A-Za-z])S\d+', text, re.IGNORECASE)
    return {code.upper() for code in codes}
# --- RUN EVALUATION ---
try:
    gold_df = pd.read_csv(GOLD_FILE)
except Exception as e:
    # Best-effort: report and continue with no questions (recall prints 0%).
    print(f"Error loading {GOLD_FILE}: {e}")
    gold_df = pd.DataFrame()


def _expected_s_codes(relevant_docs):
    """Translate a gold ``relevant_docs`` cell (';'-separated) into S-codes.

    Each entry is stripped and lower-cased, then looked up in
    ``filename_to_s_code``; failing that, by its S-prefix; failing that, the
    ``get_id_from_filename`` result is used as-is. Missing cells and blank
    entries (e.g. from a trailing ';') are skipped, because they could never
    be cited and would only inflate the expected-source count.
    """
    if pd.isna(relevant_docs):
        return set()
    codes = set()
    for entry in str(relevant_docs).split(';'):
        fname = entry.strip().lower()
        if not fname:
            continue
        prefix = get_id_from_filename(fname)
        if fname in filename_to_s_code:
            codes.add(filename_to_s_code[fname])
        elif prefix.lower() in filename_to_s_code:
            codes.add(filename_to_s_code[prefix.lower()])
        else:
            codes.add(prefix)  # fallback: unmapped ID
    return codes


results = []
current_model = getattr(llm, 'model_name', 'Unknown-Model')
# Tolerate an llm object without a .client, like the model_name lookup above.
client_url = str(getattr(getattr(llm, 'client', None), 'base_url', ''))
billing_info = "HF Credits ($57 Lab)" if "huggingface" in client_url else "Personal OpenAI Key"

print("=" * 40)
print(f"🤖 ACTIVE MODEL: {current_model}")
print(f"💳 BILLING FROM: {billing_info}")
print("=" * 40)

for index, row in gold_df.iterrows():
    question = row['question']
    true_source_s_codes = _expected_s_codes(row['relevant_docs'])
    n = len(true_source_s_codes)
    print(f"[{index+1}/{len(gold_df)}] Testing: {str(question)[:60]}...")

    ai_response = rag_reply(question)
    cited_ids = extract_sources_from_text(ai_response)

    # Recall = share of expected S-codes the answer cites; rows with no
    # expected sources score 0.
    hits = true_source_s_codes & cited_ids
    j = len(hits)
    score = j / n if n > 0 else 0

    log_entry = {
        "id": index + 1,
        "model_used": current_model,
        "billing": billing_info,
        "question": question,
        # Sorted: plain set order varies between interpreter runs (str hash
        # randomization), which made identical runs log differently.
        "expected_sources": sorted(true_source_s_codes),
        "ai_cited_sources": sorted(cited_ids),
        "hits": sorted(hits),
        "hit_rate": f"{j}/{n}",
        "score": round(score, 4),
    }
    results.append(log_entry)
    # Written per question in append mode, so earlier rows are already on
    # disk if a later rag_reply call raises.
    with open(OUTPUT_LOG, "a", encoding="utf-8") as log_file:
        log_file.write(json.dumps(log_entry) + "\n")

# --- SUMMARY ---
avg_recall = sum(r['score'] for r in results) / len(results) if results else 0
print("\n" + "=" * 40)
print(f"🏆 SOURCE RECALL: {avg_recall:.2%}")
print(f"📝 Log: {OUTPUT_LOG}")
print("=" * 40)