Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import pickle | |
| from typing import List, Dict, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| import logging | |
| # PDF and text processing | |
| import PyPDF2 | |
| import pdfplumber | |
| import pandas as pd | |
| # Vector embeddings and similarity | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import faiss | |
| import groq | |
# Groq client shared by get_response(). It is only constructed when
# GROQ_API_KEY is set, so importing this module (e.g. just to train or load the
# retrieval index) does not depend on an API key being available.
# NOTE(review): the Groq SDK is expected to reject a missing key at
# construction time — confirm against the installed groq version.
_GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = groq.Client(api_key=_GROQ_API_KEY) if _GROQ_API_KEY else None


def get_response(prompt: str, *, model: str = "llama-3.3-70b-versatile",
                 max_tokens: int = 4096, temperature: float = 0.7) -> str:
    """Get a response from the Groq LLM.

    Sends *prompt* as a single user message to the chat-completions endpoint
    and returns the reply text with surrounding whitespace stripped.

    Args:
        prompt: Full prompt text (context + question).
        model: Groq model identifier.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The stripped reply text, or "" if the API returned no content.

    Raises:
        RuntimeError: If no client is configured (GROQ_API_KEY unset or empty).
    """
    if client is None:
        raise RuntimeError("GROQ_API_KEY is not set; cannot query the Groq LLM")
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    # message.content may be None; normalise to "" before stripping.
    content = response.choices[0].message.content
    return (content or "").strip()
@dataclass(eq=False)
class InvoiceChunk:
    """Structured representation of an invoice chunk.

    One chunk is one keyword-delimited section of one PDF page.

    ``eq=False`` keeps identity-based ``==`` and ``hash``. A generated
    ``__eq__`` would compare the numpy ``embedding`` fields, and that raises
    ``ValueError`` (ambiguous truth value of an array).
    """
    content: str  # raw text of the section
    chunk_type: str  # 'header', 'vendor', 'client', 'items', 'totals', 'footer'
    metadata: Dict[str, Any]  # patterns, section_type, page_number, source_file, ...
    embedding: Optional[np.ndarray] = None  # set once the chunk has been encoded
    source_file: str = ""  # name of the PDF the chunk came from
    page_number: int = 0  # 1-based page number (0 = unknown)
class InvoicePatternExtractor:
    """Extract structured patterns from invoice text.

    ``self.patterns`` maps a pattern name to a list of regular expressions.
    Every regex is applied with ``re.IGNORECASE | re.MULTILINE``. The captured
    group(s) of each match become the extracted values; multi-group matches
    are joined with a single space.
    """

    def __init__(self):
        # Common invoice patterns
        self.patterns = {
            'invoice_number': [
                r'invoice\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'inv\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'bill\s*#?\s*:?\s*([A-Z0-9-]+)'
            ],
            'date': [
                r'date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'invoice\s*date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})'
            ],
            'total_amount': [
                r'total\s*:?\s*\$?([\d,]+\.?\d*)',
                r'amount\s*due\s*:?\s*\$?([\d,]+\.?\d*)',
                r'grand\s*total\s*:?\s*\$?([\d,]+\.?\d*)'
            ],
            'vendor_info': [
                r'from\s*:?\s*(.+?)(?=to|bill|invoice)',
                # `$` (MULTILINE) also ends a value on the last line of the
                # text, which typically has no trailing newline.
                r'vendor\s*:?\s*(.+?)(?=[\r\n]|$)',
                r'company\s*:?\s*(.+?)(?=[\r\n]|$)'
            ],
            'line_items': [
                r'(\d+\.?\d*)\s+(.+?)\s+\$?([\d,]+\.?\d*)',
                r'(.+?)\s+qty\s*:?\s*(\d+)\s+\$?([\d,]+\.?\d*)'
            ]
        }

    def extract_patterns(self, text: str) -> Dict[str, List[str]]:
        """Extract all patterns from text.

        Matching is case-insensitive, but values are returned with their
        original casing (invoice numbers and vendor names are case-sensitive
        data).

        Args:
            text: Page text; ``None`` or "" yields empty lists.

        Returns:
            Dict mapping every pattern name in ``self.patterns`` to its
            de-duplicated matches, in first-seen order (deterministic across
            runs, unlike set ordering).
        """
        if not text:
            return {pattern_name: [] for pattern_name in self.patterns}

        flags = re.IGNORECASE | re.MULTILINE
        results: Dict[str, List[str]] = {}
        for pattern_name, regex_list in self.patterns.items():
            matches: List[str] = []
            for regex in regex_list:
                for match in re.findall(regex, text, flags):
                    # findall yields tuples for multi-group patterns.
                    matches.append(match if isinstance(match, str) else ' '.join(match))
            # dict.fromkeys removes duplicates while keeping first-seen order.
            results[pattern_name] = list(dict.fromkeys(matches))
        return results
