FabIndy committed on
Commit
247f65e
·
1 Parent(s): 456aba5

Fix imports: use src package everywhere

Browse files
Files changed (5) hide show
  1. src/__init__.py +0 -0
  2. src/list.py +69 -26
  3. src/rag_core.py +12 -7
  4. src/resources.py +2 -1
  5. src/utils.py +3 -1
src/__init__.py ADDED
File without changes
src/list.py CHANGED
@@ -1,6 +1,9 @@
1
  # src/list.py
 
2
  from __future__ import annotations
3
- from typing import Dict, List, Tuple
 
 
4
  import re
5
 
6
 
@@ -8,11 +11,18 @@ import re
8
  # Configuration algorithmique
9
  # -----------------------------
10
 
11
- MAX_NGRAM = 5
12
- MIN_DOC_FREQ = 2
 
 
 
13
 
14
- WINDOW = 80
15
- SCORE_THRESHOLD = 60
 
 
 
 
16
 
17
 
18
  # -----------------------------
@@ -20,7 +30,7 @@ SCORE_THRESHOLD = 60
20
  # -----------------------------
21
 
22
  def normalize(text: str) -> str:
23
- text = text.lower()
24
  text = re.sub(r"[’']", " ", text)
25
  text = re.sub(r"[^a-zàâçéèêëîïôûùüÿñæœ\s]", " ", text)
26
  text = re.sub(r"\s+", " ", text).strip()
@@ -31,13 +41,12 @@ def tokenize(text: str) -> List[str]:
31
  return text.split()
32
 
33
 
34
- def generate_ngrams(tokens: List[str]) -> List[Tuple[str, int]]:
35
- ngrams = []
36
  n = len(tokens)
37
- for size in range(1, min(MAX_NGRAM, n) + 1):
38
  for i in range(n - size + 1):
39
- seg = " ".join(tokens[i : i + size])
40
- ngrams.append((seg, size))
41
  return ngrams
42
 
43
 
@@ -45,14 +54,14 @@ def generate_ngrams(tokens: List[str]) -> List[Tuple[str, int]]:
45
  # Phrase pivot (corpus-driven)
46
  # -----------------------------
47
 
48
- def extract_phrase_pivot(query: str, articles: Dict[str, str]) -> str | None:
49
  q_norm = normalize(query)
50
  tokens = tokenize(q_norm)
51
- candidates = generate_ngrams(tokens)
52
 
53
  stats = []
54
 
55
- for seg, size in candidates:
56
  seg_re = re.compile(rf"\b{re.escape(seg)}\b")
57
  doc_freq = 0
58
 
@@ -60,8 +69,9 @@ def extract_phrase_pivot(query: str, articles: Dict[str, str]) -> str | None:
60
  if seg_re.search(normalize(text)):
61
  doc_freq += 1
62
 
63
- if doc_freq >= MIN_DOC_FREQ:
64
- stats.append((seg, size, doc_freq))
 
65
 
66
  if not stats:
67
  return None
@@ -100,40 +110,73 @@ def centrality_factor(text: str, pivot: str) -> float:
100
  # Score lexical
101
  # -----------------------------
102
 
103
- def lexical_score(text: str, pivot: str) -> int:
104
  text_norm = normalize(text)
105
  pivot_norm = normalize(pivot)
106
 
107
  score = 0
108
  for m in re.finditer(rf"\b{re.escape(pivot_norm)}\b", text_norm):
109
- start = max(0, m.start() - WINDOW)
110
- end = min(len(text_norm), m.end() + WINDOW)
111
  score += (end - start)
112
 
113
  return score
114
 
115
 
116
  # -----------------------------
117
- # API principale LIST
118
  # -----------------------------
119
 
120
- def list_articles(query: str, articles: Dict[str, str], top_k: int = 15) -> List[str]:
121
- pivot = extract_phrase_pivot(query, articles)
122
  if not pivot:
123
  return []
124
 
125
- scored = []
126
 
127
  for aid, text in articles.items():
128
- s_lex = lexical_score(text, pivot)
129
  if s_lex == 0:
130
  continue
131
 
132
  factor = centrality_factor(text, pivot)
133
  s_final = s_lex * factor
