Spaces:
Running
Running
| """MCP tools for querying territorial ecological indicators.""" | |
| import json | |
| import logging | |
| import time | |
| import hashlib | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timezone | |
| from functools import wraps | |
| from typing import Any, Callable, Optional | |
| from .api_client import get_client, CubeJsClient, CubeJsClientError | |
# Configure logging.
# NOTE(review): logging.basicConfig at import time configures the root logger
# for the whole process — fine for a standalone MCP server entry point, but
# confirm this module is never imported as a library by code that configures
# logging itself (basicConfig is a no-op if handlers already exist).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("mcp_tools")
| # ============================================================================= | |
| # Session Tracker - Track usage patterns across calls | |
| # ============================================================================= | |
@dataclass
class SessionData:
    """Track call history and derived query metrics for a single session.

    Fix: the class used ``field(default_factory=...)`` defaults and class-level
    annotations but was missing the ``@dataclass`` decorator, so
    ``SessionData(session_id=...)`` raised TypeError and the ``field()`` objects
    were plain shared class attributes. The decorator is required for the
    declarations below to behave as per-instance fields.
    """

    session_id: str                                        # short hex id assigned by UsageTracker
    start_time: float = field(default_factory=time.time)   # wall-clock session start
    calls: list = field(default_factory=list)              # one dict per recorded tool call
    last_call_time: float = 0                              # time.time() of most recent call; 0 = none yet
    indicators_queried: set = field(default_factory=set)   # distinct indicator_id params seen
    levels_queried: set = field(default_factory=set)       # distinct geographic_level params seen

    def add_call(self, tool: str, params: dict, duration_ms: int,
                 result_count: int, response_size: int, status: str) -> None:
        """Record a tool call.

        Args:
            tool: Name of the MCP tool that was invoked.
            params: Non-empty keyword arguments the tool received.
            duration_ms: Wall-clock duration of the call in milliseconds.
            result_count: Number of result items the tool reported.
            response_size: Size of the JSON response in bytes.
            status: "ok" or "error".
        """
        now = time.time()
        # Gap since the previous call in ms; 0 for the first call of a session.
        time_since_last = int((now - self.last_call_time) * 1000) if self.last_call_time else 0
        self.calls.append({
            "tool": tool,
            "params": params,
            "duration_ms": duration_ms,
            "result_count": result_count,
            "response_size": response_size,
            "status": status,
            "time_since_last_ms": time_since_last,
        })
        self.last_call_time = now
        # Track what's being queried
        if "indicator_id" in params:
            self.indicators_queried.add(params["indicator_id"])
        if "geographic_level" in params:
            self.levels_queried.add(params["geographic_level"])

    def get_sequence(self) -> str:
        """Get the sequence of tools called, as an arrow-joined string."""
        return "→".join(c["tool"].replace("_indicators", "").replace("_indicator", "")
                        for c in self.calls)

    def get_total_duration_ms(self) -> int:
        """Total time spent in API calls."""
        return sum(c["duration_ms"] for c in self.calls)
