Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| import re | |
| import json | |
| from app import rag_reply, llm | |
# --- CONFIG ---
# Gold-standard CSV; the evaluation loop reads its 'question' and
# 'relevant_docs' columns.
GOLD_FILE = "gold.csv"
# Source catalogue CSV; its 'name' and 'id' columns feed the
# filename -> S-code lookup table.
SOURCES_FILE = "sources.csv"
# JSON-lines report; one record is appended per evaluated question.
OUTPUT_LOG = "source_accuracy_report-llama.jsonl"
def get_id_from_filename(filename):
    """Standardize a filename to an ID (e.g. 'S42- Paper.pdf' -> 'S42').

    Args:
        filename: The filename to normalize. Non-string values are returned
            as ``str(filename)`` without further processing.

    Returns:
        The upper-cased leading ``S<digits>`` code when the (whitespace-
        stripped) filename starts with one, matched case-insensitively;
        otherwise the stripped, lower-cased filename.
    """
    if not isinstance(filename, str):
        return str(filename)

    # Strip before matching so that surrounding whitespace cannot hide a
    # leading S-code; the fallback below already normalized whitespace, the
    # match path previously did not.
    cleaned = filename.strip()
    match = re.match(r'(S\d+)', cleaned, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    return cleaned.lower()
# --- INITIALIZE MAPPINGS ---
print("π Loading Source Mappings...")
sources_df = pd.read_csv(SOURCES_FILE)

# Lookup table: lower-cased filename (and lower-cased S-prefix) -> S-code.
# gold.csv lists long filenames while the AI cites S-codes, so both sides
# are translated to S-codes before they are compared.
filename_to_s_code = {}
for _, source_row in sources_df.iterrows():
    normalized_name = str(source_row['name']).strip().lower()

    # "PAPER_042" -> "42"; an all-zero ID collapses to "0" instead of "".
    digits = str(source_row['id']).replace("PAPER_", "").lstrip("0") or "0"
    code = f"S{digits}"
    filename_to_s_code[normalized_name] = code

    # When the filename itself begins with an S-code, register that prefix
    # as an additional key. get_id_from_filename returns an upper-cased code
    # only on a match (its fallback is lower-case), so the startswith('S')
    # test distinguishes the two outcomes.
    leading_code = get_id_from_filename(normalized_name)
    if leading_code.startswith('S'):
        filename_to_s_code[leading_code.lower()] = code
# An S-code is 'S' followed by digits that is NOT glued to a preceding letter
# or digit, so "[S42]", "S42" and "(s42)" match while "GPS4" or "windows10"
# do not. Surrounding brackets need no special handling.
_S_CODE_PATTERN = re.compile(r'(?<![A-Za-z0-9])(S\d+)', re.IGNORECASE)


def extract_sources_from_text(text):
    """Collect the S-codes (e.g. ``[S42]`` or ``S42``) cited in *text*.

    Args:
        text: The response text to scan. Falsy values yield an empty set;
            other non-string values are converted with ``str()`` first.

    Returns:
        A set of upper-cased S-codes such as ``{"S42", "S7"}``.
    """
    if not text:
        return set()
    if not isinstance(text, str):
        # Scan the string form rather than raising TypeError inside ``re``.
        text = str(text)
    return {code.upper() for code in _S_CODE_PATTERN.findall(text)}
# --- RUN EVALUATION ---
try:
    gold_df = pd.read_csv(GOLD_FILE)
except Exception as e:
    # Best-effort: report the problem and continue with an empty frame, so
    # the evaluation loop simply has no rows to iterate over.
    print(f"Error loading {GOLD_FILE}: {e}")
    gold_df = pd.DataFrame()

# One log record per evaluated question is accumulated here.
results = []

# Both lookups are purely informational, so neither may abort the run:
# fall back to defaults when `llm` lacks `model_name` or `client`.
current_model = getattr(llm, 'model_name', 'Unknown-Model')
client_url = str(getattr(getattr(llm, 'client', None), 'base_url', ''))

# The billing label is derived solely from whether the client's base URL
# mentions "huggingface".
billing_info = "HF Credits ($57 Lab)" if "huggingface" in client_url else "Personal OpenAI Key"

print("="*40)
print(f"π€ ACTIVE MODEL: {current_model}")
print(f"π³ BILLING FROM: {billing_info}")
print("="*40)
def _resolve_expected_sources(raw_docs):
    """Translate a gold ``relevant_docs`` cell into a set of S-codes.

    Args:
        raw_docs: The cell value; a ';'-separated list of filenames or
            S-codes. A missing value (NaN/None) means "no expected sources".

    Returns:
        A set of identifiers. Each entry is resolved through
        ``filename_to_s_code`` by full filename first, then by its S-prefix;
        entries that resolve neither way fall back to the value returned by
        ``get_id_from_filename``.
    """
    # A missing cell would otherwise stringify to "nan" and be counted as
    # one (never matchable) expected source.
    if pd.isna(raw_docs):
        return set()

    expected = set()
    for entry in str(raw_docs).split(';'):
        doc_name = entry.strip().lower()
        if not doc_name:
            # Stray or trailing ';' - an empty entry must not inflate the
            # number of expected sources.
            continue
        if doc_name in filename_to_s_code:
            expected.add(filename_to_s_code[doc_name])
            continue
        # Try the S-prefix as a key; otherwise keep the standardized ID.
        doc_id = get_id_from_filename(doc_name)
        expected.add(filename_to_s_code.get(doc_id.lower(), doc_id))
    return expected


total_questions = len(gold_df)
# enumerate() gives a plain-int, 1-based position that does not depend on the
# DataFrame's index labels being consecutive (or JSON-serializable) integers.
for position, (_, row) in enumerate(gold_df.iterrows(), start=1):
    question = row['question']

    # Parse expected sources from gold and translate them to S-codes.
    true_source_s_codes = _resolve_expected_sources(row['relevant_docs'])
    n = len(true_source_s_codes)

    print(f"[{position}/{total_questions}] Testing: {question[:60]}...")

    # Get the AI response and pull the cited S-codes out of it.
    ai_response = rag_reply(question)
    cited_ids = extract_sources_from_text(ai_response)

    # Recall for this question: expected sources that were actually cited.
    hits = true_source_s_codes & cited_ids
    j = len(hits)
    score = j / n if n > 0 else 0

    # Sets are sorted so the logged lists have a stable order between runs.
    log_entry = {
        "id": position,
        "model_used": current_model,
        "billing": billing_info,
        "question": question,
        "expected_sources": sorted(true_source_s_codes),
        "ai_cited_sources": sorted(cited_ids),
        "hits": sorted(hits),
        "hit_rate": f"{j}/{n}",
        "score": round(score, 4),
    }
    results.append(log_entry)

    # Append immediately so finished records survive a later failure.
    with open(OUTPUT_LOG, "a", encoding="utf-8") as log_file:
        log_file.write(json.dumps(log_entry) + "\n")
# --- SUMMARY ---
# Mean per-question score; 0 when nothing was evaluated.
if results:
    avg_recall = sum(entry['score'] for entry in results) / len(results)
else:
    avg_recall = 0

separator = "="*40
print("\n" + separator)
print(f"π SOURCE RECALL: {avg_recall:.2%}")
print(f"π Log: {OUTPUT_LOG}")
print(separator)