File size: 7,816 Bytes
f05e8f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from __future__ import annotations

import json
import os
import hashlib
from pathlib import Path
from typing import List

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from src.utils.logger import get_logger
from config.settings import settings

logger = get_logger(__name__)


class ChromaVectorDBManager:
    """Corporate-friendly ChromaDB manager that embeds text locally.

    Embeddings are computed in-process by ``self.embedding_model``
    (sentence-transformers) and passed to Chroma explicitly on every ``add``
    and ``query`` call, so no Chroma-side embedding function is used.
    """

    # Maximum number of ids sent to Chroma in a single ``get``/``add`` call.
    # Keeps each request small for large input files.
    # NOTE(review): confirm this stays below the installed client's
    # max batch size.
    _BATCH_SIZE = 1000

    def __init__(self, model_name: str = None, db_path: str = None):
        """Load the embedding model, open the persistent client and the collection.

        Args:
            model_name: sentence-transformers model id or path. Defaults to
                ``settings.EMBEDDING_MODEL`` or
                ``'sentence-transformers/all-MiniLM-L6-v2'``.
                NOTE(review): a hub id is only offline if the model is already
                cached locally; pass a local path for strictly offline use.
            db_path: Directory of the persistent Chroma store; created if
                missing. Defaults to ``settings.CHROMADB_PATH`` or
                ``'./chroma_db'``.
        """
        self.model_name = model_name or getattr(
            settings, 'EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2'
        )
        self.embedding_model = SentenceTransformer(self.model_name)

        self.db_path = db_path or getattr(settings, 'CHROMADB_PATH', './chroma_db')
        os.makedirs(self.db_path, exist_ok=True)

        self.client = chromadb.PersistentClient(
            path=self.db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
                is_persistent=True,
            ),
        )

        self.collection_name = getattr(settings, 'COLLECTION_NAME', 'rag_chunks')
        self.collection = self._get_collection()

        logger.info(f"ChromaDB initialized at: {self.db_path}")

    def _get_collection(self):
        """Return the configured collection, creating it if it does not exist.

        No embedding function is passed because every add/query in this class
        supplies precomputed embeddings.
        """
        # get_or_create never deletes anything. A delete-and-recreate fallback
        # on *any* lookup error could silently wipe an existing collection.
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"description": "RAG chunks"},
        )

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the local sentence-transformers model.

        Returns one embedding (a list of floats) per input text, in input
        order. An empty input returns ``[]`` without invoking the model.
        """
        if not texts:
            return []
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=len(texts) > 100,
            convert_to_tensor=False,
        )
        # With convert_to_tensor=False encode() yields a numpy array; Chroma
        # is given plain Python lists.
        return embeddings.tolist()

    def _existing_ids(self, ids: List[str]) -> set:
        """Return the subset of ``ids`` already stored in the collection.

        Lookups are batched. A failed lookup is logged and its ids are treated
        as not present, so the duplicate check stays best-effort.
        """
        existing = set()
        for start in range(0, len(ids), self._BATCH_SIZE):
            batch = ids[start:start + self._BATCH_SIZE]
            try:
                found = self.collection.get(ids=batch, include=["metadatas"])
            except Exception as e:
                logger.warning(f"Could not check existing ids in ChromaDB: {e}")
                continue
            existing.update(found["ids"])
        return existing

    def add_chunks_to_db(self, chunks: list, source_file: str) -> bool:
        """Embed and store the new chunks of one source file.

        Each chunk is expected to be a dict with a ``"text"`` key and an
        optional ``"chunk_id"`` (default 0). The following are skipped:
        non-dict entries, blank texts, texts repeated within ``chunks``, and
        ids already present in the collection. Re-running on the same file
        therefore adds nothing.

        Args:
            chunks: Chunk dicts to index.
            source_file: Name recorded in each chunk's metadata and id.

        Returns:
            ``True`` once all new chunks are stored, including when nothing
            needed adding. Errors from embedding or ``collection.add``
            propagate to the caller.
        """
        if not chunks:
            return True

        # chunk_id -> (text, metadata); dict preserves insertion order.
        pending = {}
        seen_hashes = set()
        for chunk in chunks:
            if not isinstance(chunk, dict):
                continue
            # `or ""` also covers an explicit "text": None.
            text = (chunk.get("text") or "").strip()
            if not text:
                continue

            text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()
            if text_hash in seen_hashes:
                continue
            seen_hashes.add(text_hash)

            chunk_index = chunk.get("chunk_id", 0)
            chunk_id = f"{source_file}_{chunk_index}_{text_hash[:8]}"
            # setdefault: the first occurrence wins, so no id is sent twice.
            pending.setdefault(chunk_id, (text, {
                "source_file": source_file,
                "chunk_index": chunk_index,
                "char_length": len(text),
                "text_hash": text_hash,
            }))

        # One batched lookup instead of a collection.get() per chunk.
        for chunk_id in self._existing_ids(list(pending)):
            pending.pop(chunk_id, None)

        if not pending:
            return True

        ids = list(pending)
        texts = [pending[chunk_id][0] for chunk_id in ids]
        metadatas = [pending[chunk_id][1] for chunk_id in ids]
        embeddings = self.generate_embeddings(texts)

        for start in range(0, len(ids), self._BATCH_SIZE):
            end = start + self._BATCH_SIZE
            self.collection.add(
                embeddings=embeddings[start:end],
                documents=texts[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end],
            )

        logger.info(f"Added {len(texts)} chunks from {source_file} to ChromaDB")
        return True

    def search_for_rag(
        self,
        query: str,
        n_results: int = 5,
        use_truncated: bool = False,
        filter_128_context: bool = False,
    ) -> List[dict]:
        """Return up to ``n_results`` chunks nearest to ``query``.

        The query is embedded locally, then Chroma is queried with that
        vector.

        Args:
            query: Free-text query.
            n_results: Maximum number of hits to return. A value ``<= 0``
                returns ``[]``.
            use_truncated: Accepted for API compatibility; currently ignored.
            filter_128_context: Accepted for API compatibility; currently
                ignored.

        Returns:
            Hit dicts ordered as Chroma returns them (nearest first). Keys:
            ``id``, ``score`` (``1 / (1 + distance)``), ``distance``, ``text``,
            ``source_file``, ``chunk_index``. The last two are ``None`` when
            the stored metadata lacks them.
        """
        if n_results <= 0:
            return []
        available = self.collection.count()
        if available == 0:
            return []

        # Over-fetch up to 2x (capped at 20) as before, but never fetch fewer
        # than requested. Never ask for more than the collection holds.
        fetch_k = min(max(n_results, min(n_results * 2, 20)), available)

        query_embedding = self.generate_embeddings([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=fetch_k,
            include=["documents", "metadatas", "distances"],
        )

        hits = zip(
            results['ids'][0],
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0],
        )
        search_results = []
        for chunk_id, doc, metadata, distance in hits:
            if len(search_results) >= n_results:
                break
            metadata = metadata or {}
            search_results.append({
                "id": chunk_id,
                "score": 1 / (1 + distance),
                "distance": distance,
                "text": doc,
                "source_file": metadata.get("source_file"),
                "chunk_index": metadata.get("chunk_index"),
            })
        return search_results

    def reset_database(self):
        """Delete the collection and recreate it empty.

        Returns:
            ``True`` on success; ``False`` if deleting or recreating failed.
            The error is logged.
        """
        try:
            self.client.delete_collection(name=self.collection_name)
            self.collection = self._get_collection()
        except Exception as e:
            logger.error(f"Failed to reset database: {e}")
            return False
        logger.info(f"Reset collection: {self.collection_name}")
        return True

    def get_collection_stats(self) -> dict:
        """Summarise the collection and the on-disk size of ``db_path``.

        ``db_size_mb`` is the total size of all files under ``db_path`` in
        MiB. It is 0 if the directory walk fails.
        """
        count = self.collection.count()

        try:
            total_bytes = sum(
                path.stat().st_size
                for path in Path(self.db_path).rglob("*")
                if path.is_file()
            )
        except OSError:
            total_bytes = 0

        return {
            "total_chunks": count,
            "collection_name": self.collection_name,
            "embedding_model": self.model_name,
            "db_path": self.db_path,
            "db_size_mb": total_bytes / (1024 * 1024),
        }

    @staticmethod
    def _load_chunks(chunk_file: Path):
        """Read one extracted JSON file and return its chunk list.

        Returns ``None`` when the top-level structure is not recognised.
        """
        with open(chunk_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        # New format: {"source_info": {...}, "initial_chunks": [...]}
        if isinstance(data, dict) and "initial_chunks" in data:
            return data["initial_chunks"]
        # Old format: a bare list of chunks.
        if isinstance(data, list):
            return data
        return None

    def process_all_chunks(self, chunks_dir: str = None) -> bool:
        """Index every ``*_extracted.json`` file in ``chunks_dir``.

        Args:
            chunks_dir: Directory to scan. Defaults to
                ``settings.PROCESSED_TEXT_DIR`` or ``'./data/processed_text'``.

        Returns:
            ``True`` if at least one chunk was handed to
            :meth:`add_chunks_to_db`. The count includes chunks that method
            then skipped as blank, duplicate or already stored. A file that
            fails to load or index is logged and skipped.
        """
        if not chunks_dir:
            chunks_dir = getattr(settings, 'PROCESSED_TEXT_DIR', './data/processed_text')

        # Sorted so ingestion order is deterministic across runs and platforms.
        chunk_files = sorted(Path(chunks_dir).glob("*_extracted.json"))
        logger.info(f"Found {len(chunk_files)} extracted JSON files to process")

        total_processed = 0
        for chunk_file in chunk_files:
            try:
                chunks = self._load_chunks(chunk_file)
                if chunks is None:
                    logger.warning(f"Unexpected format in {chunk_file.name}")
                    continue
                if chunks and self.add_chunks_to_db(chunks, source_file=chunk_file.name):
                    total_processed += len(chunks)
                    logger.info(f"Processed {chunk_file.name}: {len(chunks)} chunks")
            except Exception as e:
                # Best effort: one bad file must not abort the whole run.
                logger.error(f"Error processing {chunk_file}: {e}")

        logger.info(f"Successfully processed {total_processed} total chunks")
        return total_processed > 0