File size: 3,838 Bytes
5c1d1c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | """Accès bornés au graphe JDM pour le moteur d'inférence.
Toutes les lectures passent par un `LookupBudget` (1 unité par appel
`relations_from`) et par un cache mémoire local à une inférence (`mem`),
qui évite de re-dépenser le budget pour un couple `(terme, relation)`
déjà consulté. Le cache disque de `JDMClient` couvre, lui, les inférences
successives.
"""
from __future__ import annotations
from typing import Optional
from jdm_agent.client import JDMClient
from jdm_agent.inference.budget import LookupBudget
def norm(s: str) -> str:
"""Normalisation simple pour matcher des noms (casse, espaces)."""
return (s or "").strip().lower()
def outgoing(client: JDMClient, budget: LookupBudget, mem: dict,
term: str, relation: str) -> list[tuple[str, float, int]]:
"""Triplets sortants `(term, relation, ?)`.
Renvoie `[(name, w, rel_id), ...]`. 1 appel HTTP (sauf hit `mem`).
`[]` si la relation est inconnue de JDM ou en cas d'erreur réseau.
"""
key = (norm(term), relation)
if key in mem:
return mem[key]
rid = client.relation_type_id(relation)
if rid is None:
mem[key] = []
return []
budget.spend(1) # une lecture JDM consommée
try:
res = client.relations_from(term, types_ids=[rid])
except Exception:
mem[key] = []
return []
idx = res.node_index()
out: list[tuple[str, float, int]] = []
for r in res.relations:
n = idx.get(r.node2)
if n is not None:
out.append((n.name, r.w, r.id))
mem[key] = out
return out
def edge_weight(client: JDMClient, budget: LookupBudget, mem: dict,
source: str, relation: str, target: str) -> float:
"""Poids signé du triplet exact `(source, relation, target)`, 0 si absent.
Matching insensible à la casse + décodage des refinements opaques.
"""
rows = outgoing(client, budget, mem, source, relation)
tgt = norm(target)
for name, w, _rid in rows:
if norm(name) == tgt:
return w
try:
dec = client.decode_node_name(name)
except Exception:
continue
if dec.get("is_refinement") and norm(dec.get("decoded", "")) == tgt:
return w
return 0.0
def display(client: JDMClient, name: str) -> str:
"""Nom lisible d'un nœud (décodage refinement ; sûr en cas d'erreur)."""
try:
return client.decode_node_name(name).get("decoded", name)
except Exception:
return name
def topk_positive(rows: list[tuple[str, float, int]], k: int
) -> list[tuple[str, float, int]]:
"""Trie les triplets par poids positif décroissant, tronque à `k`."""
pos = [(n, w, rid) for n, w, rid in rows if w > 0]
pos.sort(key=lambda x: -x[1])
return pos[:max(1, k)]
def generics(client: JDMClient, budget: LookupBudget, mem: dict,
term: str, k: int,
relations: Optional[tuple[str, ...]] = None
) -> list[tuple[str, float, str]]:
"""Sur-ensembles d'un terme : hyperonymes (r_isa) + synonymes (r_syn).
Renvoie `[(name, w, via_relation), ...]` dédupliqué, top-k par poids.
"""
from jdm_agent.inference.constants import GENERIC_RELATIONS
rels = relations or GENERIC_RELATIONS
collected: list[tuple[str, float, str]] = []
for rel in rels:
for name, w, _rid in outgoing(client, budget, mem, term, rel):
if w > 0:
collected.append((name, w, rel))
collected.sort(key=lambda x: -x[1])
seen: set[str] = set()
uniq: list[tuple[str, float, str]] = []
for name, w, rel in collected:
kk = norm(name)
if kk in seen or kk == norm(term):
continue
seen.add(kk)
uniq.append((name, w, rel))
return uniq[:max(1, k)]
|