# Commit 247f65e — "Fix imports: use src package everywhere" (FabIndy)
# src/list.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Any, Callable
import re
# -----------------------------
# Configuration algorithmique
# -----------------------------
@dataclass
class ListConfig:
    """Tunable parameters for the lexical LIST pipeline (pivot extraction + scoring)."""
    # n-grams
    max_ngram: int = 5        # longest candidate pivot phrase, in words
    min_doc_freq: int = 2     # pivot must appear in at least this many articles
    # scoring
    window: int = 80          # chars counted around each pivot occurrence
    score_threshold: float = 60.0  # minimum weighted score to keep an article
    # output
    top_k: int = 15           # maximum number of article ids returned
# -----------------------------
# Normalisation & tokens
# -----------------------------
def normalize(text: str) -> str:
    """Lowercase *text*, replace apostrophes/punctuation/digits with spaces,
    keep only Latin letters (incl. French accented ones), collapse whitespace.

    `None` is treated as the empty string.
    """
    lowered = (text or "").lower()
    no_apostrophes = re.sub(r"[’']", " ", lowered)
    letters_only = re.sub(r"[^a-zàâçéèêëîïôûùüÿñæœ\s]", " ", no_apostrophes)
    return re.sub(r"\s+", " ", letters_only).strip()
def tokenize(text: str) -> List[str]:
    """Whitespace-tokenize *text* (expected to be already normalized)."""
    return text.split()
def generate_ngrams(tokens: List[str], max_ngram: int) -> List[str]:
    """Return every contiguous n-gram of *tokens* for n = 1..max_ngram,
    ordered by n-gram size first, then by start position (same as a
    nested size/offset loop would produce).
    """
    count = len(tokens)
    return [
        " ".join(tokens[start : start + size])
        for size in range(1, min(max_ngram, count) + 1)
        for start in range(count - size + 1)
    ]
# -----------------------------
# Phrase pivot (corpus-driven)
# -----------------------------
def extract_phrase_pivot(query: str, articles: Dict[str, str], cfg: ListConfig) -> str | None:
    """Pick the most specific query n-gram that is frequent in the corpus.

    Candidates are the 1..cfg.max_ngram word n-grams of the normalized query.
    A candidate is kept when it occurs (as a whole-word match) in at least
    cfg.min_doc_freq articles. Ties are broken by word length first, then
    document frequency (both descending).

    Returns the winning n-gram, or None when no candidate is frequent enough.
    """
    q_norm = normalize(query)
    tokens = tokenize(q_norm)
    candidates = generate_ngrams(tokens, cfg.max_ngram)
    # Fix: the original re-normalized every article for every candidate
    # (O(candidates * articles) normalize passes). Normalize each article
    # exactly once up front — results are identical, just computed once.
    norm_texts = [normalize(text) for text in articles.values()]
    stats = []
    for seg in candidates:
        seg_re = re.compile(rf"\b{re.escape(seg)}\b")
        doc_freq = sum(1 for text in norm_texts if seg_re.search(text))
        if doc_freq >= cfg.min_doc_freq:
            # length = word count (prefer more specific pivots)
            stats.append((seg, len(seg.split()), doc_freq))
    if not stats:
        return None
    # priority: length > doc_freq
    stats.sort(key=lambda x: (x[1], x[2]), reverse=True)
    return stats[0][0]
# -----------------------------
# Centralité normative
# -----------------------------
def centrality_factor(text: str, pivot: str) -> float:
    """Weight an article by how early the pivot first appears in it.

    The first occurrence's relative position (0.0 = start, 1.0 = end) maps
    to a multiplier: earlier mentions are rewarded (up to 1.4), late ones
    penalized (down to 0.6). Returns 0.0 when the pivot is absent.
    """
    haystack = normalize(text)
    needle = normalize(pivot)
    first = haystack.find(needle)
    if first == -1:
        return 0.0
    relative = first / max(len(haystack), 1)
    # Position bands, checked in ascending order.
    for bound, weight in ((0.20, 1.4), (0.40, 1.2), (0.60, 1.0), (0.80, 0.8)):
        if relative <= bound:
            return weight
    return 0.6
# -----------------------------
# Score lexical
# -----------------------------
def lexical_score(text: str, pivot: str, window: int) -> int:
    """Sum, over every whole-word occurrence of *pivot* in *text*, the size
    of a character window of ±*window* around the match (clamped to the
    text's bounds). Overlapping windows are counted independently.
    Returns 0 when the pivot never occurs.
    """
    haystack = normalize(text)
    needle = normalize(pivot)
    pattern = re.compile(rf"\b{re.escape(needle)}\b")
    return sum(
        min(len(haystack), m.end() + window) - max(0, m.start() - window)
        for m in pattern.finditer(haystack)
    )
# -----------------------------
# Algorithme LIST (coeur)
# -----------------------------
def list_articles_lexical(query: str, articles: Dict[str, str], cfg: ListConfig) -> List[str]:
    """Core LIST algorithm: rank article ids by pivot-based lexical score.

    Extracts a corpus-backed pivot phrase from *query*, scores each article
    (window score weighted by positional centrality), drops articles below
    cfg.score_threshold, and returns at most cfg.top_k ids, best first.
    Returns [] when no pivot can be extracted.
    """
    pivot = extract_phrase_pivot(query, articles, cfg)
    if not pivot:
        return []
    ranked: List[tuple[str, float]] = []
    for article_id, body in articles.items():
        raw = lexical_score(body, pivot, cfg.window)
        if raw == 0:
            continue
        weighted = raw * centrality_factor(body, pivot)
        if weighted >= cfg.score_threshold:
            ranked.append((article_id, weighted))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return [article_id for article_id, _ in ranked[: cfg.top_k]]
# -----------------------------
# API attendue par rag_core.py
# -----------------------------
def list_articles(
    query: str,
    articles: Dict[str, str],
    vs: Any = None,  # fallback possible later
    normalize_article_id: Callable[[str], str] | None = None,
    list_triggers: List[str] | None = None,
    cfg: ListConfig | None = None,
) -> Dict[str, Any]:
    """Entry point with the signature rag_core.py expects.

    Currently lexical-only: `vs`, `normalize_article_id` and `list_triggers`
    are accepted for compatibility but ignored. Returns a dict with keys
    "mode" (always "LIST"), "answer" (always "") and "articles" (ranked ids,
    empty when the query is blank or no pivot is found).
    """
    config = cfg or ListConfig()
    cleaned = (query or "").strip()
    if not cleaned:
        return {"mode": "LIST", "answer": "", "articles": []}
    return {
        "mode": "LIST",
        "answer": "",
        "articles": list_articles_lexical(cleaned, articles, config),
    }