| | import os |
| | import fitz |
| | import re |
| | import chromadb |
| | from chromadb.utils import embedding_functions |
| | import numpy as np |
| | import torch |
| | import logging |
| |
|
| | |
# Module-wide logging: INFO and above, timestamped single-line records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
| |
|
class VectorDatabase:
    """Vector database for storing and retrieving tenant rights information from PDF.

    Wraps a persistent ChromaDB store with two collections:

    - ``tenant_documents``: overlapping text chunks for passage retrieval.
    - ``tenant_states``: one short summary entry per state section.
    """

    def __init__(self, persist_directory="./data/chroma_db"):
        """Initialize the vector database.

        Args:
            persist_directory: Directory where ChromaDB persists data;
                created if it does not already exist.

        Raises:
            Exception: Any failure from embedding-function or ChromaDB client
                setup is logged (with traceback) and re-raised.
        """
        logging.info("Initializing VectorDatabase")
        # Lazy %-style args avoid formatting when the level is disabled.
        logging.info("NumPy version: %s", np.__version__)
        logging.info("PyTorch version: %s", torch.__version__)

        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

        try:
            logging.info("Creating embedding function")
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="all-MiniLM-L6-v2"
            )

            logging.info("Initializing ChromaDB client")
            self.client = chromadb.PersistentClient(path=persist_directory)

            logging.info("Setting up collections")
            self.document_collection = self._get_or_create_collection("tenant_documents")
            self.state_collection = self._get_or_create_collection("tenant_states")
        except Exception:
            # logging.exception records the traceback; the object is unusable
            # without its collections, so propagate.
            logging.exception("Initialization failed")
            raise

    def _get_or_create_collection(self, name):
        """Get or create a ChromaDB collection with the given name."""
        # Atomic get-or-create replaces the try-get/except-create pattern,
        # which both swallowed unrelated errors and was race-prone.
        return self.client.get_or_create_collection(
            name=name,
            embedding_function=self.embedding_function,
        )

    def extract_pdf_content(self, pdf_path):
        """Extract content from a PDF file and identify state sections.

        Args:
            pdf_path: Path to the PDF file.

        Returns:
            Dict mapping state name to its section text. If no state headings
            are found, ``{"Full Document": <entire text>}``.

        Raises:
            FileNotFoundError: If ``pdf_path`` does not exist.
        """
        logging.info("Extracting content from PDF: %s", pdf_path)

        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")

        # Context manager guarantees the document is closed even if text
        # extraction raises part-way through.
        with fitz.open(pdf_path) as doc:
            full_text = "".join(page.get_text("text") + "\n" for page in doc)

        # Headings like "New York Landlord-Tenant Law" (or "... Landlord
        # Tenant Laws") mark the start of each state's section.
        state_pattern = r"(?m)^\s*([A-Z][a-z]+(?:\s[A-Z][a-z]+)*)\s+Landlord(?:-|\s)Tenant\s+(?:Law|Laws)"
        state_matches = list(re.finditer(state_pattern, full_text))

        if not state_matches:
            logging.info("No state sections found. Treating as single document.")
            return {"Full Document": full_text.strip()}

        # Each section runs from the end of its heading to the start of the
        # next heading (or to the end of the document for the last state).
        state_sections = {}
        for i, match in enumerate(state_matches):
            state_name = match.group(1).strip()
            start_pos = match.end()
            end_pos = state_matches[i + 1].start() if i + 1 < len(state_matches) else len(full_text)
            state_text = full_text[start_pos:end_pos].strip()
            if state_text:
                state_sections[state_name] = state_text

        logging.info("Extracted content for %d states", len(state_sections))
        return state_sections

    def process_and_load_pdf(self, pdf_path):
        """Process a PDF and (re)load its content into the vector database.

        Existing entries in both collections are deleted first, so repeated
        loads replace rather than duplicate content.

        Args:
            pdf_path: Path to the PDF file.

        Returns:
            Number of state sections loaded.
        """
        state_sections = self.extract_pdf_content(pdf_path)

        # Wipe previously loaded content so a reload replaces, not appends.
        existing_doc_ids = self.document_collection.get()["ids"]
        if existing_doc_ids:
            self.document_collection.delete(ids=existing_doc_ids)
        existing_state_ids = self.state_collection.get()["ids"]
        if existing_state_ids:
            self.state_collection.delete(ids=existing_state_ids)

        document_ids, document_texts, document_metadatas = [], [], []
        # Renamed from the original's second `state_ids`, which shadowed the
        # collection-id list above.
        summary_ids, summary_texts, summary_metadatas = [], [], []

        for state, text in state_sections.items():
            slug = state.lower().replace(" ", "_")

            # One short summary entry per state (first ~1000 characters).
            summary = text[:1000].strip() if len(text) > 1000 else text
            summary_ids.append(f"state_{slug}")
            summary_texts.append(summary)
            summary_metadatas.append({"state": state, "type": "summary"})

            # Overlapping chunks of the full section for passage retrieval.
            chunks = self._chunk_text(text, chunk_size=1000, overlap=200)
            for i, chunk in enumerate(chunks):
                document_ids.append(f"doc_{slug}_{i}")
                document_texts.append(chunk)
                document_metadatas.append({
                    "state": state,
                    "chunk_id": i,
                    "total_chunks": len(chunks),
                    "source": os.path.basename(pdf_path),
                })

        if document_ids:
            self.document_collection.add(
                ids=document_ids,
                documents=document_texts,
                metadatas=document_metadatas,
            )
        if summary_ids:
            self.state_collection.add(
                ids=summary_ids,
                documents=summary_texts,
                metadatas=summary_metadatas,
            )

        logging.info(
            "Loaded %d document chunks and %d state summaries",
            len(document_ids), len(summary_ids),
        )
        return len(state_sections)

    def _chunk_text(self, text, chunk_size=1000, overlap=200):
        """Split text into overlapping chunks.

        Chunks are at most ``chunk_size`` characters, preferring to break at
        the last sentence/newline boundary inside the window; consecutive
        chunks overlap by roughly ``overlap`` characters.

        Args:
            text: Text to split.
            chunk_size: Maximum chunk length in characters.
            overlap: Desired overlap between consecutive chunks.

        Returns:
            List of non-empty, stripped chunk strings.
        """
        if not text:
            return []

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = min(start + chunk_size, text_length)
            if end < text_length:
                # Prefer the last sentence or line boundary in the window,
                # if one exists past `start`.
                split_point = max(
                    text.rfind(".", start, end),
                    text.rfind("\n", start, end),
                )
                if split_point > start:
                    end = split_point + 1

            piece = text[start:end].strip()
            if piece:  # skip whitespace-only windows
                chunks.append(piece)

            if end >= text_length:
                # BUG FIX: the original stepped `start` back by `overlap`
                # even after the window reached the end of the text, which
                # emitted a duplicate tail chunk.
                break
            start = end - overlap if end - overlap > start else end

        return chunks

    def query(self, query_text, state=None, n_results=5):
        """Query the vector database for relevant tenant rights information.

        Args:
            query_text: Natural-language query.
            state: Optional state name; restricts both queries to that state.
            n_results: Maximum results per collection.

        Returns:
            Dict with ``document_results`` and ``state_results`` keys, each
            holding the raw ChromaDB query result.
        """
        state_filter = {"state": state} if state else None

        document_results = self.document_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter,
        )
        state_results = self.state_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter,
        )

        return {"document_results": document_results, "state_results": state_results}

    def get_states(self):
        """Return a sorted list of all state names in the database."""
        results = self.state_collection.get()
        # Metadata entries can be None; guard before indexing.
        states = {meta["state"] for meta in results["metadatas"] if meta}
        return sorted(states)
| |
|
if __name__ == "__main__":
    # Script mode: build the database from the bundled PDF and list states.
    try:
        db = VectorDatabase()
        pdf_path = "data/tenant-landlord.pdf"
        db.process_and_load_pdf(pdf_path)
        states = db.get_states()
        print(f"Available states: {states}")
    except Exception:
        # logging.exception records the full traceback, unlike the original
        # logging.error(f"... {e}") which kept only the message.
        logging.exception("Script execution failed")
        raise