# InvoiceAgent — invoice_rag_system.py
# (uploaded by YashArya16, revision 397b599, verified)
import os
import re
import json
import pickle
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
import logging
# PDF and text processing
import PyPDF2
import pdfplumber
import pandas as pd
# Vector embeddings and similarity
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import faiss
import groq
# Groq API client, configured from the GROQ_API_KEY environment variable.
client = groq.Client(api_key=os.getenv("GROQ_API_KEY"))
def get_response(prompt: str) -> str:
    """Send a single user prompt to the Groq LLM and return its reply text.

    Args:
        prompt: Full prompt string, sent as one "user" message.

    Returns:
        The model's reply with leading/trailing whitespace stripped.
    """
    completion = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=4096,
    )
    reply = completion.choices[0].message.content
    return reply.strip()
@dataclass
class InvoiceChunk:
    """Structured representation of an invoice chunk (one section of one page)."""
    content: str  # raw text of this section
    chunk_type: str  # 'header', 'vendor', 'client', 'items', 'totals', 'footer'
    metadata: Dict[str, Any]  # extraction metadata: regex patterns found, page info, timestamp
    embedding: Optional[np.ndarray] = None  # sentence-transformer vector, filled in during training
    source_file: str = ""  # originating PDF filename
    page_number: int = 0  # 1-based page number the chunk came from
class InvoicePatternExtractor:
    """Extract structured patterns (invoice number, dates, totals, ...) from invoice text.

    All regexes are compiled once at construction time and matched
    case-insensitively against a lowercased copy of the input, so captured
    values come back lowercased.
    """

    def __init__(self):
        # Common invoice patterns, grouped by the field they extract.
        # Several alternatives per field because invoice layouts vary widely.
        self.patterns = {
            'invoice_number': [
                r'invoice\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'inv\s*#?\s*:?\s*([A-Z0-9-]+)',
                r'bill\s*#?\s*:?\s*([A-Z0-9-]+)'
            ],
            'date': [
                r'date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'invoice\s*date\s*:?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
                r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})'
            ],
            'total_amount': [
                r'total\s*:?\s*\$?([\d,]+\.?\d*)',
                r'amount\s*due\s*:?\s*\$?([\d,]+\.?\d*)',
                r'grand\s*total\s*:?\s*\$?([\d,]+\.?\d*)'
            ],
            'vendor_info': [
                r'from\s*:?\s*(.+?)(?=to|bill|invoice)',
                r'vendor\s*:?\s*(.+?)(?=\n|\r)',
                r'company\s*:?\s*(.+?)(?=\n|\r)'
            ],
            'line_items': [
                r'(\d+\.?\d*)\s+(.+?)\s+\$?([\d,]+\.?\d*)',
                r'(.+?)\s+qty\s*:?\s*(\d+)\s+\$?([\d,]+\.?\d*)'
            ]
        }
        # Pre-compile every regex once instead of recompiling per call.
        self._compiled = {
            name: [re.compile(rx, re.IGNORECASE | re.MULTILINE) for rx in regexes]
            for name, regexes in self.patterns.items()
        }

    def extract_patterns(self, text: str) -> Dict[str, List[str]]:
        """Run every pattern group against *text*.

        Args:
            text: Raw invoice text (any case; it is lowercased internally).

        Returns:
            Mapping of pattern name to a de-duplicated list of matches.
            Multi-group regexes (e.g. line_items) have their groups joined
            with single spaces.
        """
        results = {}
        text_lower = text.lower()
        for pattern_name, compiled_list in self._compiled.items():
            matches = []
            for pattern in compiled_list:
                for found in pattern.findall(text_lower):
                    matches.append(found if isinstance(found, str) else ' '.join(found))
            # De-duplicate while preserving first-seen order: the previous
            # list(set(...)) ordering varied between runs because of string
            # hash randomization.
            results[pattern_name] = list(dict.fromkeys(matches))
        return results
