Spaces:
Running
Running
| from pathlib import Path | |
| from dataclasses import dataclass | |
| from typing import List | |
| import re | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
@dataclass
class RetrievedChunk:
    """One passage returned by a retrieval query.

    The ``@dataclass`` decorator is required: it generates the keyword
    ``__init__`` (plus ``__repr__`` and ``__eq__``) that callers rely on
    when building instances with ``RetrievedChunk(text=..., filename=...,
    score=...)``. Without it the class only carries annotations and such a
    call raises ``TypeError``.
    """

    # The passage text.
    text: str
    # Name of the corpus file the passage was taken from.
    filename: str
    # Relevance score assigned to the passage for the query.
    score: float
class RAGEngineGPT2:
    """TF-IDF retrieval engine that turns a question into a context prompt.

    On construction the engine reads every ``.txt`` and ``.md`` file in the
    corpus directory, splits each one into overlapping character windows,
    and fits a TF-IDF index over those windows. ``search`` ranks the windows
    against a query by cosine similarity (plus a small filename-based domain
    bonus) and ``build_prompt`` wraps the best windows and the question in a
    French prompt template.
    """

    # Additive bonus given to a chunk whose source filename matches a domain
    # mentioned in the query.
    DOMAIN_BOOST = 0.08

    # (filename marker, compiled query pattern) pairs. Keywords are matched as
    # whole words (with an optional plural "s") rather than as raw substrings:
    # a substring test makes "art" fire on "artificielle" and "rh" fire on
    # "rhume", which boosts unrelated files. "cyber" stays a prefix match so
    # that compounds such as "cybersécurité" or "cyberattaque" still count.
    _DOMAIN_RULES = tuple(
        (marker, re.compile(rf"\b(?:{alternatives})s?\b"))
        for marker, alternatives in (
            ("ia_medecine", r"médecine|medecine|santé"),
            ("ia_finance", r"finance|banque|crédit|credit"),
            ("ia_rh", r"recrutement|rh"),
            ("ia_education", r"éducation|education|école|ecole|élève|eleve"),
            ("ia_cybersecurite", r"cyber\w*"),
            ("ia_art", r"art|image|création|creation"),
        )
    )

    def __init__(
        self,
        corpus_dir: str = "corpus/",
        chunk_size: int = 450,
        chunk_overlap: int = 80,
        min_score: float = 0.05
    ):
        """Load the corpus, chunk it and build the TF-IDF index.

        Args:
            corpus_dir: Directory containing the ``.txt`` / ``.md`` corpus.
            chunk_size: Maximum length, in characters, of one passage.
            chunk_overlap: Number of characters shared by consecutive
                passages. Must be strictly smaller than ``chunk_size``.
            min_score: Minimum (boosted) score a passage needs to be
                returned by ``search``.

        Raises:
            ValueError: If the chunking parameters cannot make progress, if
                the corpus directory holds no ``.txt`` / ``.md`` file, or if
                no usable passage could be extracted.
            FileNotFoundError: If ``corpus_dir`` does not exist.
        """
        # Validate before doing any I/O: a non-positive step would otherwise
        # make the chunking loop spin forever.
        if chunk_size <= 0:
            raise ValueError("chunk_size doit être strictement positif.")
        if chunk_overlap >= chunk_size:
            raise ValueError(
                "chunk_overlap doit être strictement inférieur à chunk_size."
            )

        self.corpus_dir = Path(corpus_dir)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_score = min_score

        # Parallel lists: chunk_sources[i] is the filename chunks[i] came from.
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []

        self.vectorizer = TfidfVectorizer(
            lowercase=True,
            strip_accents="unicode",
            ngram_range=(1, 2),
            max_features=10000
        )

        self._load_and_chunk_corpus()
        self._index()

    def _clean_text(self, text: str) -> str:
        """Collapse every run of whitespace to one space and trim the ends."""
        return re.sub(r"\s+", " ", text).strip()

    def _clean_for_prompt(self, text: str) -> str:
        """Remove bracketed upper-case tags, then normalise whitespace."""
        without_tags = re.sub(r"\[[A-ZÉÈÊÀÙÎÏÇ _-]+\]", "", text)
        return re.sub(r"\s+", " ", without_tags).strip()

    def _split_into_chunks(self, text: str) -> List[str]:
        """Split *text* into overlapping windows of at most ``chunk_size``.

        The text is whitespace-normalised first. Consecutive windows start
        ``chunk_size - chunk_overlap`` characters apart.

        Returns:
            The list of non-empty windows; an empty list when the text
            contains nothing but whitespace.

        Raises:
            ValueError: If ``chunk_overlap >= chunk_size`` (no forward
                progress would be possible).
        """
        text = self._clean_text(text)
        if not text:
            # An empty file must not contribute an empty passage to the index.
            return []
        if len(text) <= self.chunk_size:
            return [text]

        step = self.chunk_size - self.chunk_overlap
        if step <= 0:
            raise ValueError(
                "chunk_overlap doit être strictement inférieur à chunk_size."
            )

        chunks = []
        for start in range(0, len(text), step):
            end = start + self.chunk_size
            window = text[start:end].strip()
            if window:
                chunks.append(window)
            if end >= len(text):
                # This window already reaches the end of the text; any later
                # window would only repeat part of it.
                break
        return chunks

    def _load_and_chunk_corpus(self):
        """Read every corpus file and fill ``chunks`` / ``chunk_sources``.

        Files that cannot be read or decoded as UTF-8 are reported on stdout
        and skipped.

        Raises:
            FileNotFoundError: If the corpus directory does not exist.
            ValueError: If there is no ``.txt`` / ``.md`` file, or if no
                passage could be extracted from the files.
        """
        if not self.corpus_dir.exists():
            raise FileNotFoundError(f"Dossier corpus introuvable : {self.corpus_dir}")

        files = sorted(
            [*self.corpus_dir.glob("*.txt"), *self.corpus_dir.glob("*.md")]
        )
        if not files:
            raise ValueError("Aucun fichier .txt ou .md trouvé dans le dossier corpus.")

        loaded = 0
        for file_path in files:
            # Only the read can legitimately fail per file; keep the try
            # narrow so programming errors in chunking are not swallowed.
            try:
                text = file_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                print(f"Fichier ignoré : {file_path.name} | Erreur : {e}")
                continue

            file_chunks = self._split_into_chunks(text)
            self.chunks.extend(file_chunks)
            self.chunk_sources.extend([file_path.name] * len(file_chunks))
            loaded += 1

        if not self.chunks:
            raise ValueError("Aucun passage exploitable trouvé dans le corpus.")

        # Report the files actually read, not the files merely discovered.
        print(f"{loaded} fichiers chargés.")
        print(f"{len(self.chunks)} passages indexés.")

    def _index(self):
        """Fit the vectorizer on the passages and store the TF-IDF matrix."""
        self.tfidf_matrix = self.vectorizer.fit_transform(self.chunks)

    def _apply_domain_boost(self, query: str, scores):
        """Add ``DOMAIN_BOOST`` to chunks whose file matches the query domain.

        A domain is active when one of its keywords appears in the query as a
        whole word. Every chunk whose source filename contains the marker of
        an active domain receives the bonus (once per matching domain).

        Args:
            query: The user query.
            scores: Mutable sequence of scores, index-aligned with
                ``chunk_sources``. It is modified in place.

        Returns:
            The same ``scores`` object, after boosting.
        """
        query_lower = query.lower()
        # The query is the same for every chunk, so resolve the active
        # domains once instead of re-testing the keywords per chunk.
        active_markers = [
            marker
            for marker, pattern in self._DOMAIN_RULES
            if pattern.search(query_lower)
        ]
        if not active_markers:
            return scores

        for i, source in enumerate(self.chunk_sources[:len(scores)]):
            filename = source.lower()
            for marker in active_markers:
                if marker in filename:
                    scores[i] += self.DOMAIN_BOOST
        return scores

    def search(self, query: str, top_k: int = 1) -> List["RetrievedChunk"]:
        """Return up to *top_k* passages ranked by boosted cosine similarity.

        Args:
            query: The user query; whitespace is normalised before use.
            top_k: Maximum number of passages to return.

        Returns:
            Passages in decreasing score order, restricted to those whose
            score is at least ``min_score``. Empty when the query is blank
            or ``top_k`` is not positive.
        """
        query = self._clean_text(query)
        # A negative top_k would otherwise slice "all but the last N".
        if not query or top_k <= 0:
            return []

        q_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(q_vec, self.tfidf_matrix)[0]
        scores = self._apply_domain_boost(query, scores)

        best_first = np.argsort(scores)[::-1][:top_k]
        results = []
        for i in best_first:
            score = float(scores[i])
            if score >= self.min_score:
                results.append(
                    RetrievedChunk(
                        text=self.chunks[i],
                        filename=self.chunk_sources[i],
                        score=score
                    )
                )
        return results

    def build_prompt(self, query: str, top_k: int = 1) -> str:
        """Build the French context/question prompt for *query*.

        Returns:
            The prompt text, or an empty string when no passage clears the
            score threshold.
        """
        results = self.search(query, top_k=top_k)
        if not results:
            return ""

        context = "\n\n".join(
            self._clean_for_prompt(result.text) for result in results
        )

        # The prompt starts with a newline and ends with one after the
        # answer cue; each section label sits on its own line.
        return (
            "\n"
            "Contexte :\n"
            f"{context}\n"
            "Question :\n"
            f"{query}\n"
            "Réponse en français :\n"
        )