class InvoicePDFProcessor:
    """Process PDF invoices and extract structured content.

    Pages are read with pdfplumber (text, tables, page geometry). If that
    fails, PyPDF2 supplies text only. Each page is then split into
    keyword-delimited sections, and every non-empty section becomes an
    ``InvoiceChunk``.
    """

    def __init__(self):
        self.pattern_extractor = InvoicePatternExtractor()

    def extract_text_with_layout(self, pdf_path: str) -> List[Dict[str, Any]]:
        """Extract text while preserving layout information.

        Returns one dict per page with keys ``page_number`` (1-based),
        ``text`` (never ``None``), ``tables``, ``bbox``, ``width`` and
        ``height``. On any pdfplumber error, falls back to
        ``_fallback_pdf_extraction``.
        """
        pages_data = []
        try:
            with pdfplumber.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages, start=1):
                    pages_data.append({
                        'page_number': page_num,
                        'text': page.extract_text() or "",
                        'tables': page.extract_tables(),
                        'bbox': page.bbox,
                        'width': page.width,
                        'height': page.height
                    })
        except Exception as e:
            # Best effort: any pdfplumber failure switches to PyPDF2 for the
            # whole file, discarding partially extracted pages.
            logging.error(f"Error processing PDF {pdf_path}: {e}")
            pages_data = self._fallback_pdf_extraction(pdf_path)
        return pages_data

    def _fallback_pdf_extraction(self, pdf_path: str) -> List[Dict[str, Any]]:
        """Fallback PDF extraction using PyPDF2 (text only, no tables/geometry).

        Returns the same per-page dict shape as ``extract_text_with_layout``.
        Returns an empty list (and logs) if PyPDF2 fails as well.
        """
        pages_data = []
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num, page in enumerate(pdf_reader.pages, start=1):
                    pages_data.append({
                        'page_number': page_num,
                        # extract_text() may yield None; keep 'text' a str
                        # like the pdfplumber path so later .lower() works.
                        'text': page.extract_text() or "",
                        'tables': [],
                        'bbox': None,
                        'width': None,
                        'height': None
                    })
        except Exception as e:
            logging.error(f"Fallback extraction failed for {pdf_path}: {e}")
        return pages_data

    def create_semantic_chunks(self, pages_data: List[Dict], source_file: str) -> List[InvoiceChunk]:
        """Create semantically meaningful chunks from invoice pages.

        Args:
            pages_data: Per-page dicts as returned by ``extract_text_with_layout``.
            source_file: File name recorded on every chunk.

        Returns:
            One ``InvoiceChunk`` per non-blank section per page. Every chunk
            of a page carries that page's full pattern-extraction result in
            ``metadata['patterns']``.
        """
        chunks = []
        for page_data in pages_data:
            text = page_data.get('text') or ""
            page_num = page_data['page_number']
            patterns = self.pattern_extractor.extract_patterns(text)
            sections = self._identify_sections(text, patterns)
            for section_type, content in sections.items():
                if not content.strip():
                    continue
                metadata = {
                    'patterns': patterns,
                    'section_type': section_type,
                    'page_number': page_num,
                    'has_tables': bool(page_data.get('tables')),
                    'source_file': source_file,
                    'extracted_at': datetime.now().isoformat()
                }
                chunks.append(InvoiceChunk(
                    content=content,
                    chunk_type=section_type,
                    metadata=metadata,
                    source_file=source_file,
                    page_number=page_num
                ))
        return chunks

    def _identify_sections(self, text: str, patterns: Dict) -> Dict[str, str]:
        """Identify different sections of an invoice.

        Walks the text line by line. A line that contains a section keyword
        switches the current section. Checks run in the order client, items,
        totals, footer, then vendor/header (vendor/header only in the first
        5 lines). Each line, including the trigger line, is appended with a
        trailing newline to the section that is current after the check.

        *patterns* is accepted for interface compatibility but is currently
        unused.

        Returns:
            Dict with keys header, vendor, client, items, totals, footer
            (possibly empty strings).
        """
        section_lines: Dict[str, List[str]] = {
            name: [] for name in ('header', 'vendor', 'client', 'items', 'totals', 'footer')
        }
        current_section = 'header'
        for i, line in enumerate(text.split('\n')):
            line_lower = line.lower().strip()
            if any(keyword in line_lower for keyword in ('bill to', 'ship to', 'customer')):
                current_section = 'client'
            elif any(keyword in line_lower for keyword in ('description', 'item', 'qty', 'quantity')):
                current_section = 'items'
            elif any(keyword in line_lower for keyword in ('subtotal', 'tax', 'total', 'amount due')):
                current_section = 'totals'
            elif any(keyword in line_lower for keyword in ('thank you', 'terms', 'payment')):
                current_section = 'footer'
            elif i < 5 and any(keyword in line_lower for keyword in ('invoice', 'bill', 'from')):
                current_section = 'vendor' if 'from' in line_lower else 'header'
            section_lines[current_section].append(line + '\n')
        # Join once per section instead of repeated string += (quadratic).
        return {name: ''.join(lines) for name, lines in section_lines.items()}
