# aim/app/rag.py
# Rochane's picture
# Add file upload RAG: requirements + rag.py
# 45620de
"""RAG layer: load corpus, chunk, embed, and retrieve."""
import os
import shutil
import tempfile
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
# Directory holding the raw corpus files; overridable via the CORPUS_DIR
# environment variable.
CORPUS_DIR = os.environ.get("CORPUS_DIR", "corpus")
# Directory passed to ChromaDB as its persist directory; overridable via the
# CHROMA_DIR environment variable.
CHROMA_DIR = os.environ.get("CHROMA_DIR", "chroma_data")
# Chunk length in words (whitespace-separated words are used as a proxy for
# tokens).
CHUNK_SIZE = 500
# Number of words shared between two consecutive chunks.
CHUNK_OVERLAP = 50
# Default number of chunks returned by retrieve().
TOP_K = 3
# Lazily created singletons: the embedding model, the active "corpus"
# collection and the ChromaDB client. They stay None until first use.
_model: SentenceTransformer | None = None
_collection: chromadb.Collection | None = None
_client: chromadb.ClientAPI | None = None
# Lowercase file extensions (with leading dot) accepted as corpus documents.
SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".pptx", ".ppt"}
def _get_model() -> SentenceTransformer:
    """Return the shared embedding model, creating it on first call.

    The instance is cached in the module-level ``_model`` so the
    "all-MiniLM-L6-v2" model is only constructed once per process.
    """
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer("all-MiniLM-L6-v2")
    return _model
def _get_client() -> chromadb.ClientAPI:
    """Return the shared ChromaDB client, creating it on first call.

    The client is configured as persistent with ``CHROMA_DIR`` as its persist
    directory and with anonymized telemetry disabled. It is cached in the
    module-level ``_client``.
    """
    global _client
    if _client is not None:
        return _client
    settings = chromadb.config.Settings(
        persist_directory=CHROMA_DIR,
        anonymized_telemetry=False,
        is_persistent=True,
    )
    _client = chromadb.Client(settings)
    return _client
def _approximate_token_split(text: str, size: int, overlap: int) -> list[str]:
    """Split text into overlapping chunks of at most `size` words.

    Words are the whitespace-separated tokens produced by ``str.split``; each
    chunk is those words re-joined with single spaces. Consecutive chunks
    share `overlap` words.

    Args:
        text: Text to split.
        size: Maximum number of words per chunk; must be positive.
        overlap: Number of words repeated at the start of the next chunk;
            must satisfy ``0 <= overlap < size``.

    Returns:
        The chunks in document order, or an empty list when `text` contains
        no words.

    Raises:
        ValueError: If `size` is not positive or `overlap` is outside
            ``[0, size)``; such values would keep the window from advancing.
    """
    if size <= 0 or not 0 <= overlap < size:
        raise ValueError(
            f"size must be positive and 0 <= overlap < size "
            f"(got size={size}, overlap={overlap})"
        )
    words = text.split()
    step = size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + size
        chunks.append(" ".join(words[start:end]))
        # Stop once a window reaches the end of the text: the next window
        # would only repeat the overlap tail already present in this chunk.
        if end >= len(words):
            break
    return chunks
def _read_txt(path: str) -> str:
    """Return the content of a text file decoded as UTF-8.

    Bytes that are not valid UTF-8 are dropped (``errors="ignore"``) instead
    of raising.
    """
    with open(path, encoding="utf-8", errors="ignore") as handle:
        content = handle.read()
    return content
def _read_pdf(path: str) -> str:
    """Return the text of every page of a PDF, joined with newlines.

    Pages for which ``extract_text()`` yields nothing contribute an empty
    string. Any failure (PyPDF2 not importable, unreadable file, extraction
    error) is swallowed and reported as an empty string.
    """
    try:
        from PyPDF2 import PdfReader

        document = PdfReader(path)
        return "\n".join(page.extract_text() or "" for page in document.pages)
    except Exception:
        return ""
def _read_pptx(path: str) -> str:
    """Return the non-empty paragraphs of a presentation, one per line.

    Walks every shape of every slide, keeps the shapes that have a text
    frame, and collects each paragraph's stripped text when it is not empty.
    Any failure (python-pptx not importable, unreadable file) is swallowed and
    reported as an empty string.
    """
    try:
        from pptx import Presentation

        lines = []
        for slide in Presentation(path).slides:
            for shape in slide.shapes:
                if not shape.has_text_frame:
                    continue
                for paragraph in shape.text_frame.paragraphs:
                    stripped = paragraph.text.strip()
                    if stripped:
                        lines.append(stripped)
        return "\n".join(lines)
    except Exception:
        return ""
