# small-language-model / src/inference.py
# sanjanb: "Upload folder using huggingface_hub" (commit eb53bb5, verified)
"""
Inference pipeline for document text extraction.
Processes new documents and extracts structured information using the trained small language model (SLM).
"""
import json
import torch
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import numpy as np
from src.data_preparation import DocumentProcessor
from src.model import DocumentNERModel, NERTrainer, ModelConfig
class DocumentInference:
    """Inference pipeline for extracting structured data from documents.

    Loads a trained ``DocumentNERModel`` (through ``NERTrainer``) from
    ``model_path``, tags the whitespace-separated words of an input text with
    BIO labels, and turns the tagged spans into a flat dictionary of
    human-readable fields (``Name``, ``Date``, ``Amount``, ...).
    """

    # Regexes used to validate entity text for selected entity types. An
    # entity of a type listed here is kept only if at least one pattern
    # matches (case-insensitively); other types are accepted unvalidated.
    POSTPROCESS_PATTERNS: Dict[str, List[str]] = {
        'DATE': [
            r'\b\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}\b',
            r'\b\d{4}[/\-]\d{1,2}[/\-]\d{1,2}\b',
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4}\b'
        ],
        'AMOUNT': [
            r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?',
            r'\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP)'
        ],
        'PHONE': [
            r'\+?\d{1,3}[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
            r'\(\d{3}\)\s*\d{3}-\d{4}'
        ],
        'EMAIL': [
            # Letters only in the TLD: a '|' inside a character class is a
            # literal character, not alternation.
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        ],
    }

    # Maps model entity types to the keys used in the structured output.
    # Types not listed here are emitted under their raw entity-type name.
    FIELD_NAMES: Dict[str, str] = {
        'NAME': 'Name',
        'DATE': 'Date',
        'INVOICE_NO': 'InvoiceNo',
        'AMOUNT': 'Amount',
        'ADDRESS': 'Address',
        'PHONE': 'Phone',
        'EMAIL': 'Email',
    }

    # Token strings skipped when word alignment has to be inferred from
    # WordPiece-style token text (slow-tokenizer fallback only).
    _SPECIAL_TOKENS = ('[CLS]', '[SEP]', '[PAD]')

    def __init__(self, model_path: str):
        """Initialize the pipeline and load the trained model.

        Args:
            model_path: Directory holding the trained weights and, optionally,
                ``training_config.json``.

        Raises:
            RuntimeError: If the model or its weights cannot be loaded.
        """
        self.model_path = model_path
        self.config = self._load_config()
        self.model = None
        self.trainer = None
        self.document_processor = DocumentProcessor()
        self._load_model()
        # Per-instance copy so an instance's patterns can be tweaked without
        # mutating the class-level defaults.
        self.postprocess_patterns = {
            entity_type: list(patterns)
            for entity_type, patterns in self.POSTPROCESS_PATTERNS.items()
        }

    def _load_config(self) -> 'ModelConfig':
        """Load the training configuration stored next to the model.

        Falls back to a default ``ModelConfig`` when
        ``training_config.json`` does not exist.
        """
        config_path = Path(self.model_path) / "training_config.json"
        if not config_path.exists():
            print("No training config found. Using default configuration.")
            return ModelConfig()
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        # NOTE: JSON object keys always load as strings, so an ``id2label``
        # dict stored here comes back str-keyed (unless ModelConfig converts
        # it); ``_label_name`` accepts either key form.
        return ModelConfig(**config_dict)

    def _load_model(self):
        """Build the model and trainer, then load the weights from ``model_path``.

        Raises:
            RuntimeError: Wrapping (and chained to) whatever error occurred.
        """
        try:
            self.model = DocumentNERModel(self.config)
            self.trainer = NERTrainer(self.model, self.config)
            self.trainer.load_model(self.model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load model from {self.model_path}: {e}") from e
        print(f"Model loaded successfully from {self.model_path}")

    def _label_name(self, label_id: Any) -> str:
        """Map a predicted label id to its label string ('O' if unknown).

        Accepts numpy integer ids. The ``id2label`` mapping may be keyed
        either by ints or by their string form (as a JSON round-trip yields).
        """
        id2label = self.config.id2label
        key = int(label_id)
        if key in id2label:
            return id2label[key]
        return id2label.get(str(key), 'O')

    def predict_entities(self, text: str) -> List[Dict[str, Any]]:
        """Predict entity spans in ``text`` with the trained model.

        The text is split on whitespace and fed to the tokenizer pre-split.
        Words whose sub-tokens fall past ``config.max_length`` are truncated
        away and never tagged.

        Returns:
            A list of dicts with these keys:

            - ``entity``: entity type without the BIO prefix.
            - ``text``: the entity's words.
            - ``start``/``end``: word indices, end exclusive.
            - ``confidence``: the minimum per-position max probability over
              the span.

            Returns ``[]`` for blank text.
        """
        tokens = text.split()
        if not tokens:
            return []

        encoding = self.trainer.tokenizer(
            tokens,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        )
        # Fast tokenizers expose the exact sub-token -> word alignment; slow
        # ones raise here, and we fall back to a WordPiece heuristic.
        try:
            word_ids = encoding.word_ids(0)
        except (AttributeError, ValueError):
            word_ids = None

        inputs = {k: v.to(self.trainer.device) for k, v in encoding.items()}
        with torch.no_grad():
            predictions, probabilities = self.model.predict(
                inputs['input_ids'],
                inputs['attention_mask']
            )

        pred_labels = predictions[0].cpu().numpy()
        probs = probabilities[0].cpu().numpy()
        subword_tokens = self.trainer.tokenizer.convert_ids_to_tokens(
            encoding['input_ids'][0].tolist()
        )
        return self._extract_entities_from_predictions(
            tokens, pred_labels, probs, subword_tokens, word_ids
        )

    def _extract_entities_from_predictions(self, tokens: List[str],
                                           pred_labels: np.ndarray,
                                           probs: np.ndarray,
                                           word_ids_list: List[str],
                                           word_ids: Optional[List[Optional[int]]] = None
                                           ) -> List[Dict[str, Any]]:
        """Turn per-position model predictions into entity spans.

        Args:
            tokens: Original whitespace-split words of the input text.
            pred_labels: Predicted label id for each sub-token position.
            probs: Probability row for each position; the row maximum is used
                as that position's confidence.
            word_ids_list: Sub-token strings per position (as returned by
                ``convert_ids_to_tokens``).
            word_ids: Optional word index per position (``None`` for special
                and padding tokens), as provided by a fast tokenizer.
                When given:

                - Each word takes the label of its first sub-token.
                - Entity text is rebuilt from ``tokens``, so ``15/09/2025``
                  stays intact rather than becoming ``15 / 09 / 2025``.

                When omitted, word boundaries are inferred from ``##``
                continuation pieces in ``word_ids_list``.
        """
        if word_ids is not None:
            units = self._word_units(tokens, pred_labels, probs, word_ids)
        else:
            units = self._subword_units(pred_labels, probs, word_ids_list)
        return self._decode_bio(units)

    def _word_units(self, tokens: List[str], pred_labels: np.ndarray,
                    probs: np.ndarray, word_ids: List[Optional[int]]
                    ) -> List[Tuple[str, float, int, str, bool]]:
        """Build one (label, confidence, word_idx, text, glue) unit per word from its first sub-token."""
        units = []
        seen = set()
        for position, (word_idx, label_id) in enumerate(zip(word_ids, pred_labels)):
            if word_idx is None or word_idx in seen or word_idx >= len(tokens):
                continue
            seen.add(word_idx)
            units.append((
                self._label_name(label_id),
                float(np.max(probs[position])),
                word_idx,
                tokens[word_idx],
                False,
            ))
        return units

    def _subword_units(self, pred_labels: np.ndarray, probs: np.ndarray,
                       subword_tokens: List[str]
                       ) -> List[Tuple[str, float, int, str, bool]]:
        """Build per-sub-token units for slow tokenizers, where a non-'##' piece starts a new word."""
        units = []
        word_idx = -1
        for position, (piece, label_id) in enumerate(zip(subword_tokens, pred_labels)):
            if piece in self._SPECIAL_TOKENS:
                continue
            is_continuation = piece.startswith('##')
            if not is_continuation:
                word_idx += 1
            units.append((
                self._label_name(label_id),
                float(np.max(probs[position])),
                # '##' pieces belong to the word opened by the preceding piece.
                max(word_idx, 0),
                piece[2:] if is_continuation else piece,
                is_continuation,
            ))
        return units

    @staticmethod
    def _decode_bio(units: List[Tuple[str, float, int, str, bool]]) -> List[Dict[str, Any]]:
        """Group (label, confidence, index, text, glue) units into BIO entity spans.

        Decoding rules:

        - ``B-X`` opens a new span.
        - ``I-X`` extends an open ``X`` span. If no ``X`` span is open, it
          opens one (lenient decoding), so a stray ``I-`` tag is neither
          dropped nor merged into an unrelated span.
        - Any other label closes the open span.

        ``glue`` units are appended without a separating space. A span's
        confidence is the minimum over its units.
        """
        entities = []
        current = None
        for label, confidence, index, text, glue in units:
            prefix, entity_type = label[:2], label[2:]
            if prefix == 'I-' and current is not None and current['entity'] == entity_type:
                current['text'] += text if glue else ' ' + text
                current['end'] = index + 1
                current['confidence'] = min(current['confidence'], confidence)
                continue
            if current is not None:
                entities.append(current)
                current = None
            if prefix in ('B-', 'I-'):
                current = {
                    'entity': entity_type,
                    'text': text,
                    'start': index,
                    'end': index + 1,
                    'confidence': confidence,
                }
        if current is not None:
            entities.append(current)
        return entities

    def postprocess_entities(self, entities: List[Dict[str, Any]],
                             original_text: str) -> Dict[str, Any]:
        """Validate, format and collapse entities into one value per field.

        Entities of a type with validation patterns are dropped when no
        pattern matches. For each type, the highest-confidence surviving
        entity wins. Keys are renamed through ``FIELD_NAMES``.

        ``original_text`` is currently unused; it is kept for API
        compatibility.
        """
        best: Dict[str, Dict[str, Any]] = {}
        for entity in entities:
            entity_type = entity['entity']
            entity_text = entity['text']
            confidence = entity['confidence']
            if (entity_type in self.postprocess_patterns
                    and not self._validate_entity(entity_text, entity_type)):
                continue
            if entity_type not in best or confidence > best[entity_type]['confidence']:
                best[entity_type] = {
                    'value': self._format_entity_value(entity_text, entity_type),
                    'confidence': confidence,
                    'original_text': entity_text
                }
        return {
            self.FIELD_NAMES.get(entity_type, entity_type): entity_data['value']
            for entity_type, entity_data in best.items()
        }

    def _validate_entity(self, text: str, entity_type: str) -> bool:
        """Return True if any pattern for ``entity_type`` matches ``text`` (case-insensitive)."""
        patterns = self.postprocess_patterns.get(entity_type, [])
        return any(re.search(pattern, text, re.IGNORECASE) for pattern in patterns)

    def _format_entity_value(self, text: str, entity_type: str) -> str:
        """Normalize an entity's text according to its type.

        Per-type behaviour:

        - DATE: numeric dates get '/' separators, and year-first dates are
          reordered to put the year last.
        - AMOUNT: reduced to the first run of '$', digits, ',' and '.'.
        - PHONE: 10-digit and 1-prefixed 11-digit numbers are reformatted.
        - NAME: each word is capitalized.

        Anything else, or text matching no rule, is returned stripped.
        """
        text = text.strip()
        if entity_type == 'DATE':
            # Year-first is tried first and both patterns are \b-anchored.
            # Unanchored, the day-first pattern matched "25-04-22" inside
            # "2025-04-22" and produced "2025/04/22".
            date_patterns = [
                (r'\b(\d{4})[/\-](\d{1,2})[/\-](\d{1,2})\b', r'\3/\2/\1'),
                (r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{2,4})\b', r'\1/\2/\3'),
            ]
            for pattern, replacement in date_patterns:
                if re.search(pattern, text):
                    return re.sub(pattern, replacement, text)
        elif entity_type == 'AMOUNT':
            amount_match = re.search(r'[\$\d,\.]+', text)
            if amount_match:
                return amount_match.group()
        elif entity_type == 'PHONE':
            digits = re.sub(r'[^\d]', '', text)
            if len(digits) == 10:
                return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
            if len(digits) == 11 and digits[0] == '1':
                return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        elif entity_type == 'NAME':
            # NOTE: str.capitalize lowercases the rest of each word ("McDonald" -> "Mcdonald").
            return ' '.join(word.capitalize() for word in text.split())
        return text

    def process_document(self, file_path: str) -> Dict[str, Any]:
        """Extract text from a document file and return its structured data.

        Never raises. On failure, returns a dict with ``error``,
        ``file_path`` and ``processing_timestamp``.
        """
        print(f"Processing document: {file_path}")
        try:
            text = self.document_processor.process_document(file_path)
            if not text.strip():
                return {
                    'error': 'No text could be extracted from the document',
                    'file_path': file_path
                }
            entities = self.predict_entities(text)
            structured_data = self.postprocess_entities(entities, text)
            result = {
                'file_path': file_path,
                # Only a 500-character preview of the text is kept.
                'extracted_text': (text[:500] + '...') if len(text) > 500 else text,
                'entities': entities,
                'structured_data': structured_data,
                'processing_timestamp': datetime.now().isoformat(),
                'model_path': self.model_path
            }
            print(f"Successfully processed {file_path}")
            print(f"  Found {len(entities)} entities")
            print(f"  Structured fields: {list(structured_data.keys())}")
            return result
        except Exception as e:
            # Boundary handler: report the failure in the result instead of raising.
            print(f"Error processing {file_path}: {e}")
            return {
                'error': str(e),
                'file_path': file_path,
                'processing_timestamp': datetime.now().isoformat()
            }

    def process_text_directly(self, text: str) -> Dict[str, Any]:
        """Clean raw text and return its entities and structured data.

        Never raises. On failure, returns a dict with ``error``,
        ``original_text`` and ``processing_timestamp``.
        """
        print("Processing text directly...")
        try:
            cleaned_text = self.document_processor.clean_text(text)
            entities = self.predict_entities(cleaned_text)
            structured_data = self.postprocess_entities(entities, cleaned_text)
            result = {
                'original_text': text,
                'cleaned_text': cleaned_text,
                'entities': entities,
                'structured_data': structured_data,
                'processing_timestamp': datetime.now().isoformat(),
                'model_path': self.model_path
            }
            print("Successfully processed text")
            print(f"  Found {len(entities)} entities")
            print(f"  Structured fields: {list(structured_data.keys())}")
            return result
        except Exception as e:
            # Boundary handler: report the failure in the result instead of raising.
            print(f"Error processing text: {e}")
            return {
                'error': str(e),
                'original_text': text,
                'processing_timestamp': datetime.now().isoformat()
            }

    def batch_process_documents(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """Run ``process_document`` on each path, in order, and return all results."""
        total = len(file_paths)
        print(f"Processing {total} documents...")
        results = []
        for i, file_path in enumerate(file_paths, start=1):
            print(f"\nProcessing {i}/{total}: {Path(file_path).name}")
            results.append(self.process_document(file_path))
        failed = sum(1 for r in results if 'error' in r)
        print("\nBatch processing completed!")
        print(f"  Successfully processed: {len(results) - failed}")
        print(f"  Errors: {failed}")
        return results

    def save_results(self, results: List[Dict[str, Any]], output_path: str):
        """Write ``results`` as pretty-printed UTF-8 JSON, creating parent directories."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_path}")
def create_demo_inference(model_path: str = "models/document_ner_model") -> DocumentInference:
    """Build the ``DocumentInference`` pipeline used by the demo.

    If construction fails, a hint about training the model first is
    printed and the original exception is re-raised unchanged.
    """
    try:
        return DocumentInference(model_path)
    except Exception as err:
        print(f"Failed to create inference pipeline: {err}")
        print("Make sure you have trained the model first by running training_pipeline.py")
        raise
def demo_text_extraction():
    """Run the extraction pipeline over a few built-in sample texts.

    Prints each sample's structured output (or its error) and saves all
    results to ``results/demo_extraction_results.json``. Any exception,
    including a failure to load the model, is printed rather than raised.
    """
    print("DOCUMENT TEXT EXTRACTION - INFERENCE DEMO")
    print("=" * 60)

    samples = [
        "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250",
        "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Phone: (555) 123-4567",
        "Receipt for Michael Brown 456 Oak Street Boston MA Email: michael@email.com Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75"
    ]

    try:
        pipeline = create_demo_inference()
        collected = []
        for number, sample in enumerate(samples, start=1):
            print(f"\nProcessing Sample Text {number}:")
            print("-" * 40)
            print(f"Text: {sample}")

            outcome = pipeline.process_text_directly(sample)
            collected.append(outcome)

            if 'error' in outcome:
                print(f"Error: {outcome['error']}")
            else:
                print(f"Structured Output: {json.dumps(outcome['structured_data'], indent=2)}")

        pipeline.save_results(collected, "results/demo_extraction_results.json")
        print("\nDemo completed successfully!")
    except Exception as exc:
        print(f"Demo failed: {exc}")
def main():
    """Script entry point: run the inference demo."""
    demo_text_extraction()


if __name__ == '__main__':
    main()