File size: 6,379 Bytes
779b4bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Document Loading and Chunking Module
"""
from typing import List, Dict, Optional
from datasets import load_dataset
import re


class DocumentLoader:
    """Loads and chunks documents for RAG system.

    Documents are loaded either from a HuggingFace dataset (streaming) or
    from raw strings, then split into overlapping, sentence-aligned chunks
    that are filtered for minimum quality before indexing.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """Initialize the loader.

        Args:
            chunk_size: Target maximum chunk length, in characters.
            chunk_overlap: Number of trailing characters of one chunk carried
                into the start of the next chunk for context continuity.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents: List[str] = []       # raw document texts
        self.chunks: List[str] = []          # quality-filtered chunk texts
        self.chunk_metadata: List[Dict] = [] # parallel to self.chunks

    def load_from_huggingface(
        self,
        dataset_name: str,
        split: str = "train",
        text_column: Optional[str] = None,
        max_docs: Optional[int] = None,
        hf_token: Optional[str] = None
    ) -> None:
        """Load documents from HuggingFace dataset.

        Args:
            dataset_name: HuggingFace hub dataset identifier.
            split: Dataset split to stream (default "train").
            text_column: Column holding the document text. If omitted, records
                are probed for a "chapters" structure, then for common text
                keys ("text", "content", "body", "context").
            max_docs: Stop after this many documents (None = no limit).
            hf_token: Optional HuggingFace auth token.

        Raises:
            Exception: Re-raised load errors; rate-limit (HTTP 429) errors are
                wrapped with instructions for supplying a token.
        """
        print(f"Loading dataset: {dataset_name}")
        if hf_token:
            print("Using HuggingFace token for authentication")

        try:
            print("Using streaming mode for faster loading...")
            # Streaming avoids downloading the full dataset before iteration.
            dataset = load_dataset(
                dataset_name,
                split=split,
                streaming=True,
                token=hf_token  # None is the library default when no token
            )
        except Exception as e:
            if "429" in str(e) or "rate limit" in str(e).lower():
                # Explicitly chain so the original HTTP error stays visible.
                raise Exception(
                    f"Rate limit error: {str(e)}\n\n"
                    "To fix this:\n"
                    "1. Create a free HuggingFace account at https://huggingface.co/join\n"
                    "2. Get your token at https://huggingface.co/settings/tokens\n"
                    "3. Add it in the 'HuggingFace Token' field above"
                ) from e
            raise

        documents: List[str] = []
        count = 0
        for item in dataset:
            if max_docs and count >= max_docs:
                break
            text = self._extract_item_text(item, text_column)
            # Skip records that yielded no usable text (previously only the
            # chapters branch filtered empties; now applied uniformly).
            if text:
                documents.append(text)
                count += 1
                # Progress is now reported for every branch, not just two.
                if count % 10 == 0:
                    print(f"  Loaded {count} documents...")

        self.documents = documents
        print(f"Loaded {len(self.documents)} documents from dataset")

    def _extract_item_text(self, item: Dict, text_column: Optional[str]) -> str:
        """Return the document text for one dataset record, or "" if none.

        Resolution order: chapter-structured records first, then an explicit
        text_column, then a probe of common text keys.
        """
        # Chapter-structured records: title + abstract + per-chapter
        # headings and paragraph texts, joined with blank lines.
        if "chapters" in item and isinstance(item["chapters"], list):
            parts: List[str] = []
            for key in ("title", "abstract"):
                if item.get(key):
                    parts.append(item[key])
            for chapter in item["chapters"]:
                if not isinstance(chapter, dict):
                    continue
                if chapter.get("head"):
                    parts.append(chapter["head"])
                paragraphs = chapter.get("paragraphs")
                if isinstance(paragraphs, list):
                    for para in paragraphs:
                        if isinstance(para, dict) and para.get("text"):
                            parts.append(para["text"])
            full_text = "\n\n".join(parts)
            return full_text if full_text.strip() else ""

        # Explicit column: accept a plain string or a list of parts.
        if text_column:
            value = item.get(text_column)
            if isinstance(value, str):
                return value
            if isinstance(value, list):
                return "\n\n".join(str(p) for p in value)
            return ""

        # No column given: probe common text field names.
        for key in ("text", "content", "body", "context"):
            value = item.get(key)
            if isinstance(value, str):
                return value
        return ""

    def load_from_texts(self, texts: List[str]) -> None:
        """Load documents from list of text strings."""
        self.documents = texts
        print(f"Loaded {len(self.documents)} documents")

    def _split_text(self, text: str) -> List[str]:
        """Split text into chunks with overlap.

        Chunks are built sentence by sentence up to ~chunk_size characters;
        the last chunk_overlap characters of each chunk seed the next one.
        Returns [] for whitespace-only input.
        """
        # Split after sentence-ending punctuation followed by whitespace.
        sentences = re.split(r'(?<=[.!?])\s+', text)

        chunks: List[str] = []
        current_chunk = ""

        for sentence in sentences:
            if len(current_chunk) + len(sentence) > self.chunk_size and current_chunk:
                chunks.append(current_chunk.strip())
                # Carry the tail of the finished chunk into the next one.
                overlap_text = (
                    current_chunk[-self.chunk_overlap:]
                    if len(current_chunk) > self.chunk_overlap
                    else current_chunk
                )
                current_chunk = overlap_text + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence

        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        # Text with no sentence boundaries still yields one chunk.
        if not chunks and text.strip():
            chunks = [text.strip()]

        return chunks

    def _is_low_quality_chunk(self, chunk: str) -> bool:
        """Check if a chunk is low quality.

        A chunk is rejected when it is under 50 characters or under 10 words.
        """
        if len(chunk.strip()) < 50:
            return True

        if len(chunk.split()) < 10:
            return True

        return False

    def chunk_documents(self) -> None:
        """Chunk all loaded documents.

        Rebuilds self.chunks / self.chunk_metadata from self.documents,
        discarding low-quality chunks. Note: 'total_chunks_in_doc' counts
        chunks before quality filtering.
        """
        self.chunks = []
        self.chunk_metadata = []

        for doc_idx, doc in enumerate(self.documents):
            doc_chunks = self._split_text(doc)

            for chunk_idx, chunk in enumerate(doc_chunks):
                if not self._is_low_quality_chunk(chunk):
                    self.chunks.append(chunk)
                    self.chunk_metadata.append({
                        'doc_id': doc_idx,
                        'chunk_id': chunk_idx,
                        'total_chunks_in_doc': len(doc_chunks)
                    })

        print(f"Created {len(self.chunks)} chunks from {len(self.documents)} documents")

    def get_chunks(self) -> List[str]:
        """Get list of all chunks"""
        return self.chunks

    def get_chunk_metadata(self) -> List[Dict]:
        """Get metadata for all chunks"""
        return self.chunk_metadata