File size: 3,838 Bytes
5c1d1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Accès bornés au graphe JDM pour le moteur d'inférence.

Toutes les lectures passent par un `LookupBudget` (1 unité par appel
`relations_from`) et par un cache mémoire local à une inférence (`mem`),
qui évite de re-dépenser le budget pour un couple `(terme, relation)`
déjà consulté. Le cache disque de `JDMClient` couvre, lui, les inférences
successives.
"""
from __future__ import annotations

from typing import Optional

from jdm_agent.client import JDMClient
from jdm_agent.inference.budget import LookupBudget


def norm(s: str) -> str:
    """Normalisation simple pour matcher des noms (casse, espaces)."""
    return (s or "").strip().lower()


def outgoing(client: JDMClient, budget: LookupBudget, mem: dict,
             term: str, relation: str) -> list[tuple[str, float, int]]:
    """Triplets sortants `(term, relation, ?)`.

    Renvoie `[(name, w, rel_id), ...]`. 1 appel HTTP (sauf hit `mem`).
    `[]` si la relation est inconnue de JDM ou en cas d'erreur réseau.
    """
    key = (norm(term), relation)
    if key in mem:
        return mem[key]
    rid = client.relation_type_id(relation)
    if rid is None:
        mem[key] = []
        return []
    budget.spend(1)  # une lecture JDM consommée
    try:
        res = client.relations_from(term, types_ids=[rid])
    except Exception:
        mem[key] = []
        return []
    idx = res.node_index()
    out: list[tuple[str, float, int]] = []
    for r in res.relations:
        n = idx.get(r.node2)
        if n is not None:
            out.append((n.name, r.w, r.id))
    mem[key] = out
    return out


def edge_weight(client: JDMClient, budget: LookupBudget, mem: dict,
                source: str, relation: str, target: str) -> float:
    """Poids signé du triplet exact `(source, relation, target)`, 0 si absent.

    Matching insensible à la casse + décodage des refinements opaques.
    """
    rows = outgoing(client, budget, mem, source, relation)
    tgt = norm(target)
    for name, w, _rid in rows:
        if norm(name) == tgt:
            return w
        try:
            dec = client.decode_node_name(name)
        except Exception:
            continue
        if dec.get("is_refinement") and norm(dec.get("decoded", "")) == tgt:
            return w
    return 0.0


def display(client: JDMClient, name: str) -> str:
    """Nom lisible d'un nœud (décodage refinement ; sûr en cas d'erreur)."""
    try:
        return client.decode_node_name(name).get("decoded", name)
    except Exception:
        return name


def topk_positive(rows: list[tuple[str, float, int]], k: int
                  ) -> list[tuple[str, float, int]]:
    """Trie les triplets par poids positif décroissant, tronque à `k`."""
    pos = [(n, w, rid) for n, w, rid in rows if w > 0]
    pos.sort(key=lambda x: -x[1])
    return pos[:max(1, k)]


def generics(client: JDMClient, budget: LookupBudget, mem: dict,
             term: str, k: int,
             relations: Optional[tuple[str, ...]] = None
             ) -> list[tuple[str, float, str]]:
    """Sur-ensembles d'un terme : hyperonymes (r_isa) + synonymes (r_syn).

    Renvoie `[(name, w, via_relation), ...]` dédupliqué, top-k par poids.
    """
    from jdm_agent.inference.constants import GENERIC_RELATIONS
    rels = relations or GENERIC_RELATIONS
    collected: list[tuple[str, float, str]] = []
    for rel in rels:
        for name, w, _rid in outgoing(client, budget, mem, term, rel):
            if w > 0:
                collected.append((name, w, rel))
    collected.sort(key=lambda x: -x[1])
    seen: set[str] = set()
    uniq: list[tuple[str, float, str]] = []
    for name, w, rel in collected:
        kk = norm(name)
        if kk in seen or kk == norm(term):
            continue
        seen.add(kk)
        uniq.append((name, w, rel))
    return uniq[:max(1, k)]