class UsageTracker:
    """Aggregate MCP usage patterns across sessions."""

    # Seconds of inactivity after which a session is considered expired
    # (5 minutes with no call starts a fresh session).
    SESSION_TIMEOUT = 300

    def __init__(self):
        self.sessions: dict[str, SessionData] = {}
        # Maps an observed tool-call sequence string to its occurrence count.
        self.patterns: defaultdict[str, int] = defaultdict(int)
        # Per-tool aggregates, keyed by tool name.
        self.tool_stats: defaultdict[str, dict] = defaultdict(
            lambda: {"calls": 0, "total_ms": 0, "errors": 0}
        )

    def get_or_create_session(self, session_hint: str = "default") -> SessionData:
        """Return the live session for this hint, creating one if needed.

        The hint (could be an IP, user-agent hash, etc.) is reduced to a short
        md5-derived id; an expired session is finalized and replaced.
        """
        session_id = hashlib.md5(session_hint.encode()).hexdigest()[:8]
        now = time.time()

        existing = self.sessions.get(session_id)
        if existing is None:
            fresh = SessionData(session_id=session_id)
            self.sessions[session_id] = fresh
            logger.info(f"[SESSION] id={session_id} | new_session")
            return fresh

        idle_too_long = (
            existing.last_call_time
            and (now - existing.last_call_time) > self.SESSION_TIMEOUT
        )
        if idle_too_long:
            # Session expired: log its pattern, then start over with a new one.
            self._finalize_session(existing)
            fresh = SessionData(session_id=session_id)
            self.sessions[session_id] = fresh
            logger.info(f"[SESSION] id={session_id} | new_session (previous expired)")
            return fresh

        return existing

    def _finalize_session(self, session: SessionData):
        """Log a session summary when it ends (only multi-call sessions)."""
        if len(session.calls) > 1:
            sequence = session.get_sequence()
            self.patterns[sequence] += 1
            logger.info(
                f"[PATTERN] id={session.session_id} | "
                f"sequence={sequence} | "
                f"calls={len(session.calls)} | "
                f"total_ms={session.get_total_duration_ms()} | "
                f"indicators={list(session.indicators_queried)} | "
                f"levels={list(session.levels_queried)}"
            )

    def log_stats_summary(self):
        """Log accumulated statistics (top 5 most frequent call sequences)."""
        if self.patterns:
            top_patterns = sorted(self.patterns.items(), key=lambda x: -x[1])[:5]
            logger.info(f"[STATS] top_patterns={top_patterns}")
# Module-level singleton used by the log_tool_call decorator to accumulate
# session and per-tool statistics across all tool invocations in this process.
_tracker = UsageTracker()
def log_tool_call(func: Callable) -> Callable:
    """Decorator to log MCP tool calls with rich metrics.

    Wraps an async tool that returns a JSON string: logs the call with its
    session context, parses the result to extract a status and result count,
    records the call in the session, and updates global per-tool stats.
    Exceptions are logged and re-raised unchanged.

    Fix: ``functools.wraps`` was imported but never applied, so the wrapper
    replaced the tool's ``__name__`` and docstring — metadata that MCP
    frameworks typically expose as the tool's name/description.
    """
    @wraps(func)  # preserve the tool's __name__/__doc__ on the wrapper
    async def wrapper(*args, **kwargs):
        tool_name = func.__name__
        start_time = time.time()
        # Get or create session (single shared default session per process).
        session = _tracker.get_or_create_session()
        # Extract params (only non-empty)
        params = {k: v for k, v in kwargs.items() if v}
        # Build context info: position in session and the previous tool called.
        call_num = len(session.calls) + 1
        prev_tool = session.calls[-1]["tool"] if session.calls else None
        context = f"call#{call_num}"
        if prev_tool:
            context += f" | prev={prev_tool}"
        logger.info(f"[CALL] {tool_name} | {context} | params={params}")
        try:
            result = await func(*args, **kwargs)
            elapsed_ms = int((time.time() - start_time) * 1000)
            response_size = len(result.encode('utf-8'))
            # Parse result to get metrics; tools return JSON strings.
            status = "ok"
            result_count = 0
            try:
                result_data = json.loads(result)
                if "error" in result_data:
                    status = "error"
                    logger.warning(
                        f"[ERROR] {tool_name} | {elapsed_ms}ms | "
                        f"error={result_data['error'][:100]}"
                    )
                else:
                    # First non-falsy count wins: explicit counts, then data
                    # length, then 1 if only metadata is present.
                    result_count = (
                        result_data.get("count") or
                        result_data.get("total_count") or
                        len(result_data.get("data", [])) or
                        (1 if "metadata" in result_data else 0)
                    )
                    logger.info(
                        f"[OK] {tool_name} | {elapsed_ms}ms | "
                        f"count={result_count} | size={response_size}B"
                    )
            except json.JSONDecodeError:
                # Non-JSON result: still log timing and size.
                logger.info(f"[OK] {tool_name} | {elapsed_ms}ms | size={response_size}B")
            # Record in session
            session.add_call(
                tool=tool_name,
                params=params,
                duration_ms=elapsed_ms,
                result_count=result_count,
                response_size=response_size,
                status=status,
            )
            # Update global stats
            _tracker.tool_stats[tool_name]["calls"] += 1
            _tracker.tool_stats[tool_name]["total_ms"] += elapsed_ms
            if status == "error":
                _tracker.tool_stats[tool_name]["errors"] += 1
            return result
        except Exception as e:
            elapsed_ms = int((time.time() - start_time) * 1000)
            logger.error(f"[EXCEPTION] {tool_name} | {elapsed_ms}ms | {type(e).__name__}: {e}")
            _tracker.tool_stats[tool_name]["errors"] += 1
            raise
    return wrapper
