File size: 19,874 Bytes
bad6218
 
 
831a68e
 
3fb80f4
 
 
831a68e
 
3fb80f4
bad6218
 
831a68e
 
 
 
 
 
 
 
 
 
3fb80f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
3fb80f4
831a68e
 
 
 
 
3fb80f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
 
 
 
3fb80f4
831a68e
3fb80f4
 
 
831a68e
 
 
3fb80f4
 
 
 
 
831a68e
3fb80f4
 
 
 
 
 
 
 
 
 
831a68e
3fb80f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
 
 
 
 
 
3fb80f4
831a68e
 
 
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdc8a5
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdc8a5
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831a68e
bad6218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
"""MCP tools for querying territorial ecological indicators."""

import json
import logging
import time
import hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Optional

from .api_client import get_client, CubeJsClient, CubeJsClientError

# Configure logging.
# NOTE(review): logging.basicConfig mutates the ROOT logger's handlers/format,
# which affects every library in the process. Fine when this module is the
# application entry point; reconsider if it is ever imported as a library.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("mcp_tools")


# =============================================================================
# Session Tracker - Track usage patterns across calls
# =============================================================================
@dataclass
class SessionData:
    """Accumulated metrics for a single tracked session."""
    session_id: str
    start_time: float = field(default_factory=time.time)
    calls: list = field(default_factory=list)
    last_call_time: float = 0
    indicators_queried: set = field(default_factory=set)
    levels_queried: set = field(default_factory=set)
    
    def add_call(self, tool: str, params: dict, duration_ms: int, 
                 result_count: int, response_size: int, status: str):
        """Append one tool invocation record and update derived session state."""
        now = time.time()
        gap_ms = 0
        if self.last_call_time:
            gap_ms = int((now - self.last_call_time) * 1000)
        
        record = {
            "tool": tool,
            "params": params,
            "duration_ms": duration_ms,
            "result_count": result_count,
            "response_size": response_size,
            "status": status,
            "time_since_last_ms": gap_ms,
        }
        self.calls.append(record)
        self.last_call_time = now
        
        # Remember which indicators and geographic levels this session touched.
        for key, bucket in (("indicator_id", self.indicators_queried),
                            ("geographic_level", self.levels_queried)):
            if key in params:
                bucket.add(params[key])
    
    def get_sequence(self) -> str:
        """Arrow-joined summary of the tools called, in order, with
        '_indicators'/'_indicator' suffixes stripped for brevity."""
        names = []
        for call in self.calls:
            short = call["tool"].replace("_indicators", "").replace("_indicator", "")
            names.append(short)
        return "→".join(names)
    
    def get_total_duration_ms(self) -> int:
        """Sum of the durations of every recorded call, in milliseconds."""
        total = 0
        for call in self.calls:
            total += call["duration_ms"]
        return total


