|
|
"""
|
|
|
Test cases for the document text extraction system.
|
|
|
"""
|
|
|
|
|
|
import unittest
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
import tempfile
|
|
|
import os
|
|
|
|
|
|
from src.data_preparation import DocumentProcessor, NERDatasetCreator
|
|
|
from src.model import ModelConfig, create_model_and_trainer
|
|
|
from src.inference import DocumentInference
|
|
|
|
|
|
|
|
|
class TestDocumentProcessor(unittest.TestCase):
    """Test cases for document processing."""

    def setUp(self):
        """Create a fresh DocumentProcessor for every test."""
        self.processor = DocumentProcessor()

    def test_clean_text(self):
        """Test that clean_text trims outer whitespace and reduces '!!!' to '!'."""
        dirty_text = " This is a test text!!! "
        clean_text = self.processor.clean_text(dirty_text)
        self.assertEqual(clean_text, "This is a test text!")

    def test_entity_patterns(self):
        """Test that entity_patterns has an entry for each expected entity type."""
        # subTest reports every missing label instead of stopping at the first.
        # TODO(review): also assert that the patterns match a sample invoice
        # line once the value type stored in entity_patterns is confirmed.
        for label in ('NAME', 'DATE', 'INVOICE_NO', 'AMOUNT'):
            with self.subTest(label=label):
                self.assertIn(label, self.processor.entity_patterns)
|
|
|
|
|
|
|
|
|
class TestNERDatasetCreator(unittest.TestCase):
    """Unit tests for NERDatasetCreator built on top of a DocumentProcessor."""

    # Keys every training example is expected to expose, checked in this order.
    _EXAMPLE_KEYS = ('tokens', 'labels', 'text')

    def setUp(self):
        """Build the processor and the dataset creator under test."""
        processor = DocumentProcessor()
        self.processor = processor
        self.dataset_creator = NERDatasetCreator(processor)

    def test_auto_label_text(self):
        """auto_label_text yields a non-empty list of (token, label) string pairs."""
        labeled = self.dataset_creator.auto_label_text(
            "Invoice sent to Robert White on 15/09/2025 Amount: $1,250"
        )

        self.assertIsInstance(labeled, list)
        self.assertGreater(len(labeled), 0)

        for pair_token, pair_label in labeled:
            self.assertIsInstance(pair_token, str)
            self.assertIsInstance(pair_label, str)

    def test_create_training_example(self):
        """A training example has all required keys and one label per token."""
        example = self.dataset_creator.create_training_example(
            "Invoice INV-1001 for $500"
        )

        for key in self._EXAMPLE_KEYS:
            self.assertIn(key, example)

        self.assertEqual(len(example['tokens']), len(example['labels']))

    def test_create_sample_dataset(self):
        """The sample dataset is a non-empty list whose first entry is well-formed."""
        dataset = self.dataset_creator.create_sample_dataset()

        self.assertIsInstance(dataset, list)
        self.assertGreater(len(dataset), 0)

        head = dataset[0]
        for key in self._EXAMPLE_KEYS:
            self.assertIn(key, head)
|
|
|
|
|
|
|
|
|
class TestModelConfig(unittest.TestCase):
    """Unit tests for ModelConfig defaults and overrides."""

    def test_default_config(self):
        """A default ModelConfig carries the expected defaults and label maps."""
        cfg = ModelConfig()

        expected_defaults = (
            ("model_name", "distilbert-base-uncased"),
            ("max_length", 512),
            ("batch_size", 16),
        )
        for attr_name, wanted in expected_defaults:
            self.assertEqual(getattr(cfg, attr_name), wanted)

        labels = cfg.entity_labels
        self.assertIsInstance(labels, list)
        self.assertGreater(len(labels), 0)
        self.assertIn('O', labels)

        self.assertIsInstance(cfg.label2id, dict)
        self.assertIsInstance(cfg.id2label, dict)
        # The forward mapping must cover every configured label.
        self.assertEqual(len(cfg.label2id), len(labels))

    def test_custom_config(self):
        """Keyword overrides are reflected on the resulting ModelConfig."""
        custom_labels = ['O', 'B-TEST', 'I-TEST']
        overrides = {'batch_size': 32, 'learning_rate': 1e-5}
        cfg = ModelConfig(**overrides, entity_labels=custom_labels)

        for attr_name, wanted in overrides.items():
            self.assertEqual(getattr(cfg, attr_name), wanted)

        self.assertEqual(cfg.entity_labels, custom_labels)
        self.assertEqual(cfg.num_labels, 3)
|
|
|
|
|
|
|
|
|
class TestModelCreation(unittest.TestCase):
    """Unit tests for the create_model_and_trainer factory."""

    def test_create_model_and_trainer(self):
        """The factory returns a model and a trainer exposing the requested config."""
        model, trainer = create_model_and_trainer(
            ModelConfig(
                batch_size=4,
                num_epochs=1,
                entity_labels=['O', 'B-TEST', 'I-TEST'],
            )
        )

        for built in (model, trainer):
            self.assertIsNotNone(built)

        trainer_config = trainer.config
        self.assertEqual(trainer_config.batch_size, 4)
        self.assertEqual(trainer_config.num_epochs, 1)
