File size: 11,833 Bytes
3efe7a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os
import glob
import pickle, json
from tqdm import tqdm
import numpy as np

# Try imports with friendly errors
try:
    import faiss
except Exception as e:
    raise ImportError("faiss is required. Install cpu version: `pip install faiss-cpu` or install via conda for GPU (faiss-gpu).") from e

try:
    from sentence_transformers import SentenceTransformer
except Exception as e:
    raise ImportError("sentence-transformers is required. `pip install sentence-transformers`") from e

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from dotenv import load_dotenv


from Data_Cleaning import GetDataCleaning
from Logger import GetLogger


class GetEmbeddings:
    """
    Embedding pipeline for cleaned text files.
    Generates embeddings using SentenceTransformers, builds a FAISS index,
    and allows searching queries against the vector database.
    """

    def __init__(self, config_path="config.json", logger=None):
        """
        Read the JSON configuration and prepare the pipeline settings.

        config_path: path of a JSON file providing a "paths" section
            ("root", "faiss_index", "metadata") and an "embedding" section
            ("chunk_words", "batch_size", "model", "use_gpu").
        logger: logger to use; when falsy, one is obtained from
            GetLogger().get_logger().

        Also selects the compute device and reads HF_TOKEN from the
        environment after load_dotenv() has been called.
        """
        with open(config_path, "r") as config_file:
            self.config = json.load(config_file)

        paths_section = self.config["paths"]
        embedding_section = self.config["embedding"]

        # Locations and chunking/encoding settings taken from the config.
        self.root = paths_section["root"]
        self.cleaned_suffix = "_cleaned_txt"
        self.chunk_words = embedding_section["chunk_words"]
        self.batch_size = embedding_section["batch_size"]
        self.faiss_index_path = paths_section["faiss_index"]
        self.metadata_path = paths_section["metadata"]
        self.embedding_model = embedding_section["model"]

        # Fall back to the project logger when the caller supplied none.
        self.logger = logger if logger else GetLogger().get_logger()
        self.logger.info("Initializing Embedding Pipeline...")

        # "cuda" only when CUDA is usable AND the config asks for the GPU;
        # the "use_gpu" key is read only after the CUDA check succeeded.
        if self.check_cuda() and embedding_section["use_gpu"]:
            self.device = "cuda"
        else:
            self.device = "cpu"

        load_dotenv()
        self.hf_token = os.getenv("HF_TOKEN")

    def check_cuda(self):
        """Return True if CUDA is available and usable."""
        try:
            if torch.cuda.is_available():
                _ = torch.cuda.current_device()
                self.logger.info(f"โœ… CUDA available. Device: {torch.cuda.get_device_name(0)}")
                return True
            self.logger.info("โš ๏ธ CUDA not available. Using CPU.")
            return False
        except Exception as e:
            self.logger.error(f"Error checking CUDA, defaulting to CPU. Error: {e}")
            return False

    def list_cleaned_files(self):
        """Return sorted list of cleaned text files under root/*{cleaned_suffix}/*.txt"""
        pattern = os.path.join(self.root, f"*{self.cleaned_suffix}", "*.txt")
        files = glob.glob(pattern)
        files.sort()
        return files
    
    def read_text_file(self, path):
        """Return the complete content of the file at path, decoded as UTF-8."""
        with open(path, mode="r", encoding="utf-8") as handle:
            content = handle.read()
        return content

    def chunk_text_words(self, text):
        """
        Simple word-based chunking.
        Returns list of text chunks.
        """
        words = text.split()
        if not words:
            return []
        return [" ".join(words[i:i + self.chunk_words]) for i in range(0, len(words), self.chunk_words)]

    def save_index_and_metadata(self):
        """
        Save FAISS index and metadata to disk.

        Writes self.index to self.faiss_index_path with faiss.write_index and
        pickles self.metadata to self.metadata_path. The parent directory of
        each target is created first when the path has one.
        """
        for target in (self.faiss_index_path, self.metadata_path):
            parent = os.path.dirname(target)
            # A bare file name has an empty dirname, and os.makedirs("")
            # raises FileNotFoundError, so only create real directories.
            if parent:
                os.makedirs(parent, exist_ok=True)

        faiss.write_index(self.index, self.faiss_index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)
        self.logger.info(f"๐Ÿ’พ Saved FAISS index to {self.faiss_index_path}")
        self.logger.info(f"๐Ÿ’พ Saved metadata to {self.metadata_path}")

    def load_index_and_metadata(self):
        """Load FAISS index and metadata if they exist."""
        if os.path.exists(self.faiss_index_path) and os.path.exists(self.metadata_path):
            try:
                self.index = faiss.read_index(self.faiss_index_path)
                with open(self.metadata_path, "rb") as f:
                    self.metadata = pickle.load(f)
                self.logger.info(f"โœ… Loaded existing FAISS index + metadata from disk.")
                return True
            except Exception as e:
                self.logger.warning(f"โš ๏ธ Failed to load FAISS index/metadata, will rebuild. Error: {e}")
                return False
        return False

    def load_encoder(self):
        """Create the SentenceTransformer named by self.embedding_model on self.device, store it as self.encoder and return it."""
        model_name, target_device = self.embedding_model, self.device
        encoder = SentenceTransformer(model_name, device=target_device)
        self.encoder = encoder
        self.logger.info(f"Loaded embedding model '{model_name}' on {target_device}")
        return encoder


    def building_embeddings_index(self, files):
        """
        Encode every chunk of the given text files and build the FAISS index.

        files: paths of the cleaned text files to process.

        Stores the result in self.index / self.metadata and returns both as
        a tuple. Each metadata entry records a running id, the first path
        component below self.root, the file name, the chunk position within
        the file and the chunk text. Files with no text are skipped.

        Raises RuntimeError when no embedding at all was produced.
        """
        vectors = []
        records = []

        for path in tqdm(files, desc="Files", unit="file"):
            content = self.read_text_file(path)
            if content.strip() == "":
                continue

            # The folder directly below root identifies the source,
            # e.g. financial_reports/Infosys_cleaned_txt/Infosys_2023_AR.txt
            relative_path = os.path.relpath(path, self.root)
            source_folder = relative_path.split(os.sep)[0]
            file_name = os.path.basename(path)

            pieces = self.chunk_text_words(content)
            if not pieces:
                continue

            # Encode the chunks batch by batch.
            for batch_start in range(0, len(pieces), self.batch_size):
                group = pieces[batch_start:batch_start + self.batch_size]
                encoded = self.encoder.encode(
                    group,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                ).astype(np.float32)

                for offset, vector in enumerate(encoded):
                    vectors.append(vector)
                    records.append({
                        # ids are consecutive, so the next id is the current count
                        "id": len(records),
                        "source_folder": source_folder,
                        "file": file_name,
                        "chunk_id": batch_start + offset,
                        "text": group[offset],  # chunk text kept for retrieval
                    })

        if len(vectors) == 0:
            raise RuntimeError("No embeddings were produced. Check cleaned files and chunking.")

        matrix = np.vstack(vectors).astype(np.float32)
        faiss.normalize_L2(matrix)

        # Inner product over L2-normalized vectors equals cosine similarity.
        dimension = matrix.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(matrix)
        self.metadata = records
        self.logger.info(f"โœ… Built FAISS index with {self.index.ntotal} vectors, dim={dimension}")

        return self.index, self.metadata

    def run(self):
        """Main entry: load or build embeddings + FAISS index."""
        if self.load_index_and_metadata():
            return

        files = self.list_cleaned_files()
        if not files:
            self.logger.error("โŒ No cleaned text files found.")
            raise SystemExit(1)
        self.load_encoder()
        self.building_embeddings_index(files)
        self.save_index_and_metadata()

    def load_summarizer(self, model_name="google/gemma-2b"):
        """
        Load summarizer LLM once.
        If already loaded, skip.
        """
        if hasattr(self, "summarizer_pipeline"):
            self.logger.info("โ„น๏ธ Summarizer already loaded, skipping reload.")
            return

        try:
            self.logger.info(f"โณ Loading summarizer model '{model_name}'...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.hf_token)
            self.summarizer_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map=self.device,
                token=self.hf_token
            )
            self.summarizer_pipeline = pipeline(
                "text-generation",
                model=self.summarizer_model,
                tokenizer=self.tokenizer
            )
            self.logger.info(f"โœ… Summarizer model '{model_name}' loaded successfully.")

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                self.logger.warning("โš ๏ธ CUDA OOM while loading summarizer. Retrying on CPU...")
                self.device = "cpu"
                torch.cuda.empty_cache()
                return self.load_summarizer(model_name=model_name)
            else:
                self.logger.error(f"โŒ Failed to load summarizer: {e}")
                raise

    def summarize_chunks(self, chunks, max_content_tokens=2048, max_output_tokens=256):
        """
        Summarize list of text chunks using LLM.
        - Chunks are joined until they fit into max_context_tokens
        - Generates a concise summary.
        """

        if not hasattr(self, "summarizer_pipeline"):
            self.load_summarizer()
            self.logger.info("Summarizer not initialized. Called load_summarizer(). pipeline will work with default parameters.")
        
        # Join chunks into one context, respecting token budget
        context = " ".join(chunks)
        input_tokens = len(self.tokenizer.encode(context))

        if input_tokens > max_content_tokens:
            # Trim to fit context window
            context = " ".join(context.split()[:max_content_tokens])
            self.logger.warning("โš ๏ธ Context truncated to fit within model token limit.")
        
        # Build summarization prompt
        prompt = f"""
            Summarize the following financial report excerpts into a concise answer.
            Keep it factual, short, and grounded in the text.

            Excerpts: 
            {context}

            Summary:
            """
        
        try:
            output = self.summarizer_pipeline(
                prompt,
                max_new_tokens=max_output_tokens,
                do_sample=False
            )[0]["generated_text"]

            if "Summary:" in output:
                summary = output.split("Summary:")[-1].strip()
            else:
                summary = output.strip()

            return summary

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                self.logger.warning("โš ๏ธ CUDA OOM during summarization. Retrying on CPU...")
                self.device = "cpu"
                torch.cuda.empty_cache()
                return self.summarize_chunks(chunks, max_content_tokens, max_output_tokens)
            else:
                self.logger.error(f"โŒ Summarizer failed: {e}. Falling back to raw chunks.")
                return " ".join(chunks[:2])  # fallback: return first 2 chunks


    def answer_query(self, query, top_k=3):
        """
        End-to-end QA:
        - Retrieve the top_k most similar chunks from FAISS
        - Summarize into a final answer.

        query: question text to search for.
        top_k: number of nearest chunks to retrieve.

        Returns the summary string, or None when no index is available,
        nothing was retrieved, or any step raises (the error is logged).
        """
        try:
            # Make sure there is an index to search; try the saved one first.
            if not (hasattr(self, "index") and hasattr(self, "metadata")):
                if not self.load_index_and_metadata():
                    self.logger.error("Error in answer_query: no FAISS index available. Call run() first.")
                    return None

            # run() returns before load_encoder() when the index comes from
            # disk, so the query encoder may not exist yet.
            if not hasattr(self, "encoder"):
                self.load_encoder()

            #step 1: Retrieve
            self.logger.info(f"๐Ÿ” searching vector DB for query: {query}")
            q_emb = self.encoder.encode(query, show_progress_bar=False, convert_to_numpy=True)
            # faiss works on C-contiguous float32 matrices; one row per query.
            q_emb = np.ascontiguousarray(q_emb, dtype=np.float32).reshape(1, -1)
            faiss.normalize_L2(q_emb)

            _scores, idxs = self.index.search(q_emb, k=top_k)

            # FAISS pads missing results with -1, which would otherwise
            # silently select the last metadata entry.
            total = len(self.metadata)
            chunks = [self.metadata[int(idx)]["text"] for idx in idxs[0] if 0 <= idx < total]
            if not chunks:
                self.logger.warning("No matching chunks found for query.")
                return None

            # Step 2: Summarize
            summary = self.summarize_chunks(chunks)

            # Log results
            self.logger.info(f"โœ… Final Answer: {summary}")
            return summary

        except Exception as e:
            self.logger.error(f"Error in answer_query: {e}")
            return None
        

# Example usage.
# NOTE(review): the instantiation below is not wrapped in an
# `if __name__ == "__main__":` guard, so it also runs whenever this module is
# imported; with the default arguments it opens "config.json" relative to the
# current working directory. Confirm whether that is intended.
ge = GetEmbeddings()  
# ge.run()  
# # NEW STEP
# ge.load_summarizer("google/gemma-2b")  
# answer = ge.answer_query("What are the key highlights from Q2 financial report?")  
# print(answer)