class InvoiceRAGSystem:
    """Main RAG system for invoice pattern recognition.

    State:
        chunks: Every ``InvoiceChunk`` produced by training or loading.
        chunk_embeddings: Raw embeddings, index-aligned with ``chunks``.
        index: FAISS ``IndexFlatIP`` over L2-normalised copies of the
            embeddings (inner product == cosine similarity), or ``None``
            before training/loading.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformer *model_name* and start with an empty store."""
        self.embedding_model = SentenceTransformer(model_name)
        self.pdf_processor = InvoicePDFProcessor()
        self.chunks: List[InvoiceChunk] = []
        self.index = None
        self.chunk_embeddings = []

    def train_on_invoices(self, invoice_folder: str):
        """Chunk and embed every PDF in *invoice_folder*, then rebuild the index.

        Files are processed in sorted order, and the ``.pdf`` suffix is matched
        case-insensitively. Calling this again appends to the existing chunks
        rather than replacing them.
        """
        logging.info(f"Training on invoices in {invoice_folder}")
        pdf_files = sorted(f for f in os.listdir(invoice_folder) if f.lower().endswith('.pdf'))
        for pdf_file in pdf_files:
            pdf_path = os.path.join(invoice_folder, pdf_file)
            logging.info(f"Processing {pdf_file}")
            pages_data = self.pdf_processor.extract_text_with_layout(pdf_path)
            file_chunks = self.pdf_processor.create_semantic_chunks(pages_data, pdf_file)
            if file_chunks:
                # One batched encode per file instead of one call per chunk.
                embeddings = self.embedding_model.encode([chunk.content for chunk in file_chunks])
                for chunk, embedding in zip(file_chunks, embeddings):
                    chunk.embedding = embedding
                    self.chunk_embeddings.append(embedding)
            self.chunks.extend(file_chunks)
        self._build_index()
        logging.info(f"Training complete. Processed {len(self.chunks)} chunks from {len(pdf_files)} invoices")

    def _build_index(self):
        """(Re)build the FAISS index from ``self.chunk_embeddings``; no-op when empty."""
        if not self.chunk_embeddings:
            return
        embeddings_array = np.array(self.chunk_embeddings).astype('float32')
        # Inner product over L2-normalised vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(embeddings_array.shape[1])
        # normalize_L2 works in place on this float32 copy; the stored raw
        # embeddings are left untouched.
        faiss.normalize_L2(embeddings_array)
        self.index.add(embeddings_array)

    def retrieve_similar_patterns(self, query: str, top_k: int = 5,
                                  section_filter: Optional[str] = None) -> List[Tuple[InvoiceChunk, float]]:
        """Retrieve similar invoice patterns based on query.

        Args:
            query: Free-text query, embedded with the same model as the chunks.
            top_k: Maximum number of results (<= 0 returns []).
            section_filter: If set, only chunks whose ``chunk_type`` equals it.

        Returns:
            Up to *top_k* ``(chunk, cosine_score)`` pairs, in FAISS ranking order.
        """
        if self.index is None or not self.chunks or top_k <= 0:
            return []
        query_embedding = self.embedding_model.encode([query]).astype('float32')
        faiss.normalize_L2(query_embedding)

        available = min(self.index.ntotal, len(self.chunks))
        # The filter is applied after the search. When filtering, scan all
        # chunks so a fixed 2x headroom cannot return fewer than top_k matches
        # while enough matching chunks exist.
        k = available if section_filter else min(top_k * 2, available)
        if k <= 0:
            return []
        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with -1; never let that select
            # self.chunks[-1].
            if not 0 <= idx < len(self.chunks):
                continue
            chunk = self.chunks[idx]
            if section_filter and chunk.chunk_type != section_filter:
                continue
            results.append((chunk, float(score)))
            if len(results) >= top_k:
                break
        return results

    def extract_invoice_info(self, query: str, context_sections: Optional[List[str]] = None) -> Dict[str, Any]:
        """Retrieve context for *query* and gather the regex patterns attached to it.

        Args:
            query: Natural-language question used for retrieval.
            context_sections: Optional section types. If given, up to 3 chunks
                are retrieved per section; otherwise the overall top 5.

        Returns:
            Dict with ``query``; ``context_chunks`` (content/type/source/score/
            patterns per retrieved chunk); ``extracted_patterns`` (pattern name
            -> values, de-duplicated, first-seen order); and ``num_sources``
            (number of distinct source files).
        """
        if context_sections:
            all_results = []
            for section in context_sections:
                all_results.extend(self.retrieve_similar_patterns(query, top_k=3, section_filter=section))
        else:
            all_results = self.retrieve_similar_patterns(query, top_k=5)

        context_chunks = []
        patterns_found: Dict[str, List[str]] = {}
        for chunk, score in all_results:
            chunk_patterns = chunk.metadata.get('patterns', {})
            context_chunks.append({
                'content': chunk.content,
                'type': chunk.chunk_type,
                'source': chunk.source_file,
                'score': score,
                'patterns': chunk_patterns
            })
            for pattern_type, values in chunk_patterns.items():
                patterns_found.setdefault(pattern_type, []).extend(values)
        # Chunks from the same page share one page-level pattern dict, so the
        # raw concatenation repeats values. Keep each value once.
        patterns_found = {name: list(dict.fromkeys(values)) for name, values in patterns_found.items()}

        return {
            'query': query,
            'context_chunks': context_chunks,
            'extracted_patterns': patterns_found,
            'num_sources': len({chunk.source_file for chunk, _ in all_results})
        }

    def get_pattern_summary(self) -> Dict[str, Any]:
        """Get summary of patterns learned from training data.

        Returns a JSON-serialisable dict with total chunks, distinct invoices,
        chunks per section type, and per-pattern match counts with examples.
        Examples are up to 3 values taken from each chunk, de-duplicated, in
        first-seen order.
        """
        pattern_stats: Dict[str, Dict[str, Any]] = {}
        section_stats: Dict[str, int] = {}
        for chunk in self.chunks:
            section_stats[chunk.chunk_type] = section_stats.get(chunk.chunk_type, 0) + 1
            for pattern_type, values in chunk.metadata.get('patterns', {}).items():
                stats = pattern_stats.setdefault(pattern_type, {'count': 0, 'examples': {}})
                # NOTE(review): every section chunk of a page carries the same
                # page-level pattern dict, so counts scale with sections/page.
                stats['count'] += len(values)
                # A dict (not a set) keeps examples unique AND stably ordered.
                for value in values[:3]:
                    stats['examples'].setdefault(value, None)

        for stats in pattern_stats.values():
            stats['examples'] = list(stats['examples'])

        return {
            'total_chunks': len(self.chunks),
            'total_invoices': len({chunk.source_file for chunk in self.chunks}),
            'section_distribution': section_stats,
            'pattern_statistics': pattern_stats
        }

    @staticmethod
    def _faiss_path(model_path: str) -> str:
        """Return the FAISS index path stored alongside *model_path*.

        ``model.pkl`` maps to ``model.faiss``. Any other name gets ``.faiss``
        appended, so the index path never equals the pickle path, and
        ``.pkl`` in a directory name is left alone.
        """
        root, ext = os.path.splitext(model_path)
        return root + '.faiss' if ext == '.pkl' else model_path + '.faiss'

    def save_model(self, save_path: str):
        """Pickle the chunks and embeddings to *save_path*; write the FAISS index next to it."""
        model_data = {
            'chunks': self.chunks,
            'chunk_embeddings': self.chunk_embeddings
        }
        with open(save_path, 'wb') as f:
            pickle.dump(model_data, f)
        if self.index is not None:
            faiss.write_index(self.index, self._faiss_path(save_path))

    def load_model(self, load_path: str):
        """Load a model written by ``save_model`` and restore a matching index.

        Security: ``pickle.load`` can execute arbitrary code; only load files
        you created yourself.
        """
        with open(load_path, 'rb') as f:
            model_data = pickle.load(f)  # trusted input only — see docstring
        self.chunks = model_data['chunks']
        self.chunk_embeddings = model_data['chunk_embeddings']

        # Discard any index from earlier training so it cannot be paired with
        # the newly loaded chunks.
        self.index = None
        faiss_path = self._faiss_path(load_path)
        if os.path.exists(faiss_path):
            self.index = faiss.read_index(faiss_path)
        if self.index is None or self.index.ntotal != len(self.chunk_embeddings):
            # Missing or out-of-sync index file: rebuild from stored embeddings.
            self.index = None
            self._build_index()
