Project_Gen_AI / rag_gpt2.py
Darryl237's picture
Initial commit - Mini chatbot RAG
eff6ee6
Raw
History Blame Contribute Delete
5.35 kB
from pathlib import Path
from dataclasses import dataclass
from typing import List
import re
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
@dataclass
class RetrievedChunk:
    """One passage returned by a retrieval query."""

    # Passage text as stored in the index.
    text: str
    # Name of the corpus file the passage was cut from.
    filename: str
    # Similarity score of the passage for the query that retrieved it.
    score: float
class RAGEngineGPT2:
    """TF-IDF passage retriever plus prompt builder.

    On construction the engine reads every ``.txt`` and ``.md`` file located
    directly inside ``corpus_dir``, cuts each file into overlapping character
    windows and fits a TF-IDF index over the resulting passages.
    """

    # Additive bonus given to a passage whose source file belongs to a domain
    # mentioned in the query.
    DOMAIN_BOOST = 0.08

    # Collapses any run of whitespace (including newlines) into one space.
    _WHITESPACE_RE = re.compile(r"\s+")

    # Bracketed upper-case tags such as "[TITRE]" that are removed from
    # passages before they are placed in a prompt.
    _SECTION_TAG_RE = re.compile(r"\[[A-ZÉÈÊÀÙÎÏÇ _-]+\]")

    # (filename marker, query pattern) pairs. Keywords are matched as whole
    # words (with an optional plural "s") rather than raw substrings, so that
    # "rh" does not fire on "recherche" nor "art" on "artificielle".
    # "cyber" is deliberately a prefix so it still covers "cybersécurité",
    # "cyberattaque", etc.
    _DOMAIN_RULES = (
        (
            "ia_medecine",
            re.compile(r"\b(?:médecine|medecine|santé|sante)s?\b"),
        ),
        (
            "ia_finance",
            re.compile(r"\b(?:finance|banque|crédit|credit)s?\b"),
        ),
        (
            "ia_rh",
            re.compile(r"\b(?:recrutement|rh)s?\b"),
        ),
        (
            "ia_education",
            re.compile(
                r"\b(?:éducation|education|école|ecole|élève|eleve)s?\b"
            ),
        ),
        (
            "ia_cybersecurite",
            re.compile(r"\bcyber\w*"),
        ),
        (
            "ia_art",
            re.compile(
                r"\b(?:art|artiste|artistique|image|création|creation)s?\b"
            ),
        ),
    )

    def __init__(
        self,
        corpus_dir: str = "corpus/",
        chunk_size: int = 450,
        chunk_overlap: int = 80,
        min_score: float = 0.05,
    ):
        """Load the corpus, split it into passages and build the index.

        Args:
            corpus_dir: Directory holding the ``.txt`` / ``.md`` corpus files.
            chunk_size: Maximum passage length, in characters.
            chunk_overlap: Number of characters shared by consecutive
                passages of the same file.
            min_score: Passages scoring below this value are never returned
                by :meth:`search`.

        Raises:
            ValueError: If the chunking parameters are inconsistent, if the
                directory holds no ``.txt`` / ``.md`` file, or if no passage
                could be extracted.
            FileNotFoundError: If ``corpus_dir`` does not exist.
        """
        # Validate first: an overlap >= chunk_size would give the sliding
        # window a non-positive step and the splitter would never terminate.
        if chunk_size <= 0:
            raise ValueError(
                "chunk_size doit être un entier strictement positif."
            )
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                "chunk_overlap doit vérifier 0 <= chunk_overlap < chunk_size."
            )

        self.corpus_dir = Path(corpus_dir)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_score = min_score

        # Parallel lists: chunk_sources[i] is the file name of chunks[i].
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []

        self.vectorizer = TfidfVectorizer(
            lowercase=True,
            strip_accents="unicode",
            ngram_range=(1, 2),
            max_features=10000,
        )

        self._load_and_chunk_corpus()
        self._index()

    def _clean_text(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and trim both ends."""
        return self._WHITESPACE_RE.sub(" ", text).strip()

    def _clean_for_prompt(self, text: str) -> str:
        """Strip bracketed upper-case tags, then normalise whitespace."""
        without_tags = self._SECTION_TAG_RE.sub("", text)
        return self._WHITESPACE_RE.sub(" ", without_tags).strip()

    def _split_into_chunks(self, text: str) -> List[str]:
        """Cut ``text`` into overlapping windows of ``chunk_size`` characters.

        The text is whitespace-normalised first. Blank text yields no chunk
        at all, and text that fits in one window is returned as a single
        chunk.
        """
        text = self._clean_text(text)
        if not text:
            # Never emit an empty passage: it would be indexed as a document
            # with no terms.
            return []
        if len(text) <= self.chunk_size:
            return [text]

        step = self.chunk_size - self.chunk_overlap
        chunks = []
        start = 0
        while start < len(text):
            end = start + self.chunk_size
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= len(text):
                # This window already reached the end of the text; another
                # one would only repeat the tail of the chunk just emitted.
                break
            start += step
        return chunks

    def _load_and_chunk_corpus(self):
        """Read the corpus files and fill ``chunks`` / ``chunk_sources``.

        Files that cannot be read or decoded as UTF-8 are reported on stdout
        and skipped.

        Raises:
            FileNotFoundError: If ``corpus_dir`` does not exist.
            ValueError: If there is no ``.txt`` / ``.md`` file, or if no
                passage could be extracted from the files.
        """
        if not self.corpus_dir.exists():
            raise FileNotFoundError(
                f"Dossier corpus introuvable : {self.corpus_dir}"
            )

        files = sorted(
            list(self.corpus_dir.glob("*.txt"))
            + list(self.corpus_dir.glob("*.md"))
        )
        if not files:
            raise ValueError(
                "Aucun fichier .txt ou .md trouvé dans le dossier corpus."
            )

        loaded = 0
        for file_path in files:
            # Only the read can legitimately fail; keeping the try narrow
            # avoids hiding unrelated bugs behind the "file skipped" message.
            try:
                text = file_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                print(f"Fichier ignoré : {file_path.name} | Erreur : {e}")
                continue

            loaded += 1
            for chunk in self._split_into_chunks(text):
                self.chunks.append(chunk)
                self.chunk_sources.append(file_path.name)

        if not self.chunks:
            raise ValueError(
                "Aucun passage exploitable trouvé dans le corpus."
            )

        # Report the files actually read, not the files merely discovered.
        print(f"{loaded} fichiers chargés.")
        print(f"{len(self.chunks)} passages indexés.")

    def _index(self):
        """Fit the TF-IDF vectorizer on the passages and store the matrix."""
        self.tfidf_matrix = self.vectorizer.fit_transform(self.chunks)

    def _apply_domain_boost(self, query: str, scores):
        """Add ``DOMAIN_BOOST`` to passages from domains named in the query.

        Args:
            query: The user query.
            scores: Mutable sequence of scores aligned with ``self.chunks``.
                It is updated in place.

        Returns:
            The same ``scores`` object, after the update.
        """
        query_lower = query.lower()

        # The query is identical for every passage, so decide once which
        # domains it mentions instead of re-testing it per passage.
        active_markers = [
            marker
            for marker, pattern in self._DOMAIN_RULES
            if pattern.search(query_lower)
        ]
        if not active_markers:
            return scores

        for i, source in enumerate(self.chunk_sources):
            filename = source.lower()
            for marker in active_markers:
                if marker in filename:
                    scores[i] += self.DOMAIN_BOOST
        return scores

    def search(self, query: str, top_k: int = 1) -> "List[RetrievedChunk]":
        """Return up to ``top_k`` passages ranked by boosted similarity.

        Args:
            query: Free-text query.
            top_k: Maximum number of passages to return. Values <= 0 yield
                an empty list.

        Returns:
            Passages in decreasing score order, restricted to those whose
            score is at least ``min_score``. Empty when the query is blank.
        """
        query = self._clean_text(query)
        if not query or top_k <= 0:
            return []

        q_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(q_vec, self.tfidf_matrix)[0]
        scores = self._apply_domain_boost(query, scores)

        # Stable sort on the negated scores: descending order, with ties
        # resolved in corpus order so results are reproducible.
        ranked = np.argsort(-scores, kind="stable")[:top_k]

        results = []
        for i in ranked:
            score = float(scores[i])
            if score >= self.min_score:
                results.append(
                    RetrievedChunk(
                        text=self.chunks[i],
                        filename=self.chunk_sources[i],
                        score=score,
                    )
                )
        return results

    def build_prompt(self, query: str, top_k: int = 1) -> str:
        """Build the generation prompt for ``query``.

        Returns:
            A prompt holding the retrieved context followed by the question,
            or an empty string when no passage passes the score threshold.
        """
        results = self.search(query, top_k=top_k)
        if not results:
            return ""

        context = "\n\n".join(
            self._clean_for_prompt(result.text) for result in results
        )
        # Written as explicit pieces so the prompt text does not depend on
        # how this source file is indented.
        return (
            "\n"
            "Contexte :\n"
            f"{context}\n"
            "Question :\n"
            f"{query}\n"
            "Réponse en français :\n"
        )