134
 
135
- if s_final >= SCORE_THRESHOLD:
136
  scored.append((aid, s_final))
137
 
138
  scored.sort(key=lambda x: x[1], reverse=True)
139
- return [aid for aid, _ in scored[:top_k]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/list.py
2
+
3
  from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Any, Callable
7
  import re
8
 
9
 
 
11
  # Configuration algorithmique
12
  # -----------------------------
13
 
14
+ @dataclass
15
+ class ListConfig:
16
+ # n-grams
17
+ max_ngram: int = 5
18
+ min_doc_freq: int = 2
19
 
20
+ # scoring
21
+ window: int = 80
22
+ score_threshold: float = 60.0
23
+
24
+ # output
25
+ top_k: int = 15
26
 
27
 
28
  # -----------------------------
 
30
  # -----------------------------
31
 
32
  def normalize(text: str) -> str:
33
+ text = (text or "").lower()
34
  text = re.sub(r"[’']", " ", text)
35
  text = re.sub(r"[^a-zàâçéèêëîïôûùüÿñæœ\s]", " ", text)
36
  text = re.sub(r"\s+", " ", text).strip()
 
41
  return text.split()
42
 
43
 
44
+ def generate_ngrams(tokens: List[str], max_ngram: int) -> List[str]:
45
+ ngrams: List[str] = []
46
  n = len(tokens)
47
+ for size in range(1, min(max_ngram, n) + 1):
48
  for i in range(n - size + 1):
49
+ ngrams.append(" ".join(tokens[i : i + size]))
 
50
  return ngrams
51
 
52
 
 
54
  # Phrase pivot (corpus-driven)
55
  # -----------------------------
56
 
57
+ def extract_phrase_pivot(query: str, articles: Dict[str, str], cfg: ListConfig) -> str | None:
58
  q_norm = normalize(query)
59
  tokens = tokenize(q_norm)
60
+ candidates = generate_ngrams(tokens, cfg.max_ngram)
61
 
62
  stats = []
63
 
64
+ for seg in candidates:
65
  seg_re = re.compile(rf"\b{re.escape(seg)}\b")
66
  doc_freq = 0
67
 
 
69
  if seg_re.search(normalize(text)):
70
  doc_freq += 1
71
 
72
+ if doc_freq >= cfg.min_doc_freq:
73
+ # longueur = nb de mots (préférence aux pivots plus spécifiques)
74
+ stats.append((seg, len(seg.split()), doc_freq))
75
 
76
  if not stats:
77
  return None
 
110
  # Score lexical
111
  # -----------------------------
112
 
113
+ def lexical_score(text: str, pivot: str, window: int) -> int:
114
  text_norm = normalize(text)
115
  pivot_norm = normalize(pivot)
116
 
117
  score = 0
118
  for m in re.finditer(rf"\b{re.escape(pivot_norm)}\b", text_norm):
119
+ start = max(0, m.start() - window)
120
+ end = min(len(text_norm), m.end() + window)
121
  score += (end - start)
122
 
123
  return score
124
 
125
 
126
  # -----------------------------
127
+ # Algorithme LIST (coeur)
128
  # -----------------------------
129
 
130
+ def list_articles_lexical(query: str, articles: Dict[str, str], cfg: ListConfig) -> List[str]:
131
+ pivot = extract_phrase_pivot(query, articles, cfg)
132
  if not pivot:
133
  return []
134
 
135
+ scored: List[tuple[str, float]] = []
136
 
137
  for aid, text in articles.items():
138
+ s_lex = lexical_score(text, pivot, cfg.window)
139
  if s_lex == 0:
140
  continue
141
 
142
  factor = centrality_factor(text, pivot)
143
  s_final = s_lex * factor
144
 
145
+ if s_final >= cfg.score_threshold:
146
  scored.append((aid, s_final))
147
 
148
  scored.sort(key=lambda x: x[1], reverse=True)
149
+ return [aid for aid, _ in scored[: cfg.top_k]]
150
+
151
+
152
+ # -----------------------------
153
+ # API attendue par rag_core.py
154
+ # -----------------------------
155
+
156
+ def list_articles(
157
+ query: str,
158
+ articles: Dict[str, str],
159
+ vs: Any = None, # fallback possible plus tard
160
+ normalize_article_id: Callable[[str], str] | None = None,
161
+ list_triggers: List[str] | None = None,
162
+ cfg: ListConfig | None = None,
163
+ ) -> Dict[str, Any]:
164
+ """
165
+ Signature compatible avec rag_core.py.
166
+
167
+ Pour l'instant : lexical-only (ton algo).
168
+ Le paramètre `vs` est accepté pour compatibilité, mais pas utilisé ici.
169
+ """
170
+ cfg = cfg or ListConfig()
171
+
172
+ q = (query or "").strip()
173
+ if not q:
174
+ return {"mode": "LIST", "answer": "", "articles": []}
175
+
176
+ ids = list_articles_lexical(q, articles, cfg)
177
+
178
+ return {
179
+ "mode": "LIST",
180
+ "answer": "",
181
+ "articles": ids,
182
+ }
src/rag_core.py CHANGED
@@ -3,12 +3,15 @@ from __future__ import annotations
3
  from typing import Dict, Any, List
4
  import json
5
 
6
- import list as list_mode
7
- import fulltext as fulltext_mode
8
- import synthesis as synthesis_mode
9
- import qa as qa_mode
10
 
11
- from config import (
 
 
 
 
 
 
 
12
  CHUNKS_PATH,
13
  LIST_TRIGGERS,
14
  REFUSAL,
@@ -19,14 +22,16 @@ from config import (
19
  QA_MAX_TOKENS,
20
  QA_TEMPERATURE,
21
  )
22
- from utils import (
23
  normalize_article_id,
24
  extract_article_id,
25
  is_list_request,
26
  is_fulltext_request,
27
  is_synthesis_request,
28
  )
29
- from resources import get_vectorstore, get_llm
 
 
30
 
31
 
32
  # ====================
 
3
  from typing import Dict, Any, List
4
  import json
5
 
 
 
 
 
6
 
7
+ from src import list as list_mode
8
+ from src import fulltext as fulltext_mode
9
+ from src import synthesis as synthesis_mode
10
+ from src import qa as qa_mode
11
+ from src import resources
12
+
13
+
14
+ from src.config import (
15
  CHUNKS_PATH,
16
  LIST_TRIGGERS,
17
  REFUSAL,
 
22
  QA_MAX_TOKENS,
23
  QA_TEMPERATURE,
24
  )
25
+ from src.utils import (
26
  normalize_article_id,
27
  extract_article_id,
28
  is_list_request,
29
  is_fulltext_request,
30
  is_synthesis_request,
31
  )
32
+
33
+ from src.resources import get_vectorstore, get_llm
34
+
35
 
36
 
37
  # ====================
src/resources.py CHANGED
@@ -6,7 +6,8 @@ from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from llama_cpp import Llama
8
 
9
- from config import DB_DIR, EMBED_MODEL, LLM_MODEL_PATH, LLM_N_CTX, LLM_N_THREADS, LLM_N_BATCH
 
10
 
11
 
12
  _VS: Optional[FAISS] = None
 
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from llama_cpp import Llama
8
 
9
+ from src.config import DB_DIR, EMBED_MODEL, LLM_MODEL_PATH, LLM_N_CTX, LLM_N_THREADS, LLM_N_BATCH
10
+
11
 
12
 
13
  _VS: Optional[FAISS] = None
src/utils.py CHANGED
@@ -1,7 +1,9 @@
1
  # src/utils.py
2
  from __future__ import annotations
3
  from typing import Optional
4
- from config import ARTICLE_ID_RE, LIST_TRIGGERS, FULLTEXT_TRIGGERS, EXPLAIN_TRIGGERS
 
 
5
 
6
 
7
  def normalize_article_id(raw: str) -> str:
 
1
  # src/utils.py
2
  from __future__ import annotations
3
  from typing import Optional
4
+
5
+ from src.config import ARTICLE_ID_RE, LIST_TRIGGERS, FULLTEXT_TRIGGERS, EXPLAIN_TRIGGERS
6
+
7
 
8
 
9
  def normalize_article_id(raw: str) -> str: