Rochane committed on
Commit
45620de
·
1 Parent(s): cc33196

Add file upload RAG: requirements + rag.py

Browse files
Files changed (2) hide show
  1. app/rag.py +173 -12
  2. requirements.txt +1 -0
app/rag.py CHANGED
@@ -1,6 +1,9 @@
1
  """RAG layer: load corpus, chunk, embed, and retrieve."""
2
 
3
  import os
 
 
 
4
 
5
  import chromadb
6
  from sentence_transformers import SentenceTransformer
@@ -13,6 +16,9 @@ TOP_K = 3
13
 
14
  _model: SentenceTransformer | None = None
15
  _collection: chromadb.Collection | None = None
 
 
 
16
 
17
 
18
  def _get_model() -> SentenceTransformer:
@@ -22,6 +28,17 @@ def _get_model() -> SentenceTransformer:
22
  return _model
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  def _approximate_token_split(text: str, size: int, overlap: int) -> list[str]:
26
  """Split text into chunks of approximately `size` words with `overlap`."""
27
  words = text.split()
@@ -36,7 +53,7 @@ def _approximate_token_split(text: str, size: int, overlap: int) -> list[str]:
36
 
37
 
38
  def _read_txt(path: str) -> str:
39
- with open(path, "r", encoding="utf-8") as f:
40
  return f.read()
41
 
42
 
@@ -50,15 +67,65 @@ def _read_pdf(path: str) -> str:
50
  return ""
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def load_corpus() -> None:
54
- """Load all .pdf and .txt files from corpus, chunk, embed, store in ChromaDB."""
55
  global _collection
56
 
57
- client = chromadb.Client(chromadb.config.Settings(
58
- persist_directory=CHROMA_DIR,
59
- anonymized_telemetry=False,
60
- is_persistent=True,
61
- ))
62
 
63
  try:
64
  client.delete_collection("corpus")
@@ -76,17 +143,16 @@ def load_corpus() -> None:
76
  all_meta: list[dict] = []
77
 
78
  if not os.path.isdir(CORPUS_DIR):
 
79
  return
80
 
81
  for filename in sorted(os.listdir(CORPUS_DIR)):
82
  filepath = os.path.join(CORPUS_DIR, filename)
83
- if filename.lower().endswith(".txt"):
84
- text = _read_txt(filepath)
85
- elif filename.lower().endswith(".pdf"):
86
- text = _read_pdf(filepath)
87
- else:
88
  continue
89
 
 
90
  if not text.strip():
91
  continue
92
 
@@ -107,6 +173,101 @@ def load_corpus() -> None:
107
  )
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def retrieve(query: str, top_k: int = TOP_K) -> list[str]:
111
  """Retrieve the top_k most relevant chunks for a query."""
112
  if _collection is None or _collection.count() == 0:
 
1
  """RAG layer: load corpus, chunk, embed, and retrieve."""
2
 
3
  import os
4
+ import shutil
5
+ import tempfile
6
+ import zipfile
7
 
8
  import chromadb
9
  from sentence_transformers import SentenceTransformer
 
16
 
17
  _model: SentenceTransformer | None = None
18
  _collection: chromadb.Collection | None = None
19
+ _client: chromadb.ClientAPI | None = None
20
+
21
+ SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".pptx", ".ppt"}
22
 
23
 
24
  def _get_model() -> SentenceTransformer:
 
28
  return _model
29
 
30
 
31
def _get_client() -> chromadb.ClientAPI:
    """Return the shared ChromaDB client, creating it on first use.

    The client is configured for persistence under ``CHROMA_DIR`` with
    anonymized telemetry turned off, and is cached in the module-level
    ``_client`` so that later calls reuse the same instance.
    """
    global _client
    if _client is None:
        _client = chromadb.Client(chromadb.config.Settings(
            persist_directory=CHROMA_DIR,
            anonymized_telemetry=False,
            is_persistent=True,
        ))
    return _client
40
+
41
+
42
  def _approximate_token_split(text: str, size: int, overlap: int) -> list[str]:
43
  """Split text into chunks of approximately `size` words with `overlap`."""
44
  words = text.split()
 
53
 
54
 
