| import os |
| import fitz |
| import re |
| import chromadb |
| from chromadb.utils import embedding_functions |
| import numpy as np |
| import torch |
| import logging |
|
|
| |
# Configure root logging once at import time; every message from this module
# (and its dependencies) is emitted with a timestamp and severity level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
| class VectorDatabase: |
| """Vector database for storing and retrieving tenant rights information from PDF.""" |
|
|
| def __init__(self, persist_directory="./chroma_db"): |
| """Initialize the vector database.""" |
| logging.info("Initializing VectorDatabase") |
| logging.info(f"NumPy version: {np.__version__}") |
| logging.info(f"PyTorch version: {torch.__version__}") |
|
|
| self.persist_directory = persist_directory |
| os.makedirs(persist_directory, exist_ok=True) |
| |
| try: |
| logging.info("Creating embedding function") |
| self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction( |
| model_name="all-MiniLM-L6-v2" |
| ) |
| |
| logging.info("Initializing ChromaDB client") |
| self.client = chromadb.PersistentClient(path=persist_directory) |
| |
| logging.info("Setting up collections") |
| self.document_collection = self._get_or_create_collection("tenant_documents") |
| self.state_collection = self._get_or_create_collection("tenant_states") |
| except Exception as e: |
| logging.error(f"Initialization failed: {str(e)}") |
| raise |
|
|
| def _get_or_create_collection(self, name): |
| """Get or create a collection with the given name.""" |
| try: |
| return self.client.get_collection( |
| name=name, |
| embedding_function=self.embedding_function |
| ) |
| except Exception: |
| return self.client.create_collection( |
| name=name, |
| embedding_function=self.embedding_function |
| ) |
|
|
| def extract_pdf_content(self, pdf_path): |
| """Extract content from PDF file and identify state sections.""" |
| logging.info(f"Extracting content from PDF: {pdf_path}") |
| |
| if not os.path.exists(pdf_path): |
| raise FileNotFoundError(f"PDF file not found: {pdf_path}") |
| |
| doc = fitz.open(pdf_path) |
| full_text = "" |
| for page_num in range(len(doc)): |
| page = doc.load_page(page_num) |
| full_text += page.get_text("text") + "\n" |
| doc.close() |
| |
| state_pattern = r"(?m)^\s*([A-Z][a-z]+(?:\s[A-Z][a-z]+)*)\s+Landlord(?:-|\s)Tenant\s+(?:Law|Laws)" |
| state_matches = list(re.finditer(state_pattern, full_text)) |
| |
| if not state_matches: |
| logging.info("No state sections found. Treating as single document.") |
| return {"Full Document": full_text.strip()} |
| |
| state_sections = {} |
| for i, match in enumerate(state_matches): |
| state_name = match.group(1).strip() |
| start_pos = match.end() |
| end_pos = state_matches[i + 1].start() if i + 1 < len(state_matches) else len(full_text) |
| state_text = full_text[start_pos:end_pos].strip() |
| if state_text: |
| state_sections[state_name] = state_text |
| |
| logging.info(f"Extracted content for {len(state_sections)} states") |
| return state_sections |
|
|
| def process_and_load_pdf(self, pdf_path): |
| """Process PDF and load content into vector database.""" |
| state_sections = self.extract_pdf_content(pdf_path) |
| |
| doc_ids = self.document_collection.get()["ids"] |
| state_ids = self.state_collection.get()["ids"] |
| |
| if doc_ids: |
| self.document_collection.delete(ids=doc_ids) |
| if state_ids: |
| self.state_collection.delete(ids=state_ids) |
| |
| document_ids, document_texts, document_metadatas = [], [], [] |
| state_ids, state_texts, state_metadatas = [], [], [] |
| |
| for state, text in state_sections.items(): |
| state_id = f"state_{state.lower().replace(' ', '_')}" |
| summary = text[:1000].strip() if len(text) > 1000 else text |
| state_ids.append(state_id) |
| state_texts.append(summary) |
| state_metadatas.append({"state": state, "type": "summary"}) |
| |
| chunks = self._chunk_text(text, chunk_size=1000, overlap=200) |
| for i, chunk in enumerate(chunks): |
| doc_id = f"doc_{state.lower().replace(' ', '_')}_{i}" |
| document_ids.append(doc_id) |
| document_texts.append(chunk) |
| document_metadatas.append({ |
| "state": state, |
| "chunk_id": i, |
| "total_chunks": len(chunks), |
| "source": os.path.basename(pdf_path) |
| }) |
| |
| if document_ids: |
| self.document_collection.add( |
| ids=document_ids, |
| documents=document_texts, |
| metadatas=document_metadatas |
| ) |
| if state_ids: |
| self.state_collection.add( |
| ids=state_ids, |
| documents=state_texts, |
| metadatas=state_metadatas |
| ) |
| |
| logging.info(f"Loaded {len(document_ids)} document chunks and {len(state_ids)} state summaries") |
| return len(state_sections) |
|
|
| def _chunk_text(self, text, chunk_size=1000, overlap=200): |
| """Split text into overlapping chunks.""" |
| if not text: |
| return [] |
| |
| chunks = [] |
| start = 0 |
| text_length = len(text) |
| |
| while start < text_length: |
| end = min(start + chunk_size, text_length) |
| if end < text_length: |
| last_period = text.rfind(".", start, end) |
| last_newline = text.rfind("\n", start, end) |
| split_point = max(last_period, last_newline) |
| if split_point > start: |
| end = split_point + 1 |
| chunks.append(text[start:end].strip()) |
| start = end - overlap if end - overlap > start else end |
| |
| return chunks |
|
|
| def query(self, query_text, state=None, n_results=5): |
| """Query the vector database for relevant tenant rights information.""" |
| state_filter = {"state": state} if state else None |
| |
| document_results = self.document_collection.query( |
| query_texts=[query_text], |
| n_results=n_results, |
| where=state_filter |
| ) |
| state_results = self.state_collection.query( |
| query_texts=[query_text], |
| n_results=n_results, |
| where=state_filter |
| ) |
| |
| return {"document_results": document_results, "state_results": state_results} |
|
|
| def get_states(self): |
| """Get a list of all states in the database.""" |
| results = self.state_collection.get() |
| states = {meta["state"] for meta in results["metadatas"] if meta} |
| return sorted(list(states)) |
|
|
if __name__ == "__main__":
    # Smoke-test entry point: build the database from a local PDF and list the
    # states that were loaded.
    try:
        db = VectorDatabase()
        pdf_path = "tenant-landlord.pdf"
        db.process_and_load_pdf(pdf_path)
        states = db.get_states()
        print(f"Available states: {states}")
    except Exception as e:
        # logging.exception records the full traceback, not just the message,
        # which logging.error(f"...") would have discarded.
        logging.exception("Script execution failed: %s", e)
        raise