def _read_file(path: str) -> str:
    """Extract text from a file, choosing the reader from its suffix.

    The suffix match is case-insensitive. ``.txt`` goes to the plain-text
    reader, ``.pdf`` to the PDF reader, and both ``.pptx`` and ``.ppt`` to the
    presentation reader. Any other suffix yields an empty string.
    """
    normalized = path.lower()
    if normalized.endswith(".txt"):
        return _read_txt(path)
    if normalized.endswith(".pdf"):
        return _read_pdf(path)
    if normalized.endswith((".pptx", ".ppt")):
        return _read_pptx(path)
    return ""
def _extract_zip(zip_bytes: bytes) -> list[tuple[str, bytes]]:
    """Collect the supported files contained in a ZIP archive.

    Members are read directly from memory, in archive order; nothing is
    written to disk. A member is skipped when any of its parent directories
    starts with "." or "__" (e.g. ``__MACOSX``), when its own name starts
    with ".", or when its extension is not in ``SUPPORTED_EXTENSIONS``.

    Args:
        zip_bytes: Raw content of the uploaded archive.

    Returns:
        A list of ``(base filename, content)`` tuples. The directory part of
        each member name is dropped. The list is empty when the archive holds
        no supported file or when `zip_bytes` is not a valid ZIP archive.
    """
    import io

    results = []
    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as archive:
            for info in archive.infolist():
                if info.is_dir():
                    continue
                *parents, fname = info.filename.split("/")
                # Skip __MACOSX and hidden directories (this also drops any
                # member whose path contains a ".." component).
                if any(part.startswith((".", "__")) for part in parents):
                    continue
                if not fname or fname.startswith("."):
                    continue
                ext = os.path.splitext(fname)[1].lower()
                if ext in SUPPORTED_EXTENSIONS:
                    results.append((fname, archive.read(info)))
    except zipfile.BadZipFile:
        # A corrupt or non-ZIP upload is reported as "nothing usable" so one
        # bad archive does not abort the rest of an upload batch.
        return []
    return results
def load_corpus() -> None:
    """Rebuild the "corpus" collection from the files in ``CORPUS_DIR``.

    Any existing "corpus" collection is dropped and recreated with cosine
    distance, and the module-level ``_collection`` is pointed at the new one.
    Every supported file (by extension, in sorted filename order) is read,
    split into chunks and embedded; chunk ids have the form
    ``"<filename>_<index>"``. When ``CORPUS_DIR`` does not exist it is created
    and the collection is left empty.
    """
    global _collection
    db = _get_client()
    try:
        db.delete_collection("corpus")
    except Exception:
        # Best effort: presumably raised when no "corpus" collection exists
        # yet; either way a fresh collection is created right below.
        pass
    _collection = db.create_collection(
        name="corpus",
        metadata={"hnsw:space": "cosine"},
    )
    encoder = _get_model()

    if not os.path.isdir(CORPUS_DIR):
        os.makedirs(CORPUS_DIR, exist_ok=True)
        return

    documents: list[str] = []
    ids: list[str] = []
    metadatas: list[dict] = []
    for name in sorted(os.listdir(CORPUS_DIR)):
        suffix = os.path.splitext(name)[1].lower()
        if suffix not in SUPPORTED_EXTENSIONS:
            continue
        content = _read_file(os.path.join(CORPUS_DIR, name))
        if not content.strip():
            # Nothing extractable from this file; leave it out of the index.
            continue
        pieces = _approximate_token_split(content, CHUNK_SIZE, CHUNK_OVERLAP)
        documents.extend(pieces)
        ids.extend(f"{name}_{idx}" for idx in range(len(pieces)))
        metadatas.extend(
            {"source": name, "chunk_index": idx} for idx in range(len(pieces))
        )

    if not documents:
        return
    # All chunks are embedded in one batch and stored with a single add().
    vectors = encoder.encode(documents).tolist()
    _collection.add(
        ids=ids,
        embeddings=vectors,
        documents=documents,
        metadatas=metadatas,
    )
