"""RAG system for invoice pattern recognition.

Pipeline: extract text from PDF invoices (pdfplumber, with a PyPDF2
fallback), split each page into semantic sections, pull structured
fields out with regexes, embed the sections with sentence-transformers,
index them in FAISS, and answer free-form queries by feeding the
retrieved context to a Groq-hosted LLM.
"""

import os
import re
import json
import pickle
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
import logging

# PDF and text processing
import PyPDF2
import pdfplumber
import pandas as pd

# Vector embeddings and similarity
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import faiss

import groq

# Module-level Groq client; reads the API key from the environment.
client = groq.Client(api_key=os.getenv("GROQ_API_KEY"))


def get_response(prompt: str) -> str:
    """Get a single-turn completion from the Groq LLM.

    Args:
        prompt: Full prompt text (context + question) sent as one user message.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
        max_tokens=4096,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()


@dataclass
class InvoiceChunk:
    """Structured representation of one semantic section of an invoice page."""

    content: str
    chunk_type: str  # 'header', 'vendor', 'client', 'items', 'totals', 'footer'
    metadata: Dict[str, Any]
    embedding: Optional[np.ndarray] = None
    source_file: str = ""
    page_number: int = 0


class InvoicePatternExtractor:
    """Extract structured fields (invoice number, dates, totals, ...) via regex."""

    def __init__(self):
        # Each field maps to a list of alternative patterns, tried in order.
        # All patterns are applied with re.IGNORECASE, so character classes
        # like [A-Z0-9-] also match lowercase input.
        self.patterns = {
            'invoice_number': [
                r'invoice\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'inv\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'bill\s*#?\s*:?\s*([A-Z0-9-]+)'
            ],
            'date': [
                r'date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'invoice\s*date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})'
            ],
            'total_amount': [
                r'total\s*:?\s*\$?([\d,]+\.?\d*)',
                r'amount\s*due\s*:?\s*\$?([\d,]+\.?\d*)',
                r'grand\s*total\s*:?\s*\$?([\d,]+\.?\d*)'
            ],
            'vendor_info': [
                r'from\s*:?\s*(.+?)(?=to|bill|invoice)',
                r'vendor\s*:?\s*(.+?)(?=\n|\r)',
                r'company\s*:?\s*(.+?)(?=\n|\r)'
            ],
            'line_items': [
                r'(\d+\.?\d*)\s+(.+?)\s+\$?([\d,]+\.?\d*)',
                r'(.+?)\s+qty\s*:?\s*(\d+)\s+\$?([\d,]+\.?\d*)'
            ]
        }

    def extract_patterns(self, text: str) -> Dict[str, List[str]]:
        """Run every configured regex over *text* and collect unique matches.

        Args:
            text: Raw page text extracted from a PDF.

        Returns:
            Mapping of field name -> deduplicated list of matched strings.
            Multi-group matches are joined with single spaces.
        """
        results: Dict[str, List[str]] = {}
        # FIX: search the original text rather than text.lower().
        # re.IGNORECASE already makes matching case-insensitive, and
        # searching the lowered copy destroyed the original casing of
        # extracted values (e.g. invoice numbers came back lowercased).
        for pattern_name, regex_list in self.patterns.items():
            matches: List[str] = []
            for regex in regex_list:
                found = re.findall(regex, text, re.IGNORECASE | re.MULTILINE)
                matches.extend(
                    match if isinstance(match, str) else ' '.join(match)
                    for match in found
                )
            results[pattern_name] = list(set(matches))  # Remove duplicates
        return results


class InvoicePDFProcessor:
    """Process PDF invoices: text/table extraction plus semantic chunking."""

    def __init__(self):
        self.pattern_extractor = InvoicePatternExtractor()

    def extract_text_with_layout(self, pdf_path: str) -> List[Dict[str, Any]]:
        """Extract per-page text, tables, and layout info with pdfplumber.

        Falls back to PyPDF2 (text only) if pdfplumber fails on the file.

        Returns:
            One dict per page with keys: page_number (1-based), text,
            tables, bbox, width, height.
        """
        pages_data: List[Dict[str, Any]] = []
        try:
            with pdfplumber.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    # extract_text() may return None on image-only pages.
                    text = page.extract_text() or ""
                    tables = page.extract_tables()
                    pages_data.append({
                        'page_number': page_num + 1,
                        'text': text,
                        'tables': tables,
                        'bbox': page.bbox,
                        'width': page.width,
                        'height': page.height
                    })
        except Exception as e:
            logging.error(f"Error processing PDF {pdf_path}: {e}")
            # Fallback to PyPDF2
            pages_data = self._fallback_pdf_extraction(pdf_path)
        return pages_data

    def _fallback_pdf_extraction(self, pdf_path: str) -> List[Dict[str, Any]]:
        """Fallback PDF extraction using PyPDF2 (no tables/layout info)."""
        pages_data: List[Dict[str, Any]] = []
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num, page in enumerate(pdf_reader.pages):
                    text = page.extract_text()
                    pages_data.append({
                        'page_number': page_num + 1,
                        'text': text,
                        'tables': [],
                        'bbox': None,
                        'width': None,
                        'height': None
                    })
        except Exception as e:
            logging.error(f"Fallback extraction failed for {pdf_path}: {e}")
        return pages_data

    def create_semantic_chunks(self, pages_data: List[Dict],
                               source_file: str) -> List[InvoiceChunk]:
        """Turn extracted pages into InvoiceChunk objects, one per section.

        Empty sections are skipped. Each chunk's metadata carries the page's
        regex-extracted patterns, the section type, and provenance info.
        """
        chunks: List[InvoiceChunk] = []
        for page_data in pages_data:
            text = page_data['text']
            page_num = page_data['page_number']

            # Extract patterns from the text
            patterns = self.pattern_extractor.extract_patterns(text)

            # Identify different sections of the invoice
            sections = self._identify_sections(text, patterns)

            for section_type, content in sections.items():
                if content.strip():
                    metadata = {
                        'patterns': patterns,
                        'section_type': section_type,
                        'page_number': page_num,
                        'has_tables': len(page_data.get('tables', [])) > 0,
                        'source_file': source_file,
                        'extracted_at': datetime.now().isoformat()
                    }
                    chunks.append(InvoiceChunk(
                        content=content,
                        chunk_type=section_type,
                        metadata=metadata,
                        source_file=source_file,
                        page_number=page_num
                    ))
        return chunks

    def _identify_sections(self, text: str, patterns: Dict) -> Dict[str, str]:
        """Split invoice text into named sections via keyword heuristics.

        Lines are assigned to the current section; the section switches
        whenever a trigger keyword appears. Everything before the first
        trigger accumulates in 'header'.
        """
        lines = text.split('\n')
        sections = {
            'header': '', 'vendor': '', 'client': '',
            'items': '', 'totals': '', 'footer': ''
        }
        current_section = 'header'
        for i, line in enumerate(lines):
            line_lower = line.lower().strip()

            # Section identification logic
            if any(keyword in line_lower for keyword in
                   ['bill to', 'ship to', 'customer']):
                current_section = 'client'
            elif any(keyword in line_lower for keyword in
                     ['description', 'item', 'qty', 'quantity']):
                current_section = 'items'
            elif any(keyword in line_lower for keyword in
                     ['subtotal', 'tax', 'total', 'amount due']):
                current_section = 'totals'
            elif any(keyword in line_lower for keyword in
                     ['thank you', 'terms', 'payment']):
                current_section = 'footer'
            elif i < 5 and any(keyword in line_lower for keyword in
                               ['invoice', 'bill', 'from']):
                # Only in the first few lines: distinguish vendor vs. header.
                current_section = 'vendor' if 'from' in line_lower else 'header'

            sections[current_section] += line + '\n'
        return sections


class InvoiceRAGSystem:
    """Main RAG system for invoice pattern recognition."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        self.embedding_model = SentenceTransformer(model_name)
        self.pdf_processor = InvoicePDFProcessor()
        self.chunks: List[InvoiceChunk] = []
        self.index = None  # FAISS index, built after training
        self.chunk_embeddings: List[np.ndarray] = []

    def train_on_invoices(self, invoice_folder: str):
        """Index every PDF in *invoice_folder*: chunk, embed, build FAISS.

        Repeated calls accumulate chunks (the index is rebuilt over all of
        them each time).
        """
        logging.info(f"Training on invoices in {invoice_folder}")

        # FIX: case-insensitive suffix check so '.PDF'/'.Pdf' files are not
        # silently skipped; sorted() makes processing order deterministic.
        pdf_files = sorted(
            f for f in os.listdir(invoice_folder)
            if f.lower().endswith('.pdf')
        )

        for pdf_file in pdf_files:
            pdf_path = os.path.join(invoice_folder, pdf_file)
            logging.info(f"Processing {pdf_file}")

            # Process PDF
            pages_data = self.pdf_processor.extract_text_with_layout(pdf_path)

            # Create chunks
            file_chunks = self.pdf_processor.create_semantic_chunks(
                pages_data, pdf_file)

            # Generate embeddings
            for chunk in file_chunks:
                embedding = self.embedding_model.encode(chunk.content)
                chunk.embedding = embedding
                self.chunk_embeddings.append(embedding)

            self.chunks.extend(file_chunks)

        # Build FAISS index
        self._build_index()
        logging.info(
            f"Training complete. Processed {len(self.chunks)} chunks "
            f"from {len(pdf_files)} invoices")

    def _build_index(self):
        """Build a FAISS inner-product index over L2-normalized embeddings
        (inner product on unit vectors == cosine similarity)."""
        if not self.chunk_embeddings:
            return
        embeddings_array = np.array(self.chunk_embeddings).astype('float32')
        dimension = embeddings_array.shape[1]

        # Use IndexFlatIP for cosine similarity
        self.index = faiss.IndexFlatIP(dimension)

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings_array)
        self.index.add(embeddings_array)

    def retrieve_similar_patterns(self, query: str, top_k: int = 5,
                                  section_filter: Optional[str] = None
                                  ) -> List[Tuple[InvoiceChunk, float]]:
        """Return up to *top_k* (chunk, cosine score) pairs similar to *query*.

        Args:
            query: Natural-language query to embed and search with.
            top_k: Maximum number of results to return.
            section_filter: If given, only chunks of this chunk_type are kept.
                We over-fetch (top_k * 2) from FAISS so filtering still has
                candidates to choose from, but a heavily filtered index may
                return fewer than top_k results.
        """
        if not self.index:
            return []

        # Encode and normalize the query the same way as the index vectors.
        query_embedding = self.embedding_model.encode([query]).astype('float32')
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(
            query_embedding, min(top_k * 2, len(self.chunks)))

        results: List[Tuple[InvoiceChunk, float]] = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer results exist; guard the lookup.
            if idx < len(self.chunks):
                chunk = self.chunks[idx]
                # Apply section filter if specified
                if section_filter and chunk.chunk_type != section_filter:
                    continue
                results.append((chunk, float(score)))
                if len(results) >= top_k:
                    break
        return results

    def extract_invoice_info(self, query: str,
                             context_sections: Optional[List[str]] = None
                             ) -> Dict[str, Any]:
        """Retrieve context for *query* and aggregate regex patterns from it.

        Args:
            query: Natural-language question about the invoices.
            context_sections: Optional list of chunk_type values; when given,
                retrieval runs per section (top 3 each) instead of globally.

        Returns:
            Dict with the query, the retrieved context chunks, the merged
            extracted patterns, and the number of distinct source files.
        """
        if context_sections:
            all_results = []
            for section in context_sections:
                all_results.extend(self.retrieve_similar_patterns(
                    query, top_k=3, section_filter=section))
        else:
            all_results = self.retrieve_similar_patterns(query, top_k=5)

        # Prepare context for LLM and merge every chunk's regex hits.
        context_chunks = []
        patterns_found: Dict[str, List[str]] = {}
        for chunk, score in all_results:
            context_chunks.append({
                'content': chunk.content,
                'type': chunk.chunk_type,
                'source': chunk.source_file,
                'score': score,
                'patterns': chunk.metadata.get('patterns', {})
            })
            for pattern_type, values in chunk.metadata.get('patterns', {}).items():
                patterns_found.setdefault(pattern_type, []).extend(values)

        return {
            'query': query,
            'context_chunks': context_chunks,
            'extracted_patterns': patterns_found,
            'num_sources': len(set(
                chunk.source_file for chunk, _ in all_results))
        }

    def get_pattern_summary(self) -> Dict[str, Any]:
        """Summarize what was learned in training: chunk counts per section
        and per-pattern match counts with up to a few examples each."""
        pattern_stats: Dict[str, Dict[str, Any]] = {}
        section_stats: Dict[str, int] = {}

        for chunk in self.chunks:
            # Count section types
            section_type = chunk.chunk_type
            section_stats[section_type] = section_stats.get(section_type, 0) + 1

            # Count patterns
            for pattern_type, values in chunk.metadata.get('patterns', {}).items():
                if pattern_type not in pattern_stats:
                    pattern_stats[pattern_type] = {'count': 0, 'examples': set()}
                pattern_stats[pattern_type]['count'] += len(values)
                # Keep at most the first 3 values per chunk as examples.
                pattern_stats[pattern_type]['examples'].update(values[:3])

        # Convert sets to lists for JSON serialization
        for pattern_type in pattern_stats:
            pattern_stats[pattern_type]['examples'] = list(
                pattern_stats[pattern_type]['examples'])

        return {
            'total_chunks': len(self.chunks),
            'total_invoices': len(set(
                chunk.source_file for chunk in self.chunks)),
            'section_distribution': section_stats,
            'pattern_statistics': pattern_stats
        }

    def save_model(self, save_path: str):
        """Pickle chunks + embeddings to *save_path*; the FAISS index is
        written alongside with a .faiss extension."""
        model_data = {
            'chunks': self.chunks,
            'chunk_embeddings': self.chunk_embeddings
        }
        with open(save_path, 'wb') as f:
            pickle.dump(model_data, f)
        # Save FAISS index separately (FAISS objects don't pickle cleanly).
        if self.index:
            faiss.write_index(self.index, save_path.replace('.pkl', '.faiss'))

    def load_model(self, load_path: str):
        """Load a model previously written by save_model.

        SECURITY: pickle.load executes arbitrary code from the file — only
        load model files you created yourself / from a trusted source.
        """
        with open(load_path, 'rb') as f:
            model_data = pickle.load(f)
        self.chunks = model_data['chunks']
        self.chunk_embeddings = model_data['chunk_embeddings']

        # Load FAISS index if its companion file exists.
        faiss_path = load_path.replace('.pkl', '.faiss')
        if os.path.exists(faiss_path):
            self.index = faiss.read_index(faiss_path)


