dnj0 commited on
Commit
f84a554
·
1 Parent(s): 34bfedc
Files changed (2) hide show
  1. src/config.py +10 -17
  2. src/vector_store.py +161 -49
src/config.py CHANGED
@@ -2,37 +2,30 @@ import os
2
  from pathlib import Path
3
 
4
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
5
-
6
- OPENAI_MODEL = "gpt-4o-mini"
7
-
8
- USE_CACHE = True
9
 
10
  CHROMA_DB_PATH = "./chroma_db"
11
-
12
  DOCSTORE_PATH = "./docstore"
13
-
14
  PROCESSED_FILES_LOG = "./processed_files.txt"
15
 
16
  EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
17
-
18
  EMBEDDING_DIM = 768
19
 
20
- MAX_CHUNK_SIZE = 500
21
-
22
- CHUNK_OVERLAP = 50
23
-
24
- TEMPERATURE = 0.3
25
-
26
- MAX_TOKENS = 500
27
 
28
  LANGUAGE = "russian"
29
 
30
  Path(CHROMA_DB_PATH).mkdir(exist_ok=True)
31
-
32
  Path(DOCSTORE_PATH).mkdir(exist_ok=True)
33
 
34
  UPLOAD_FOLDER = "./uploaded_pdfs"
35
-
36
  Path(UPLOAD_FOLDER).mkdir(exist_ok=True)
 
37
 
38
- MAX_PDF_SIZE_MB = 50
 
 
 
2
"""Application configuration: model choice, storage paths, and generation limits.

Importing this module creates the on-disk directories used by the app.
"""
import os
from pathlib import Path

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
OPENAI_MODEL = "gpt-4o-mini"  # cost-optimized model variant
USE_CACHE = True  # enable response caching

# Storage locations (relative to the working directory).
CHROMA_DB_PATH = "./chroma_db"
DOCSTORE_PATH = "./docstore"
PROCESSED_FILES_LOG = "./processed_files.txt"

# Embedding model and its output dimensionality.
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
EMBEDDING_DIM = 768  # all-mpnet-base-v2 produces 768-dim vectors

# Chunking and generation limits.
MAX_CHUNK_SIZE = 500  # smaller chunks = fewer tokens per retrieved passage
CHUNK_OVERLAP = 50    # less overlap = fewer total chunks
TEMPERATURE = 0.3     # low temperature for more focused, deterministic answers
MAX_TOKENS = 500      # cap response size (was 1500)

LANGUAGE = "russian"

# Create storage directories eagerly so later writes cannot fail on a
# missing path; parents=True makes this safe for nested configurations.
Path(CHROMA_DB_PATH).mkdir(parents=True, exist_ok=True)
Path(DOCSTORE_PATH).mkdir(parents=True, exist_ok=True)

UPLOAD_FOLDER = "./uploaded_pdfs"
Path(UPLOAD_FOLDER).mkdir(parents=True, exist_ok=True)
MAX_PDF_SIZE_MB = 50

BATCH_SEARCH_RESULTS = 3
CACHE_RESPONSES = True
SUMMARIZE_FIRST = True
src/vector_store.py CHANGED
@@ -1,88 +1,200 @@
1
  import os
 
2
  from typing import List, Dict
3
- from chromadb.config import Settings
4
  import chromadb
5
- from config import CHROMA_DB_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class VectorStore:
8
  def __init__(self):
9
- self.chroma_path = CHROMA_DB_PATH
10
- self.settings = Settings(
11
- chroma_db_impl_embed_collection_mixin=True,
12
- persist_directory=self.chroma_path,
13
- anonymized_telemetry=False,
14
- allow_reset=True,
15
- )
16
- self.client = chromadb.Client(self.settings)
17
- self.collection = self.client.get_or_create_collection(
18
- name="documents",
19
- metadata={"hnsw:space": "cosine"}
20
- )
21
-
22
- def add_documents(self, documents: Dict, doc_id: str):
23
  try:
24
- text = documents.get('text', '')
25
- if not text or len(text.strip()) < 1:
26
- print(f"Empty text for {doc_id}")
27
- return
28
- self.collection.add(
29
- ids=[doc_id],
30
- documents=[text],
31
- metadatas=[{
32
- 'doc_id': doc_id,
33
- 'source': 'pdf_document'
34
- }]
35
  )
36
- print(f"Added document to vector store: {doc_id}")
37
  except Exception as e:
38
- print(f"Error adding documents to vector store: {e}")
39
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def search(self, query: str, n_results: int = 5) -> List[Dict]:
42
  try:
 
 
