# src/list.py
#
# Lexical "LIST" retrieval: pick a pivot phrase from the query (the longest
# query n-gram that occurs in enough articles), then rank the articles that
# contain this pivot by the amount of surrounding text and by how early the
# pivot first appears.
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, List, Any, Callable


# -----------------------------
# Algorithm configuration
# -----------------------------
@dataclass
class ListConfig:
    """Tunable parameters of the lexical LIST algorithm."""

    # n-grams
    max_ngram: int = 5  # longest query n-gram (in words) tried as a pivot
    min_doc_freq: int = 2  # a pivot must occur in at least this many articles

    # scoring
    window: int = 80  # characters counted on each side of a pivot match
    score_threshold: float = 60.0  # minimum final score for an article to be kept

    # output
    top_k: int = 15  # maximum number of article ids returned


# -----------------------------
# Normalization & tokens
# -----------------------------
_APOSTROPHES_RE = re.compile(r"[’']")
_NON_LETTERS_RE = re.compile(r"[^a-zàâçéèêëîïôûùüÿñæœ\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize(text: str) -> str:
    """Return *text* lowercased and reduced to letters and single spaces.

    Apostrophes (straight and typographic) become spaces, every character
    that is neither whitespace nor one of the letters listed in
    ``_NON_LETTERS_RE`` becomes a space (digits and punctuation included),
    and whitespace runs are collapsed. ``None`` or an empty value yields "".
    """
    lowered = (text or "").lower()
    without_apostrophes = _APOSTROPHES_RE.sub(" ", lowered)
    letters_only = _NON_LETTERS_RE.sub(" ", without_apostrophes)
    return _WHITESPACE_RE.sub(" ", letters_only).strip()


def tokenize(text: str) -> List[str]:
    """Split *text* on whitespace."""
    return text.split()


def generate_ngrams(tokens: List[str], max_ngram: int) -> List[str]:
    """Return every contiguous n-gram of *tokens*, for n = 1..max_ngram.

    N-grams are space-joined strings, ordered by increasing size and then by
    position in *tokens*. Duplicates are kept.
    """
    count = len(tokens)
    return [
        " ".join(tokens[start : start + size])
        for size in range(1, min(max_ngram, count) + 1)
        for start in range(count - size + 1)
    ]


def _pivot_pattern(pivot_norm: str) -> re.Pattern[str]:
    """Compile the whole-word pattern used for every pivot lookup."""
    return re.compile(rf"\b{re.escape(pivot_norm)}\b")


# -----------------------------
# Pivot phrase (corpus-driven)
# -----------------------------
def extract_phrase_pivot(query: str, articles: Dict[str, str], cfg: ListConfig) -> str | None:
    """Pick the query n-gram that best anchors the search in the corpus.

    A candidate is kept when it occurs, as whole words, in at least
    ``cfg.min_doc_freq`` articles. Among the kept candidates the longest one
    (in words) wins; ties are broken by the higher document frequency, then
    by first appearance in the query.

    Returns the normalized pivot, or ``None`` when no candidate qualifies.
    """
    tokens = tokenize(normalize(query))
    # Repeated query words produce duplicate n-grams; testing each one once
    # does not change the winner (first occurrence order is preserved).
    candidates = list(dict.fromkeys(generate_ngrams(tokens, cfg.max_ngram)))
    if not candidates:
        return None

    # Normalize the corpus once instead of once per candidate.
    normalized_articles = [normalize(text) for text in articles.values()]

    stats = []
    for seg in candidates:
        pattern = _pivot_pattern(seg)
        doc_freq = sum(1 for text_norm in normalized_articles if pattern.search(text_norm))
        if doc_freq >= cfg.min_doc_freq:
            # length = number of words (more specific pivots are preferred)
            stats.append((seg, len(seg.split()), doc_freq))

    if not stats:
        return None

    # priority: length > doc_freq; max() keeps the first of equal candidates
    return max(stats, key=lambda entry: (entry[1], entry[2]))[0]


# -----------------------------
# Normative centrality
# -----------------------------
# (upper bound of the relative position, factor); positions past the last
# bound get _CENTRALITY_TAIL.
_CENTRALITY_BANDS = ((0.20, 1.4), (0.40, 1.2), (0.60, 1.0), (0.80, 0.8))
_CENTRALITY_TAIL = 0.6


def _centrality_of(text_norm: str, pattern: re.Pattern[str]) -> float:
    """Centrality factor for an already-normalized text and compiled pivot."""
    first = pattern.search(text_norm)
    if first is None:
        return 0.0
    pos = first.start() / max(len(text_norm), 1)
    for upper_bound, factor in _CENTRALITY_BANDS:
        if pos <= upper_bound:
            return factor
    return _CENTRALITY_TAIL


def centrality_factor(text: str, pivot: str) -> float:
    """Weight an article by how early the pivot first appears in it.

    The pivot is located as whole words, with the same pattern that
    ``lexical_score`` uses, so the position always refers to an occurrence
    that actually contributes to the score. The relative position of the
    first occurrence maps to 1.4 / 1.2 / 1.0 / 0.8 / 0.6 per fifth of the
    normalized text.

    Returns 0.0 when the pivot is empty after normalization or is absent.
    """
    pivot_norm = normalize(pivot)
    if not pivot_norm:
        return 0.0
    return _centrality_of(normalize(text), _pivot_pattern(pivot_norm))


# -----------------------------
# Lexical score
# -----------------------------
def _window_score(text_norm: str, pattern: re.Pattern[str], window: int) -> int:
    """Sum the clipped window sizes around every match of *pattern*."""
    text_len = len(text_norm)
    return sum(
        min(text_len, match.end() + window) - max(0, match.start() - window)
        for match in pattern.finditer(text_norm)
    )


def lexical_score(text: str, pivot: str, window: int) -> int:
    """Score *text* by the amount of context around each pivot occurrence.

    Each whole-word occurrence of the normalized pivot contributes the number
    of characters between ``window`` characters before its start and
    ``window`` characters after its end, clipped to the text. Overlapping
    windows are counted once per occurrence.

    Returns 0 when the pivot is empty after normalization or is absent.
    """
    pivot_norm = normalize(pivot)
    if not pivot_norm:
        # An empty pivot would compile to r"\b\b" and match at every word
        # boundary, which is not a meaningful score.
        return 0
    return _window_score(normalize(text), _pivot_pattern(pivot_norm), window)


# -----------------------------
# LIST algorithm (core)
# -----------------------------
def list_articles_lexical(query: str, articles: Dict[str, str], cfg: ListConfig) -> List[str]:
    """Return the ids of the articles that best match *query* lexically.

    An article is kept when its lexical score multiplied by its centrality
    factor reaches ``cfg.score_threshold``. Ids are ordered by decreasing
    final score (ties keep the iteration order of *articles*) and truncated
    to ``cfg.top_k``. Returns ``[]`` when there are no articles or no pivot.
    """
    if not articles:
        return []

    pivot = extract_phrase_pivot(query, articles, cfg)
    if not pivot:
        return []

    pattern = _pivot_pattern(normalize(pivot))

    scored: List[tuple[str, float]] = []
    for aid, text in articles.items():
        # Normalize once per article and share it between both measures.
        text_norm = normalize(text)
        s_lex = _window_score(text_norm, pattern, cfg.window)
        if s_lex == 0:
            continue
        s_final = s_lex * _centrality_of(text_norm, pattern)
        if s_final >= cfg.score_threshold:
            scored.append((aid, s_final))

    scored.sort(key=lambda item: item[1], reverse=True)
    return [aid for aid, _ in scored[: cfg.top_k]]


# -----------------------------
# API expected by rag_core.py
# -----------------------------
def list_articles(
    query: str,
    articles: Dict[str, str],
    vs: Any = None,  # possible fallback later
    normalize_article_id: Callable[[str], str] | None = None,
    list_triggers: List[str] | None = None,
    cfg: ListConfig | None = None,
) -> Dict[str, Any]:
    """Run the LIST mode and return its result dictionary.

    Signature compatible with rag_core.py. For now the implementation is
    lexical-only: `vs`, `normalize_article_id` and `list_triggers` are
    accepted for compatibility but are not used here.

    Returns a dict with keys "mode" (always "LIST"), "answer" (always "")
    and "articles" (the selected article ids, best first). A blank query or
    missing/empty *articles* yields an empty "articles" list.
    """
    cfg = cfg or ListConfig()
    q = (query or "").strip()
    if not q or not articles:
        return {"mode": "LIST", "answer": "", "articles": []}

    ids = list_articles_lexical(q, articles, cfg)
    return {
        "mode": "LIST",
        "answer": "",
        "articles": ids,
    }