class UsageTracker:
    """Aggregate MCP usage patterns across sessions."""
    
    # A session is considered expired after 5 minutes without a call.
    SESSION_TIMEOUT = 300
    
    def __init__(self):
        self.sessions: dict[str, SessionData] = {}
        self.patterns: defaultdict[str, int] = defaultdict(int)  # sequence -> count
        self.tool_stats: defaultdict[str, dict] = defaultdict(
            lambda: {"calls": 0, "total_ms": 0, "errors": 0}
        )
    
    def get_or_create_session(self, session_hint: str = "default") -> SessionData:
        """Return the live session for *session_hint*, rotating expired ones."""
        # md5 is used purely as a cheap bucketing hash here (hint could be an
        # IP, a user-agent hash, etc.) — not for anything security-sensitive.
        session_id = hashlib.md5(session_hint.encode()).hexdigest()[:8]
        now = time.time()
        
        existing = self.sessions.get(session_id)
        if existing is None:
            fresh = SessionData(session_id=session_id)
            self.sessions[session_id] = fresh
            logger.info(f"[SESSION] id={session_id} | new_session")
            return fresh
        
        idle_too_long = (
            bool(existing.last_call_time)
            and (now - existing.last_call_time) > self.SESSION_TIMEOUT
        )
        if idle_too_long:
            # Flush the old session's pattern, then start over under the same id.
            self._finalize_session(existing)
            existing = SessionData(session_id=session_id)
            self.sessions[session_id] = existing
            logger.info(f"[SESSION] id={session_id} | new_session (previous expired)")
        
        return existing
    
    def _finalize_session(self, session: SessionData):
        """Record the session's call sequence and emit a summary log line."""
        if len(session.calls) <= 1:
            return
        sequence = session.get_sequence()
        self.patterns[sequence] += 1
        
        logger.info(
            f"[PATTERN] id={session.session_id} | "
            f"sequence={sequence} | "
            f"calls={len(session.calls)} | "
            f"total_ms={session.get_total_duration_ms()} | "
            f"indicators={list(session.indicators_queried)} | "
            f"levels={list(session.levels_queried)}"
        )
    
    def log_stats_summary(self):
        """Log the five most frequent call sequences observed so far."""
        if not self.patterns:
            return
        top_patterns = sorted(self.patterns.items(), key=lambda item: -item[1])[:5]
        logger.info(f"[STATS] top_patterns={top_patterns}")


# Module-level singleton: every tool call in this process records into the
# same tracker (sessions, patterns, per-tool stats).
_tracker = UsageTracker()


