Spaces:
Sleeping
Sleeping
| """ | |
| Grading module for Curator environment. | |
| Implements standard Information Retrieval metrics for deterministic, | |
| reproducible scoring of agent performance (0.0-1.0). | |
| """ | |
| import math | |
| from typing import Dict, List, Optional | |
def dcg_at_k(relevances: List[float], k: int) -> float:
    """Compute Discounted Cumulative Gain at k.

    Args:
        relevances: Gains in ranked order (best-ranked item first).
        k: Number of leading positions to evaluate.

    Returns:
        Sum of ``rel / log2(position + 1)`` over the first ``k`` gains,
        with positions counted from 1. Returns 0.0 when ``k <= 0``.
    """
    # A negative k would make the slice below drop items from the END of the
    # list instead of selecting an empty prefix, so reject it explicitly.
    if k <= 0:
        return 0.0
    dcg = 0.0
    for i, rel in enumerate(relevances[:k]):
        # i + 2 because positions are 1-based and log2(1) == 0.
        dcg += rel / math.log2(i + 2)
    return dcg
def ndcg_at_k(
    ranked_ids: List[str],
    relevance_scores: Dict[str, float],
    k: int,
) -> float:
    """Compute Normalized Discounted Cumulative Gain at k.

    Args:
        ranked_ids: Agent's ranked list of item IDs (best first).
        relevance_scores: Ground truth {item_id: relevance} scores.
        k: Evaluate top-k items.

    Returns:
        NDCG score. For non-negative relevance scores the result lies in
        [0, 1]. Returns 0.0 for empty inputs, ``k <= 0``, or when the ideal
        DCG is not positive.
    """
    if not ranked_ids or not relevance_scores or k <= 0:
        return 0.0

    # Actual DCG from the agent ranking. An ID that appears more than once
    # only earns its gain at the first occurrence; later repeats still occupy
    # a rank position but contribute nothing. Without this, repeating one
    # highly relevant ID could push the score above 1.0.
    seen = set()
    actual_rels = []
    for iid in ranked_ids[:k]:
        if iid in seen:
            actual_rels.append(0.0)
        else:
            seen.add(iid)
            actual_rels.append(relevance_scores.get(iid, 0.0))
    actual_dcg = dcg_at_k(actual_rels, k)

    # Ideal DCG: the k highest ground-truth relevances in descending order.
    ideal_rels = sorted(relevance_scores.values(), reverse=True)[:k]
    ideal_dcg = dcg_at_k(ideal_rels, k)

    # Nothing to normalise against (all relevances zero or negative).
    if ideal_dcg <= 0:
        return 0.0
    return actual_dcg / ideal_dcg
def precision_at_k(
    selected_ids: List[str],
    relevance_scores: Dict[str, float],
    k: int,
    threshold: float = 0.5,
) -> float:
    """Compute Precision at k.

    Args:
        selected_ids: Agent's selected item IDs.
        relevance_scores: Ground truth {item_id: relevance} scores.
        k: Evaluate top-k items.
        threshold: Minimum relevance to count as "relevant".

    Returns:
        Precision score in [0, 1]: the number of distinct relevant IDs among
        the first ``k`` selections divided by the number of selections
        actually evaluated (fewer than ``k`` when the list is shorter).
        Returns 0.0 for an empty selection or ``k <= 0``.
    """
    if not selected_ids or k <= 0:
        return 0.0
    top_k = selected_ids[:k]
    # Count each distinct ID once: a repeated ID still occupies a slot in the
    # denominator but must not be credited as an additional relevant hit.
    relevant_count = sum(
        1 for iid in set(top_k) if relevance_scores.get(iid, 0.0) >= threshold
    )
    # top_k is non-empty and never longer than k here.
    return relevant_count / len(top_k)
def recall_at_k(
    selected_ids: List[str],
    relevance_scores: Dict[str, float],
    k: int,
    threshold: float = 0.5,
) -> float:
    """Compute Recall at k.

    Args:
        selected_ids: Agent's selected item IDs.
        relevance_scores: Ground truth {item_id: relevance} scores.
        k: Evaluate top-k items.
        threshold: Minimum relevance to count as "relevant".

    Returns:
        Recall score in [0, 1]: the fraction of relevant ground-truth items
        that appear among the first ``k`` selections. Returns 1.0 when the
        ground truth contains no relevant items, and 0.0 when there are
        relevant items but nothing is evaluated (empty selection or
        ``k <= 0``).
    """
    relevant_ids = {
        iid for iid, rel in relevance_scores.items() if rel >= threshold
    }
    if not relevant_ids:
        return 1.0  # No relevant items to find
    # A negative k would slice from the end of the list; treat it like k == 0.
    if not selected_ids or k <= 0:
        return 0.0
    # Set intersection counts each relevant item at most once, so duplicated
    # selections cannot push recall above 1.0.
    found_relevant = relevant_ids.intersection(selected_ids[:k])
    return len(found_relevant) / len(relevant_ids)
