elly99 committed on
Commit
73cd1d1
·
verified ·
1 Parent(s): e854e57

Create memory/faiss_memory.py

Browse files
Files changed (1) hide show
  1. src/memory/faiss_memory.py +93 -0
src/memory/faiss_memory.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # © 2025 Elena Marziali — Code released under Apache 2.0 license.
2
+ # See LICENSE in the repository for details.
3
+ # Removal of this copyright is prohibited.
4
+
5
+ # === FAISS Parameters ===
6
# Pickle file in which the index is persisted between runs.
INDEX_FILE = "faiss_memoria_pq.pkl"

# Dimensionality of the vectors the index is built for; embeddings added to
# the index must have this many components.
dimension = 768

# Remaining arguments handed to faiss.IndexIVFPQ when a new index is created:
# number of inverted lists, number of product-quantizer sub-vectors, and bits
# per sub-vector code. `dimension` is a multiple of `m` (768 = 32 * 24).
nlist = 100
m = 32
nbits = 8
11
+
12
+ # Load or create a FAISS index for vector memory
13
def load_or_create_index():
    """Return the FAISS index that backs the vector memory.

    When ``INDEX_FILE`` does not exist, a new ``IndexIVFPQ`` (with an
    ``IndexFlatL2`` coarse quantizer) is created, trained, pickled to
    ``INDEX_FILE`` and returned. Otherwise the pickled index is loaded; if it
    exposes ``is_trained`` and that flag is false, the index is trained and
    written back to ``INDEX_FILE`` before being returned.

    Training always uses 5000 uniformly random vectors of size ``dimension``,
    not real embeddings.

    Returns:
        The loaded or newly created index object.
    """
    if not os.path.exists(INDEX_FILE):
        # No persisted index yet: build, train and save a fresh one.
        coarse_quantizer = faiss.IndexFlatL2(dimension)
        fresh_index = faiss.IndexIVFPQ(coarse_quantizer, dimension, nlist, m, nbits)
        fresh_index.train(np.random.rand(5000, dimension).astype(np.float32))
        with open(INDEX_FILE, "wb") as out_file:
            pickle.dump(fresh_index, out_file)
        return fresh_index

    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file, so INDEX_FILE must only ever come from a trusted location.
    with open(INDEX_FILE, "rb") as in_file:
        stored_index = pickle.load(in_file)

    # Make sure the loaded index is trained before handing it out.
    needs_training = hasattr(stored_index, "is_trained") and not stored_index.is_trained
    if needs_training:
        print("Indice FAISS caricato ma non addestrato. Addestramento in corso...")
        training_vectors = np.random.rand(5000, dimension).astype(np.float32)
        stored_index.train(training_vectors)
        with open(INDEX_FILE, "wb") as out_file:
            pickle.dump(stored_index, out_file)
    return stored_index
31
+
32
# Module-level index shared by the memory and retrieval functions below.
index = load_or_create_index()

# Safety net in case the returned index still reports itself as untrained.
if hasattr(index, "is_trained") and not index.is_trained:
    logging.warning("Indice FAISS non addestrato. Addestramento in corso...")
    # Size the random training vectors from the index's own dimensionality
    # (index.d) so they always match the index being trained.
    index.train(np.random.rand(5000, index.d).astype(np.float32))
37
+
38
+
39
+ # === Semantic coherence check ===
40
def check_coherence(query, response, threshold=0.7):
    """Check that a response is semantically close to the query.

    Both texts are embedded with ``embedding_model`` and compared by cosine
    similarity.

    Args:
        query: The user's question.
        response: The candidate answer to validate.
        threshold: Minimum cosine similarity for the response to be accepted.
            Defaults to 0.7.

    Returns:
        ``response`` unchanged when the similarity is at least ``threshold``
        (or cannot be computed because an embedding has zero norm); otherwise
        a fixed message saying the response is too generic.
    """
    # Flatten to 1-D so the dot product yields a plain scalar rather than a
    # (1, 1) array whose truth value the comparison would have to rely on.
    query_vec = np.asarray(embedding_model.encode([query])).ravel()
    response_vec = np.asarray(embedding_model.encode([response])).ravel()

    norm_product = float(np.linalg.norm(query_vec) * np.linalg.norm(response_vec))
    if norm_product == 0.0:
        # Cosine similarity is undefined for a zero vector; pass the response
        # through instead of dividing by zero.
        return response

    similarity = float(np.dot(query_vec, response_vec)) / norm_product
    if similarity < threshold:
        return "The response is too generic, reformulating with more precision..."
    return response