43
  results = self.collection.query(
44
- query_texts=[query],
45
- n_results=n_results,
46
- include=['documents', 'metadatas', 'distances', 'embeddings']
47
  )
 
48
  formatted_results = []
49
- if results and results['documents'] and len(results['documents']) > 0:
50
- for idx, doc in enumerate(results['documents'][0]):
51
- distance = results['distances'][0][idx] if results['distances'] else 0
 
 
52
  formatted_results.append({
53
  'content': doc,
54
- 'metadata': results['metadatas'][0][idx] if results['metadatas'] else {},
55
  'distance': distance,
56
- 'type': 'document'
57
  })
 
58
  return formatted_results
59
  except Exception as e:
60
  print(f"Error searching vector store: {e}")
61
  return []
62
 
 
 
 
 
 
 
 
 
 
63
  def get_collection_info(self) -> Dict:
64
  try:
65
  count = self.collection.count()
66
  return {
 
67
  'count': count,
68
- 'status': 'ready',
69
- 'persist_path': self.chroma_path
70
  }
71
  except Exception as e:
72
  print(f"Error getting collection info: {e}")
73
- return {
74
- 'count': 0,
75
- 'status': 'error',
76
- 'persist_path': self.chroma_path
77
- }
 
 
 
 
 
 
 
 
 
78
 
79
  def clear_all(self):
80
  try:
81
- self.client.delete_collection(name="documents")
82
  self.collection = self.client.get_or_create_collection(
83
- name="documents",
84
  metadata={"hnsw:space": "cosine"}
85
  )
86
- print("Vector store cleared")
87
  except Exception as e:
88
- print(f"Error clearing vector store: {e}")
 
1
  import os
2
+ import json
3
  from typing import List, Dict
 
4
  import chromadb
5
+ from sentence_transformers import SentenceTransformer
6
+ import numpy as np
7
+ from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM
8
+
9
+
10
class CLIPEmbedder:
    """Text embedder backed by a SentenceTransformer model.

    NOTE(review): despite the name, this loads a sentence-transformers
    text model (all-mpnet-base-v2 by default), not CLIP — consider
    renaming in a later refactor; the name is kept here so callers
    remain unaffected.

    On embedding failure a zero vector of EMBEDDING_DIM is returned so
    ingestion can proceed best-effort instead of aborting the batch.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL):
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print(f"Model loaded successfully")

    def embed(self, text: str) -> List[float]:
        """Embed a single string; returns a zero vector on failure."""
        try:
            # convert_to_numpy=True always yields an ndarray, so .tolist()
            # is guaranteed — no need to probe with hasattr().
            return self.model.encode(text, convert_to_numpy=True).tolist()
        except Exception as e:
            print(f"Error embedding text: {e}")
            return [0.0] * EMBEDDING_DIM

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of strings; returns zero vectors on failure.

        The fallback builds a distinct list per entry — the original
        ``[[0.0] * DIM] * len(texts)`` aliased one shared list object,
        so mutating one fallback vector would have mutated them all.
        """
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return [e.tolist() for e in embeddings]
        except Exception as e:
            print(f"Error embedding batch: {e}")
            return [[0.0] * EMBEDDING_DIM for _ in texts]
32
 