| from .cache import get_cache, initialize_cache, refresh_cache_if_needed | |
| from .cube_resolver import get_resolver | |
| from .models import ( | |
| IndicatorMetadata, | |
| SourceMetadata, | |
| IndicatorListItem, | |
| GEOGRAPHIC_LEVELS, | |
| GEO_DIMENSION_PATTERNS, | |
| ) | |
async def _ensure_cache_initialized() -> None:
    """Initialize the metadata cache on first use; refresh it otherwise."""
    cache = get_cache()
    if cache.is_initialized:
        await refresh_cache_if_needed()
    else:
        await initialize_cache()
async def list_indicators(
    thematique: str = "",
    maille: str = "",
) -> str:
    """List all available territorial ecological indicators.
    Returns a list of indicators with their main characteristics. You can filter
    by thematic (France Nation Verte themes like "mieux se déplacer", "mieux se loger")
    or by geographic level (region, departement, epci, commune).
    Args:
        thematique: Optional filter by FNV thematic. Use partial match, e.g., "déplacer"
            for mobility indicators, "loger" for housing, "produire" for production.
        maille: Optional filter by available geographic level. Valid values:
            "region", "departement", "epci", "commune".
    Returns:
        JSON string containing a list of indicators with id, libelle, unite,
        mailles_disponibles, and thematique_fnv.
    Example:
        To find mobility indicators available at department level:
        list_indicators(thematique="déplacer", maille="departement")
    """
    await _ensure_cache_initialized()
    cache = get_cache()

    # Empty strings mean "no filter".
    theme_query = thematique.strip() if thematique else None
    level_query = maille.strip().lower() if maille else None

    # Reject unknown geographic levels before hitting the cache.
    if level_query and level_query not in GEOGRAPHIC_LEVELS:
        error_payload = {
            "error": f"Invalid geographic level: {maille}",
            "valid_levels": GEOGRAPHIC_LEVELS,
        }
        return json.dumps(error_payload, ensure_ascii=False)

    matching = cache.list_indicators(thematique=theme_query, maille=level_query)

    payload = {
        "indicators": [item.model_dump() for item in matching],
        "count": len(matching),
        "filters_applied": {
            "thematique": theme_query,
            "maille": level_query,
        },
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)
async def get_indicator_details(indicator_id: str) -> str:
    """Get detailed information about a specific indicator.
    Returns comprehensive metadata including description, calculation method,
    data coverage, and data sources for a given indicator ID.
    Args:
        indicator_id: The numeric ID of the indicator (e.g., "42", "94", "611").
    Returns:
        JSON string containing:
        - metadata: Full indicator metadata (description, methode_calcul,
          annees_disponibles, completion rates by geographic level, etc.)
        - sources: List of data sources with producer, license, and links.
        - available_cubes: Dict mapping maille to cube name for data queries.
    Example:
        get_indicator_details("611") returns details about indicator 611
        (Consommation d'espaces naturels, agricoles et forestiers).
    """
    await _ensure_cache_initialized()

    # The ID must parse as an integer.
    try:
        numeric_id = int(indicator_id)
    except ValueError:
        return json.dumps(
            {"error": f"Invalid indicator ID: {indicator_id}. Must be a number."},
            ensure_ascii=False,
        )

    indicator = get_cache().get_indicator(numeric_id)
    if indicator is None:
        return json.dumps(
            {
                "error": f"Indicator {numeric_id} not found in metadata.",
                "hint": "Use list_indicators() to see available indicators.",
            },
            ensure_ascii=False,
        )

    # Map of maille -> cube name, resolved from the cube metadata.
    cubes_by_maille = get_resolver().get_cubes_for_indicator(numeric_id)

    # Best-effort fetch of source metadata: a failing API call degrades to an
    # empty list plus a warning field rather than an error response.
    sources_error = None
    try:
        raw_sources = await get_client().load_sources_metadata(indicator_id=numeric_id)
        sources = [SourceMetadata.from_api_response(r).model_dump() for r in raw_sources]
    except CubeJsClientError as exc:
        sources = []
        sources_error = str(exc)

    details = {
        "metadata": indicator.model_dump(),
        "sources": sources,
        "available_cubes": cubes_by_maille,
    }
    if sources_error:
        details["sources_warning"] = f"Could not fetch sources: {sources_error}"
    return json.dumps(details, ensure_ascii=False, indent=2)
