Ragnar / src /store.py
Eric Marchand
Meilleur gestion du multiclient
f4f9c98
from math import sqrt
import operator
import json
import os
from pathlib import Path
from .astore import AStore
class Document:
'''
Un document est :
une chaîne de caractère, le chunk
une source, livre ou page ou chapitre
un vecteur issu d'un modèle d'embedding
un id, calculé automatiquement par la collection
'''
def __init__(self, chunk:str, source:str, vec:list[float], idd:int):
self.chunk = chunk
self.source = source
self.vec = vec
self.idd = idd
def get_json(self):
json = {
'c':self.chunk,
's':self.source,
'v':self.vec
}
return json
class Collection:
'''
Une Collection est :
un nom
une liste de documents
un id, calculé automatiquement par le Store
Une collection est sauvée dans un fichier idc.col
le nom de la collection
la liste des documents:
chunk
source
vector
'''
def __init__(self,
name:str,
docs:list[Document],
idc:int):
self.name = name
self.docs = docs
self.idc = idc
def add_document(self, chunk:str, source:str, vec:list[float])->Document:
'''
Ajoute un document à la collection
Args:
chunk: le texte du document
source: la source du document (livre, chap...)
vec: la représentation vectorielle du document
Returns:
un Document ou None si problème rencontré
Raise:
si un des paramètres n'est pas défini
'''
if chunk == None or source == None or vec == None:
raise Exception("Document error: chunk, source or vec is None !")
idd:int = len(self.docs) + 1
doc:Document = Document(chunk, source, vec, idd)
self.docs.append(doc)
return doc
def get_length_octets(self)->int:
'''
Return la taille en octets de la collection
'''
if len(self.docs) == 0:
return 0
vector_size = len(self.docs[0])
return len(self.docs) * vector_size * 4 # un float sur 4 octets
@classmethod
def from_disk(self, file_path:str):
'''
Méthode de classe qui renvoie une Collection à partir d'un fichier de la base
Args:
file_path: le chemin vers le fichier
Return:
la Collection
Exception:
si le fichier n'existe pas ou qu'on ne peut pas le lire
'''
if not os.path.exists(file_path):
raise Exception("File {file} doesn't exist !".format(file=file_path))
idc:int = int(Path(file_path).stem)
# print("Collection.from_disk, reading : ", idc)
try:
with open(file_path, "r") as f:
datas = json.load(f)
name:str = datas['name']
docs = []
idd: int = 1
for d in datas['docs']:
doc:Document = Document(d['c'], d['s'], d['v'], idd)
docs.append(doc)
idd += 1
return Collection(name, docs, idc)
except:
raise Exception("Unable to read {file_path} !".format(file_path=file_path))
def save(self, persist_dir:str):
'''
La collection est enregistrée avec le nom idc.col dans le persist_dir
Args:
persist_dir: le chemin du repertoire de la bdd
Exception:
Si on ne peut pas sauver sur le disque
'''
file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
# print("Collection.save : ", file_path)
json_object = {
'name':self.name,
'docs':[]
}
for doc in self.docs:
json_object['docs'].append(doc.get_json())
json_object = json.dumps(json_object)
try:
with open(file_path, "w+") as f:
f.write(json_object)
except:
raise Exception("Unable to save the collection {name}, id={id} !".format(name=self.name, id=self.idc))
def delete(self, persist_dir:str)->None:
'''
Supprime la collection de la bdd
Args:
persist_dir: le chemin du repertoire de la bdd
Exception:
Si on ne peut pas supprimer du disque
'''
self.docs.clear()
file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
try:
os.remove(file_path)
except:
raise Exception("Unable to delete the collection {name}, id={id} !".format(name=self.name, id=self.idc))
class Store(AStore):
'''
Un store est une liste de collections.
A chaque création, ajout ou suppression d'un élément, la base est sauvée si elle est persistante
Sur le disque, dans store_dir:
Un sous-repertoire par collection, portant le nom de la collection
Dans chaque sous-repertoire d'une collection : la liste des vecteurs
'''
def __init__(self, persist_dir:str):
''' Constructeur de Store
Args:
dir_name: le répertoire persistant de la base de données ou None
Exception:
Dans le cas d'une base persistante:
Impossible de créer le répertoire persistant
Impossible de lire les collections
'''
self.persist_dir = persist_dir
self.collections = []
if persist_dir == None: # store éphémère
pass # Rien à faire
else:
# Charger la liste des collections
try:
self._create_persist_dir()
files = [os.path.join(persist_dir, f) for f in os.listdir(persist_dir) if os.path.isfile(os.path.join(persist_dir, f))]
for f in files:
col: Collection = Collection.from_disk(f)
self.collections.append(col)
except Exception as e:
raise
def reset(self)->None:
'''
Vide la base et l'efface du disque si elle est persistante
Exception:
Dans le cas d'une base persistante:
Impossible de créer le répertoire persistant
Impossible de lire les collections
'''
self.collections = []
if self.persist_dir == None: # store éphémère
pass
else:
try:
# Supprimer les fichiers du disque
if os.path.exists(self.persist_dir):
files = [os.path.join(self.persist_dir, f) for f in os.listdir(self.persist_dir) if os.path.isfile(os.path.join(self.persist_dir, f))]
# print(files)
for f in files:
os.remove(f)
os.rmdir(self.persist_dir)
except Exception as e:
raise
def get_collection_names(self)->list[str]:
return [col.name for col in self.collections]
def print_infos(self)->None:
''' Affiche le nombre de collections et pour chaque collection, affiche son nom et son nombre de documents '''
print("-------- STORE INFOS ---------------")
for col in self.collections:
print(col.name)
# idds = [doc.idd for doc in col.docs]
# print("\t", idds)
print("\tdocuments:", len(col.docs))
print("-------- /STORE INFOS ---------------")
def get_collection(self, collection_name:str)->Collection:
'''
Renvoie la collection dont le nom est 'collection_name' ou None si elle n'existe pas
'''
for col in self.collections:
if col.name == collection_name:
return col
return None
def _create_persist_dir(self):
'''
Recrée le répertoir persistant s'il a disparu après un reset par exemple
Exception:
Si on ne peut pas créer le 'persist_dir'
'''
# Vérifier si le persist_dir existe, sinon le créer
# print("Persist_dir:" + self.persist_dir)
try:
if not os.path.exists(self.persist_dir):
print("Trying to recreate persist_dir", self.persist_dir)
os.mkdir(self.persist_dir)
except:
raise Exception("Unable to create the persit directory: {dir}".format(dir=self.persist_dir))
def create_collection(self, name:str)->Collection:
'''
Crée et renvoie une nouvelle collection vide de documents
Args:
name: le nom de la création à créer
Exception:
Dans le cas d'une base persistante:
Impossible de créer le répertoire persistant
Impossible de sauver la collection
'''
idc:int = len(self.collections) + 1
col:Collection = Collection(name, [], idc)
if self.persist_dir != None:
try:
self._create_persist_dir()
col.save(self.persist_dir)
except:
raise
return col
def add_to_collection(self, collection_name:str, source:str, vectors:list[list[float]], chunks:list[str])->None:
'''
Ajoute une liste de vecteurs à la collection 'collection_name'
Args:
collection_name: le nom de la collection
source: la source unique des chunks, par exemple un nom de fichier, une url ...
vectors: la liste des vecteurs obtenus à l'aide d'un modèle d'embeddings
chunks: la liste des chunks (documents) correspondant aux vecteurs
Exception:
Dans le cas d'une base persistante:
Impossible de créer le répertoire persistant
Impossible de sauver la collection
'''
col:Collection = self.get_collection(collection_name)
if col == None:
col = self.create_collection(collection_name)
self.collections.append(col)
for i in range(len(chunks)):
col.add_document(chunks[i], source, vectors[i])
if self.persist_dir != None:
try:
self._create_persist_dir()
col.save(self.persist_dir)
except:
raise
def delete_collection(self, name:str)->None:
''' Vide et supprime la collection dont le nom est 'name', et la supprime du disque si elle est persistante '''
col = self.get_collection(name)
if col != None:
self.collections.remove(col)
if self.persist_dir != None:
try:
self._create_persist_dir()
col.delete(self.persist_dir)
except:
raise
def normalize(self, v:list[float])->list[float]:
'''
Normalement les LLMs renvoient des vecteurs normalisés mais:
c'est pas sûr pour ceux que je n'ai pas testés
c'est pratique d'avoir cette méthode pour 'test_store.py'
Args:
v: le vecteur à normaliser
Returns:
le vecteur normalisé
'''
norm = 0.0
for i in range(len(v)):
norm += v[i] * v[i]
norm = sqrt(norm)
if norm == 0.0:
return v.copy()
result = [None] * len(v)
for i in range(len(v)):
result[i] = v[i] / norm
return result
def dot_product(self, v1:list[float], v2:list[float])->float:
'''
Le produit scalaire est utilisé pour une similarité en cosinus:
cos(a) = (vecA dot vecB) / (A.B)
si les vecteurs A et B sont normalisés, le cos est simplement le produit scalaire
Args:
v1, v2: les deux vecteurs à multiplier
Returns:
Un float égal à v1 dot v2
'''
result = 0.0
for i in range(len(v1)):
result += v1[i] * v2[i]
return result
def get_similar_vector(self, vector:list[float], collection_name:str)->list[float]:
'''
Renvoie le vecteur de 'collection' le pus similaire à 'vector'.
Args:
vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
collection_name: le nom de la collection de la base dans laquelle on cherche une similarité
Return:
Le vecteur le plus similaire 'vector'
'''
col:Collection = self.get_collection(collection_name)
best_doc:Document = None
best_dp: float = -20.0
if col != None:
for doc in col.docs:
dp:float = self.dot_product(vector, doc.vec)
if dp > best_dp:
best_dp = dp
best_doc = doc
return best_doc.vec
else:
return None
def get_similar_chunk(self, query_vector:list[float], collection_name:str)->tuple[str, str]:
'''
Renvoie le document de la 'collection' le plus similaire à 'query_vector'.
Args:
query_vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
collection: la collection de la base dans laquelle on cherche une similarité
Returns:
Un tuple contenant:
le document
la source du document
'''
col:Collection = self.get_collection(collection_name)
best_doc:Document = None
best_dp: float = -20.0
if col != None:
for doc in col.docs:
dp:float = self.dot_product(query_vector, doc.vec)
print(dp)
if dp > best_dp:
best_dp = dp
best_doc = doc
return best_doc.chunk, best_doc.source
else:
return None, None
def get_similar_chunks(self, query_vector:list[float], count:int, collection_name:str):
'''
Returns:
Un tuple contenant:
les documents
la source des documents
les ids des documents
a[0:count-1]
'''
# start:int = time.time()
col:Collection = self.get_collection(collection_name)
if col == None:
return None, None, None
bests:list[dict] = []
# Ajouter tous les docs avec leur dotproduct à la liste bests
for doc in col.docs:
dp:float = self.dot_product(query_vector, doc.vec)
bests.append({'doc':doc, 'dp':dp})
# Trier la liste en reverse à partir de la clé 'dp'
bests.sort(key=operator.itemgetter('dp'), reverse=True)
# Adapter le nombre de documents à renvoyer s'il n'y a pas assez de chunks
n:int = count if len(bests) >= count else len(bests)
# print("get_similar_chunks, count=", count, ", n=", n)
# Créer les variables de retour
docs = [b['doc'].chunk for b in bests[0:n]]
source = bests[0]['doc'].source if n > 0 else None
ids = [b['doc'].idd for b in bests[0:n]]
# print("my_store.get_similar_chunks:", time.time() - start, "s")
return docs, source, ids