Spaces:
Paused
Paused
| # -*- coding: utf-8 -*- | |
| import os | |
| import re | |
| import time | |
| import logging | |
| from gradio_client import Client | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| confusion_matrix, | |
| classification_report, | |
| precision_recall_fscore_support | |
| ) | |
| import pandas as pd # Optional: if loading data from file | |
# --- Configuration ---
# Target Space, resolved in this order:
#   1. the SPACE_ID environment variable, when it is set (useful when running
#      this script on the Space itself or anywhere else the variable is exported);
#   2. the hardcoded default below.
# An empty SPACE_ID is rejected by the `if not SPACE_ID` check before connecting.
# If an ID is not enough, a full URL of the form
# "https://your-username-your-space-name.hf.space" can be used as the value.
SPACE_ID = os.environ.get("SPACE_ID", "rkostov/thesis-agent")
API_NAME = "/respond"  # Endpoint name, taken from the view_api output
NUM_RESULTS = 3  # Default value for the slider input
SLEEP_BETWEEN_CALLS = 1  # Seconds to wait between calls to avoid rate limiting

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('evaluation_script')
# --- Benchmark Dataset ---
# Hand-labelled queries, each paired with the route the agent is expected to
# choose ('RAG' or 'TEXT_SEARCH'). Extend the pairs below, or load them from a
# file (CSV/JSON) instead. Every pair becomes a dict with the keys 'query' and
# 'intended_target', in the order listed here.
benchmark_data = [
    {'query': query_text, 'intended_target': target}
    for query_text, target in (
        # === RAG - Specific Questions ===
        # Ingredients & Quantities
        ('Does spaghetti carbonara use cream?', 'RAG'),
        ('What kind of cheese is in the Greek Salad recipe?', 'RAG'),
        ('Is there onion in the banana bread?', 'RAG'),
        ('List ingredients for chocolate chip cookies.', 'RAG'),
        ('How much butter for the chocolate chip cookies?', 'RAG'),
        ('Does the stir fry recipe contain peanuts?', 'RAG'),
        ('What oil is recommended for the stir fry?', 'RAG'),
        ('Are eggs required for the carbonara?', 'RAG'),
        ('Tell me the spices in the default chicken recipe.', 'RAG'),
        ('Any garlic in the greek salad?', 'RAG'),
        ('What type of flour is used in the banana bread?', 'RAG'),
        ('How many eggs in the carbonara?', 'RAG'),
        ('Does the banana bread use baking soda or baking powder?', 'RAG'),
        ('Are fresh tomatoes needed for the greek salad?', 'RAG'),
        ('What cut of chicken for the stir fry?', 'RAG'),
        # Instructions & Timing
        ('How long do I bake the chocolate chip cookies?', 'RAG'),
        ('What temperature to bake cookies?', 'RAG'),
        ('What is the first step for the chicken stir fry?', 'RAG'),
        ('How do you make the dressing for the Greek Salad?', 'RAG'),
        ('Tell me how to cook spaghetti carbonara.', 'RAG'),
        ('Summarize the banana bread instructions.', 'RAG'),
        ('How many steps are there to make the cookies?', 'RAG'),
        ('What do I do after frying the pancetta in carbonara?', 'RAG'),
        ('How should I prepare the vegetables for the stir fry?', 'RAG'),
        ("What's the final step for the Greek salad?", 'RAG'),
        ('How long does the banana bread need to cool?', 'RAG'),
        ('At what point are the chocolate chips added?', 'RAG'),
        ('How long to cook the chicken in the stir fry?', 'RAG'),
        ('When is the pasta water used in carbonara?', 'RAG'),
        ('Should the feta be crumbled or cubed for the salad?', 'RAG'),
        # Properties/Suitability
        ('Is the Greek Salad vegetarian?', 'RAG'),
        ('Are the chocolate chip cookies gluten-free?', 'RAG'),
        ('Is the banana bread recipe vegan?', 'RAG'),
        ('Can the carbonara be made ahead of time?', 'RAG'),
        ('Is the chicken stir fry spicy?', 'RAG'),
        ('Approximate prep time for banana bread?', 'RAG'),
        ('Which of the backup recipes are vegetarian?', 'RAG'),
        ('Difficulty level of the carbonara?', 'RAG'),
        ('Does the cookie recipe yield many cookies?', 'RAG'),
        ('Is the stir fry low-carb?', 'RAG'),
        # Technique/Tools
        ('How do I cream butter and sugar?', 'RAG'),
        ("What does 'fold in' mean for the banana bread?", 'RAG'),
        ('What pan size for the banana bread?', 'RAG'),
        ('Do I need a whisk for the carbonara?', 'RAG'),
        ("What does 'tender-crisp' mean for stir fry vegetables?", 'RAG'),
        ('How to mash bananas properly?', 'RAG'),
        ('What kind of pan for stir fry?', 'RAG'),
        ('How to chop an onion for the salad?', 'RAG'),
        ("What does 'al dente' mean for spaghetti?", 'RAG'),
        ('Why mix wet and dry ingredients separately for cookies?', 'RAG'),
        # === Text Search - General Queries ===
        # Recipe Name
        ('Spaghetti Carbonara', 'TEXT_SEARCH'),
        ('Easy Banana Bread', 'TEXT_SEARCH'),
        ('Chicken Stir Fry', 'TEXT_SEARCH'),
        ('Greek Salad', 'TEXT_SEARCH'),
        ('Chocolate Chip Cookies', 'TEXT_SEARCH'),
        ('Recipe for carbonara', 'TEXT_SEARCH'),
        ('Show me banana bread', 'TEXT_SEARCH'),
        ('cookies', 'TEXT_SEARCH'),
        ('salad', 'TEXT_SEARCH'),
        ('pasta', 'TEXT_SEARCH'),
        # Main Ingredient(s)
        ('recipes with chicken breast', 'TEXT_SEARCH'),
        ('broccoli soup', 'TEXT_SEARCH'),
        ('something with eggs and pancetta', 'TEXT_SEARCH'),
        ('Find recipes using feta cheese.', 'TEXT_SEARCH'),
        ('pasta with eggs', 'TEXT_SEARCH'),
        ('banana recipes', 'TEXT_SEARCH'),
        ('cookies with chocolate', 'TEXT_SEARCH'),
        ('salad with olives', 'TEXT_SEARCH'),
        ('dinner with chicken', 'TEXT_SEARCH'),
        ('recipes using ripe bananas', 'TEXT_SEARCH'),
        ('find recipes with bell peppers', 'TEXT_SEARCH'),
        ('Pecorino Romano recipes', 'TEXT_SEARCH'),
        ('What can I make with butter and sugar?', 'TEXT_SEARCH'),
        ('Search for recipes with cucumber', 'TEXT_SEARCH'),
        ('Got extra eggs, what can I make?', 'TEXT_SEARCH'),
        # Meal Type/Descriptor
        ('quick weeknight dinner', 'TEXT_SEARCH'),
        ('healthy dessert', 'TEXT_SEARCH'),
        ('vegetarian main course', 'TEXT_SEARCH'),
        ('party appetizer', 'TEXT_SEARCH'),
        ('easy baking recipes', 'TEXT_SEARCH'),
        ('low carb meals', 'TEXT_SEARCH'),
        ('comfort food', 'TEXT_SEARCH'),
        ('salad recipes', 'TEXT_SEARCH'),
        ('budget friendly ideas', 'TEXT_SEARCH'),
        ('simple lunch', 'TEXT_SEARCH'),
        ('italian pasta', 'TEXT_SEARCH'),
        ('something sweet', 'TEXT_SEARCH'),
        ('savory dishes', 'TEXT_SEARCH'),
        ('recipes for beginners', 'TEXT_SEARCH'),
        ('30 minute meals', 'TEXT_SEARCH'),
        # === Ambiguous Queries (a default target is assigned for evaluation) ===
        ('ingredients for healthy vegetarian soup', 'TEXT_SEARCH'),
        ('how to make vegetarian lasagna', 'TEXT_SEARCH'),
        ('best chocolate chip cookie recipe', 'TEXT_SEARCH'),
        ('carbonara no cream', 'TEXT_SEARCH'),
        ('information about banana bread', 'RAG'),
        ('Greek salad dressing instructions', 'RAG'),
        ('quick vegetarian pasta', 'TEXT_SEARCH'),
        ('tell me about stir fry', 'RAG'),
        ('carbonara recipe details', 'RAG'),
        ('cookie variations', 'TEXT_SEARCH'),
        ('Can you find a low-sugar banana bread?', 'TEXT_SEARCH'),
        ('What are some salads with cucumber?', 'TEXT_SEARCH'),
        ('Talk me through the carbonara recipe', 'RAG'),
        ('Nutritional info for cookies', 'RAG'),  # RAG likely to fail gracefully
        ('Compare carbonara and stir fry', 'RAG'),  # RAG likely to fail gracefully
        # === Edge Cases ===
        ('choclate chip cookis', 'TEXT_SEARCH'),  # Misspelling
        ('soup', 'TEXT_SEARCH'),  # Broad
        ('Does any recipe use saffron?', 'RAG'),  # Likely out-of-scope ingredient
        ('asdfghjkl', 'TEXT_SEARCH'),  # Nonsense
        ('tell me a joke about cooking', 'RAG'),  # Out-of-scope topic
    )
]
# --- End Benchmark Dataset ---
# --- Helper function to extract routing decision ---
def extract_routing_decision(response_content: object) -> str:
    """Extract the agent's routing decision from an assistant reply.

    The reply is expected to embed a debug fragment of the form
    ``DEBUG: Router=VALUE,...`` (possibly wrapped in backticks).

    Args:
        response_content: The assistant message content. Anything that is not
            a ``str`` is treated as unparseable.

    Returns:
        ``"RAG"`` or ``"TEXT_SEARCH"`` when a decision is found or can be
        inferred, otherwise ``"PARSE_ERROR"``.
    """
    if not isinstance(response_content, str):
        return "PARSE_ERROR"  # Non-string content cannot be parsed

    # Primary signal: the value after `Router=`, up to the next comma or backtick.
    router_match = re.search(r"Router=([^,`]+)", response_content)
    if router_match:
        decision = router_match.group(1).strip()
        if decision in ("RAG", "TEXT_SEARCH"):
            return decision
        # An explicit but unknown router value is an error; the Method=
        # fallback below applies only when Router= is missing altogether.
        logger.warning(f"Parsed unexpected decision value: {decision}")
        return "PARSE_ERROR"

    # Fallback: Router= is missing, so infer the route from `Method=...`
    # (e.g. output produced by the text search helper).
    method_match = re.search(r"Method=([^`]+)", response_content)
    if method_match:
        method_used = method_match.group(1).strip()
        # (marker in the method text, inferred label, log note); first match wins.
        inference_rules = (
            ("text (router chosen)", "TEXT_SEARCH",
             "Router= missing, inferred TEXT_SEARCH from method."),
            # The router intended RAG even though RAG then fell back to text search.
            ("text (RAG fallback)", "RAG",
             "Router= missing, inferred RAG (fallback) from method."),
            ("vector (RAG executed)", "RAG",
             "Router= missing, inferred RAG from method."),
        )
        for marker, label, note in inference_rules:
            if marker in method_used:
                logger.warning(note)
                return label

    logger.warning("Could not parse routing decision or infer from method in response.")
    return "PARSE_ERROR"  # Neither pattern produced a usable decision
# --- Main Evaluation Logic ---
# Connect to the Space up front; nothing below can run without a client, so
# both failure paths abort the script with a non-zero exit status.
if not SPACE_ID:
    logger.error("Error: SPACE_ID not configured. Set the environment variable or hardcode it.")
    # `exit()` is a helper installed by the `site` module for interactive use
    # and reports status 0; raise SystemExit directly to signal failure.
    raise SystemExit(1)

logger.info(f"Connecting to Gradio Space: {SPACE_ID}")
try:
    # If the Space needs authentication, pass extra Client arguments here,
    # e.g. Client(SPACE_ID, hf_token=...).
    client = Client(SPACE_ID)
except Exception as e:
    # Top-level boundary: log whatever went wrong and stop.
    logger.error(f"Failed to connect to Gradio Space: {e}")
    raise SystemExit(1) from e
else:
    logger.info("Connection successful.")
# Parallel result lists: y_true[k] is the manual label ('intended_target') and
# y_pred[k] the agent's actual routing decision (or an error code) for query k.
y_true = []
y_pred = []

logger.info(f"Starting evaluation for {len(benchmark_data)} queries...")
for position, entry in enumerate(benchmark_data, start=1):
    query = entry['query']
    intended_target = entry['intended_target']
    logger.info(f"Processing query {position}/{len(benchmark_data)}: '{query}' (Expected: {intended_target})")

    try:
        # Stateless call: the endpoint is given only the message and the slider
        # value. No chat_history argument here due to API bug.
        result = client.predict(
            message=query,
            num_results_value=NUM_RESULTS,
            api_name=API_NAME,  # "/respond"
        )

        # Expected shape: a 2-tuple (chat_history_list, "") whose first
        # element is a non-empty list of messages.
        if result is None:
            logger.error(f"API call for query '{query}' returned None.")
            actual_decision = "API_NONE_RETURN"
        elif not (isinstance(result, tuple) and len(result) == 2
                  and isinstance(result[0], list) and result[0]):
            logger.warning(f"Unexpected API result structure: {type(result)} | Content: {result}")
            actual_decision = "API_STRUCT_ERROR"
        else:
            # The last history entry should be the assistant's reply.
            last_message = result[0][-1]
            if isinstance(last_message, dict) and last_message.get("role") == "assistant":
                bot_content = last_message.get("content")
                actual_decision = extract_routing_decision(bot_content)  # Parse debug info
            else:
                logger.warning(f"Unexpected structure in last message: {last_message}")
                actual_decision = "PARSE_ERROR"
    except Exception as e:
        # Any failure while calling the API or inspecting its result counts
        # as an API error for this query; the run continues.
        logger.error(f"API call failed for query '{query}': {e}")
        actual_decision = "API_ERROR"

    # Record the pair so both lists stay aligned, one entry per query.
    y_true.append(intended_target)
    y_pred.append(actual_decision)
    logger.info(f" -> Actual Decision: {actual_decision}")

    # Pause between calls to avoid hitting potential rate limits on free Spaces.
    time.sleep(SLEEP_BETWEEN_CALLS)

logger.info("Evaluation loop finished.")
# --- Calculate and Print Metrics ---
logger.info("\n--- Evaluation Results ---")

# Labels the parser may legitimately return; only these are scored.
valid_labels = ['RAG', 'TEXT_SEARCH']
# Error codes written into y_pred by the evaluation loop, tallied separately.
error_codes = ["API_ERROR", "PARSE_ERROR", "API_STRUCT_ERROR", "API_NONE_RETURN"]
error_counts = dict.fromkeys(error_codes, 0)

filtered_y_true = []
filtered_y_pred = []
unknown_preds = []
for true_label, pred_label in zip(y_true, y_pred):
    if pred_label in valid_labels:
        filtered_y_true.append(true_label)
        filtered_y_pred.append(pred_label)
        continue
    if pred_label in error_counts:
        error_counts[pred_label] += 1
        continue
    # Neither a valid label nor a known error code.
    logger.error(f"Encountered unexpected predicted label: {pred_label} for true label: {true_label}")
    unknown_preds.append(pred_label)

total_processed = len(filtered_y_true)
total_errors = sum(error_counts.values())

logger.info(f"Total Queries Run: {len(benchmark_data)}")
logger.info(f"Successfully Parsed Predictions: {total_processed}")
logger.info(f"API/Parse Errors: {total_errors}")
for code, count in error_counts.items():
    if count > 0:
        logger.info(f" - {code}: {count}")