def source_diversity(selected_ids: List[str], items_by_id: Dict[str, dict]) -> float:
    """Measure how many of the available sources the selection covers.

    An item without a "source" key is treated as having the source "".
    Selected IDs that are not present in ``items_by_id`` are ignored.

    Returns:
        Number of distinct sources among the selected items divided by the
        number of distinct sources across all items, in [0, 1]. Returns 0.0
        when nothing is selected or there are no items at all.
    """
    if not selected_ids:
        return 0.0

    # Every source that appears anywhere in the catalogue.
    catalogue_sources = {entry.get("source", "") for entry in items_by_id.values()}

    # Sources reached by the selection (unknown IDs contribute nothing).
    covered_sources = set()
    for item_id in selected_ids:
        if item_id in items_by_id:
            covered_sources.add(items_by_id[item_id].get("source", ""))

    if len(catalogue_sources) == 0:
        return 0.0
    return len(covered_sources) / len(catalogue_sources)
def filter_quality(
    removed_ids: List[str],
    relevance_scores: Dict[str, float],
) -> float:
    """Score a filter action by how irrelevant the removed items were.

    IDs missing from ``relevance_scores`` are assumed to have relevance 0.5.

    Returns:
        ``1 - mean relevance of the removed items``, clamped to [0, 1].
        Higher is better. Returns 0.0 when nothing was removed.
    """
    removed_count = len(removed_ids)
    if removed_count == 0:
        return 0.0

    # Mean ground-truth relevance over everything that was filtered out.
    total_relevance = sum(relevance_scores.get(item_id, 0.5) for item_id in removed_ids)
    mean_relevance = total_relevance / removed_count

    # Removing low-relevance items is good, so invert and clamp.
    inverted = 1.0 - mean_relevance
    capped = min(1.0, inverted)
    return max(0.0, capped)
def categorize_quality(
    agent_categories: Dict[str, str],
    relevance_scores: Dict[str, float],
    threshold_urgent: float = 0.7,
    threshold_read: float = 0.4,
) -> float:
    """Score the agent's category labels against relevance-derived truth.

    Accepted labels per item, based on its ground-truth relevance
    (items missing from ``relevance_scores`` are treated as relevance 0.0):
        >= threshold_urgent  -> "urgent" or "share"
        >= threshold_read    -> "read_later" or "share"
        <  threshold_read    -> "skip" only

    Returns:
        Fraction of labelled items whose label is accepted, in [0, 1].
        Returns 0.0 when no labels were supplied.
    """
    if not agent_categories:
        return 0.0

    def accepted_labels(relevance: float) -> set:
        # Map a relevance value onto the set of labels counted as correct.
        if relevance >= threshold_urgent:
            return {"urgent", "share"}
        if relevance >= threshold_read:
            return {"read_later", "share"}
        return {"skip"}

    hits = 0
    for item_id, label in agent_categories.items():
        if label in accepted_labels(relevance_scores.get(item_id, 0.0)):
            hits += 1
    return hits / len(agent_categories)
def grade_episode(
    recommended_ids: List[str],
    ranked_ids: Optional[List[str]],
    categories: Optional[Dict[str, str]],
    relevance_scores: Dict[str, float],
    items_by_id: Dict[str, dict],
    recommend_k: int,
) -> float:
    """Combine the individual metrics into one episode score in [0, 1].

    Weights:
        0.35  NDCG@k              (ranking quality)
        0.25  Precision@k         (over the recommendations)
        0.20  Recall@k            (over the recommendations)
        0.10  Category accuracy   (0.0 when no categories were given)
        0.10  Source diversity    (over all recommendations)
    """
    # Fall back to the recommendation order when no explicit ranking exists.
    ranking = ranked_ids or recommended_ids

    ndcg = ndcg_at_k(ranking, relevance_scores, recommend_k)
    precision = precision_at_k(recommended_ids, relevance_scores, recommend_k)
    recall = recall_at_k(recommended_ids, relevance_scores, recommend_k)
    if categories:
        cat_acc = categorize_quality(categories, relevance_scores)
    else:
        cat_acc = 0.0
    diversity = source_diversity(recommended_ids, items_by_id)

    # Accumulate left to right so the floating-point result is unchanged.
    weighted = 0.35 * ndcg
    weighted += 0.25 * precision
    weighted += 0.20 * recall
    weighted += 0.10 * cat_acc
    weighted += 0.10 * diversity

    return max(0.0, min(1.0, weighted))