33
class VectorStore:
    """Persistent ChromaDB-backed store for text, image-OCR, and table chunks.

    Embeddings are produced locally via CLIPEmbedder (sentence-transformers)
    rather than Chroma's built-in embedding function, so search must query
    with ``query_embeddings`` (not ``query_texts``) to stay consistent.
    """

    def __init__(self):
        self.persist_directory = CHROMA_DB_PATH
        self.embedder = CLIPEmbedder()

        print(f"Initializing ChromaDB at: {self.persist_directory}")

        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )
            print(f"ChromaDB initialized")
        except Exception as e:
            # Best-effort single retry; if this also fails the exception
            # propagates to the caller (same as the original behavior).
            print(f"Error initializing ChromaDB: {e}")
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )

        try:
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            count = self.collection.count()
            print(f"Collection loaded: {count} items in store")
        except Exception as e:
            # Fallback: create without the cosine-space hint so the store
            # is still usable (distance metric then defaults per Chroma).
            print(f"Error with collection: {e}")
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag"
            )

    def add_documents(self, documents: Dict, doc_id: str):
        """Index one parsed document (text, image OCR, tables) under doc_id.

        ``documents`` is a single mapping with optional keys 'text',
        'images', 'tables' — the original annotation ``List[Dict]`` was
        wrong: the body indexes it as a dict (``documents['text']``).

        NOTE(review): chunk_size=1000/overlap=200 here ignore the
        MAX_CHUNK_SIZE/CHUNK_OVERLAP config constants — confirm which is
        intended before unifying.
        """
        texts = []
        metadatas = []
        ids = []

        print(f"Adding documents for: {doc_id}")

        # Plain text: split into overlapping chunks.
        if 'text' in documents and documents['text']:
            chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
            for idx, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'text',
                    'chunk_idx': str(idx)
                })
                ids.append(f"{doc_id}_text_{idx}")
            print(f"Text: {len(chunks)} chunks")

        # Images: only those with OCR text are indexable.
        if 'images' in documents:
            image_count = 0
            for idx, image_data in enumerate(documents['images']):
                if image_data.get('ocr_text'):
                    texts.append(f"Image {idx}: {image_data['ocr_text']}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'image',
                        'image_idx': str(idx),
                        'image_path': image_data.get('path', '')
                    })
                    ids.append(f"{doc_id}_image_{idx}")
                    image_count += 1
            if image_count > 0:
                print(f"Images: {image_count} with OCR text")

        # Tables: indexed by their textual content.
        if 'tables' in documents:
            table_count = 0
            for idx, table_data in enumerate(documents['tables']):
                if table_data.get('content'):
                    texts.append(f"Table {idx}: {table_data.get('content', '')}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'table',
                        'table_idx': str(idx)
                    })
                    ids.append(f"{doc_id}_table_{idx}")
                    table_count += 1
            if table_count > 0:
                print(f"Tables: {table_count}")

        if texts:
            print(f"Generating {len(texts)} embeddings...")
            embeddings = self.embedder.embed_batch(texts)

            try:
                self.collection.add(
                    ids=ids,
                    documents=texts,
                    embeddings=embeddings,
                    metadatas=metadatas
                )
                print(f"Successfully added {len(texts)} items to vector store")
            except Exception as e:
                print(f"Error adding to collection: {e}")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to n_results nearest chunks for ``query``.

        Each result dict has 'content', 'metadata', 'distance' (cosine),
        and 'type' (text/image/table, or 'unknown' if metadata is absent).
        Returns [] on any error.
        """
        try:
            # Embed locally — the collection was populated with our own
            # embeddings, so query_texts would use a mismatched model.
            query_embedding = self.embedder.embed(query)

            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )

            formatted_results = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    metadata = results['metadatas'][0][i] if results['metadatas'] else {}
                    distance = results['distances'][0][i] if results['distances'] else 0

                    formatted_results.append({
                        'content': doc,
                        'metadata': metadata,
                        'distance': distance,
                        'type': metadata.get('type', 'unknown')
                    })

            return formatted_results
        except Exception as e:
            print(f"Error searching vector store: {e}")
            return []

    def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split ``text`` into chunks of ``chunk_size`` with ``overlap``.

        Fixes two defects in the original sliding window:
        - when the last chunk reached the end of the text, ``start`` was
          still rewound by ``overlap``, producing a redundant duplicate
          tail chunk (e.g. a text of exactly chunk_size chars yielded
          two chunks);
        - ``overlap >= chunk_size`` made ``start`` never advance — an
          infinite loop. Now rejected explicitly.
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        chunks = []
        start = 0
        length = len(text)
        while start < length:
            end = start + chunk_size
            chunks.append(text[start:end])
            if end >= length:
                break  # reached the end — do not emit a duplicate tail
            start = end - overlap
        return chunks

    def get_collection_info(self) -> Dict:
        """Report collection name, item count, status, and storage path."""
        try:
            count = self.collection.count()
            return {
                'name': 'multimodal_rag',
                'count': count,
                'status': 'active',
                'persist_path': self.persist_directory
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {'status': 'error', 'message': str(e)}

    def delete_by_doc_id(self, doc_id: str):
        """Delete every chunk whose metadata doc_id matches ``doc_id``."""
        try:
            results = self.collection.get(where={'doc_id': doc_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"Deleted {len(results['ids'])} documents for {doc_id}")
                print(f"Changes persisted automatically")
        except Exception as e:
            print(f"Error deleting documents: {e}")

    def persist(self):
        """No-op kept for API compatibility: PersistentClient auto-persists."""
        print("Vector store is using auto-persist")

    def clear_all(self):
        """Drop and recreate the collection, emptying the store."""
        try:
            self.client.delete_collection(name="multimodal_rag")
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            print("Collection cleared and reset")
        except Exception as e:
            print(f"Error clearing collection: {e}")