55
def _read_txt(path: str) -> str:
    """Return the contents of the text file at ``path`` decoded as UTF-8.

    Bytes that are not valid UTF-8 are dropped (``errors="ignore"``) instead
    of raising ``UnicodeDecodeError``.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()
58
 
59
 
 
67
  return ""
68
 
69
 
70
def _read_pptx(path: str) -> str:
    """Extract the text of a presentation as newline-separated paragraphs.

    Every slide is visited, and for each shape that has a text frame the
    non-empty, stripped paragraphs are collected in order. Shapes without a
    text frame are skipped.

    Returns:
        The collected paragraphs joined with ``"\\n"``, or ``""`` if anything
        goes wrong (missing ``pptx`` package, unreadable file, parse error).
    """
    try:
        # Imported here so that an ImportError is caught by the handler below
        # and reported as "no text" rather than breaking the module import.
        from pptx import Presentation

        deck = Presentation(path)
        return "\n".join(
            stripped
            for slide in deck.slides
            for shape in slide.shapes
            if shape.has_text_frame
            for para in shape.text_frame.paragraphs
            if (stripped := para.text.strip())
        )
    except Exception:
        # Best-effort extraction: callers treat "" as "no extractable text".
        return ""
85
+
86
+
87
def _read_file(path: str) -> str:
    """Extract text from ``path`` using the reader matching its extension.

    The extension check is case-insensitive. ``.txt`` goes to ``_read_txt``,
    ``.pdf`` to ``_read_pdf``, and ``.pptx``/``.ppt`` to ``_read_pptx``.

    Returns:
        The extracted text, or ``""`` for any other extension.
    """
    lower = path.lower()
    if lower.endswith(".txt"):
        return _read_txt(path)
    if lower.endswith(".pdf"):
        return _read_pdf(path)
    # NOTE(review): python-pptx presumably only parses the XML-based .pptx
    # format; legacy binary .ppt files likely come back as "" via the
    # exception handler in _read_pptx -- confirm before advertising .ppt.
    if lower.endswith((".pptx", ".ppt")):
        return _read_pptx(path)
    return ""
97
+
98
+
99
def _extract_zip(zip_bytes: bytes) -> list[tuple[str, bytes]]:
    """Return the supported files contained in a ZIP archive.

    The archive is read entirely in memory: nothing is written to disk and
    only members whose extension is in ``SUPPORTED_EXTENSIONS`` are
    decompressed. Members are skipped when any directory in their path starts
    with ``"."`` or ``"__"`` (e.g. ``__MACOSX``) or when the file name itself
    starts with ``"."``.

    Args:
        zip_bytes: Raw bytes of the uploaded ``.zip`` file.

    Returns:
        A list of ``(filename, content)`` tuples in archive order, where
        ``filename`` is the member's base name (directories stripped).

    Raises:
        zipfile.BadZipFile: If ``zip_bytes`` is not a valid ZIP archive.
    """
    import io  # local import: only this function needs an in-memory file

    results: list[tuple[str, bytes]] = []
    with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
        for info in zf.infolist():
            if info.is_dir():
                continue
            # ZIP member names use "/" separators; normalise stray "\" too.
            *dir_parts, fname = info.filename.replace("\\", "/").split("/")
            # Skip __MACOSX and hidden directories (this also drops "..").
            if any(part.startswith((".", "__")) for part in dir_parts):
                continue
            if not fname or fname.startswith("."):
                continue
            ext = os.path.splitext(fname)[1].lower()
            if ext in SUPPORTED_EXTENSIONS:
                # NOTE: no limit is enforced on the decompressed size.
                results.append((fname, zf.read(info)))
    return results
122
+
123
+
124
  def load_corpus() -> None:
125
+ """Load all supported files from corpus, chunk, embed, store in ChromaDB."""
126
  global _collection
127
 
128
+ client = _get_client()
 
 
 
 
129
 
130
  try:
131
  client.delete_collection("corpus")
 
143
  all_meta: list[dict] = []
144
 
145
  if not os.path.isdir(CORPUS_DIR):
146
+ os.makedirs(CORPUS_DIR, exist_ok=True)
147
  return
148
 
149
  for filename in sorted(os.listdir(CORPUS_DIR)):
150
  filepath = os.path.join(CORPUS_DIR, filename)
151
+ ext = os.path.splitext(filename)[1].lower()
152
+ if ext not in SUPPORTED_EXTENSIONS:
 
 
 
153
  continue
154
 
155
+ text = _read_file(filepath)
156
  if not text.strip():
157
  continue
158
 
 
173
  )
174
 
175
 
176
def _add_single_file(filename: str, file_bytes: bytes) -> dict:
    """Save one uploaded file into the corpus directory and index its text.

    The file is written under ``CORPUS_DIR``, its text is extracted with
    ``_read_file``, split with ``_approximate_token_split``, embedded, and
    added to the current collection. Chunks previously stored for the same
    file name are removed first so that re-uploading replaces them.

    Args:
        filename: Name supplied by the uploader. Only its final path
            component is used, so a name such as ``"../../x.txt"`` cannot
            write outside ``CORPUS_DIR``.
        file_bytes: Raw file content.

    Returns:
        ``{"filename", "status": "ok", "chunks"}`` on success, or
        ``{"filename", "status": "error", "message"}`` when the name is
        unusable or no text could be extracted (the saved file is then
        removed again).
    """
    global _collection

    # Upload names are untrusted: strip any directory part (both separator
    # styles) and refuse names that do not designate a file.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        return {"filename": filename, "status": "error", "message": "Nom de fichier invalide"}

    os.makedirs(CORPUS_DIR, exist_ok=True)
    filepath = os.path.join(CORPUS_DIR, safe_name)

    with open(filepath, "wb") as f:
        f.write(file_bytes)

    text = _read_file(filepath)
    if not text.strip():
        os.remove(filepath)
        return {"filename": safe_name, "status": "error", "message": "Texte non extractible"}

    chunks = _approximate_token_split(text, CHUNK_SIZE, CHUNK_OVERLAP)

    if _collection is None:
        # No collection yet: rebuild from the corpus directory, which now
        # contains the file that was just saved.
        load_corpus()
        return {"filename": safe_name, "status": "ok", "chunks": len(chunks)}

    # Remove old chunks from same file if re-uploading (best effort).
    try:
        existing = _collection.get(where={"source": safe_name})
        if existing["ids"]:
            _collection.delete(ids=existing["ids"])
    except Exception:
        pass

    chunk_ids = [f"{safe_name}_{i}" for i in range(len(chunks))]
    metas = [{"source": safe_name, "chunk_index": i} for i in range(len(chunks))]
    embeddings = _get_model().encode(chunks).tolist()

    _collection.add(
        ids=chunk_ids,
        embeddings=embeddings,
        documents=chunks,
        metadatas=metas,
    )

    return {"filename": safe_name, "status": "ok", "chunks": len(chunks)}
218
+
219
+
220
def add_documents(files: list[tuple[str, bytes]]) -> list[dict]:
    """Add one or more uploaded files to the corpus.

    Names ending in ``.zip`` (case-insensitive) are unpacked with
    ``_extract_zip`` and each supported member is added individually; every
    other file is passed straight to ``_add_single_file``.

    Args:
        files: ``(filename, content)`` pairs as received from the upload.

    Returns:
        One result dict per processed file (or per archive member). An
        archive that is corrupt or contains no supported file contributes a
        single ``"status": "error"`` entry instead of aborting the batch.
    """
    results: list[dict] = []
    for filename, file_bytes in files:
        if not filename.lower().endswith(".zip"):
            results.append(_add_single_file(filename, file_bytes))
            continue

        try:
            extracted = _extract_zip(file_bytes)
        except zipfile.BadZipFile:
            # A corrupt archive must not prevent the other uploads from
            # being processed.
            results.append({"filename": filename, "status": "error",
                            "message": "Archive ZIP invalide"})
            continue

        if not extracted:
            results.append({"filename": filename, "status": "error",
                            "message": "Aucun fichier supporte trouve dans le ZIP"})
            continue

        results.extend(
            _add_single_file(inner_name, inner_bytes)
            for inner_name, inner_bytes in extracted
        )
    return results
235
+
236
+
237
def list_documents() -> list[dict]:
    """List the supported documents present in the corpus directory.

    Returns:
        One ``{"filename", "size"}`` dict per entry of ``CORPUS_DIR`` whose
        extension is in ``SUPPORTED_EXTENSIONS``, sorted by file name, with
        ``size`` in bytes. An empty list if ``CORPUS_DIR`` does not exist.
    """
    docs: list[dict] = []
    if not os.path.isdir(CORPUS_DIR):
        return docs
    for filename in sorted(os.listdir(CORPUS_DIR)):
        ext = os.path.splitext(filename)[1].lower()
        if ext not in SUPPORTED_EXTENSIONS:
            continue
        size = os.path.getsize(os.path.join(CORPUS_DIR, filename))
        docs.append({"filename": filename, "size": size})
    return docs
249
+
250
+
251
def delete_document(filename: str) -> bool:
    """Delete a document from the corpus directory and drop its embeddings.

    Args:
        filename: Bare file name of a document inside ``CORPUS_DIR``. Names
            that are empty, are ``"."``/``".."``, or contain a path separator
            are refused so that nothing outside ``CORPUS_DIR`` can be removed.

    Returns:
        ``True`` if the file existed and was removed, ``False`` otherwise.
        Removal of the matching chunks from the collection is best effort and
        does not affect the return value.
    """
    global _collection

    # The name comes from the caller and is joined to CORPUS_DIR below;
    # reject anything that could escape that directory.
    if filename in ("", ".", "..") or "/" in filename or "\\" in filename:
        return False

    filepath = os.path.join(CORPUS_DIR, filename)
    if not os.path.isfile(filepath):
        return False

    os.remove(filepath)

    if _collection is not None:
        try:
            existing = _collection.get(where={"source": filename})
            if existing["ids"]:
                _collection.delete(ids=existing["ids"])
        except Exception:
            # The file is already gone; stale chunks are tolerated here.
            pass

    return True
269
+
270
+
271
  def retrieve(query: str, top_k: int = TOP_K) -> list[str]:
272
  """Retrieve the top_k most relevant chunks for a query."""
273
  if _collection is None or _collection.count() == 0:
requirements.txt CHANGED
@@ -6,4 +6,5 @@ sentence-transformers==3.3.1
6
  pydantic==2.10.4
7
  python-multipart==0.0.20
8
  pypdf2==3.0.1
 
9
  python-dotenv==1.0.1
 
6
  pydantic==2.10.4
7
  python-multipart==0.0.20
8
  pypdf2==3.0.1
9
+ python-pptx==1.0.2
10
  python-dotenv==1.0.1