"""Document ingestion, chunking, and a lightweight cluster-based retrieval index."""

import io
import json
import os
import shutil

import joblib
import numpy as np
import pypdf
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from huggingface_hub import HfApi, login
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans


class DocumentHandler:
    """Turn files/URLs into text chunks and build/query a k-means index.

    The index consists of a fitted ``MiniBatchKMeans`` model plus a JSON map
    from cluster id to the chunk ids assigned to that cluster; both are stored
    in a Hugging Face Hub dataset repository.
    """

    # Tried in this priority order when looking for a place to end a chunk.
    _SENTENCE_DELIMITERS = ('. ', '.\n', '! ', '!\n', '? ', '?\n', '\n\n')
    _PARAGRAPH_SEPARATOR = '\n\n'
    _TEXT_EXTENSIONS = ('.txt', '.md', '.py', '.js', '.html', '.json', '.csv')

    def __init__(self, chunk_size=512, chunk_overlap=50):
        """
        Create a handler and log in to the Hub when ``HF_TOKEN`` is set.

        Args:
            chunk_size: Default maximum characters per chunk
            chunk_overlap: Default characters shared by consecutive chunks
        """
        self.hf_token = os.environ.get("HF_TOKEN")
        if self.hf_token:
            login(token=self.hf_token)
        self.api = HfApi()
        self.cluster_model = None
        self.id_map = None
        self.embedding_model = None
        self.loaded = False
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def load_embedding_model(self):
        """Instantiate the sentence-embedding model on first use."""
        if self.embedding_model is None:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

    def chunk_text(self, text, chunk_size=None, overlap=None):
        """
        Split text into overlapping chunks for better context preservation.

        Args:
            text: Input text to chunk
            chunk_size: Maximum characters per chunk (default: self.chunk_size)
            overlap: Characters to overlap between chunks
                (default: self.chunk_overlap); negative values are treated as 0

        Returns:
            List of text chunks. Text no longer than ``chunk_size`` is returned
            unchanged as a single-element list; otherwise chunks are stripped
            and empty ones are dropped.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if overlap is None:
            overlap = self.chunk_overlap
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        # A negative overlap would silently skip text between chunks.
        overlap = max(overlap, 0)

        text_len = len(text)
        if text_len <= chunk_size:
            return [text]

        # Sentence boundaries are only searched in the last 20% of a chunk.
        lookback = int(chunk_size * 0.2)
        chunks = []
        start = 0
        while start < text_len:
            end = min(start + chunk_size, text_len)

            # If not the last chunk, try to break at a sentence boundary.
            if end < text_len:
                search_start = end - lookback
                window = text[search_start:end]
                for delimiter in self._SENTENCE_DELIMITERS:
                    pos = window.rfind(delimiter)
                    if pos != -1:
                        end = search_start + pos + len(delimiter)
                        break

            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)

            # The chunk that reaches the end of the text is the last one;
            # stepping back by `overlap` here would only re-emit its tail.
            if end >= text_len:
                break

            # Always move forward, even if overlap >= the distance advanced.
            next_start = end - overlap
            start = next_start if next_start > start else end

        return chunks

    def chunk_by_paragraphs(self, text, max_chunk_size=None):
        """
        Chunk text by paragraphs, combining small paragraphs and splitting
        large ones.

        Args:
            text: Input text to chunk; paragraphs are separated by blank lines
            max_chunk_size: Maximum size per chunk (default: self.chunk_size)

        Returns:
            List of text chunks
        """
        if max_chunk_size is None:
            max_chunk_size = self.chunk_size

        separator = self._PARAGRAPH_SEPARATOR
        paragraphs = [p.strip() for p in text.split(separator) if p.strip()]

        chunks = []
        current = []
        current_size = 0  # exact length of separator.join(current)

        for para in paragraphs:
            para_size = len(para)

            # A paragraph that cannot fit in one chunk is split on its own.
            if para_size > max_chunk_size:
                if current:
                    chunks.append(separator.join(current))
                    current = []
                    current_size = 0
                chunks.extend(
                    self.chunk_text(para, max_chunk_size, self.chunk_overlap)
                )
                continue

            if current:
                projected = current_size + len(separator) + para_size
            else:
                projected = para_size

            if current and projected > max_chunk_size:
                # Adding this paragraph would overflow: flush and start anew.
                chunks.append(separator.join(current))
                current = [para]
                current_size = para_size
            else:
                current.append(para)
                current_size = projected

        if current:
            chunks.append(separator.join(current))

        return chunks

    def _extract_pdf_text(self, data):
        """Return the text of every page of the PDF in ``data`` (bytes)."""
        reader = pypdf.PdfReader(io.BytesIO(data))
        # Guard against pages for which no text could be extracted.
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    def _build_chunk_records(self, text_content, metadata):
        """Clean and chunk ``text_content``; attach ``metadata`` to each chunk."""
        text_chunks = self.chunk_by_paragraphs(self._clean_text(text_content))
        total = len(text_chunks)
        return [
            {"text": chunk, **metadata, "chunk_id": idx, "total_chunks": total}
            for idx, chunk in enumerate(text_chunks)
        ]

    def process_file(self, file_storage, filename):
        """
        Process file and return chunks (default behavior).

        Args:
            file_storage: File object whose ``read()`` returns bytes
            filename: Name of the file; its lower-cased form selects the
                parser and is reported as ``source``

        Returns:
            List of dicts with ``text``, ``source``, ``chunk_id`` and
            ``total_chunks``; on failure a single-element list holding an
            ``error`` dict.
        """
        # Errors are reported in-band: callers receive an error record
        # instead of an exception.
        try:
            filename = filename.lower()
            if filename.endswith('.pdf'):
                text_content = self._extract_pdf_text(file_storage.read())
            elif filename.endswith(self._TEXT_EXTENSIONS):
                text_content = file_storage.read().decode('utf-8', errors='ignore')
            else:
                return [{"error": f"Unsupported file type: {filename}"}]

            return self._build_chunk_records(text_content, {"source": filename})
        except Exception as e:
            return [{"error": f"Error processing file {filename}: {str(e)}"}]

    def process_url(self, url):
        """
        Process URL and return chunks (default behavior).

        Args:
            url: URL to process

        Returns:
            List of dicts with ``text``, ``source``, ``title``, ``chunk_id``
            and ``total_chunks``; on failure (including a non-2xx HTTP
            status) a single-element list holding an ``error`` dict.
        """
        # Errors are reported in-band: callers receive an error record
        # instead of an exception.
        try:
            headers = {'User-Agent': 'VisMemBot/1.0'}
            response = requests.get(url, headers=headers, timeout=10)
            # Do not index the body of an error page as document content.
            response.raise_for_status()
            content_type = response.headers.get('Content-Type', '').lower()

            if 'application/pdf' in content_type or url.lower().endswith('.pdf'):
                text_content = self._extract_pdf_text(response.content)
                title = f"PDF: {url}"
            else:
                soup = BeautifulSoup(response.content, 'html.parser')
                for tag in soup(["script", "style", "nav", "footer", "header"]):
                    tag.extract()
                text_content = soup.get_text()
                # <title> may be absent, or its .string may be None.
                raw_title = soup.title.string if soup.title else None
                title = (str(raw_title).strip() if raw_title else "") or url

            return self._build_chunk_records(
                text_content, {"source": url, "title": title}
            )
        except Exception as e:
            return [{"error": f"Error processing URL {url}: {str(e)}"}]

    def _clean_text(self, text):
        """
        Normalize whitespace while keeping paragraph structure.

        Each line is stripped and split on double spaces into separate
        phrases; empty phrases are dropped. A run of blank lines between two
        non-blank lines is kept as exactly one blank line so that
        ``chunk_by_paragraphs`` can still find paragraph boundaries.

        Args:
            text: Raw text

        Returns:
            Cleaned text
        """
        cleaned = []
        pending_break = False
        for line in text.splitlines():
            # Two spaces: splits multi-column/headline runs, not single words.
            phrases = [p.strip() for p in line.strip().split("  ")]
            phrases = [p for p in phrases if p]
            if not phrases:
                # Only remember a break once some content has been emitted.
                pending_break = bool(cleaned)
                continue
            if pending_break:
                cleaned.append("")
                pending_break = False
            cleaned.extend(phrases)
        return '\n'.join(cleaned)

    def build_dataset_index(self, repo_id, dataset_name="wikitext",
                            config="wikitext-103-v1", split="train",
                            max_docs=3000):
        """
        Build a k-means index over a streamed dataset and upload it.

        This is a generator: it yields human-readable progress messages and,
        on failure, a final ``"Error: ..."`` message instead of raising.

        Args:
            repo_id: Hub dataset repository that receives the artifacts
            dataset_name: Dataset to stream
            config: Dataset configuration name
            split: Dataset split
            max_docs: Number of documents taken from the stream

        Yields:
            Progress / error message strings
        """
        try:
            self.load_embedding_model()
            local_path = "lightweight_index"
            # Start from an empty folder so stale artifacts are not uploaded.
            if os.path.exists(local_path):
                shutil.rmtree(local_path)
            os.makedirs(local_path)

            yield f"Streaming {dataset_name}..."
            dataset = load_dataset(dataset_name, config, split=split, streaming=True)

            chunk_texts = []
            doc_ids = []
            yield "Vectorizing documents with chunking..."
            for i, doc in enumerate(dataset.take(max_docs)):
                text = doc.get("text", "")
                if len(text) > 50:
                    # Chunk long documents
                    pieces = self.chunk_text(text, chunk_size=512, overlap=50)
                    for chunk_idx, chunk in enumerate(pieces):
                        chunk_texts.append(chunk)
                        doc_ids.append(f"doc_{i}_chunk_{chunk_idx}")

            if not chunk_texts:
                raise ValueError("No documents long enough to index.")

            # One batched call instead of one model invocation per chunk.
            embeddings = np.asarray(self.embedding_model.encode(chunk_texts))

            yield f"Clustering {len(embeddings)} vectors..."
            # At least one cluster, even for very small corpora.
            n_clusters = max(1, min(300, len(embeddings) // 5))
            kmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=256, n_init="auto")
            kmeans.fit(embeddings)

            cluster_id_map = {cluster: [] for cluster in range(n_clusters)}
            for label, doc_id in zip(kmeans.labels_, doc_ids):
                cluster_id_map[int(label)].append(doc_id)

            yield "Saving artifacts..."
            joblib.dump(kmeans, os.path.join(local_path, "kmeans_model.joblib"))
            with open(os.path.join(local_path, "id_map.json"), "w") as f:
                json.dump(cluster_id_map, f)

            yield f"Uploading to Hub: {repo_id}..."
            self.api.create_repo(repo_id=repo_id, repo_type="dataset",
                                 token=self.hf_token, exist_ok=True)
            self.api.upload_folder(folder_path=local_path, repo_id=repo_id,
                                   repo_type="dataset", token=self.hf_token)
            yield "Done. Index built."
        except Exception as e:
            yield f"Error: {str(e)}"

    def load_index(self, repo_id):
        """
        Download and load a previously built index.

        Args:
            repo_id: Hub dataset repository holding the index artifacts

        Returns:
            ``(True, message)`` on success, ``(False, error_text)`` on failure.
            On failure the handler's previous state is left untouched.
        """
        try:
            self.load_embedding_model()
            local_path = self.api.snapshot_download(
                repo_id=repo_id, repo_type="dataset", token=self.hf_token
            )
            # SECURITY: joblib.load unpickles the file and can execute
            # arbitrary code; only load repositories you trust.
            cluster_model = joblib.load(os.path.join(local_path, "kmeans_model.joblib"))
            with open(os.path.join(local_path, "id_map.json"), "r") as f:
                # JSON object keys are strings; restore integer cluster ids.
                id_map = {int(k): v for k, v in json.load(f).items()}
        except Exception as e:
            return False, str(e)

        # Publish the new state only after everything loaded successfully.
        self.cluster_model = cluster_model
        self.id_map = id_map
        self.loaded = True
        return True, f"Index loaded with {len(self.id_map)} clusters."

    def retrieve(self, query):
        """
        Report which cluster a query falls into.

        Args:
            query: Query text

        Returns:
            A summary string naming the predicted cluster and the number of
            chunk ids stored for it, or ``""`` when no index is loaded.
        """
        if not self.loaded:
            return ""
        q_vec = self.embedding_model.encode([query])
        # predict() returns a NumPy integer; id_map keys are plain ints.
        cluster_id = int(self.cluster_model.predict(q_vec)[0])
        hits = self.id_map.get(cluster_id, [])
        return f"[RAG Database]: Found {len(hits)} relevant documents in Cluster #{cluster_id}."