if unknown_preds:
    logger.warning(f"Unknown predicted labels encountered: {set(unknown_preds)}")

if total_processed == 0:
    logger.warning("No successful predictions were parsed, cannot calculate metrics.")
else:
    # Overall accuracy, computed over the successfully parsed predictions only.
    accuracy = accuracy_score(filtered_y_true, filtered_y_pred)
    logger.info(f"\nOverall Routing Accuracy (on {total_processed} successful predictions): {accuracy:.2%}")

    # Confusion matrix with a fixed label order so rows/columns are stable.
    logger.info("\nConfusion Matrix (Rows: Actual/Intended, Columns: Predicted by Agent):")
    cm = confusion_matrix(filtered_y_true, filtered_y_pred, labels=valid_labels)
    logger.info(f"Labels: {valid_labels}")
    row_names = [f'Actual_{label}' for label in valid_labels]
    column_names = [f'Predicted_{label}' for label in valid_labels]
    cm_df = pd.DataFrame(cm, index=row_names, columns=column_names)
    logger.info(f"\n{cm_df}\n")

    # Cell-by-cell reading; the TN/FP/FN/TP tags treat TEXT_SEARCH as the
    # positive class (row/column 1) and RAG as the negative class (row/column 0).
    logger.info(f"TN (Actual RAG, Predicted RAG): {cm[0][0]}")
    logger.info(f"FP (Actual RAG, Predicted TEXT_SEARCH): {cm[0][1]}")
    logger.info(f"FN (Actual TEXT_SEARCH, Predicted RAG): {cm[1][0]}")
    logger.info(f"TP (Actual TEXT_SEARCH, Predicted TEXT_SEARCH): {cm[1][1]}")

    # Per-class precision, recall and F1.
    logger.info("\nClassification Report:")
    report = classification_report(
        filtered_y_true,
        filtered_y_pred,
        labels=valid_labels,
        target_names=valid_labels,
        zero_division=0,  # Report 0 instead of warning for classes with no support/predictions
    )
    logger.info(f"\n{report}")

logger.info("--- Evaluation Complete ---")