"""
Document Loading and Chunking Module
"""
from typing import List, Dict, Optional
from datasets import load_dataset
import re
class DocumentLoader:
    """Loads and chunks documents for a RAG system.

    Documents are loaded either from a HuggingFace dataset (streaming) or
    from raw strings, then split into overlapping, sentence-aligned chunks
    suitable for embedding in a retrieval pipeline.
    """

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        min_chunk_chars: int = 50,
        min_chunk_words: int = 10,
    ):
        """Initialize the loader.

        Args:
            chunk_size: Target maximum chunk length, in characters.
            chunk_overlap: Number of trailing characters of each chunk
                carried into the next chunk to preserve context across
                boundaries.
            min_chunk_chars: Chunks with fewer (stripped) characters than
                this are discarded as low quality.
            min_chunk_words: Chunks with fewer whitespace-separated words
                than this are discarded as low quality.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_chars = min_chunk_chars
        self.min_chunk_words = min_chunk_words
        self.documents: List[str] = []
        self.chunks: List[str] = []
        self.chunk_metadata: List[Dict] = []

    @staticmethod
    def _document_from_chapters(item: Dict) -> str:
        """Assemble one document string from a chapter-structured item.

        Concatenates title, abstract, chapter headings, and paragraph texts
        (each only when present and non-empty) separated by blank lines.
        """
        parts = []
        if item.get("title"):
            parts.append(item["title"])
        if item.get("abstract"):
            parts.append(item["abstract"])
        for chapter in item["chapters"]:
            if isinstance(chapter, dict):
                if chapter.get("head"):
                    parts.append(chapter["head"])
                if isinstance(chapter.get("paragraphs"), list):
                    for para in chapter["paragraphs"]:
                        if isinstance(para, dict) and para.get("text"):
                            parts.append(para["text"])
        return "\n\n".join(parts)

    def load_from_huggingface(
        self,
        dataset_name: str,
        split: str = "train",
        text_column: Optional[str] = None,
        max_docs: Optional[int] = None,
        hf_token: Optional[str] = None
    ):
        """Load documents from a HuggingFace dataset.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Dataset split to stream (default ``"train"``).
            text_column: Explicit column holding the document text; when
                omitted, common column names are probed per item.
            max_docs: Stop after this many documents (falsy = unlimited).
            hf_token: Optional HuggingFace auth token.

        Raises:
            RuntimeError: On a rate-limit error (HTTP 429), with guidance
                on supplying a token. Other loading errors re-raise as-is.
        """
        print(f"Loading dataset: {dataset_name}")
        if hf_token:
            print("Using HuggingFace token for authentication")
        try:
            # Streaming avoids downloading the full dataset up front.
            print("Using streaming mode for faster loading...")
            dataset = load_dataset(
                dataset_name,
                split=split,
                streaming=True,
                token=hf_token or None,
            )
        except Exception as e:
            if "429" in str(e) or "rate limit" in str(e).lower():
                # RuntimeError is still an Exception subclass, so existing
                # `except Exception` callers keep working; `from e` keeps
                # the original traceback for debugging.
                raise RuntimeError(
                    f"Rate limit error: {str(e)}\n\n"
                    "To fix this:\n"
                    "1. Create a free HuggingFace account at https://huggingface.co/join\n"
                    "2. Get your token at https://huggingface.co/settings/tokens\n"
                    "3. Add it in the 'HuggingFace Token' field above"
                ) from e
            raise
        documents = []
        count = 0
        for item in dataset:
            if max_docs and count >= max_docs:
                break
            if "chapters" in item and isinstance(item["chapters"], list):
                # Chapter-structured items (title/abstract/chapters schema).
                full_text = self._document_from_chapters(item)
                if full_text.strip():
                    documents.append(full_text)
                    count += 1
                    if count % 10 == 0:
                        print(f" Loaded {count} documents...")
            elif text_column and text_column in item:
                # Caller named the text column explicitly.
                value = item[text_column]
                if isinstance(value, str):
                    documents.append(value)
                    count += 1
                elif isinstance(value, list):
                    documents.append("\n\n".join(str(p) for p in value))
                    count += 1
            elif not text_column:
                # Probe common column names; take the first string match.
                for key in ("text", "content", "body", "context"):
                    if key in item and isinstance(item[key], str):
                        documents.append(item[key])
                        count += 1
                        if count % 10 == 0:
                            print(f" Loaded {count} documents...")
                        break
        self.documents = documents
        print(f"Loaded {len(self.documents)} documents from dataset")

    def load_from_texts(self, texts: List[str]):
        """Load documents from a list of raw text strings."""
        self.documents = texts
        print(f"Loaded {len(self.documents)} documents")

    def _split_text(self, text: str) -> List[str]:
        """Split text into overlapping, sentence-aligned chunks.

        Text is split on sentence boundaries (``.``, ``!``, ``?`` followed
        by whitespace), then sentences are packed greedily into chunks of
        at most ``chunk_size`` characters; the last ``chunk_overlap``
        characters of each emitted chunk are carried into the next one.
        A single sentence longer than ``chunk_size`` becomes its own
        oversized chunk.
        """
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) > self.chunk_size and current_chunk:
                chunks.append(current_chunk.strip())
                # Carry a short tail of the previous chunk forward so
                # context is not lost at the boundary. The slice is taken
                # from the un-stripped buffer, matching accumulation above.
                overlap_text = (
                    current_chunk[-self.chunk_overlap:]
                    if len(current_chunk) > self.chunk_overlap
                    else current_chunk
                )
                current_chunk = overlap_text + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        if not chunks and text.strip():
            # Fallback: sentence splitting produced nothing usable, so
            # return the whole text as one chunk.
            chunks = [text.strip()]
        return chunks

    def _is_low_quality_chunk(self, chunk: str) -> bool:
        """Return True when a chunk is too short to be a useful retrieval unit."""
        if len(chunk.strip()) < self.min_chunk_chars:
            return True
        if len(chunk.split()) < self.min_chunk_words:
            return True
        return False

    def chunk_documents(self):
        """Chunk all loaded documents, filtering out low-quality chunks.

        Populates ``self.chunks`` and parallel ``self.chunk_metadata``
        entries (``doc_id``, ``chunk_id``, ``total_chunks_in_doc``).
        """
        self.chunks = []
        self.chunk_metadata = []
        for doc_idx, doc in enumerate(self.documents):
            doc_chunks = self._split_text(doc)
            for chunk_idx, chunk in enumerate(doc_chunks):
                if not self._is_low_quality_chunk(chunk):
                    self.chunks.append(chunk)
                    self.chunk_metadata.append({
                        'doc_id': doc_idx,
                        'chunk_id': chunk_idx,
                        'total_chunks_in_doc': len(doc_chunks)
                    })
        print(f"Created {len(self.chunks)} chunks from {len(self.documents)} documents")

    def get_chunks(self) -> List[str]:
        """Get list of all chunks."""
        return self.chunks

    def get_chunk_metadata(self) -> List[Dict]:
        """Get metadata for all chunks."""
        return self.chunk_metadata