# NOTE: non-code scrape residue removed ("Spaces / Sleeping / File size" page header).
import os
import faiss
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document
from util import getYamlConfig
# Pull in variables from a local .env file (development convenience),
# then read the Mistral API key from the process environment.
load_dotenv()
env_api_key = os.getenv("MISTRAL_API_KEY")
class Rag:
    """Retrieval-augmented generation helper around a Mistral chat model.

    Maintains two stores:

    * ``document_vector_store`` — an in-memory FAISS index rebuilt from the
      PDFs fed to :meth:`ingest` during the session.
    * ``vector_store`` — an external, project-provided store used by
      :meth:`ingestToDb` / :meth:`getDbFiles` (its ``addDoc``/``getDocs``
      contract is defined elsewhere — TODO confirm against the caller).
    """

    def __init__(self, vectore_store=None):
        print("Nouvelle instance de Rag créée")
        self.document_vector_store = None
        self.retriever = None
        self.chain = None
        # Fix: ``model`` used to be created only by setModel(); initialize it
        # here so code paths that touch it before setModel() see None instead
        # of raising AttributeError.
        self.model = None
        self.readableModelName = ""
        self.documents = []
        self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
        # Fix: ``separators`` must be a list of strings. The original passed
        # the bare string "\n\n", which RecursiveCharacterTextSplitter
        # iterates character by character ("\n", "\n").
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=300,
            separators=["\n\n"],
            length_function=len,
        )
        base_template = getYamlConfig()['prompt_template']
        self.prompt = PromptTemplate.from_template(base_template)
        self.reset_faiss_store()
        # NOTE: parameter name kept as-is ("vectore_store") for backward
        # compatibility with keyword-argument callers.
        self.vector_store = vectore_store

    def reset_faiss_store(self):
        """Initialise an empty FAISS index with the embedding's dimension.

        FAISS needs at least one vector to establish the dimension, so we
        seed with a dummy document, then reset the index to drop it while
        keeping the correctly-dimensioned (now empty) index.
        """
        docs = [Document(page_content=" ")]
        self.document_vector_store = FAISS.from_documents(docs, self.embedding)
        self.document_vector_store.index.reset()
        print(f"Nombre de vecteurs après reset: {self.document_vector_store.index.ntotal}")

    def setModel(self, model, readableModelName=""):
        """Attach the chat model used by ask(), plus a display name."""
        self.model = model
        self.readableModelName = readableModelName

    def getReadableModel(self):
        """Return the human-readable name of the currently attached model."""
        return self.readableModelName

    def ingestToDb(self, file_path: str, filename: str):
        """Load a PDF, split its full text into chunks, persist them in the
        external store; returns whatever ``vector_store.addDoc`` returns."""
        docs = PyPDFLoader(file_path=file_path).load()
        # Fix: join pages with a newline so the last word of one page is not
        # glued to the first word of the next (the original used bare +=).
        text = "\n".join(page.page_content for page in docs)
        chunks = self.text_splitter.split_text(text)
        return self.vector_store.addDoc(filename=filename, text_chunks=chunks, embedding=self.embedding)

    def getDbFiles(self):
        """List the documents known to the external store."""
        return self.vector_store.getDocs()

    def ingest(self, pdf_file_path: str):
        """Ingest a PDF into the in-memory FAISS store and rebuild the retriever."""
        docs = PyPDFLoader(file_path=pdf_file_path).load()
        chunks = self.text_splitter.split_documents(docs)
        self.documents.extend(chunks)
        if self.document_vector_store:
            print(f"Nombre de documents indexés dans FAISS : {self.document_vector_store.index.ntotal}")
        else:
            print("No document_vectore")
        # NOTE(review): this re-embeds every accumulated chunk on each call;
        # add_documents() on the existing store would avoid the repeated
        # embedding cost — kept as a rebuild to preserve current behavior.
        self.document_vector_store = FAISS.from_documents(self.documents, self.embedding)
        print(f"Après ingestion, FAISS contient {self.document_vector_store.index.ntotal} documents.")
        self.retriever = self.document_vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 10,
                "score_threshold": 0.5,
            },
        )

    def ask(self, query: str, prompt_system: str, messages: list, variables: list = None):
        """Stream an answer to *query* through prompt | model | parser.

        ``variables`` is an optional list of ``{'key': ..., 'value': ...}``
        dicts merged into the chain input. Returns the chain's stream iterator.
        """
        self.chain = self.prompt | self.model | StrOutputParser()
        queryForRetriever = query
        print(f"Nb messages : {len(messages)}")
        # On the first exchange (exactly two messages) prepend the opening
        # message to the retrieval query for better recall.
        if len(messages) == 2:
            queryForRetriever = messages[0].content + "\n" + query
        # Retrieve the context documents (empty string when nothing ingested).
        if self.retriever is None:
            documentContext = ''
        else:
            documentContext = self.retriever.invoke(queryForRetriever)
        chain_input = {
            "query": query,
            "documentContext": documentContext,
            "prompt_system": prompt_system,
            "messages": messages,
        }
        # Drop None values so the prompt template only sees provided fields.
        chain_input = {k: v for k, v in chain_input.items() if v is not None}
        if variables:
            # Convert the list of {'key', 'value'} dicts and merge it in.
            extra_vars = {item['key']: item['value'] for item in variables if 'key' in item and 'value' in item}
            chain_input.update(extra_vars)
        return self.chain.stream(chain_input)

    def clear(self):
        """Forget all session state: stores, retriever and chain."""
        self.document_vector_store = None
        self.vector_store = None
        self.retriever = None
        self.chain = None