# Example usage and testing
def main():
    """Train on ./invoices, print a pattern summary, save the model, then answer sample queries.

    Prints setup instructions instead when the invoice folder does not exist.
    The model is saved before any LLM call, so a failing LLM request cannot
    lose the training result.
    """
    logging.basicConfig(level=logging.INFO)

    # Train on invoice folder (replace with your path)
    invoice_folder = "invoices"
    if not os.path.exists(invoice_folder):
        print(f"Invoice folder {invoice_folder} not found. Please update the path.")
        print("To use this system:")
        print("1. Create a folder with invoice PDFs")
        print("2. Update the invoice_folder path")
        print("3. Run the training process")
        return

    # Created only after the folder check, so the embedding model is not
    # loaded when there is nothing to train on.
    rag_system = InvoiceRAGSystem()
    rag_system.train_on_invoices(invoice_folder)

    summary = rag_system.get_pattern_summary()
    print("Pattern Summary:")
    print(json.dumps(summary, indent=2))

    rag_system.save_model("invoice_rag_model.pkl")
    print("\nModel saved to invoice_rag_model.pkl")

    queries = [
        "What are the invoice numbers?",
        "Show me vendor information",
        "Extract total amounts",
        "Find products with batch number, price per pc, quantities, total amount per product",
        "What is the invoice date?",
    ]
    for query in queries:
        print(f"\n=== Query: {query} ===")
        results = rag_system.extract_invoice_info(query)
        logging.debug("Context from %d sources; patterns: %s",
                      results['num_sources'], results['extracted_patterns'])
        context_text = "\n\n".join(
            f"[{chunk['type']}] {chunk['content']}" for chunk in results['context_chunks']
        )
        prompt = f"Context:\n{context_text}\n\nQuestion: {query}\nAnswer:"
        try:
            llm_response = get_response(prompt)
        except Exception:
            # Script boundary: report the failure and move on to the next query.
            logging.exception("LLM request failed for query %r", query)
            continue
        print(f"LLM Answer:\n{llm_response}")


if __name__ == "__main__":
    main()