Sami Ali commited on
Commit
51349bc
·
1 Parent(s): 7f5929f

implement pipeline for madrag-assistant

Browse files
Files changed (8) hide show
  1. .gitignore +3 -0
  2. app.py +12 -0
  3. madrag.ipynb +113 -0
  4. src/__init__.py +0 -0
  5. src/constant.py +3 -0
  6. src/data_processor.py +108 -0
  7. src/embedding.py +26 -0
  8. src/vectorstore.py +91 -0
.gitignore CHANGED
@@ -3,6 +3,9 @@ __pycache__/
3
  *.py[codz]
4
  *$py.class
5
 
 
 
 
6
  # C extensions
7
  *.so
8
 
 
3
  *.py[codz]
4
  *$py.class
5
 
6
+ # data
7
+ data
8
+
9
  # C extensions
10
  *.so
11
 
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from src.data_processor import DataProcessor
from src.embedding import EmbeddingManager
from src.vectorstore import VectorStore

if __name__ == '__main__':
    # End-to-end indexing pipeline: load + chunk the corpus, embed the
    # chunk texts, persist them in the vector store, and build a retriever.
    dp = DataProcessor()
    chunks, documents = dp.build()

    embedder = EmbeddingManager()
    # embed_texts expects plain strings (its signature is List[str]), not
    # Document objects — extract page_content before embedding.
    chunk_embeddings = embedder.embed_texts([chunk.page_content for chunk in chunks])

    vectorstore = VectorStore()
    vectorstore.add_documents(chunks, chunk_embeddings)

    # get_retriever requires the embedding function so queries can be
    # encoded with the same model used for indexing.
    retriever = vectorstore.get_retriever(embedder.get_model())
madrag.ipynb ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 16,
6
+ "id": "c80e0812",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from langchain.text_splitter import CharacterTextSplitter"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 17,
16
+ "id": "bbc6a9d6",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "with open('./data/pmc/PMC10000000.txt', \"r\", encoding='utf-8') as file:\n",
21
+ " data = file.read()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 18,
27
+ "id": "9eba0782",
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "name": "stdout",
32
+ "output_type": "stream",
33
+ "text": [
34
+ "23842\n"
35
+ ]
36
+ }
37
+ ],
38
+ "source": [
39
+ "print(len(data))"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 23,
45
+ "id": "c0b716f8",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0, separator=' ')"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 24,
55
+ "id": "14aa384c",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "temp = chunks.split_text(data)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 25,
65
+ "id": "77187982",
66
+ "metadata": {},
67
+ "outputs": [
68
+ {
69
+ "data": {
70
+ "text/plain": [
71
+ "24"
72
+ ]
73
+ },
74
+ "execution_count": 25,
75
+ "metadata": {},
76
+ "output_type": "execute_result"
77
+ }
78
+ ],
79
+ "source": [
80
+ "len(temp)"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "4c254a11",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": []
90
+ }
91
+ ],
92
+ "metadata": {
93
+ "kernelspec": {
94
+ "display_name": "venv",
95
+ "language": "python",
96
+ "name": "python3"
97
+ },
98
+ "language_info": {
99
+ "codemirror_mode": {
100
+ "name": "ipython",
101
+ "version": 3
102
+ },
103
+ "file_extension": ".py",
104
+ "mimetype": "text/x-python",
105
+ "name": "python",
106
+ "nbconvert_exporter": "python",
107
+ "pygments_lexer": "ipython3",
108
+ "version": "3.11.0"
109
+ }
110
+ },
111
+ "nbformat": 4,
112
+ "nbformat_minor": 5
113
+ }
src/__init__.py ADDED
File without changes
src/constant.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
import os

