"""
Document Loading and Chunking Module
"""
from typing import List, Dict, Optional
from datasets import load_dataset
import re
class DocumentLoader:
    """Loads documents and splits them into overlapping chunks for a RAG system.

    Typical flow: ``load_from_huggingface(...)`` or ``load_from_texts(...)``,
    then ``chunk_documents()``, then ``get_chunks()`` / ``get_chunk_metadata()``.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """
        Args:
            chunk_size: Approximate maximum chunk length, in characters
                (a single sentence longer than this is kept whole).
            chunk_overlap: Characters of trailing context repeated at the
                start of the next chunk. Must be smaller than ``chunk_size``;
                an overlap >= chunk_size makes chunks grow without bound.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents: List[str] = []        # raw document texts
        self.chunks: List[str] = []           # chunk texts (parallel to chunk_metadata)
        self.chunk_metadata: List[Dict] = []  # per-chunk provenance records

    def load_from_huggingface(
        self,
        dataset_name: str,
        split: str = "train",
        text_column: Optional[str] = None,
        max_docs: Optional[int] = None,
        hf_token: Optional[str] = None
    ):
        """Load documents from a HuggingFace dataset (streaming mode).

        Args:
            dataset_name: Dataset identifier, e.g. ``"user/dataset"``.
            split: Dataset split to read.
            text_column: Column holding the document text. When omitted,
                common column names ("text", "content", "body", "context")
                are probed; records carrying a ``chapters`` list are
                flattened regardless of this parameter.
            max_docs: Stop after this many documents (None = no limit).
            hf_token: Optional HuggingFace access token.

        Raises:
            RuntimeError: On HuggingFace rate-limit errors, with guidance on
                obtaining a token. Other load errors propagate unchanged.
        """
        print(f"Loading dataset: {dataset_name}")
        if hf_token:
            print("Using HuggingFace token for authentication")
        try:
            print("Using streaming mode for faster loading...")
            # streaming=True avoids downloading the full dataset up front.
            dataset = load_dataset(
                dataset_name,
                split=split,
                streaming=True,
                token=hf_token,  # None means anonymous access
            )
        except Exception as e:
            if "429" in str(e) or "rate limit" in str(e).lower():
                # Chain the cause so the original HF error is preserved.
                raise RuntimeError(
                    f"Rate limit error: {str(e)}\n\n"
                    "To fix this:\n"
                    "1. Create a free HuggingFace account at https://huggingface.co/join\n"
                    "2. Get your token at https://huggingface.co/settings/tokens\n"
                    "3. Add it in the 'HuggingFace Token' field above"
                ) from e
            raise
        documents: List[str] = []
        for item in dataset:
            if max_docs and len(documents) >= max_docs:
                break
            text = self._extract_item_text(item, text_column)
            if text:  # skip records that yield no (or empty) text
                documents.append(text)
                # Uniform progress feedback regardless of record shape.
                if len(documents) % 10 == 0:
                    print(f" Loaded {len(documents)} documents...")
        self.documents = documents
        print(f"Loaded {len(self.documents)} documents from dataset")

    def _extract_item_text(self, item: Dict, text_column: Optional[str]) -> Optional[str]:
        """Return the text of one dataset record, or None if it has none."""
        # Structured records (title/abstract/chapters) are flattened first.
        if "chapters" in item and isinstance(item["chapters"], list):
            full_text = self._flatten_structured_item(item)
            return full_text if full_text.strip() else None
        if text_column:
            if text_column in item:
                value = item[text_column]
                if isinstance(value, str):
                    return value
                if isinstance(value, list):
                    return "\n\n".join(str(p) for p in value)
            return None
        # No column specified: probe common text column names.
        for key in ("text", "content", "body", "context"):
            if key in item and isinstance(item[key], str):
                return item[key]
        return None

    @staticmethod
    def _flatten_structured_item(item: Dict) -> str:
        """Join title, abstract, chapter heads and paragraphs into one blob."""
        parts: List[str] = []
        if item.get("title"):
            parts.append(item["title"])
        if item.get("abstract"):
            parts.append(item["abstract"])
        for chapter in item["chapters"]:
            if not isinstance(chapter, dict):
                continue
            if chapter.get("head"):
                parts.append(chapter["head"])
            paragraphs = chapter.get("paragraphs")
            if isinstance(paragraphs, list):
                for para in paragraphs:
                    if isinstance(para, dict) and para.get("text"):
                        parts.append(para["text"])
        return "\n\n".join(parts)

    def load_from_texts(self, texts: List[str]):
        """Load documents directly from a list of text strings."""
        self.documents = texts
        print(f"Loaded {len(self.documents)} documents")

    def _split_text(self, text: str) -> List[str]:
        """Split *text* into sentence-aligned chunks of ~``chunk_size`` chars.

        The last ``chunk_overlap`` characters of each chunk are repeated at
        the start of the next one so retrieval context is not cut abruptly.
        A single sentence longer than ``chunk_size`` is kept whole, so chunks
        may exceed the target size.
        """
        # Split after sentence-ending punctuation followed by whitespace.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks: List[str] = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) > self.chunk_size and current_chunk:
                chunks.append(current_chunk.strip())
                # Carry trailing context into the next chunk (may cut mid-word).
                overlap_text = current_chunk[-self.chunk_overlap:] if len(current_chunk) > self.chunk_overlap else current_chunk
                current_chunk = overlap_text + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        if not chunks and text.strip():
            # Fall back to the whole text if sentence splitting yielded nothing.
            chunks = [text.strip()]
        return chunks

    def _is_low_quality_chunk(self, chunk: str) -> bool:
        """True for chunks too short (<50 chars or <10 words) to be useful."""
        if len(chunk.strip()) < 50:
            return True
        if len(chunk.split()) < 10:
            return True
        return False

    def chunk_documents(self):
        """Chunk all loaded documents, dropping low-quality chunks.

        Populates ``self.chunks`` and the parallel ``self.chunk_metadata``
        (doc_id, chunk_id within the doc, and the doc's total chunk count,
        which counts chunks before quality filtering).
        """
        self.chunks = []
        self.chunk_metadata = []
        for doc_idx, doc in enumerate(self.documents):
            doc_chunks = self._split_text(doc)
            for chunk_idx, chunk in enumerate(doc_chunks):
                if self._is_low_quality_chunk(chunk):
                    continue
                self.chunks.append(chunk)
                self.chunk_metadata.append({
                    'doc_id': doc_idx,
                    'chunk_id': chunk_idx,
                    'total_chunks_in_doc': len(doc_chunks)
                })
        print(f"Created {len(self.chunks)} chunks from {len(self.documents)} documents")

    def get_chunks(self) -> List[str]:
        """Return the list of all retained chunk texts."""
        return self.chunks

    def get_chunk_metadata(self) -> List[Dict]:
        """Return per-chunk metadata parallel to ``get_chunks()``."""
        return self.chunk_metadata
|