# llm-project/rag_system/document_loader.py
# Initial commit: RAG Q&A system for agricultural research (779b4bd)
"""
Document Loading and Chunking Module
"""
from typing import List, Dict, Optional
from datasets import load_dataset
import re
class DocumentLoader:
    """Loads documents and splits them into overlapping text chunks for a RAG system.

    Documents can come from a HuggingFace dataset (streamed) or from plain
    Python strings. Chunking is sentence-aware with a character-based overlap
    so retrieval context is preserved across chunk boundaries.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """
        Args:
            chunk_size: Target maximum chunk length, in characters.
            chunk_overlap: Number of trailing characters carried from one
                chunk into the start of the next.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents: List[str] = []       # raw document texts
        self.chunks: List[str] = []          # chunk texts after chunk_documents()
        self.chunk_metadata: List[Dict] = [] # one metadata dict per chunk

    @staticmethod
    def _assemble_chapter_document(item: Dict) -> str:
        """Flatten a structured record (title / abstract / chapters) into one
        plain-text document, joining the pieces with blank lines."""
        parts = []
        if item.get("title"):
            parts.append(item["title"])
        if item.get("abstract"):
            parts.append(item["abstract"])
        for chapter in item["chapters"]:
            if not isinstance(chapter, dict):
                continue
            if chapter.get("head"):
                parts.append(chapter["head"])
            paragraphs = chapter.get("paragraphs")
            if isinstance(paragraphs, list):
                for para in paragraphs:
                    if isinstance(para, dict) and para.get("text"):
                        parts.append(para["text"])
        return "\n\n".join(parts)

    def _extract_document_text(self, item: Dict, text_column: Optional[str]) -> Optional[str]:
        """Pull the document text out of one dataset record.

        Returns None when the record has no usable text (so the caller can
        skip it without incrementing the document count).
        """
        # Structured records with a "chapters" list take priority.
        if isinstance(item.get("chapters"), list):
            text = self._assemble_chapter_document(item)
            return text if text.strip() else None
        if text_column:
            value = item.get(text_column)
            if isinstance(value, str):
                return value
            if isinstance(value, list):
                return "\n\n".join(str(p) for p in value)
            return None
        # No column specified: fall back to the first conventional text field.
        for key in ("text", "content", "body", "context"):
            value = item.get(key)
            if isinstance(value, str):
                return value
        return None

    def load_from_huggingface(
        self,
        dataset_name: str,
        split: str = "train",
        text_column: Optional[str] = None,
        max_docs: Optional[int] = None,
        hf_token: Optional[str] = None
    ):
        """Load documents from a HuggingFace dataset in streaming mode.

        Args:
            dataset_name: HuggingFace Hub dataset identifier.
            split: Dataset split to stream (default "train").
            text_column: Column holding document text; when None, common
                column names ("text", "content", "body", "context") are tried.
            max_docs: Stop after this many documents (None = no limit;
                0 now correctly means "load nothing").
            hf_token: Optional HuggingFace access token for gated/rate-limited
                datasets.

        Raises:
            RuntimeError: On HTTP 429 / rate-limit errors, with instructions
                for supplying a token. Other load errors propagate unchanged.
        """
        print(f"Loading dataset: {dataset_name}")
        if hf_token:
            print("Using HuggingFace token for authentication")
        try:
            print("Using streaming mode for faster loading...")
            dataset = load_dataset(
                dataset_name,
                split=split,
                streaming=True,
                token=hf_token or None  # normalize empty string to "no token"
            )
        except Exception as e:
            if "429" in str(e) or "rate limit" in str(e).lower():
                # RuntimeError is a subclass of Exception, so callers that
                # caught the old bare Exception still catch this; `from e`
                # keeps the original traceback for debugging.
                raise RuntimeError(
                    f"Rate limit error: {str(e)}\n\n"
                    "To fix this:\n"
                    "1. Create a free HuggingFace account at https://huggingface.co/join\n"
                    "2. Get your token at https://huggingface.co/settings/tokens\n"
                    "3. Add it in the 'HuggingFace Token' field above"
                ) from e
            raise
        documents = []
        count = 0
        for item in dataset:
            # BUGFIX: the old `if max_docs and count >= max_docs` treated
            # max_docs=0 as falsy, i.e. "no limit"; compare against None.
            if max_docs is not None and count >= max_docs:
                break
            text = self._extract_document_text(item, text_column)
            if text is None:
                continue
            documents.append(text)
            count += 1
            # Progress report now fires for every source format (previously
            # the explicit-text_column branch never reported progress).
            if count % 10 == 0:
                print(f" Loaded {count} documents...")
        self.documents = documents
        print(f"Loaded {len(self.documents)} documents from dataset")

    def load_from_texts(self, texts: List[str]):
        """Load documents directly from a list of text strings."""
        self.documents = texts
        print(f"Loaded {len(self.documents)} documents")

    def _split_text(self, text: str) -> List[str]:
        """Split text into chunks of roughly ``chunk_size`` characters.

        Splits on sentence boundaries (., !, ? followed by whitespace) and
        carries the last ``chunk_overlap`` characters of each chunk into the
        next one. A text with no sentence boundary is returned as one chunk.
        """
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) > self.chunk_size and current_chunk:
                chunks.append(current_chunk.strip())
                # Seed the next chunk with the tail of this one for context.
                overlap_text = current_chunk[-self.chunk_overlap:] if len(current_chunk) > self.chunk_overlap else current_chunk
                current_chunk = overlap_text + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        # Guarantee at least one chunk for non-empty input.
        if not chunks and text.strip():
            chunks = [text.strip()]
        return chunks

    def _is_low_quality_chunk(self, chunk: str) -> bool:
        """Return True for chunks too short to be useful for retrieval
        (< 50 characters or < 10 words)."""
        if len(chunk.strip()) < 50:
            return True
        if len(chunk.split()) < 10:
            return True
        return False

    def chunk_documents(self):
        """Chunk all loaded documents, discarding low-quality chunks.

        Populates ``self.chunks`` and parallel ``self.chunk_metadata``
        entries ({'doc_id', 'chunk_id', 'total_chunks_in_doc'}). Note that
        chunk_id indexes the pre-filter chunk list, so ids may be sparse.
        """
        self.chunks = []
        self.chunk_metadata = []
        for doc_idx, doc in enumerate(self.documents):
            doc_chunks = self._split_text(doc)
            for chunk_idx, chunk in enumerate(doc_chunks):
                if not self._is_low_quality_chunk(chunk):
                    self.chunks.append(chunk)
                    self.chunk_metadata.append({
                        'doc_id': doc_idx,
                        'chunk_id': chunk_idx,
                        'total_chunks_in_doc': len(doc_chunks)
                    })
        print(f"Created {len(self.chunks)} chunks from {len(self.documents)} documents")

    def get_chunks(self) -> List[str]:
        """Get list of all chunks."""
        return self.chunks

    def get_chunk_metadata(self) -> List[Dict]:
        """Get metadata for all chunks (parallel to get_chunks())."""
        return self.chunk_metadata