# Absolute path to the project root: .../src/constant.py -> src/ -> project root.
# Used by sibling modules to locate the data/ and db/ directories.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
src/data_processor.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from src.constant import BASE_DIR
3
+ from langchain.schema import Document
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+
6
+ DATA_DIR = os.path.join(BASE_DIR, "data", "pmc")
7
+
8
+
9
+ class DataProcessor:
10
+ """
11
+ Handles loading, cleaning, and chunking of text files
12
+ from the PubMed Central (PMC) dataset.
13
+ """
14
+
15
+ def __init__(self, data_path: str = DATA_DIR):
16
+ self.data_path = data_path
17
+
18
+ def _load_files(self) -> list[dict]:
19
+ """
20
+ Load raw text files from the dataset directory.
21
+ Returns a list of dictionaries with file name and raw content.
22
+ """
23
+ count = 0
24
+ data_list = []
25
+ for file_name in os.listdir(self.data_path):
26
+ if not file_name.endswith(".txt"):
27
+ continue
28
+ file_path = os.path.join(self.data_path, file_name)
29
+ with open(file_path, "r", encoding="utf-8") as file_ref:
30
+ data_list.append(
31
+ {
32
+ "file_name": file_name,
33
+ "page_content": file_ref.read()
34
+ }
35
+ )
36
+ if count >= 2:
37
+ break
38
+ count += 1
39
+
40
+ return data_list
41
+
42
+ @staticmethod
43
+ def _decode_unicode(text: str) -> str:
44
+ """
45
+ Convert escaped unicode sequences to proper text.
46
+ """
47
+ if not isinstance(text, str):
48
+ return text
49
+ try:
50
+ return text.encode("utf-8").decode("unicode-escape")
51
+ except Exception:
52
+ return text
53
+
54
+ def _preprocess(self, data: list[dict]) -> list[dict]:
55
+ """
56
+ Apply preprocessing steps (e.g., unicode decoding) to raw data.
57
+ """
58
+ cleaned_data = []
59
+ for record in data:
60
+ decoded_text = self._decode_unicode(record["page_content"])
61
+ cleaned_data.append(
62
+ {
63
+ "file_name": record["file_name"],
64
+ "page_content": decoded_text
65
+ }
66
+ )
67
+ return cleaned_data
68
+
69
+ def load_documents(self) -> list[Document]:
70
+ """
71
+ Load and preprocess text files, converting them into
72
+ LangChain Document objects.
73
+ """
74
+ raw_data = self._load_files()
75
+ cleaned_data = self._preprocess(raw_data)
76
+
77
+ return [
78
+ Document(
79
+ page_content=item["page_content"],
80
+ metadata={"source": item["file_name"]}
81
+ )
82
+ for item in cleaned_data
83
+ ]
84
+
85
+ @staticmethod
86
+ def chunk_documents(documents: list[Document],
87
+ chunk_size: int = 1000,
88
+ chunk_overlap: int = 200) -> list[Document]:
89
+ """
90
+ Split documents into smaller chunks for embedding and retrieval.
91
+ """
92
+ splitter = RecursiveCharacterTextSplitter(
93
+ chunk_size=chunk_size,
94
+ chunk_overlap=chunk_overlap,
95
+ length_function=len
96
+ )
97
+ return splitter.split_documents(documents)
98
+
99
+ def build(self) -> tuple[list[Document], list[Document]]:
100
+ """
101
+ End-to-end pipeline:
102
+ - Load documents
103
+ - Chunk them
104
+ Returns (chunks, original documents).
105
+ """
106
+ documents = self.load_documents()
107
+ chunks = self.chunk_documents(documents)
108
+ return chunks, documents
src/embedding.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import numpy as np
3
+ from langchain_huggingface import HuggingFaceEmbeddings
4
+
5
+
6
class EmbeddingManager:
    """Thin wrapper around a HuggingFace sentence-embedding model."""

    def __init__(self, model_name: str = "pritamdeka/S-BioBERT-snli-multinli-stsb"):
        """
        Args:
            model_name: HuggingFace model id; defaults to a biomedical
                sentence-similarity model.
        """
        self.model_name = model_name
        self.model = None
        self.load_model()

    def load_model(self):
        """Instantiate the HuggingFace embedding model (downloads on first use)."""
        print("Loading embedding model:", self.model_name)
        self.model = HuggingFaceEmbeddings(model_name=self.model_name)
        print("Model loaded.")

    def get_model(self):
        """Return the underlying embeddings object (e.g. for retriever wiring)."""
        return self.model

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model.embed_documents(texts)

    def embed_one(self, text: str) -> List[float]:
        """
        Embed a single query string.

        Raises:
            RuntimeError: if the model has not been loaded (same guard as
            embed_texts, instead of an opaque AttributeError).
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model.embed_query(text)
src/vectorstore.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from typing import List
4
+ from pathlib import Path
5
+ from src.constant import BASE_DIR
6
+ import chromadb
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.schema import Document
9
+ from uuid import uuid4
10
+
11
+ DATA_DIR = os.path.join(BASE_DIR, "db")
12
+
13
+
14
+ class VectorStore:
15
+ """
16
+ Wrapper around Chroma vector database for persistent storage
17
+ and retrieval of document embeddings.
18
+ """
19
+
20
+ def __init__(self,
21
+ collection_name: str = "medrag",
22
+ persist_directory: str = DATA_DIR):
23
+ self.collection_name = collection_name
24
+ self.persist_directory = persist_directory
25
+ self.client = None
26
+ self.collection = None
27
+ self._initialize_store()
28
+
29
+ def _initialize_store(self):
30
+ """Initialize Chroma client and collection."""
31
+ try:
32
+ dir_path = Path(self.persist_directory)
33
+ dir_path.mkdir(parents=True, exist_ok=True)
34
+
35
+ self.client = chromadb.PersistentClient(self.persist_directory)
36
+ self.collection = self.client.get_or_create_collection(
37
+ name=self.collection_name,
38
+ metadata={"description": "RAG collection for biomedical research"}
39
+ )
40
+ print(f"Store initialized successfully: {self.collection_name}")
41
+ except Exception as e:
42
+ print(f"Error initializing the store: {e}")
43
+ raise
44
+
45
+ def get_len(self) -> int:
46
+ """Return number of documents in the collection."""
47
+ return self.collection.count()
48
+
49
+ def add_documents(self, documents: List[Document], embeddings: np.ndarray, batch_size: int = 5000):
50
+ """
51
+ Add documents and their embeddings to the vector store in batches.
52
+ """
53
+ if isinstance(embeddings, np.ndarray):
54
+ embeddings = embeddings.tolist() # Ensure compatibility
55
+
56
+ for start in range(0, len(documents), batch_size):
57
+ batch_docs = documents[start:start + batch_size]
58
+ batch_embeds = embeddings[start:start + batch_size]
59
+
60
+ ids, metadatas, texts, embeds = [], [], [], []
61
+
62
+ for idx, (doc, emb) in enumerate(zip(batch_docs, batch_embeds)):
63
+ ids.append(f"doc_{uuid4().hex}")
64
+ texts.append(doc.page_content)
65
+ metadata = dict(doc.metadata) if getattr(doc, "metadata", None) else {}
66
+ metadata.update({"doc_index": idx, "content_length": len(doc.page_content)})
67
+ metadatas.append(metadata)
68
+ embeds.append(emb)
69
+
70
+ self.collection.add(
71
+ ids=ids,
72
+ documents=texts,
73
+ embeddings=embeds,
74
+ metadatas=metadatas
75
+ )
76
+
77
+ print(f"Documents and embeddings added to collection: {self.collection_name}")
78
+
79
+ def get_retriever(self, embedding_function, search_kwargs: dict = None):
80
+ """
81
+ Return a retriever interface for semantic search.
82
+ """
83
+ if search_kwargs is None:
84
+ search_kwargs = {"k": 5}
85
+
86
+ vectorstore = Chroma(
87
+ client=self.client,
88
+ collection_name=self.collection_name,
89
+ embedding_function=embedding_function
90
+ )
91
+ return vectorstore.as_retriever(search_kwargs=search_kwargs)