async def query_indicator_data(
    indicator_id: str,
    geographic_level: str,
    geographic_code: str = "",
    year: str = "",
) -> str:
    """Query data values for a specific indicator and territory.
    Retrieves actual data values for an indicator at the specified geographic level.
    You can filter by a specific territory code and/or year.
    Args:
        indicator_id: The numeric ID of the indicator (e.g., "611").
        geographic_level: The geographic level to query. Valid values:
            "region", "departement", "epci", "commune".
        geographic_code: Optional INSEE code to filter by territory:
            - Region: 2 digits (e.g., "93" for PACA, "11" for Île-de-France)
            - Departement: 2-3 characters (e.g., "13", "2A", "974")
            - EPCI: 9 digits (SIREN code)
            - Commune: 5 digits (e.g., "75056" for Paris)
        year: Optional year to filter data (e.g., "2020").
    Returns:
        JSON string containing:
        - indicator_id: The queried indicator ID
        - indicator_name: Human-readable name
        - geographic_level: The queried level
        - data: List of data points with geocode, libelle, valeur, annee
        - total_count: Number of results
    Example:
        Query indicator 611 (ENAF consumption) for PACA region:
        query_indicator_data("611", "region", "93")
        Query all departments for 2020:
        query_indicator_data("611", "departement", year="2020")
    """
    await _ensure_cache_initialized()

    # The ID must parse as an integer.
    try:
        numeric_id = int(indicator_id)
    except ValueError:
        return json.dumps(
            {"error": f"Invalid indicator ID: {indicator_id}. Must be a number."},
            ensure_ascii=False,
        )

    # Normalize and validate the requested maille.
    level = geographic_level.strip().lower()
    if level not in GEOGRAPHIC_LEVELS:
        return json.dumps(
            {
                "error": f"Invalid geographic level: {geographic_level}",
                "valid_levels": GEOGRAPHIC_LEVELS,
            },
            ensure_ascii=False,
        )

    cache = get_cache()
    resolver = get_resolver()

    # Metadata is optional here: fall back to a generic name if unknown.
    meta = cache.get_indicator(numeric_id)
    display_name = meta.libelle if meta else f"Indicator {numeric_id}"
    unit = meta.unite if meta else None

    # Locate the cube serving this indicator at the requested maille.
    cube_name = resolver.find_cube_for_indicator(numeric_id, level)
    if cube_name is None:
        if not resolver.is_indicator_known(numeric_id):
            return json.dumps(
                {
                    "error": f"Indicator {numeric_id} not found in any data cube.",
                    "hint": "Use get_indicator_details() to check available mailles.",
                },
                ensure_ascii=False,
            )
        # Known indicator, but not served at this maille: suggest alternatives.
        mailles = resolver.get_available_mailles(numeric_id)
        return json.dumps(
            {
                "error": f"Indicator {numeric_id} is not available at {level} level.",
                "available_levels": mailles,
                "hint": f"Try one of: {', '.join(mailles)}",
            },
            ensure_ascii=False,
        )

    # Resolve fully-qualified (cube-prefixed) measure and dimension names.
    patterns = GEO_DIMENSION_PATTERNS[level]
    measure = resolver.get_measure_name(cube_name, numeric_id)
    geocode_dim = resolver.get_dimension_name(cube_name, patterns["geocode"])
    libelle_dim = resolver.get_dimension_name(cube_name, patterns["libelle"])
    annee_dim = resolver.get_dimension_name(cube_name, "annee")

    cube_query: dict[str, Any] = {
        "measures": [measure],
        "dimensions": [geocode_dim, libelle_dim, annee_dim],
        "limit": 500,
    }

    # Optional territory and year filters (blank inputs are ignored).
    code_filter = geographic_code.strip() if geographic_code else None
    year_filter = year.strip() if year else None
    active_filters = []
    if code_filter:
        active_filters.append(
            {"member": geocode_dim, "operator": "equals", "values": [code_filter]}
        )
    if year_filter:
        active_filters.append(
            {"member": annee_dim, "operator": "equals", "values": [year_filter]}
        )
    if active_filters:
        cube_query["filters"] = active_filters

    # Execute the query against the Cube.js API.
    try:
        response = await get_client().load(cube_query)
    except CubeJsClientError as exc:
        return json.dumps(
            {
                "error": f"Query failed: {str(exc)}",
                "cube": cube_name,
                "query": cube_query,
            },
            ensure_ascii=False,
            indent=2,
        )

    # Flatten API rows into the documented data-point shape.
    data_points = [
        {
            "geocode": row.get(geocode_dim),
            "libelle": row.get(libelle_dim),
            "annee": row.get(annee_dim),
            "valeur": row.get(measure),
            "unite": unit,
        }
        for row in response.get("data", [])
    ]
    # Sort by year, then by libelle (None values sort first via "").
    data_points.sort(key=lambda pt: (pt.get("annee") or "", pt.get("libelle") or ""))

    return json.dumps(
        {
            "indicator_id": numeric_id,
            "indicator_name": display_name,
            "geographic_level": level,
            "data": data_points,
            "total_count": len(data_points),
            "query_info": {
                "cube": cube_name,
                "measure": measure,
                "geographic_code_filter": code_filter,
                "year_filter": year_filter,
            },
        },
        ensure_ascii=False,
        indent=2,
    )