def main():
    """Example usage: train on ./invoices, run sample queries, save model."""
    logging.basicConfig(level=logging.INFO)

    rag_system = InvoiceRAGSystem()

    # Train on invoice folder (replace with your path)
    invoice_folder = "invoices"
    if os.path.exists(invoice_folder):
        rag_system.train_on_invoices(invoice_folder)

        # Get pattern summary
        summary = rag_system.get_pattern_summary()
        print("Pattern Summary:")
        print(json.dumps(summary, indent=2))

        # Example queries
        queries = [
            "What are the invoice numbers?",
            "Show me vendor information",
            "Extract total amounts",
            "Find products with batch number, price per pc, quantities, total amount per product",
            "What is the invoice date?",
        ]
        for query in queries:
            print(f"\n=== Query: {query} ===")
            results = rag_system.extract_invoice_info(query)

            # Feed the context and query to the LLM pipeline
            context_text = "\n\n".join(
                f"[{chunk['type']}] {chunk['content']}"
                for chunk in results['context_chunks']
            )
            prompt = f"Context:\n{context_text}\n\nQuestion: {query}\nAnswer:"
            llm_response = get_response(prompt)
            print(f"LLM Answer:\n{llm_response}")

        # Save the trained model
        rag_system.save_model("invoice_rag_model.pkl")
        print("\nModel saved to invoice_rag_model.pkl")
    else:
        print(f"Invoice folder {invoice_folder} not found. Please update the path.")
        print("To use this system:")
        print("1. Create a folder with invoice PDFs")
        print("2. Update the invoice_folder path")
        print("3. Run the training process")


if __name__ == "__main__":
    main()