Spaces:
Sleeping
Sleeping
| from google import genai | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Dict, Tuple | |
| import pdf2image | |
| import os | |
| from pathlib import Path | |
| import concurrent.futures | |
| from dataclasses import dataclass | |
| from functools import partial | |
| import logging | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| class InvoiceItem(BaseModel): | |
| """Represents a single item in an invoice.""" | |
| product_name: str = Field(description="The name of the product") | |
| batch_number: str = Field(description="The batch number of the product") | |
| expiry_date: str = Field(description="The expiry date (format: MM/YY)") | |
| mrp: str = Field(description="Maximum Retail Price") | |
| quantity: int = Field(description="Product quantity") | |
| class InvoiceData(BaseModel): | |
| """Represents the complete invoice data including headers.""" | |
| headers: List[str] = Field( | |
| description="Column headers from the invoice table", | |
| default_factory=list | |
| ) | |
| items: List[InvoiceItem] = Field( | |
| description="List of extracted invoice items", | |
| default_factory=list | |
| ) | |
| class HeaderExtraction(BaseModel): | |
| """Model for extracting headers separately.""" | |
| headers: List[str] = Field( | |
| description="The column headers found in the invoice table" | |
| ) | |
| class PageData: | |
| """Container for page processing data.""" | |
| idx: int | |
| image_path: str | |
| headers: List[str] | |
| items: List[InvoiceItem] | |
| def extract_headers(client: genai.Client, image_path: str, model_id: str) -> List[str]: | |
| """ | |
| Extract column headers from the first page of the invoice. | |
| Args: | |
| client: The Gemini API client | |
| image_path: Path to the image file | |
| model_id: The model ID to use for extraction | |
| Returns: | |
| List of column headers | |
| """ | |
| header_prompt = """ | |
| Extract only the column headers from this invoice table. | |
| Return them exactly as they appear, maintaining their order from left to right. | |
| Only extract the headers, not any data from the rows. | |
| """ | |
| image_file = client.files.upload( | |
| file=image_path, | |
| config={'display_name': 'invoice_header_page'} | |
| ) | |
| response = client.models.generate_content( | |
| model=model_id, | |
| contents=[header_prompt, image_file], | |
| config={ | |
| 'response_mime_type': 'application/json', | |
| 'response_schema': HeaderExtraction | |
| } | |
| ) | |
| return response.parsed.headers if response.parsed else [] | |
| def setup_client() -> genai.Client: | |
| """Create and return a Gemini API client.""" | |
| return genai.Client(api_key=os.getenv("GEMINI_API_KEY")) | |
| def save_image(image: Image, temp_dir: Path, idx: int) -> str: | |
| """ | |
| Save a single page image to disk. | |
| Args: | |
| image: The PDF page image (PIL Image) | |
| temp_dir: Directory to save the image | |
| idx: Page index | |
| Returns: | |
| Path to the saved image | |
| """ | |
| image_path = str(temp_dir / f"page_{idx+1}.jpg") | |
| image.save(image_path, "JPEG") | |
| return image_path | |
| def process_single_page( | |
| page_data: Tuple[int, Image.Image, Path, List[str], genai.Client, str] | |
| ) -> PageData: | |
| """ | |
| Process a single page of the PDF. | |
| Args: | |
| page_data: Tuple containing (page_index, page_image, temp_dir, headers, client, model_id) | |
| Returns: | |
| PageData object containing extracted information | |
| """ | |
| idx, image, temp_dir, headers, client, model_id = page_data | |
| # Save image | |
| image_path = save_image(image, temp_dir, idx) | |
| # First page: extract headers | |
| if idx == 0: | |
| headers = extract_headers(client, image_path, model_id) | |
| prompt = """ | |
| Extract product details from this invoice table. | |
| Use the exact column headers you see in the table. | |
| """ | |
| else: | |
| headers_str = ", ".join(headers) | |
| prompt = f""" | |
| Extract product details from this invoice table. | |
| This is page {idx + 1} of the same invoice. | |
| Use these column headers: {headers_str} | |
| Ensure the extracted data aligns with these columns in order. | |
| """ | |
| # Process image | |
| image_file = client.files.upload( | |
| file=image_path, | |
| config={'display_name': f'invoice_page_{idx+1}'} | |
| ) | |
| response = client.models.generate_content( | |
| model=model_id, | |
| contents=[prompt, image_file], | |
| config={ | |
| 'response_mime_type': 'application/json', | |
| 'response_schema': InvoiceData | |
| } | |
| ) | |
| items = response.parsed.items if response.parsed and response.parsed.items else [] | |
| return PageData(idx=idx, image_path=image_path, headers=headers, items=items) | |
| def process_pdf_with_headers(pdf_path: str, max_workers: int = 3) -> InvoiceData: | |
| """ | |
| Process a PDF invoice while preserving column header context using parallel processing. | |
| Args: | |
| pdf_path: Path to the PDF file | |
| max_workers: Maximum number of concurrent workers | |
| Returns: | |
| InvoiceData object containing headers and extracted items | |
| """ | |
| # Convert PDF pages to images | |
| images = pdf2image.convert_from_path(pdf_path) | |
| # Create temp directory | |
| temp_dir = Path("content/temp") | |
| temp_dir.mkdir(parents=True, exist_ok=True) | |
| # Initialize shared resources | |
| client = setup_client() | |
| model_id = "gemini-2.0-flash" | |
| headers: List[str] = [] | |
| # Prepare data for parallel processing | |
| page_data = [] | |
| try: | |
| # Process first page separately to get headers | |
| first_page = process_single_page((0, images[0], temp_dir, headers, client, model_id)) | |
| headers = first_page.headers | |
| all_items = first_page.items | |
| # Prepare remaining pages for parallel processing | |
| remaining_pages = [ | |
| (i, img, temp_dir, headers, client, model_id) | |
| for i, img in enumerate(images[1:], start=1) | |
| ] | |
| # Process remaining pages in parallel | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | |
| future_to_page = { | |
| executor.submit(process_single_page, page): page[0] | |
| for page in remaining_pages | |
| } | |
| # Collect results as they complete | |
| for future in concurrent.futures.as_completed(future_to_page): | |
| page_idx = future_to_page[future] | |
| try: | |
| page_result = future.result() | |
| all_items.extend(page_result.items) | |
| except Exception as e: | |
| logging.error(f"Error processing page {page_idx}: {str(e)}") | |
| finally: | |
| # Cleanup temporary files | |
| for file in temp_dir.glob("*.jpg"): | |
| try: | |
| file.unlink() | |
| except Exception as e: | |
| logging.warning(f"Failed to delete temporary file {file}: {str(e)}") | |
| return InvoiceData(headers=headers, items=all_items) | |
| def main(): | |
| """Main function to demonstrate usage.""" | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(message)s' | |
| ) | |
| try: | |
| invoice_data = process_pdf_with_headers( | |
| "/Users/krishnaadithya/Desktop/dev/invoice_processing_2.0/pdf_only/expiry_invoice/DR REDDYS PE 1194.pdf", | |
| max_workers=3 # Adjust based on your system and API limits | |
| ) | |
| # Print headers | |
| print("Column Headers:", ", ".join(invoice_data.headers)) | |
| print("\nExtracted Items:") | |
| # Print results | |
| for item in invoice_data.items: | |
| print(f"Product: {item.product_name}") | |
| print(f"Batch: {item.batch_number}") | |
| print(f"Expiry: {item.expiry_date}") | |
| print(f"MRP: {item.mrp}") | |
| print(f"Quantity: {item.quantity}") | |
| print("-" * 50) | |
| except Exception as e: | |
| logging.error(f"Error processing invoice: {str(e)}") | |
| if __name__ == "__main__": | |
| main() |