Spaces:
Paused
Paused
| import os | |
| import shutil | |
| import json | |
| import joblib | |
| import numpy as np | |
| import requests | |
| import io | |
| import pypdf | |
| from bs4 import BeautifulSoup | |
| from huggingface_hub import HfApi, login | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.cluster import MiniBatchKMeans | |
class DocumentHandler:
    """Chunk documents and build/load a cluster index hosted on the Hugging Face Hub.

    The handler can:
      * split raw text into overlapping or paragraph-based chunks,
      * extract text from uploaded files and from URLs (PDF, HTML, plain text),
      * embed chunks of a streamed dataset, cluster the embeddings with
        MiniBatchKMeans and upload the resulting artifacts to a Hub dataset repo,
      * download those artifacts again and map a query to its nearest cluster.
    """

    # Sentence boundaries are tried in this order; the first delimiter type that
    # occurs in the search window wins (its last occurrence is used).
    _SENTENCE_DELIMITERS = ('. ', '.\n', '! ', '!\n', '? ', '?\n', '\n\n')

    # File extensions that are decoded as UTF-8 text by process_file().
    _TEXT_EXTENSIONS = ('.txt', '.md', '.py', '.js', '.html', '.json', '.csv')

    # Separator that marks a paragraph break in cleaned text.
    _PARAGRAPH_SEPARATOR = '\n\n'

    def __init__(self, chunk_size=512, chunk_overlap=50):
        """
        Create a handler and, when HF_TOKEN is set, log in to the Hugging Face Hub.
        Args:
            chunk_size: Default maximum number of characters per chunk
            chunk_overlap: Default number of characters shared by consecutive chunks
        """
        self.hf_token = os.environ.get("HF_TOKEN")
        if self.hf_token:
            login(token=self.hf_token)
        self.api = HfApi()
        self.cluster_model = None
        self.id_map = None
        self.embedding_model = None
        self.loaded = False
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def load_embedding_model(self):
        """Load the 'all-MiniLM-L6-v2' sentence-transformer on first use."""
        if self.embedding_model is None:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

    def chunk_text(self, text, chunk_size=None, overlap=None):
        """
        Split text into overlapping chunks for better context preservation.
        Args:
            text: Input text to chunk
            chunk_size: Maximum characters per chunk (default: self.chunk_size)
            overlap: Characters to overlap between chunks (default: self.chunk_overlap)
        Returns:
            List of text chunks. Text that already fits in one chunk is returned
            unchanged as a single-element list; otherwise chunks are stripped and
            empty chunks are dropped.
        Raises:
            ValueError: If chunk_size is not positive.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if overlap is None:
            overlap = self.chunk_overlap
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        overlap = max(0, overlap)

        text_len = len(text)
        if text_len <= chunk_size:
            return [text]

        chunks = []
        start = 0
        while start < text_len:
            end = min(start + chunk_size, text_len)
            # If not the last chunk, try to break at a sentence boundary
            if end < text_len:
                # Look for sentence endings within the last 20% of the chunk
                search_start = max(start, end - int(chunk_size * 0.2))
                window = text[search_start:end]
                for delimiter in self._SENTENCE_DELIMITERS:
                    pos = window.rfind(delimiter)
                    if pos != -1:
                        end = search_start + pos + len(delimiter)
                        break

            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)

            # The chunk that reaches the end of the text is the last one; stepping
            # back by `overlap` here would only re-emit text already covered.
            if end >= text_len:
                break

            # Always move forward: when the overlap would put the next start at or
            # before the current one, continue without overlap instead of looping
            # forever.
            next_start = end - overlap
            start = next_start if next_start > start else end
        return chunks

    def chunk_by_paragraphs(self, text, max_chunk_size=None):
        """
        Chunk text by paragraphs, combining small paragraphs and splitting large ones.
        Args:
            text: Input text to chunk (paragraphs are separated by a blank line)
            max_chunk_size: Maximum size per chunk (default: self.chunk_size)
        Returns:
            List of text chunks; combined paragraphs are joined with a blank line
            and never exceed max_chunk_size characters including separators
        """
        if max_chunk_size is None:
            max_chunk_size = self.chunk_size

        separator = self._PARAGRAPH_SEPARATOR
        paragraphs = [p.strip() for p in text.split(separator) if p.strip()]

        chunks = []
        current_chunk = []
        current_size = 0  # exact length of separator.join(current_chunk)

        for para in paragraphs:
            para_size = len(para)

            # If paragraph is too large, flush what we have and split it
            if para_size > max_chunk_size:
                if current_chunk:
                    chunks.append(separator.join(current_chunk))
                    current_chunk = []
                    current_size = 0
                chunks.extend(self.chunk_text(para, max_chunk_size, self.chunk_overlap))
                continue

            # Length of the chunk if this paragraph were appended, counting the
            # separator only when something precedes it
            if current_chunk:
                joined_size = current_size + len(separator) + para_size
            else:
                joined_size = para_size

            if current_chunk and joined_size > max_chunk_size:
                # Adding the paragraph would exceed the limit: start a new chunk
                chunks.append(separator.join(current_chunk))
                current_chunk = [para]
                current_size = para_size
            else:
                current_chunk.append(para)
                current_size = joined_size

        # Add remaining chunk
        if current_chunk:
            chunks.append(separator.join(current_chunk))
        return chunks

    def _extract_pdf_text(self, data):
        """Return the text of every page of a PDF (given as bytes), one page per line block."""
        reader = pypdf.PdfReader(io.BytesIO(data))
        # Guard against pages that yield no text so the join never sees None
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    def _build_chunk_records(self, text, source, title=None):
        """Clean and chunk text, returning one metadata dict per chunk.

        The "title" key is only included when a title is given.
        """
        cleaned = self._clean_text(text)
        text_chunks = self.chunk_by_paragraphs(cleaned)
        total = len(text_chunks)

        records = []
        for idx, chunk in enumerate(text_chunks):
            record = {"text": chunk, "source": source}
            if title is not None:
                record["title"] = title
            record["chunk_id"] = idx
            record["total_chunks"] = total
            records.append(record)
        return records

    def process_file(self, file_storage, filename):
        """
        Process file and return chunks (default behavior).
        Args:
            file_storage: File object whose read() returns bytes
            filename: Name of the file; its extension selects the parser
        Returns:
            List of dicts with "text", "source" (the lower-cased filename),
            "chunk_id" and "total_chunks"; or a single-element list holding an
            {"error": ...} dict when the type is unsupported or processing fails
        """
        try:
            filename = filename.lower()
            if filename.endswith('.pdf'):
                text_content = self._extract_pdf_text(file_storage.read())
            elif filename.endswith(self._TEXT_EXTENSIONS):
                text_content = file_storage.read().decode('utf-8', errors='ignore')
            else:
                return [{"error": f"Unsupported file type: {filename}"}]
            return self._build_chunk_records(text_content, filename)
        except Exception as e:
            return [{"error": f"Error processing file {filename}: {str(e)}"}]

    def process_url(self, url):
        """
        Process URL and return chunks (default behavior).
        Args:
            url: URL to process
        Returns:
            List of dicts with "text", "source" (the URL), "title", "chunk_id" and
            "total_chunks"; or a single-element list holding an {"error": ...}
            dict when the request or the parsing fails
        """
        try:
            headers = {'User-Agent': 'VisMemBot/1.0'}
            response = requests.get(url, headers=headers, timeout=10)
            # Fail on 4xx/5xx instead of chunking the server's error page
            response.raise_for_status()
            content_type = response.headers.get('Content-Type', '').lower()

            if 'application/pdf' in content_type or url.lower().endswith('.pdf'):
                text_content = self._extract_pdf_text(response.content)
                title = f"PDF: {url}"
            else:
                soup = BeautifulSoup(response.content, 'html.parser')
                # Drop non-content elements before extracting text
                for tag in soup(["script", "style", "nav", "footer", "header"]):
                    tag.extract()
                text_content = soup.get_text()
                # <title> may be missing, empty or contain nested tags (string is
                # None in that case); fall back to the URL
                raw_title = soup.title.string if soup.title else None
                title = (str(raw_title).strip() if raw_title else "") or url

            return self._build_chunk_records(text_content, url, title=title)
        except Exception as e:
            return [{"error": f"Error processing URL {url}: {str(e)}"}]

    def _clean_text(self, text):
        """Normalise whitespace while keeping paragraph breaks.

        Each line is stripped and its internal whitespace runs collapsed to one
        space. Runs of blank lines become a single blank line so that
        chunk_by_paragraphs() can still find paragraph boundaries.
        """
        paragraphs = []
        current_lines = []
        for raw_line in text.splitlines():
            line = " ".join(raw_line.split())
            if line:
                current_lines.append(line)
            elif current_lines:
                paragraphs.append("\n".join(current_lines))
                current_lines = []
        if current_lines:
            paragraphs.append("\n".join(current_lines))
        return self._PARAGRAPH_SEPARATOR.join(paragraphs)

    def build_dataset_index(self, repo_id, dataset_name="wikitext", config="wikitext-103-v1",
                            split="train", max_docs=3000, max_clusters=300):
        """
        Build a cluster index from a streamed dataset and upload it to the Hub.

        This is a generator: it yields human-readable progress messages and, on
        failure, a final message starting with "Error: ".
        Args:
            repo_id: Hub dataset repository that receives the artifacts
            dataset_name: Dataset to stream
            config: Dataset configuration name
            split: Dataset split to stream
            max_docs: Number of documents taken from the stream
            max_clusters: Upper bound for the number of k-means clusters
        Yields:
            Progress / error messages (str)
        """
        try:
            self.load_embedding_model()
            local_path = "lightweight_index"
            if os.path.exists(local_path):
                shutil.rmtree(local_path)
            os.makedirs(local_path)

            yield f"Streaming {dataset_name}..."
            dataset = load_dataset(dataset_name, config, split=split, streaming=True)

            chunk_texts = []
            doc_ids = []
            yield "Vectorizing documents with chunking..."
            for doc_idx, doc in enumerate(dataset.take(max_docs)):
                text = doc.get("text") or ""
                if len(text) <= 50:
                    continue
                # Chunk long documents using the handler's configured sizes
                for chunk_idx, chunk in enumerate(self.chunk_text(text)):
                    chunk_texts.append(chunk)
                    doc_ids.append(f"doc_{doc_idx}_chunk_{chunk_idx}")

            if not chunk_texts:
                yield "Error: no documents long enough to index."
                return

            # One batched call instead of one model call per chunk
            embeddings = np.asarray(self.embedding_model.encode(chunk_texts, batch_size=64))

            yield f"Clustering {len(embeddings)} vectors..."
            # At least one cluster, even for very small inputs
            n_clusters = max(1, min(max_clusters, len(embeddings) // 5))
            kmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=256, n_init="auto")
            kmeans.fit(embeddings)

            # Pre-fill so clusters without members still appear in the map
            cluster_id_map = {int(i): [] for i in range(len(kmeans.cluster_centers_))}
            for label, doc_id in zip(kmeans.labels_, doc_ids):
                cluster_id_map[int(label)].append(doc_id)

            yield "Saving artifacts..."
            joblib.dump(kmeans, os.path.join(local_path, "kmeans_model.joblib"))
            with open(os.path.join(local_path, "id_map.json"), "w") as f:
                json.dump(cluster_id_map, f)

            yield f"Uploading to Hub: {repo_id}..."
            self.api.create_repo(repo_id=repo_id, repo_type="dataset", token=self.hf_token, exist_ok=True)
            self.api.upload_folder(folder_path=local_path, repo_id=repo_id, repo_type="dataset", token=self.hf_token)
            yield "Done. Index built."
        except Exception as e:
            yield f"Error: {str(e)}"

    def load_index(self, repo_id):
        """
        Download the index artifacts from the Hub and make them ready for retrieve().
        Args:
            repo_id: Hub dataset repository holding kmeans_model.joblib and id_map.json
        Returns:
            Tuple (success, message); on failure the message is the error text and
            the previously loaded index (if any) is left untouched
        """
        try:
            self.load_embedding_model()
            local_path = self.api.snapshot_download(repo_id=repo_id, repo_type="dataset", token=self.hf_token)
            # NOTE: joblib.load unpickles the file and can execute arbitrary code;
            # only load indexes from repositories you trust.
            cluster_model = joblib.load(os.path.join(local_path, "kmeans_model.joblib"))
            with open(os.path.join(local_path, "id_map.json"), "r") as f:
                # JSON object keys are strings; restore the integer cluster ids
                id_map = {int(k): v for k, v in json.load(f).items()}
        except Exception as e:
            return False, str(e)

        # Publish both artifacts together so a failed load never leaves a
        # model paired with a stale id map
        self.cluster_model = cluster_model
        self.id_map = id_map
        self.loaded = True
        return True, f"Index loaded with {len(self.id_map)} clusters."

    def retrieve(self, query):
        """
        Report how many indexed documents share the query's cluster.
        Args:
            query: Query text
        Returns:
            Summary string, or "" when no index has been loaded
        """
        if not self.loaded:
            return ""
        q_vec = self.embedding_model.encode([query])
        # Cast to a plain int so lookup and formatting do not depend on the
        # array scalar type returned by predict()
        cluster_id = int(self.cluster_model.predict(q_vec)[0])
        hits = self.id_map.get(cluster_id, [])
        return f"[RAG Database]: Found {len(hits)} relevant documents in Cluster #{cluster_id}."