47
+
48
+ # === Memory addition ===
49
+ # Each document is converted into embeddings and inserted into the index.
50
def add_to_memory(question, answer):
    """Embed ``question``, add it to the shared index and persist the index.

    Args:
        question: Text whose embedding is inserted into the index.
        answer: Accepted for the caller's convenience but not used here; only
            the question embedding is stored.

    Raises:
        ValueError: If the embedding width differs from the index dimension.
    """
    question_vectors = embedding_model.encode([question])
    embedding_dim = question_vectors.shape[1]
    if embedding_dim != index.d:
        raise ValueError(f"Embedding dimension ({embedding_dim}) not compatible with FAISS ({index.d})")

    index.add(np.array(question_vectors, dtype=np.float32))

    # Rewrite the pickle after every insertion so the memory survives restarts.
    with open(INDEX_FILE, "wb") as index_file:
        pickle.dump(index, index_file)
    print("Memory updated with new question!")
58
+
59
def add_diary_to_memory(diary_text, index):
    """Embed ``diary_text`` and add the vector to the given ``index``.

    The index is modified in memory only; nothing is written to disk here.

    Args:
        diary_text: Diary entry to embed.
        index: Index object that receives the vector (shadows the module-level
            ``index``).
    """
    diary_vectors = np.array(embedding_model.encode([diary_text]), dtype=np.float32)
    index.add(diary_vectors)
62
+
63
def search_similar_diaries(query, index, top_k=3):
    """Return the ids of the ``top_k`` stored vectors closest to ``query``.

    Args:
        query: Text to search for.
        index: Index object to search (shadows the module-level ``index``).
        top_k: Number of neighbours to request.

    Returns:
        The first row of the id array returned by ``index.search``. The caller
        is expected to map these ids to files or content.
    """
    query_vectors = np.array(embedding_model.encode([query]), dtype=np.float32)
    _distances, neighbor_ids = index.search(query_vectors, top_k)
    # Only one query was submitted, so only the first row is meaningful.
    return neighbor_ids[0]
67
+
68
+ # === Context retrieval ===
69
def retrieve_context(question, top_k=3):
    """Return placeholder labels for the stored entries closest to ``question``.

    Args:
        question: Text to search for in the shared index.
        top_k: Maximum number of neighbours to request.

    Returns:
        A list of strings ``"Similar response <id + 1>"``, one per neighbour
        found; an empty list when nothing is found.
    """
    query_vectors = np.array(embedding_model.encode([question]), dtype=np.float32)
    _, indices = index.search(query_vectors, top_k)
    # The id -1 marks a missing neighbour (fewer than top_k matches). Every
    # slot is filtered, not just the first, so padding never turns into a
    # bogus "Similar response 0" entry.
    return [f"Similar response {i+1}" for i in indices[0] if i != -1]
73
+
74
def retrieve_similar_embeddings(question, top_k=2):
    """
    Retrieves the top-k most similar embeddings to the given question.

    Args:
        question: Text to search for in the shared index.
        top_k: Maximum number of neighbours to request.

    Returns:
        A list of strings ``"Similar response <id + 1>"``, one per neighbour
        found; an empty list when nothing is found.
    """
    emb = embedding_model.encode([question])
    # `emb` already holds one row per input text, so it is passed as-is:
    # wrapping it in another list would produce a 3-D array, which the index
    # search (one query vector per row) cannot accept.
    _, indices = index.search(np.array(emb, dtype=np.float32), top_k)
    # The id -1 marks a missing neighbour; drop every such slot.
    return [f"Similar response {i+1}" for i in indices[0] if i != -1]
81
+
82
+ # === Multi-turn retrieval ===
83
+ # Retrieves context from previous conversations
84
def retrieve_multiturn_context(question, top_k=5):
    """Build a context string from the stored entries closest to ``question``.

    Args:
        question: Text to search for in the shared index.
        top_k: Maximum number of neighbours to request.

    Returns:
        Space-separated labels ``"Previous turn <id + 1>"`` for every
        neighbour found, or an empty string when there are none.
    """
    query_vectors = np.array(embedding_model.encode([question]), dtype=np.float32)
    _, neighbor_ids = index.search(query_vectors, top_k)

    turns = []
    for neighbor_id in neighbor_ids[0]:
        # -1 marks a slot for which no neighbour was found.
        if neighbor_id == -1:
            continue
        turns.append(f"Previous turn {neighbor_id+1}")
    return " ".join(turns)
89
+
90
+ # === Usage example ===
91
# NOTE(review): these statements sit at module level, so they run on every
# import: one question is added to the index (and the index file rewritten),
# then a lookup is performed and printed.
add_to_memory("What is general relativity?", "General relativity is Einstein's theory of gravity.")

# Query with a paraphrase of the stored question and show what comes back.
similar_responses = retrieve_context("Can you explain general relativity?")
print("Related responses:", similar_responses)