class InvoicePDFProcessor:
"""Process PDF invoices and extract structured content"""
def __init__(self):
self.pattern_extractor = InvoicePatternExtractor()
def extract_text_with_layout(self, pdf_path: str) -> List[Dict[str, Any]]:
"""Extract text while preserving layout information"""
pages_data = []
try:
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages):
# Extract text
text = page.extract_text() or ""
# Extract tables
tables = page.extract_tables()
# Get page dimensions for layout analysis
page_data = {
'page_number': page_num + 1,
'text': text,
'tables': tables,
'bbox': page.bbox,
'width': page.width,
'height': page.height
}
pages_data.append(page_data)
except Exception as e:
logging.error(f"Error processing PDF {pdf_path}: {e}")
# Fallback to PyPDF2
pages_data = self._fallback_pdf_extraction(pdf_path)
return pages_data
def _fallback_pdf_extraction(self, pdf_path: str) -> List[Dict[str, Any]]:
"""Fallback PDF extraction using PyPDF2"""
pages_data = []
try:
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page_num, page in enumerate(pdf_reader.pages):
text = page.extract_text()
pages_data.append({
'page_number': page_num + 1,
'text': text,
'tables': [],
'bbox': None,
'width': None,
'height': None
})
except Exception as e:
logging.error(f"Fallback extraction failed for {pdf_path}: {e}")
return pages_data
def create_semantic_chunks(self, pages_data: List[Dict], source_file: str) -> List[InvoiceChunk]:
"""Create semantically meaningful chunks from invoice pages"""
chunks = []
for page_data in pages_data:
text = page_data['text']
page_num = page_data['page_number']
# Extract patterns from the text
patterns = self.pattern_extractor.extract_patterns(text)
# Identify different sections of the invoice
sections = self._identify_sections(text, patterns)
for section_type, content in sections.items():
if content.strip():
metadata = {
'patterns': patterns,
'section_type': section_type,
'page_number': page_num,
'has_tables': len(page_data.get('tables', [])) > 0,
'source_file': source_file,
'extracted_at': datetime.now().isoformat()
}
chunk = InvoiceChunk(
content=content,
chunk_type=section_type,
metadata=metadata,
source_file=source_file,
page_number=page_num
)
chunks.append(chunk)
return chunks
def _identify_sections(self, text: str, patterns: Dict) -> Dict[str, str]:
"""Identify different sections of an invoice"""
lines = text.split('\n')
sections = {
'header': '',
'vendor': '',
'client': '',
'items': '',
'totals': '',
'footer': ''
}
current_section = 'header'
for i, line in enumerate(lines):
line_lower = line.lower().strip()
# Section identification logic
if any(keyword in line_lower for keyword in ['bill to', 'ship to', 'customer']):
current_section = 'client'
elif any(keyword in line_lower for keyword in ['description', 'item', 'qty', 'quantity']):
current_section = 'items'
elif any(keyword in line_lower for keyword in ['subtotal', 'tax', 'total', 'amount due']):
current_section = 'totals'
elif any(keyword in line_lower for keyword in ['thank you', 'terms', 'payment']):
current_section = 'footer'
elif i < 5 and any(keyword in line_lower for keyword in ['invoice', 'bill', 'from']):
current_section = 'vendor' if 'from' in line_lower else 'header'
sections[current_section] += line + '\n'
return sections
class InvoiceRAGSystem:
"""Main RAG system for invoice pattern recognition"""
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
self.embedding_model = SentenceTransformer(model_name)
self.pdf_processor = InvoicePDFProcessor()
self.chunks: List[InvoiceChunk] = []
self.index = None
self.chunk_embeddings = []
def train_on_invoices(self, invoice_folder: str):
"""Train the RAG system on a folder of invoice PDFs"""
logging.info(f"Training on invoices in {invoice_folder}")
pdf_files = [f for f in os.listdir(invoice_folder) if f.endswith('.pdf')]
for pdf_file in pdf_files:
pdf_path = os.path.join(invoice_folder, pdf_file)
logging.info(f"Processing {pdf_file}")
# Process PDF
pages_data = self.pdf_processor.extract_text_with_layout(pdf_path)
# Create chunks
file_chunks = self.pdf_processor.create_semantic_chunks(pages_data, pdf_file)
# Generate embeddings
for chunk in file_chunks:
embedding = self.embedding_model.encode(chunk.content)
chunk.embedding = embedding
self.chunk_embeddings.append(embedding)
self.chunks.extend(file_chunks)
# Build FAISS index
self._build_index()
logging.info(f"Training complete. Processed {len(self.chunks)} chunks from {len(pdf_files)} invoices")
def _build_index(self):
"""Build FAISS index for efficient similarity search"""
if not self.chunk_embeddings:
return
embeddings_array = np.array(self.chunk_embeddings).astype('float32')
dimension = embeddings_array.shape[1]
# Use IndexFlatIP for cosine similarity
self.index = faiss.IndexFlatIP(dimension)
# Normalize embeddings for cosine similarity
faiss.normalize_L2(embeddings_array)
self.index.add(embeddings_array)
def retrieve_similar_patterns(self, query: str, top_k: int = 5,
section_filter: Optional[str] = None) -> List[Tuple[InvoiceChunk, float]]:
"""Retrieve similar invoice patterns based on query"""
if not self.index:
return []
# Encode query
query_embedding = self.embedding_model.encode([query]).astype('float32')
faiss.normalize_L2(query_embedding)
# Search
scores, indices = self.index.search(query_embedding, min(top_k * 2, len(self.chunks)))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < len(self.chunks):
chunk = self.chunks[idx]
# Apply section filter if specified
if section_filter and chunk.chunk_type != section_filter:
continue
results.append((chunk, float(score)))
if len(results) >= top_k:
break
return results
def extract_invoice_info(self, query: str, context_sections: Optional[List[str]] = None) -> Dict[str, Any]:
"""Extract specific information from invoices using RAG"""
# Retrieve relevant chunks
if context_sections:
all_results = []
for section in context_sections:
section_results = self.retrieve_similar_patterns(query, top_k=3, section_filter=section)
all_results.extend(section_results)
else:
all_results = self.retrieve_similar_patterns(query, top_k=5)
# Prepare context for LLM
context_chunks = []
patterns_found = {}
for chunk, score in all_results:
context_chunks.append({
'content': chunk.content,
'type': chunk.chunk_type,
'source': chunk.source_file,
'score': score,
'patterns': chunk.metadata.get('patterns', {})
})
# Collect patterns
for pattern_type, values in chunk.metadata.get('patterns', {}).items():
if pattern_type not in patterns_found:
patterns_found[pattern_type] = []
patterns_found[pattern_type].extend(values)
return {
'query': query,
'context_chunks': context_chunks,
'extracted_patterns': patterns_found,
'num_sources': len(set(chunk.source_file for chunk, _ in all_results))
}
def get_pattern_summary(self) -> Dict[str, Any]:
"""Get summary of patterns learned from training data"""
pattern_stats = {}
section_stats = {}
for chunk in self.chunks:
# Count section types
section_type = chunk.chunk_type
section_stats[section_type] = section_stats.get(section_type, 0) + 1
# Count patterns
for pattern_type, values in chunk.metadata.get('patterns', {}).items():
if pattern_type not in pattern_stats:
pattern_stats[pattern_type] = {'count': 0, 'examples': set()}
pattern_stats[pattern_type]['count'] += len(values)
pattern_stats[pattern_type]['examples'].update(values[:3]) # Keep first 3 examples
# Convert sets to lists for JSON serialization
for pattern_type in pattern_stats:
pattern_stats[pattern_type]['examples'] = list(pattern_stats[pattern_type]['examples'])
return {
'total_chunks': len(self.chunks),
'total_invoices': len(set(chunk.source_file for chunk in self.chunks)),
'section_distribution': section_stats,
'pattern_statistics': pattern_stats
}
def save_model(self, save_path: str):
"""Save the trained model"""
model_data = {
'chunks': self.chunks,
'chunk_embeddings': self.chunk_embeddings
}
with open(save_path, 'wb') as f:
pickle.dump(model_data, f)
# Save FAISS index separately
if self.index:
faiss.write_index(self.index, save_path.replace('.pkl', '.faiss'))
def load_model(self, load_path: str):
"""Load a trained model"""
with open(load_path, 'rb') as f:
model_data = pickle.load(f)
self.chunks = model_data['chunks']
self.chunk_embeddings = model_data['chunk_embeddings']
# Load FAISS index
faiss_path = load_path.replace('.pkl', '.faiss')
if os.path.exists(faiss_path):
self.index = faiss.read_index(faiss_path)
# Example usage and testing
def main():
    """Demo entry point: train on ./invoices, answer sample queries, save the model."""
    logging.basicConfig(level=logging.INFO)

    system = InvoiceRAGSystem()
    invoice_folder = "invoices"

    # Guard clause: without training data there is nothing to do.
    if not os.path.exists(invoice_folder):
        print(f"Invoice folder {invoice_folder} not found. Please update the path.")
        print("To use this system:")
        print("1. Create a folder with invoice PDFs")
        print("2. Update the invoice_folder path")
        print("3. Run the training process")
        return

    system.train_on_invoices(invoice_folder)

    print("Pattern Summary:")
    print(json.dumps(system.get_pattern_summary(), indent=2))

    sample_queries = [
        "What are the invoice numbers?",
        "Show me vendor information",
        "Extract total amounts",
        "Find products with batch number, price per pc, quantities, total amount per product",
        "What is the invoice date?",
    ]
    for query in sample_queries:
        print(f"\n=== Query: {query} ===")
        info = system.extract_invoice_info(query)
        # Assemble retrieved chunks into a single LLM context string.
        context = "\n\n".join(
            f"[{chunk['type']}] {chunk['content']}" for chunk in info['context_chunks']
        )
        answer = get_response(f"Context:\n{context}\n\nQuestion: {query}\nAnswer:")
        print(f"LLM Answer:\n{answer}")

    system.save_model("invoice_rag_model.pkl")
    print("\nModel saved to invoice_rag_model.pkl")


if __name__ == "__main__":
    main()