|
|
|
|
|
|
|
|
|
class TestInference(unittest.TestCase):
    """Test cases for inference pipeline."""

    @classmethod
    def setUpClass(cls):
        """Set up class-level fixtures shared by the tests in this class."""
        # NOTE(review): no test in this class reads these attributes yet;
        # they are kept for future inference tests.
        cls.model_path = "test_model"
        cls.test_text = "Invoice sent to John Doe on 01/15/2025 Amount: $500.00"

    def test_entity_validation(self):
        """Check the entity validation regexes against sample values.

        Each pattern must match every sample listed for it. Formats these
        patterns do not accept are intentionally not asserted: ISO and
        long-form dates ('2025-01-15', 'January 15, 2025'), '1000.00 USD',
        and phone numbers (no pattern is defined for them).
        """
        patterns = {
            'DATE': r'\b\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}\b',
            'AMOUNT': r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?',
            # The TLD class is [A-Za-z]: writing [A-Z|a-z] would also accept
            # a literal '|', because '|' is not an operator inside [...].
            'EMAIL': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        }
        samples = {
            'DATE': ['01/15/2025'],
            'AMOUNT': ['$500.00', '$1,250.50'],
            'EMAIL': ['test@email.com', 'user.name@domain.co.uk'],
        }

        # subTest reports each failing (entity, value) pair separately.
        for entity, values in samples.items():
            for value in values:
                with self.subTest(entity=entity, value=value):
                    self.assertRegex(value, patterns[entity])

        # Regression guard: '|' must not be accepted inside the TLD.
        self.assertNotRegex('user@domain.c|m', patterns['EMAIL'])
|
|
|
|
|
|
|
|
|
class TestEndToEnd(unittest.TestCase):
    """End-to-end integration tests."""

    def test_data_preparation_flow(self):
        """Test that the sample dataset is a non-empty list of well-formed examples."""
        processor = DocumentProcessor()
        dataset_creator = NERDatasetCreator(processor)

        dataset = dataset_creator.create_sample_dataset()

        self.assertIsInstance(dataset, list)
        self.assertGreater(len(dataset), 0)

        for example in dataset:
            self.assertIn('tokens', example)
            self.assertIn('labels', example)
            self.assertIn('text', example)
            # Token-level labelling: exactly one label per token.
            self.assertEqual(len(example['tokens']), len(example['labels']))

    def test_model_config_flow(self):
        """Test that a config flows through create_model_and_trainer to the trainer."""
        config = ModelConfig(batch_size=4, num_epochs=1)

        model, trainer = create_model_and_trainer(config)

        self.assertIsNotNone(model)
        self.assertIsNotNone(trainer)
        self.assertEqual(trainer.config.batch_size, 4)
        self.assertEqual(trainer.config.num_epochs, 1)

    def test_save_and_load_dataset(self):
        """Test that the sample dataset survives a JSON round trip on disk.

        The file is written inside a TemporaryDirectory, so it is removed
        even when json.dump or json.load raises.
        """
        processor = DocumentProcessor()
        dataset_creator = NERDatasetCreator(processor)
        dataset = dataset_creator.create_sample_dataset()

        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = Path(tmp_dir) / 'dataset.json'
            # Explicit encoding so the round trip does not depend on the
            # platform's default locale encoding.
            with temp_path.open('w', encoding='utf-8') as f:
                json.dump(dataset, f, indent=2)
            with temp_path.open('r', encoding='utf-8') as f:
                loaded_dataset = json.load(f)

        self.assertEqual(len(loaded_dataset), len(dataset))
        self.assertEqual(loaded_dataset[0]['text'], dataset[0]['text'])
|
|
|
|
|
|
|
|
|
def run_tests():
    """Run every test case class in this module with a verbose text runner.

    Prints a banner, runs the combined suite, and then prints a summary.
    On failure the summary lists each recorded failure and error.

    Returns:
        bool: True when the whole suite passed, False otherwise.
    """
    print("Running Document Text Extraction Tests")
    print("=" * 50)

    case_classes = (
        TestDocumentProcessor,
        TestNERDatasetCreator,
        TestModelConfig,
        TestModelCreation,
        TestInference,
        TestEndToEnd,
    )

    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTests(unittest.TestLoader().loadTestsFromTestCase(case_class))

    result = unittest.TextTestRunner(verbosity=2).run(suite)
    passed = result.wasSuccessful()

    if passed:
        print(f"\nAll tests passed! ({result.testsRun} tests)")
        return passed

    print(f"\n{len(result.failures)} failures, {len(result.errors)} errors")

    # Failures first, then errors, each section only when non-empty.
    for heading, problems in (("\nFailures:", result.failures),
                              ("\nErrors:", result.errors)):
        if not problems:
            continue
        print(heading)
        for test, detail in problems:
            print(f" {test}: {detail}")

    return passed
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the suite outcome as the process exit status. Without this,
    # CI and shell callers would see exit code 0 even when tests fail.
    raise SystemExit(0 if run_tests() else 1)