async def search_indicators(query: str) -> str:
    """Search indicators by keywords in their name or description.
    Performs a full-text search across indicator names (libelle) and descriptions.
    All search terms must be present for an indicator to match (AND logic).
    Args:
        query: Search terms separated by spaces. Examples:
            - "consommation espace" finds indicators about land consumption
            - "émissions CO2" finds indicators about CO2 emissions
            - "surface bio" finds organic surface indicators
    Returns:
        JSON string containing:
        - indicators: List of matching indicators with id, libelle, unite,
          mailles_disponibles, thematique_fnv
        - query: The original search query
        - total_count: Number of results
    Example:
        search_indicators("consommation espace") returns indicators mentioning
        both "consommation" and "espace" in their name or description.
    """
    await _ensure_cache_initialized()
    cache = get_cache()

    terms = query.strip() if query else ""
    # An empty query falls back to the full catalogue.
    matching = cache.search_indicators(terms) if terms else cache.list_indicators()

    return json.dumps(
        {
            "indicators": [item.model_dump() for item in matching],
            "query": terms,
            "total_count": len(matching),
        },
        ensure_ascii=False,
        indent=2,
    )
# Public API of this module: the four MCP tool entry points.
__all__ = [
    "list_indicators",
    "get_indicator_details",
    "query_indicator_data",
    "search_indicators",
]