Spaces:
Sleeping
Sleeping
Sami Ali committed on
Commit ·
51349bc
1
Parent(s): 7f5929f
implement pipeline for magrag-assistant
Browse files- .gitignore +3 -0
- app.py +12 -0
- madrag.ipynb +113 -0
- src/__init__.py +0 -0
- src/constant.py +3 -0
- src/data_processor.py +108 -0
- src/embedding.py +26 -0
- src/vectorstore.py +91 -0
.gitignore
CHANGED
|
@@ -3,6 +3,9 @@ __pycache__/
|
|
| 3 |
*.py[codz]
|
| 4 |
*$py.class
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
# C extensions
|
| 7 |
*.so
|
| 8 |
|
|
|
|
| 3 |
*.py[codz]
|
| 4 |
*$py.class
|
| 5 |
|
| 6 |
+
# data
|
| 7 |
+
data
|
| 8 |
+
|
| 9 |
# C extensions
|
| 10 |
*.so
|
| 11 |
|
app.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.data_processor import DataProcessor
|
| 2 |
+
from src.embedding import EmbeddingManager
|
| 3 |
+
from src.vectorstore import VectorStore
|
| 4 |
+
|
| 5 |
+
if __name__ == '__main__':
|
| 6 |
+
dp = DataProcessor()
|
| 7 |
+
chunks, document = dp.build()
|
| 8 |
+
embd = EmbeddingManager()
|
| 9 |
+
chunks_embedding = embd.embed_texts(chunks)
|
| 10 |
+
vectorstore = VectorStore()
|
| 11 |
+
vectorstore.add_documents(chunks, chunks_embedding)
|
| 12 |
+
retriver = vectorstore.get_retriever()
|
madrag.ipynb
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 16,
|
| 6 |
+
"id": "c80e0812",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from langchain.text_splitter import CharacterTextSplitter"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 17,
|
| 16 |
+
"id": "bbc6a9d6",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"with open('./data/pmc/PMC10000000.txt', \"r\", encoding='utf-8') as file:\n",
|
| 21 |
+
" data = file.read()"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": 18,
|
| 27 |
+
"id": "9eba0782",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [
|
| 30 |
+
{
|
| 31 |
+
"name": "stdout",
|
| 32 |
+
"output_type": "stream",
|
| 33 |
+
"text": [
|
| 34 |
+
"23842\n"
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
],
|
| 38 |
+
"source": [
|
| 39 |
+
"print(len(data))"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 23,
|
| 45 |
+
"id": "c0b716f8",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0, separator=' ')"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": 24,
|
| 55 |
+
"id": "14aa384c",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"temp = chunks.split_text(data)"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": 25,
|
| 65 |
+
"id": "77187982",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [
|
| 68 |
+
{
|
| 69 |
+
"data": {
|
| 70 |
+
"text/plain": [
|
| 71 |
+
"24"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
"execution_count": 25,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"output_type": "execute_result"
|
| 77 |
+
}
|
| 78 |
+
],
|
| 79 |
+
"source": [
|
| 80 |
+
"len(temp)"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"id": "4c254a11",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": []
|
| 90 |
+
}
|
| 91 |
+
],
|
| 92 |
+
"metadata": {
|
| 93 |
+
"kernelspec": {
|
| 94 |
+
"display_name": "venv",
|
| 95 |
+
"language": "python",
|
| 96 |
+
"name": "python3"
|
| 97 |
+
},
|
| 98 |
+
"language_info": {
|
| 99 |
+
"codemirror_mode": {
|
| 100 |
+
"name": "ipython",
|
| 101 |
+
"version": 3
|
| 102 |
+
},
|
| 103 |
+
"file_extension": ".py",
|
| 104 |
+
"mimetype": "text/x-python",
|
| 105 |
+
"name": "python",
|
| 106 |
+
"nbconvert_exporter": "python",
|
| 107 |
+
"pygments_lexer": "ipython3",
|
| 108 |
+
"version": "3.11.0"
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
"nbformat": 4,
|
| 112 |
+
"nbformat_minor": 5
|
| 113 |
+
}
|
src/__init__.py
ADDED
|
File without changes
|
src/constant.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
src/data_processor.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from src.constant import BASE_DIR
|
| 3 |
+
from langchain.schema import Document
|
| 4 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
+
|
| 6 |
+
DATA_DIR = os.path.join(BASE_DIR, "data", "pmc")


class DataProcessor:
    """
    Handles loading, cleaning, and chunking of text files
    from the PubMed Central (PMC) dataset.
    """

    def __init__(self, data_path: str = DATA_DIR, max_files: int | None = 3):
        """
        :param data_path: directory containing the ``.txt`` corpus files.
        :param max_files: maximum number of files to load; ``None`` loads
            all files. Defaults to 3, matching the original hard-coded
            limiter, but is now configurable for full-corpus runs.
        """
        self.data_path = data_path
        self.max_files = max_files

    def _load_files(self) -> list[dict]:
        """
        Load raw text files from the dataset directory.
        Returns a list of dictionaries with file name and raw content.
        """
        data_list = []
        # sorted() makes the (possibly truncated) selection deterministic;
        # os.listdir order is filesystem-dependent.
        for file_name in sorted(os.listdir(self.data_path)):
            if not file_name.endswith(".txt"):
                continue
            if self.max_files is not None and len(data_list) >= self.max_files:
                break
            file_path = os.path.join(self.data_path, file_name)
            with open(file_path, "r", encoding="utf-8") as file_ref:
                data_list.append(
                    {
                        "file_name": file_name,
                        "page_content": file_ref.read(),
                    }
                )
        return data_list

    @staticmethod
    def _decode_unicode(text: str) -> str:
        """
        Convert escaped unicode sequences (e.g. ``\\u00e9``) to proper text.
        Non-string input is returned unchanged.
        """
        if not isinstance(text, str):
            return text
        try:
            # BUG FIX: encoding with UTF-8 before 'unicode-escape' mangles
            # any real non-ASCII character into mojibake. 'backslashreplace'
            # turns non-ASCII chars into escapes first, so both existing
            # characters and literal \uXXXX sequences round-trip correctly.
            return text.encode("ascii", "backslashreplace").decode("unicode-escape")
        except Exception:
            # Best effort: fall back to the original text on any failure.
            return text

    def _preprocess(self, data: list[dict]) -> list[dict]:
        """
        Apply preprocessing steps (e.g., unicode decoding) to raw data.
        """
        return [
            {
                "file_name": record["file_name"],
                "page_content": self._decode_unicode(record["page_content"]),
            }
            for record in data
        ]

    def load_documents(self) -> list[Document]:
        """
        Load and preprocess text files, converting them into
        LangChain Document objects (file name kept as ``source`` metadata).
        """
        cleaned_data = self._preprocess(self._load_files())
        return [
            Document(
                page_content=item["page_content"],
                metadata={"source": item["file_name"]},
            )
            for item in cleaned_data
        ]

    @staticmethod
    def chunk_documents(documents: list[Document],
                        chunk_size: int = 1000,
                        chunk_overlap: int = 200) -> list[Document]:
        """
        Split documents into smaller chunks for embedding and retrieval.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        return splitter.split_documents(documents)

    def build(self) -> tuple[list[Document], list[Document]]:
        """
        End-to-end pipeline:
        - Load documents
        - Chunk them
        Returns (chunks, original documents).
        """
        documents = self.load_documents()
        chunks = self.chunk_documents(documents)
        return chunks, documents
|
src/embedding.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import numpy as np
|
| 3 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EmbeddingManager:
    """
    Thin wrapper around a HuggingFace sentence-embedding model,
    loaded eagerly at construction time.
    """

    def __init__(self, model_name: str = "pritamdeka/S-BioBERT-snli-multinli-stsb"):
        self.model_name = model_name
        self.model = None
        self.load_model()

    def load_model(self):
        """Instantiate the HuggingFace embeddings object (downloads on first use)."""
        print("Loading embedding model:", self.model_name)
        self.model = HuggingFaceEmbeddings(model_name=self.model_name)
        print("Model loaded.")

    def get_model(self):
        """Return the underlying embeddings object (e.g. to hand to a vector store)."""
        return self.model

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        Raises RuntimeError when no model is loaded.
        Note: embed_documents returns a list of float lists, not an ndarray —
        the previous annotation was wrong.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model.embed_documents(texts)

    def embed_one(self, text: str) -> List[float]:
        """
        Embed a single query string.

        Raises RuntimeError when no model is loaded — guard added for
        consistency with embed_texts (previously this raised AttributeError).
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model.embed_query(text)
|
src/vectorstore.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from src.constant import BASE_DIR
|
| 6 |
+
import chromadb
|
| 7 |
+
from langchain.vectorstores import Chroma
|
| 8 |
+
from langchain.schema import Document
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
|
| 11 |
+
DATA_DIR = os.path.join(BASE_DIR, "db")


class VectorStore:
    """
    Wrapper around Chroma vector database for persistent storage
    and retrieval of document embeddings.
    """

    def __init__(self,
                 collection_name: str = "medrag",
                 persist_directory: str = DATA_DIR):
        """
        :param collection_name: Chroma collection to create or reuse.
        :param persist_directory: on-disk location for the Chroma database.
        """
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self.client = None
        self.collection = None
        self._initialize_store()

    def _initialize_store(self):
        """Initialize Chroma client and collection (created if absent)."""
        try:
            dir_path = Path(self.persist_directory)
            dir_path.mkdir(parents=True, exist_ok=True)

            self.client = chromadb.PersistentClient(self.persist_directory)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "RAG collection for biomedical research"}
            )
            print(f"Store initialized successfully: {self.collection_name}")
        except Exception as e:
            # Log and re-raise: the store is unusable if this fails.
            print(f"Error initializing the store: {e}")
            raise

    def get_len(self) -> int:
        """Return number of documents in the collection."""
        return self.collection.count()

    def add_documents(self, documents: List[Document], embeddings: np.ndarray, batch_size: int = 5000):
        """
        Add documents and their embeddings to the vector store in batches.

        :param documents: LangChain Documents whose page_content is stored.
        :param embeddings: one embedding per document (ndarray or list).
        :param batch_size: max items per collection.add call.
        :raises ValueError: when documents and embeddings differ in length
            (zip would otherwise silently drop the excess).
        """
        if isinstance(embeddings, np.ndarray):
            embeddings = embeddings.tolist()  # Chroma expects plain lists

        if len(documents) != len(embeddings):
            raise ValueError(
                f"documents ({len(documents)}) and embeddings "
                f"({len(embeddings)}) must have the same length"
            )

        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start:start + batch_size]
            batch_embeds = embeddings[start:start + batch_size]

            ids, metadatas, texts, embeds = [], [], [], []

            for offset, (doc, emb) in enumerate(zip(batch_docs, batch_embeds)):
                ids.append(f"doc_{uuid4().hex}")
                texts.append(doc.page_content)
                metadata = dict(doc.metadata) if getattr(doc, "metadata", None) else {}
                # BUG FIX: 'doc_index' must be the document's position in the
                # full input list; the bare enumerate index restarts at 0 for
                # every batch, producing duplicate indices across batches.
                metadata.update({
                    "doc_index": start + offset,
                    "content_length": len(doc.page_content),
                })
                metadatas.append(metadata)
                embeds.append(emb)

            self.collection.add(
                ids=ids,
                documents=texts,
                embeddings=embeds,
                metadatas=metadatas
            )

        print(f"Documents and embeddings added to collection: {self.collection_name}")

    def get_retriever(self, embedding_function, search_kwargs: dict = None):
        """
        Return a retriever interface for semantic search.

        :param embedding_function: embeddings object used to embed queries
            (must match the model used to embed the stored documents).
        :param search_kwargs: passed to as_retriever; defaults to {"k": 5}.
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}

        vectorstore = Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=embedding_function
        )
        return vectorstore.as_retriever(search_kwargs=search_kwargs)
|