def _add_single_file(filename: str, file_bytes: bytes) -> dict:
    """Save one uploaded file into the corpus directory and index its chunks.

    The file name comes from an upload and is untrusted, so it is reduced to
    its base name before being joined to ``CORPUS_DIR``; a name such as
    ``"../x.txt"`` therefore cannot write outside the corpus directory.

    Args:
        filename: Name of the uploaded file (any directory part is dropped).
        file_bytes: Raw file content.

    Returns:
        ``{"filename", "status": "ok", "chunks"}`` on success, or
        ``{"filename", "status": "error", "message"}`` when the name is
        unusable or no text could be extracted (in which case the saved file
        is removed again). ``"filename"`` is the sanitized base name.
    """
    # Treat backslashes as separators too, so Windows-style client paths
    # ("C:\\docs\\a.txt") collapse to their last component on every platform.
    filename = os.path.basename(filename.replace("\\", "/"))
    if filename in ("", ".", ".."):
        return {"filename": filename, "status": "error", "message": "Nom de fichier invalide"}

    os.makedirs(CORPUS_DIR, exist_ok=True)
    filepath = os.path.join(CORPUS_DIR, filename)
    with open(filepath, "wb") as f:
        f.write(file_bytes)

    text = _read_file(filepath)
    if not text.strip():
        os.remove(filepath)
        return {"filename": filename, "status": "error", "message": "Texte non extractible"}

    chunks = _approximate_token_split(text, CHUNK_SIZE, CHUNK_OVERLAP)

    if _collection is None:
        # No collection yet: a full rebuild indexes the file just written
        # together with the rest of the corpus.
        load_corpus()
        return {"filename": filename, "status": "ok", "chunks": len(chunks)}

    model = _get_model()

    # Remove old chunks from the same file if re-uploading. Best effort: a
    # failed lookup must not block indexing of the new content.
    try:
        existing = _collection.get(where={"source": filename})
        if existing["ids"]:
            _collection.delete(ids=existing["ids"])
    except Exception:
        pass

    chunk_ids = [f"{filename}_{i}" for i in range(len(chunks))]
    metas = [{"source": filename, "chunk_index": i} for i in range(len(chunks))]
    embeddings = model.encode(chunks).tolist()
    _collection.add(
        ids=chunk_ids,
        embeddings=embeddings,
        documents=chunks,
        metadatas=metas,
    )
    return {"filename": filename, "status": "ok", "chunks": len(chunks)}
def add_documents(files: list[tuple[str, bytes]]) -> list[dict]:
    """Ingest uploaded files and report one result dict per processed file.

    A file whose name ends with ".zip" (case-insensitive) is unpacked and each
    supported member is ingested individually; an archive without any
    supported member produces a single error entry. Every other file is
    ingested directly.

    Args:
        files: ``(filename, content)`` pairs.

    Returns:
        The result dicts, in processing order.
    """
    outcomes = []
    for name, payload in files:
        if not name.lower().endswith(".zip"):
            outcomes.append(_add_single_file(name, payload))
            continue
        members = _extract_zip(payload)
        if members:
            outcomes.extend(
                _add_single_file(member_name, member_bytes)
                for member_name, member_bytes in members
            )
        else:
            outcomes.append({"filename": name, "status": "error",
                             "message": "Aucun fichier supporte trouve dans le ZIP"})
    return outcomes
def list_documents() -> list[dict]:
    """Describe the supported files currently stored in ``CORPUS_DIR``.

    Returns:
        One ``{"filename", "size"}`` dict per entry whose extension is in
        ``SUPPORTED_EXTENSIONS`` (case-insensitive), sorted by filename, with
        the size in bytes. Empty when the directory does not exist.
    """
    if not os.path.isdir(CORPUS_DIR):
        return []
    return [
        {"filename": name, "size": os.path.getsize(os.path.join(CORPUS_DIR, name))}
        for name in sorted(os.listdir(CORPUS_DIR))
        if os.path.splitext(name)[1].lower() in SUPPORTED_EXTENSIONS
    ]
def delete_document(filename: str) -> bool:
    """Delete a document from the corpus directory and drop its embeddings.

    Only plain file names are accepted. A name that carries a directory
    component (for example ``"../secret.txt"``) is rejected, so a caller
    cannot remove files outside ``CORPUS_DIR``.

    Args:
        filename: Base name of the document inside ``CORPUS_DIR``.

    Returns:
        True when the file existed and was removed, False otherwise. Removal
        of the matching chunks from the collection is best-effort and does
        not affect the return value.
    """
    # Reject anything that is not a bare file name (path traversal guard).
    if os.path.basename(filename) != filename:
        return False
    filepath = os.path.join(CORPUS_DIR, filename)
    if not os.path.isfile(filepath):
        return False
    os.remove(filepath)
    if _collection is not None:
        try:
            existing = _collection.get(where={"source": filename})
            if existing["ids"]:
                _collection.delete(ids=existing["ids"])
        except Exception:
            # The file is already gone; a failed index cleanup must not turn
            # the deletion into an error for the caller.
            pass
    return True
def retrieve(query: str, top_k: int = TOP_K) -> list[str]:
    """Retrieve the chunks most relevant to a query.

    Args:
        query: Free-text query; it is embedded with the shared model.
        top_k: Maximum number of chunks to return. Values below 1 yield an
            empty list instead of being forwarded to the vector store.

    Returns:
        Up to `top_k` chunk texts as returned by the collection query, or an
        empty list when no collection is loaded, the collection is empty, or
        `top_k` is not positive.
    """
    if _collection is None or top_k < 1:
        return []
    # Count once: the value is needed both for the emptiness check and to cap
    # n_results at the number of stored chunks.
    available = _collection.count()
    if available == 0:
        return []
    query_embedding = _get_model().encode([query]).tolist()
    results = _collection.query(
        query_embeddings=query_embedding,
        n_results=min(top_k, available),
    )
    documents = results["documents"]
    # One query embedding was sent, so the first entry holds its matches.
    return documents[0] if documents else []