def log_tool_call(func: Callable) -> Callable:
    """Decorator that logs async MCP tool calls with rich metrics.

    Wraps an async tool that returns a JSON string. Logs the call with session
    context (call number, previous tool), then on success parses the result
    (best-effort) to extract an error status and a result count, records the
    call in the current session, and updates the global per-tool statistics.
    Exceptions are logged with timing, counted, and re-raised.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        tool_name = func.__name__
        start_time = time.time()
        
        # Get or create session (single shared default session per process).
        session = _tracker.get_or_create_session()
        
        # Extract params (only non-empty) for logging and pattern tracking.
        params = {k: v for k, v in kwargs.items() if v}
        
        # Build context info: position in session and previous tool, if any.
        call_num = len(session.calls) + 1
        prev_tool = session.calls[-1]["tool"] if session.calls else None
        
        context = f"call#{call_num}"
        if prev_tool:
            context += f" | prev={prev_tool}"
        logger.info(f"[CALL] {tool_name} | {context} | params={params}")
        
        try:
            result = await func(*args, **kwargs)
            elapsed_ms = int((time.time() - start_time) * 1000)
            response_size = len(result.encode('utf-8'))
            
            # Parse result to get metrics (tools are expected to return JSON).
            status = "ok"
            result_count = 0
            try:
                result_data = json.loads(result)
                if "error" in result_data:
                    status = "error"
                    logger.warning(
                        f"[ERROR] {tool_name} | {elapsed_ms}ms | "
                        f"error={result_data['error'][:100]}"
                    )
                else:
                    # Explicit None checks so a legitimate count of 0 is
                    # reported as 0 instead of falling through an or-chain.
                    result_count = result_data.get("count")
                    if result_count is None:
                        result_count = result_data.get("total_count")
                    if result_count is None:
                        result_count = (
                            len(result_data.get("data", [])) or
                            (1 if "metadata" in result_data else 0)
                        )
                    logger.info(
                        f"[OK] {tool_name} | {elapsed_ms}ms | "
                        f"count={result_count} | size={response_size}B"
                    )
            except json.JSONDecodeError:
                # Non-JSON result: still log timing and size.
                logger.info(f"[OK] {tool_name} | {elapsed_ms}ms | size={response_size}B")
            
            # Record in session
            session.add_call(
                tool=tool_name,
                params=params,
                duration_ms=elapsed_ms,
                result_count=result_count,
                response_size=response_size,
                status=status,
            )
            
            # Update global stats
            _tracker.tool_stats[tool_name]["calls"] += 1
            _tracker.tool_stats[tool_name]["total_ms"] += elapsed_ms
            if status == "error":
                _tracker.tool_stats[tool_name]["errors"] += 1
            
            return result
            
        except Exception as e:
            elapsed_ms = int((time.time() - start_time) * 1000)
            logger.error(f"[EXCEPTION] {tool_name} | {elapsed_ms}ms | {type(e).__name__}: {e}")
            # Count the failed call as a call too, so errors/calls ratios in
            # tool_stats stay consistent (previously only "errors" was bumped).
            _tracker.tool_stats[tool_name]["calls"] += 1
            _tracker.tool_stats[tool_name]["total_ms"] += elapsed_ms
            _tracker.tool_stats[tool_name]["errors"] += 1
            raise
    
    return wrapper
from .cache import get_cache, initialize_cache, refresh_cache_if_needed
from .cube_resolver import get_resolver
from .models import (
    IndicatorMetadata,
    SourceMetadata,
    IndicatorListItem,
    GEOGRAPHIC_LEVELS,
    GEO_DIMENSION_PATTERNS,
)


async def _ensure_cache_initialized() -> None:
    """Initialize the metadata cache on first use; refresh it afterwards."""
    cache = get_cache()
    if cache.is_initialized:
        await refresh_cache_if_needed()
        return
    await initialize_cache()


@log_tool_call
async def list_indicators(
    thematique: str = "",
    maille: str = "",
) -> str:
    """List all available territorial ecological indicators.
    
    Each entry carries the indicator's main characteristics. Results can be
    narrowed by France Nation Verte thematic (themes like "mieux se déplacer",
    "mieux se loger") and/or by available geographic level.
    
    Args:
        thematique: Optional FNV thematic filter (partial match), e.g.
            "déplacer" for mobility indicators, "loger" for housing,
            "produire" for production.
        maille: Optional geographic-level filter. Valid values:
            "region", "departement", "epci", "commune".
    
    Returns:
        JSON string containing a list of indicators with id, libelle, unite,
        mailles_disponibles, and thematique_fnv.
    
    Example:
        To find mobility indicators available at department level:
        list_indicators(thematique="déplacer", maille="departement")
    """
    await _ensure_cache_initialized()
    cache = get_cache()
    
    # Blank arguments mean "no filter".
    theme_filter = thematique.strip() if thematique else None
    maille_filter = maille.strip().lower() if maille else None
    
    # Reject unknown geographic levels up front.
    if maille_filter and maille_filter not in GEOGRAPHIC_LEVELS:
        error_payload = {
            "error": f"Invalid geographic level: {maille}",
            "valid_levels": GEOGRAPHIC_LEVELS,
        }
        return json.dumps(error_payload, ensure_ascii=False)
    
    matching = cache.list_indicators(thematique=theme_filter, maille=maille_filter)
    
    payload = {
        "indicators": [item.model_dump() for item in matching],
        "count": len(matching),
        "filters_applied": {
            "thematique": theme_filter,
            "maille": maille_filter,
        },
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)


@log_tool_call
async def get_indicator_details(indicator_id: str) -> str:
    """Get detailed information about a specific indicator.
    
    Returns comprehensive metadata including description, calculation method,
    data coverage, and data sources for a given indicator ID.
    
    Args:
        indicator_id: The numeric ID of the indicator (e.g., "42", "94", "611").
    
    Returns:
        JSON string containing:
        - metadata: Full indicator metadata (description, methode_calcul,
          annees_disponibles, completion rates by geographic level, etc.)
        - sources: List of data sources with producer, license, and links.
        - available_cubes: Dict mapping maille to cube name for data queries.
    
    Example:
        get_indicator_details("611") returns details about indicator 611
        (Consommation d'espaces naturels, agricoles et forestiers).
    """
    await _ensure_cache_initialized()
    
    # The ID must be numeric.
    try:
        ind_id = int(indicator_id)
    except ValueError:
        return json.dumps({
            "error": f"Invalid indicator ID: {indicator_id}. Must be a number.",
        }, ensure_ascii=False)
    
    indicator = get_cache().get_indicator(ind_id)
    if indicator is None:
        return json.dumps({
            "error": f"Indicator {ind_id} not found in metadata.",
            "hint": "Use list_indicators() to see available indicators.",
        }, ensure_ascii=False)
    
    # Map maille -> cube name for follow-up data queries.
    available_cubes = get_resolver().get_cubes_for_indicator(ind_id)
    
    # Sources come from the API; a failure degrades to a warning, not an error.
    sources: list = []
    sources_error = None
    try:
        rows = await get_client().load_sources_metadata(indicator_id=ind_id)
    except CubeJsClientError as exc:
        sources_error = str(exc)
    else:
        sources = [SourceMetadata.from_api_response(row).model_dump() for row in rows]
    
    result = {
        "metadata": indicator.model_dump(),
        "sources": sources,
        "available_cubes": available_cubes,
    }
    if sources_error:
        result["sources_warning"] = f"Could not fetch sources: {sources_error}"
    
    return json.dumps(result, ensure_ascii=False, indent=2)


@log_tool_call
async def query_indicator_data(
    indicator_id: str,
    geographic_level: str,
    geographic_code: str = "",
    year: str = "",
) -> str:
    """Query data values for a specific indicator and territory.
    
    Retrieves actual data values for an indicator at the specified geographic level.
    You can filter by a specific territory code and/or year.
    
    Args:
        indicator_id: The numeric ID of the indicator (e.g., "611").
        geographic_level: The geographic level to query. Valid values:
            "region", "departement", "epci", "commune".
        geographic_code: Optional INSEE code to filter by territory:
            - Region: 2 digits (e.g., "93" for PACA, "11" for Île-de-France)
            - Departement: 2-3 characters (e.g., "13", "2A", "974")
            - EPCI: 9 digits (SIREN code)
            - Commune: 5 digits (e.g., "75056" for Paris)
        year: Optional year to filter data (e.g., "2020").
    
    Returns:
        JSON string containing:
        - indicator_id: The queried indicator ID
        - indicator_name: Human-readable name
        - geographic_level: The queried level
        - data: List of data points with geocode, libelle, valeur, annee
        - total_count: Number of results
    
    Example:
        Query indicator 611 (ENAF consumption) for PACA region:
        query_indicator_data("611", "region", "93")
        
        Query all departments for 2020:
        query_indicator_data("611", "departement", year="2020")
    """
    await _ensure_cache_initialized()
    
    # Parse indicator ID
    try:
        ind_id = int(indicator_id)
    except ValueError:
        return json.dumps({
            "error": f"Invalid indicator ID: {indicator_id}. Must be a number.",
        }, ensure_ascii=False)
    
    # Validate geographic level
    geo_level = geographic_level.strip().lower()
    if geo_level not in GEOGRAPHIC_LEVELS:
        return json.dumps({
            "error": f"Invalid geographic level: {geographic_level}",
            "valid_levels": GEOGRAPHIC_LEVELS,
        }, ensure_ascii=False)
    
    cache = get_cache()
    resolver = get_resolver()
    
    # Name/unit are cosmetic: fall back gracefully if metadata is missing.
    indicator = cache.get_indicator(ind_id)
    indicator_name = indicator.libelle if indicator else f"Indicator {ind_id}"
    indicator_unite = indicator.unite if indicator else None
    
    # Find the cube holding this indicator at the requested maille.
    cube_name = resolver.find_cube_for_indicator(ind_id, geo_level)
    
    if cube_name is None:
        # Distinguish "unknown indicator" from "known, but not at this level".
        if not resolver.is_indicator_known(ind_id):
            return json.dumps({
                "error": f"Indicator {ind_id} not found in any data cube.",
                "hint": "Use get_indicator_details() to check available mailles.",
            }, ensure_ascii=False)
        
        available = resolver.get_available_mailles(ind_id)
        return json.dumps({
            "error": f"Indicator {ind_id} is not available at {geo_level} level.",
            "available_levels": available,
            "hint": f"Try one of: {', '.join(available)}",
        }, ensure_ascii=False)
    
    # Build the query: measure plus the geo/label/year dimensions of this cube.
    geo_patterns = GEO_DIMENSION_PATTERNS[geo_level]
    
    measure = resolver.get_measure_name(cube_name, ind_id)
    geocode_dim = resolver.get_dimension_name(cube_name, geo_patterns["geocode"])
    libelle_dim = resolver.get_dimension_name(cube_name, geo_patterns["libelle"])
    annee_dim = resolver.get_dimension_name(cube_name, "annee")
    
    query: dict[str, Any] = {
        "measures": [measure],
        "dimensions": [geocode_dim, libelle_dim, annee_dim],
        "limit": 500,
    }
    
    # Add optional territory/year filters (blank arguments mean "no filter").
    filters = []
    
    geo_code = geographic_code.strip() if geographic_code else None
    if geo_code:
        filters.append({
            "member": geocode_dim,
            "operator": "equals",
            "values": [geo_code],
        })
    
    year_filter = year.strip() if year else None
    if year_filter:
        filters.append({
            "member": annee_dim,
            "operator": "equals",
            "values": [year_filter],
        })
    
    if filters:
        query["filters"] = filters
    
    # Execute query; surface API failures as a structured error payload.
    client = get_client()
    try:
        result = await client.load(query)
        data_rows = result.get("data", [])
    except CubeJsClientError as e:
        return json.dumps({
            "error": f"Query failed: {str(e)}",
            "cube": cube_name,
            "query": query,
        }, ensure_ascii=False, indent=2)
    
    # Reshape rows into stable, cube-agnostic data points.
    data_points = []
    for row in data_rows:
        data_points.append({
            "geocode": row.get(geocode_dim),
            "libelle": row.get(libelle_dim),
            "annee": row.get(annee_dim),
            "valeur": row.get(measure),
            "unite": indicator_unite,
        })
    
    # Sort by year, then by libelle. Normalize through str() so the sort
    # cannot raise TypeError if the cube returns years as ints for some
    # rows while other rows are missing/None (int vs str comparison).
    data_points.sort(key=lambda x: (str(x.get("annee") or ""), str(x.get("libelle") or "")))
    
    return json.dumps({
        "indicator_id": ind_id,
        "indicator_name": indicator_name,
        "geographic_level": geo_level,
        "data": data_points,
        "total_count": len(data_points),
        "query_info": {
            "cube": cube_name,
            "measure": measure,
            "geographic_code_filter": geo_code,
            "year_filter": year_filter,
        },
    }, ensure_ascii=False, indent=2)


@log_tool_call
async def search_indicators(query: str) -> str:
    """Search indicators by keywords in their name or description.
    
    Performs a full-text search across indicator names (libelle) and
    descriptions. Every search term must be present for an indicator to
    match (AND logic).
    
    Args:
        query: Search terms separated by spaces. Examples:
            - "consommation espace" finds indicators about land consumption
            - "émissions CO2" finds indicators about CO2 emissions
            - "surface bio" finds organic surface indicators
    
    Returns:
        JSON string containing:
        - indicators: List of matching indicators with id, libelle, unite,
          mailles_disponibles, thematique_fnv
        - query: The original search query
        - total_count: Number of results
    
    Example:
        search_indicators("consommation espace") returns indicators mentioning
        both "consommation" and "espace" in their name or description.
    """
    await _ensure_cache_initialized()
    cache = get_cache()
    
    search_query = query.strip() if query else ""
    
    # An empty query degrades to the full listing.
    indicators = (
        cache.search_indicators(search_query)
        if search_query
        else cache.list_indicators()
    )
    
    payload = {
        "indicators": [ind.model_dump() for ind in indicators],
        "query": search_query,
        "total_count": len(indicators),
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)


# Public API: the four MCP tool coroutines exported by this module.
__all__ = [
    "list_indicators",
    "get_indicator_details",
    "query_indicator_data",
    "search_indicators",
]