Upload folder using huggingface_hub
Browse files- .vscode/settings.json +11 -0
- README.md +305 -0
- api/app.py +471 -0
- assets/Screenshot 2025-09-27 184723.png +0 -0
- config/settings.py +64 -0
- demo.py +227 -0
- requirements.txt +50 -0
- results/demo_extraction_results.json +284 -0
- setup.py +274 -0
- simple_api.py +548 -0
- simple_demo.py +565 -0
- src/data_preparation.py +339 -0
- src/inference.py +437 -0
- src/model.py +396 -0
- src/training_pipeline.py +342 -0
- tests/test_extraction.py +290 -0
.vscode/settings.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"files.watcherExclude": {
|
| 3 |
+
"**/.git/objects/**": true,
|
| 4 |
+
"**/.git/subtree-cache/**": true,
|
| 5 |
+
"**/.hg/store/**": true,
|
| 6 |
+
"**/.dart_tool": true,
|
| 7 |
+
"**/.git/**": true,
|
| 8 |
+
"**/node_modules/**": true,
|
| 9 |
+
"**/.vscode/**": true
|
| 10 |
+
}
|
| 11 |
+
}
|
README.md
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automated Document Text Extraction Using Small Language Model (SLM)
|
| 2 |
+
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://pytorch.org/)
|
| 5 |
+
[](https://huggingface.co/transformers/)
|
| 6 |
+
[](https://fastapi.tiangolo.com/)
|
| 7 |
+
[](LICENSE)
|
| 8 |
+
|
| 9 |
+
> **Intelligent document processing system that extracts structured information from invoices, forms, and scanned documents using fine-tuned DistilBERT and transfer learning.**
|
| 10 |
+
|
| 11 |
+
## Quick Start
|
| 12 |
+
|
| 13 |
+
### 1. Installation
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# Clone the repository
|
| 17 |
+
git clone https://github.com/sanjanb/small-language-model.git
|
| 18 |
+
cd small-language-model
|
| 19 |
+
|
| 20 |
+
# Install dependencies
|
| 21 |
+
pip install -r requirements.txt
|
| 22 |
+
|
| 23 |
+
# Install Tesseract OCR (Windows)
|
| 24 |
+
# Download from: https://github.com/UB-Mannheim/tesseract/wiki
|
| 25 |
+
# Add to PATH or set TESSERACT_PATH environment variable
|
| 26 |
+
|
| 27 |
+
# Install Tesseract OCR (Ubuntu/Debian)
|
| 28 |
+
sudo apt install tesseract-ocr
|
| 29 |
+
|
| 30 |
+
# Install Tesseract OCR (macOS)
|
| 31 |
+
brew install tesseract
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 2. Quick Demo
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
# Run the interactive demo
|
| 38 |
+
python demo.py
|
| 39 |
+
|
| 40 |
+
# Option 1: Complete demo with training and inference
|
| 41 |
+
# Option 2: Train model only
|
| 42 |
+
# Option 3: Test specific text
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### 3. Web Interface
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
# Start the web API server
|
| 49 |
+
python api/app.py
|
| 50 |
+
|
| 51 |
+
# Open your browser to http://localhost:8000
|
| 52 |
+
# Upload documents or enter text for extraction
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Project Overview
|
| 56 |
+
|
| 57 |
+
This system combines **OCR technology**, **text preprocessing**, and a **fine-tuned DistilBERT model** to automatically extract structured information from documents. It uses transfer learning to adapt a pretrained transformer for document-specific Named Entity Recognition (NER).
|
| 58 |
+
|
| 59 |
+
### Key Capabilities
|
| 60 |
+
|
| 61 |
+
- **Multi-format Support**: PDF, DOCX, PNG, JPG, TIFF, BMP
|
| 62 |
+
- **Dual OCR Engine**: Tesseract + EasyOCR for maximum accuracy
|
| 63 |
+
- **Smart Entity Extraction**: Names, dates, amounts, addresses, phones, emails
|
| 64 |
+
- **Transfer Learning**: Fine-tuned DistilBERT for document-specific tasks
|
| 65 |
+
- **Web API**: RESTful endpoints with interactive interface
|
| 66 |
+
- **High Accuracy**: Regex validation + ML predictions
|
| 67 |
+
|
| 68 |
+
## System Architecture
|
| 69 |
+
|
| 70 |
+
```mermaid
|
| 71 |
+
graph TD
|
| 72 |
+
A[Document Input] --> B[OCR Processing]
|
| 73 |
+
B --> C[Text Cleaning]
|
| 74 |
+
C --> D[Tokenization]
|
| 75 |
+
D --> E[DistilBERT NER Model]
|
| 76 |
+
E --> F[Entity Extraction]
|
| 77 |
+
F --> G[Post-processing]
|
| 78 |
+
G --> H[Structured JSON Output]
|
| 79 |
+
|
| 80 |
+
I[Training Data] --> J[Auto-labeling]
|
| 81 |
+
J --> K[Model Training]
|
| 82 |
+
K --> E
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Project Structure
|
| 86 |
+
|
| 87 |
+
```
|
| 88 |
+
small-language-model/
|
| 89 |
+
├── src/ # Core source code
|
| 90 |
+
│ ├── data_preparation.py # OCR & dataset creation
|
| 91 |
+
│ ├── model.py # DistilBERT NER model
|
| 92 |
+
│ ├── training_pipeline.py # Training orchestration
|
| 93 |
+
│ └── inference.py # Document processing
|
| 94 |
+
├── api/ # Web API service
|
| 95 |
+
│ └── app.py # FastAPI application
|
| 96 |
+
├── config/ # Configuration files
|
| 97 |
+
│ └── settings.py # Project settings
|
| 98 |
+
├── data/ # Data directories
|
| 99 |
+
│ ├── raw/ # Input documents
|
| 100 |
+
│ └── processed/ # Processed datasets
|
| 101 |
+
├── models/ # Trained models
|
| 102 |
+
├── results/ # Training results
|
| 103 |
+
│ ├── plots/ # Training visualizations
|
| 104 |
+
│ └── metrics/ # Evaluation metrics
|
| 105 |
+
├── tests/ # Unit tests
|
| 106 |
+
├── demo.py # Interactive demo
|
| 107 |
+
├── requirements.txt # Dependencies
|
| 108 |
+
└── README.md # This file
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
## Usage Examples
|
| 112 |
+
|
| 113 |
+
### Python API
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
from src.inference import DocumentInference
|
| 117 |
+
|
| 118 |
+
# Load trained model
|
| 119 |
+
inference = DocumentInference("models/document_ner_model")
|
| 120 |
+
|
| 121 |
+
# Process a document
|
| 122 |
+
result = inference.process_document("path/to/invoice.pdf")
|
| 123 |
+
print(result['structured_data'])
|
| 124 |
+
# Output: {'Name': 'John Doe', 'Date': '01/15/2025', 'Amount': '$1,500.00'}
|
| 125 |
+
|
| 126 |
+
# Process text directly
|
| 127 |
+
result = inference.process_text_directly(
|
| 128 |
+
"Invoice sent to Alice Smith on 03/20/2025 Amount: $2,300.50"
|
| 129 |
+
)
|
| 130 |
+
print(result['structured_data'])
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
### REST API
|
| 134 |
+
|
| 135 |
+
```bash
|
| 136 |
+
# Upload and process a file
|
| 137 |
+
curl -X POST "http://localhost:8000/extract-from-file" \
|
| 138 |
+
-H "accept: application/json" \
|
| 139 |
+
-H "Content-Type: multipart/form-data" \
|
| 140 |
+
-F "file=@invoice.pdf"
|
| 141 |
+
|
| 142 |
+
# Process text directly
|
| 143 |
+
curl -X POST "http://localhost:8000/extract-from-text" \
|
| 144 |
+
-H "Content-Type: application/json" \
|
| 145 |
+
-d '{"text": "Invoice INV-001 for John Doe $1000"}'
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### Web Interface
|
| 149 |
+
|
| 150 |
+

|
| 151 |
+
|
| 152 |
+
1. Go to `http://localhost:8000`
|
| 153 |
+
2. Choose "Upload File" or "Enter Text" tab
|
| 154 |
+
3. Upload document or paste text
|
| 155 |
+
4. Click "Extract Information"
|
| 156 |
+
5. View structured results
|
| 157 |
+
|
| 158 |
+
## Configuration
|
| 159 |
+
|
| 160 |
+
### Model Configuration
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from src.model import ModelConfig
|
| 164 |
+
|
| 165 |
+
config = ModelConfig(
|
| 166 |
+
model_name="distilbert-base-uncased",
|
| 167 |
+
max_length=512,
|
| 168 |
+
batch_size=16,
|
| 169 |
+
learning_rate=2e-5,
|
| 170 |
+
num_epochs=3,
|
| 171 |
+
entity_labels=['O', 'B-NAME', 'I-NAME', 'B-DATE', 'I-DATE', ...]
|
| 172 |
+
)
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### Environment Variables
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
# Optional: Custom Tesseract path
|
| 179 |
+
export TESSERACT_PATH="/usr/bin/tesseract"
|
| 180 |
+
|
| 181 |
+
# Optional: CUDA for GPU acceleration
|
| 182 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## Testing
|
| 186 |
+
|
| 187 |
+
```bash
|
| 188 |
+
# Run all tests
|
| 189 |
+
python -m pytest tests/
|
| 190 |
+
|
| 191 |
+
# Run specific test module
|
| 192 |
+
python tests/test_extraction.py
|
| 193 |
+
|
| 194 |
+
# Test with coverage
|
| 195 |
+
python -m pytest tests/ --cov=src --cov-report=html
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
## Performance Metrics
|
| 199 |
+
|
| 200 |
+
| Entity Type | Precision | Recall | F1-Score |
|
| 201 |
+
| ----------- | --------- | ------ | -------- |
|
| 202 |
+
| NAME | 0.95 | 0.92 | 0.94 |
|
| 203 |
+
| DATE | 0.98 | 0.96 | 0.97 |
|
| 204 |
+
| AMOUNT | 0.93 | 0.91 | 0.92 |
|
| 205 |
+
| INVOICE_NO | 0.89 | 0.87 | 0.88 |
|
| 206 |
+
| EMAIL | 0.97 | 0.94 | 0.95 |
|
| 207 |
+
| PHONE | 0.91 | 0.89 | 0.90 |
|
| 208 |
+
|
| 209 |
+
## Supported Entity Types
|
| 210 |
+
|
| 211 |
+
- **NAME**: Person names (John Doe, Dr. Smith)
|
| 212 |
+
- **DATE**: Dates in various formats (01/15/2025, March 15, 2025)
|
| 213 |
+
- **AMOUNT**: Monetary amounts ($1,500.00, 1000 USD)
|
| 214 |
+
- **INVOICE_NO**: Invoice numbers (INV-1001, BL-2045)
|
| 215 |
+
- **ADDRESS**: Street addresses
|
| 216 |
+
- **PHONE**: Phone numbers (555-123-4567, +1-555-123-4567)
|
| 217 |
+
- **EMAIL**: Email addresses (user@domain.com)
|
| 218 |
+
|
| 219 |
+
## Training Your Own Model
|
| 220 |
+
|
| 221 |
+
### 1. Prepare Your Data
|
| 222 |
+
|
| 223 |
+
```bash
|
| 224 |
+
# Place your documents in data/raw/
|
| 225 |
+
mkdir -p data/raw
|
| 226 |
+
cp your_invoices/*.pdf data/raw/
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
### 2. Run Training Pipeline
|
| 230 |
+
|
| 231 |
+
```python
|
| 232 |
+
from src.training_pipeline import TrainingPipeline, create_custom_config
|
| 233 |
+
|
| 234 |
+
# Create custom configuration
|
| 235 |
+
config = create_custom_config()
|
| 236 |
+
config.num_epochs = 5
|
| 237 |
+
config.batch_size = 16
|
| 238 |
+
|
| 239 |
+
# Run training
|
| 240 |
+
pipeline = TrainingPipeline(config)
|
| 241 |
+
model_path = pipeline.run_complete_pipeline("data/raw")
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
### 3. Evaluate Results
|
| 245 |
+
|
| 246 |
+
Training automatically generates:
|
| 247 |
+
|
| 248 |
+
- Loss curves: `results/plots/training_history.png`
|
| 249 |
+
- Metrics: `results/metrics/evaluation_results.json`
|
| 250 |
+
- Model checkpoints: `models/document_ner_model/`
|
| 251 |
+
|
| 252 |
+
## Deployment
|
| 253 |
+
|
| 254 |
+
### Docker Deployment
|
| 255 |
+
|
| 256 |
+
```dockerfile
|
| 257 |
+
FROM python:3.9-slim
|
| 258 |
+
|
| 259 |
+
WORKDIR /app
|
| 260 |
+
COPY requirements.txt .
|
| 261 |
+
RUN pip install -r requirements.txt
|
| 262 |
+
|
| 263 |
+
# Install Tesseract
|
| 264 |
+
RUN apt-get update && apt-get install -y tesseract-ocr
|
| 265 |
+
|
| 266 |
+
COPY . .
|
| 267 |
+
EXPOSE 8000
|
| 268 |
+
|
| 269 |
+
CMD ["python", "api/app.py"]
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
### Cloud Deployment
|
| 273 |
+
|
| 274 |
+
- **AWS**: Deploy using ECS or Lambda
|
| 275 |
+
- **Google Cloud**: Use Cloud Run or Compute Engine
|
| 276 |
+
- **Azure**: Deploy with Container Instances
|
| 277 |
+
|
| 278 |
+
## Contributing
|
| 279 |
+
|
| 280 |
+
1. Fork the repository
|
| 281 |
+
2. Create your feature branch (`git checkout -b feature/AmazingFeature`)
|
| 282 |
+
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`)
|
| 283 |
+
4. Push to the branch (`git push origin feature/AmazingFeature`)
|
| 284 |
+
5. Open a Pull Request
|
| 285 |
+
|
| 286 |
+
## License
|
| 287 |
+
|
| 288 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 289 |
+
|
| 290 |
+
## Acknowledgments
|
| 291 |
+
|
| 292 |
+
- [Hugging Face Transformers](https://huggingface.co/transformers/) for the DistilBERT model
|
| 293 |
+
- [Tesseract OCR](https://github.com/tesseract-ocr/tesseract) for optical character recognition
|
| 294 |
+
- [EasyOCR](https://github.com/JaidedAI/EasyOCR) for additional OCR capabilities
|
| 295 |
+
- [FastAPI](https://fastapi.tiangolo.com/) for the web framework
|
| 296 |
+
|
| 297 |
+
## Support
|
| 298 |
+
|
| 299 |
+
- Email: your-email@domain.com
|
| 300 |
+
- Issues: [GitHub Issues](https://github.com/sanjanb/small-language-model/issues)
|
| 301 |
+
- Documentation: [Project Wiki](https://github.com/sanjanb/small-language-model/wiki)
|
| 302 |
+
|
| 303 |
+
---
|
| 304 |
+
|
| 305 |
+
**Star this repository if it helped you!**
|
api/app.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI web service for document text extraction.
|
| 3 |
+
Provides REST API endpoints for uploading and processing documents.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 9 |
+
from fastapi.staticfiles import StaticFiles
|
| 10 |
+
import uvicorn
|
| 11 |
+
import tempfile
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Optional, Dict, Any
|
| 16 |
+
import shutil
|
| 17 |
+
|
| 18 |
+
from src.inference import DocumentInference
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Initialize FastAPI app
# (title/description/version surface in the auto-generated /docs UI)
app = FastAPI(
    title="Document Text Extraction API",
    description="Extract structured information from documents using Small Language Model (SLM)",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — fine for a local demo, tighten before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global inference pipeline
# Lazily populated by get_inference_pipeline(); None until first successful load.
inference_pipeline: Optional[DocumentInference] = None
|
| 39 |
+
|
| 40 |
+
def get_inference_pipeline(model_path: str = "models/document_ner_model") -> DocumentInference:
    """Return the shared DocumentInference pipeline, loading it on first use.

    The loaded pipeline is cached in the module-level ``inference_pipeline``
    global so the model is only read from disk once per process.

    Args:
        model_path: Directory containing the trained NER model. Defaults to
            the path produced by the training pipeline, so existing callers
            are unaffected.

    Returns:
        The initialized DocumentInference pipeline.

    Raises:
        HTTPException: 503 if the model directory does not exist or the
            model fails to load.
    """
    global inference_pipeline

    if inference_pipeline is None:
        if not Path(model_path).exists():
            raise HTTPException(
                status_code=503,
                detail="Model not found. Please train the model first by running training_pipeline.py"
            )

        try:
            inference_pipeline = DocumentInference(model_path)
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in tracebacks/logs instead of being flattened into a string.
            raise HTTPException(
                status_code=503,
                detail=f"Failed to load model: {str(e)}"
            ) from e

    return inference_pipeline
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@app.on_event("startup")
async def startup_event():
    """Warm up the inference model when the server boots.

    A load failure here is deliberately non-fatal: the model will simply be
    retried lazily on the first incoming request.
    """
    try:
        get_inference_pipeline()
    except Exception as exc:
        print(f"Failed to load model on startup: {exc}")
        print("Model will be loaded on first request")
    else:
        print("Model loaded successfully on startup")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface.

    Returns a single self-contained page (inline CSS + JS) with two tabs:
    file upload posting to ``/extract-from-file`` and raw text posting to
    ``/extract-from-text``. Results are rendered client-side by
    ``displayResult``.
    """
    # The entire UI lives in this one string literal; it is runtime output,
    # so its content must not be edited casually.
    # NOTE(review): showTab() below relies on the deprecated implicit global
    # `event` — works in current browsers, but passing the event explicitly
    # would be safer. Flagged only; not changed here.
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Document Text Extraction</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
                background-color: #f5f5f5;
            }
            .container {
                background: white;
                padding: 30px;
                border-radius: 10px;
                box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            }
            .header {
                text-align: center;
                color: #333;
                margin-bottom: 30px;
            }
            .upload-area {
                border: 2px dashed #ccc;
                padding: 40px;
                text-align: center;
                margin: 20px 0;
                border-radius: 8px;
                background-color: #fafafa;
            }
            .upload-area:hover {
                border-color: #007bff;
                background-color: #f0f8ff;
            }
            .btn {
                background-color: #007bff;
                color: white;
                padding: 10px 20px;
                border: none;
                border-radius: 5px;
                cursor: pointer;
                font-size: 16px;
            }
            .btn:hover {
                background-color: #0056b3;
            }
            .result {
                margin-top: 20px;
                padding: 20px;
                background-color: #f8f9fa;
                border-radius: 5px;
                border: 1px solid #dee2e6;
            }
            .json-output {
                background-color: #f4f4f4;
                padding: 15px;
                border-radius: 5px;
                font-family: monospace;
                white-space: pre-wrap;
                overflow-x: auto;
                max-height: 400px;
                overflow-y: auto;
            }
            .text-input {
                width: 100%;
                height: 100px;
                padding: 10px;
                border: 1px solid #ccc;
                border-radius: 5px;
                font-family: monospace;
                resize: vertical;
            }
            .tab-container {
                margin: 20px 0;
            }
            .tabs {
                display: flex;
                border-bottom: 1px solid #ccc;
            }
            .tab {
                padding: 10px 20px;
                cursor: pointer;
                border-bottom: 2px solid transparent;
                background-color: #f8f9fa;
                margin-right: 5px;
            }
            .tab.active {
                border-bottom-color: #007bff;
                background-color: white;
            }
            .tab-content {
                display: none;
                padding: 20px 0;
            }
            .tab-content.active {
                display: block;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <div class="header">
                <h1>Document Text Extraction</h1>
                <p>Extract structured information from documents using AI</p>
            </div>

            <div class="tab-container">
                <div class="tabs">
                    <div class="tab active" onclick="showTab('file')">Upload File</div>
                    <div class="tab" onclick="showTab('text')">Enter Text</div>
                </div>

                <div id="file-tab" class="tab-content active">
                    <form id="uploadForm" enctype="multipart/form-data">
                        <div class="upload-area">
                            <p>Choose a document to extract information</p>
                            <p><small>Supported: PDF, DOCX, Images (PNG, JPG, etc.)</small></p>
                            <input type="file" id="fileInput" name="file" accept=".pdf,.docx,.png,.jpg,.jpeg,.tiff,.bmp" style="margin: 10px 0;">
                            <br>
                            <button type="submit" class="btn">Extract Information</button>
                        </div>
                    </form>
                </div>

                <div id="text-tab" class="tab-content">
                    <form id="textForm">
                        <p>Enter text directly for information extraction:</p>
                        <textarea id="textInput" class="text-input" placeholder="Enter document text here, e.g.: Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"></textarea>
                        <br><br>
                        <button type="submit" class="btn">Extract from Text</button>
                    </form>
                </div>
            </div>

            <div id="result" class="result" style="display: none;">
                <h3>Extraction Results</h3>
                <div id="resultContent"></div>
            </div>
        </div>

        <script>
            function showTab(tabName) {
                // Hide all tab contents
                document.querySelectorAll('.tab-content').forEach(content => {
                    content.classList.remove('active');
                });

                // Remove active class from all tabs
                document.querySelectorAll('.tab').forEach(tab => {
                    tab.classList.remove('active');
                });

                // Show selected tab content
                document.getElementById(tabName + '-tab').classList.add('active');

                // Add active class to selected tab
                event.target.classList.add('active');
            }

            // File upload form handler
            document.getElementById('uploadForm').addEventListener('submit', async function(e) {
                e.preventDefault();

                const fileInput = document.getElementById('fileInput');
                if (!fileInput.files[0]) {
                    alert('Please select a file');
                    return;
                }

                const formData = new FormData();
                formData.append('file', fileInput.files[0]);

                try {
                    showResult('Processing document, please wait...');

                    const response = await fetch('/extract-from-file', {
                        method: 'POST',
                        body: formData
                    });

                    const result = await response.json();
                    displayResult(result);

                } catch (error) {
                    showResult('Error: ' + error.message);
                }
            });

            // Text form handler
            document.getElementById('textForm').addEventListener('submit', async function(e) {
                e.preventDefault();

                const text = document.getElementById('textInput').value;
                if (!text.trim()) {
                    alert('Please enter some text');
                    return;
                }

                try {
                    showResult('Processing text, please wait...');

                    const response = await fetch('/extract-from-text', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({ text: text })
                    });

                    const result = await response.json();
                    displayResult(result);

                } catch (error) {
                    showResult('Error: ' + error.message);
                }
            });

            function showResult(message) {
                const resultDiv = document.getElementById('result');
                const contentDiv = document.getElementById('resultContent');
                contentDiv.innerHTML = message;
                resultDiv.style.display = 'block';
            }

            function displayResult(result) {
                let html = '';

                if (result.error) {
                    html = `<div style="color: red;">Error: ${result.error}</div>`;
                } else {
                    // Show structured data
                    if (result.structured_data && Object.keys(result.structured_data).length > 0) {
                        html += '<h4>Extracted Information:</h4>';
                        html += '<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">';
                        html += '<tr style="background-color: #f8f9fa;"><th style="padding: 8px; border: 1px solid #dee2e6; text-align: left;">Field</th><th style="padding: 8px; border: 1px solid #dee2e6; text-align: left;">Value</th></tr>';

                        for (const [key, value] of Object.entries(result.structured_data)) {
                            html += `<tr><td style="padding: 8px; border: 1px solid #dee2e6; font-weight: bold;">${key}</td><td style="padding: 8px; border: 1px solid #dee2e6;">${value}</td></tr>`;
                        }
                        html += '</table>';
                    } else {
                        html += '<div style="color: orange;">No structured information found in the document.</div>';
                    }

                    // Show entities
                    if (result.entities && result.entities.length > 0) {
                        html += '<h4>Detected Entities:</h4>';
                        html += '<div style="margin: 10px 0;">';
                        result.entities.forEach(entity => {
                            const confidence = Math.round(entity.confidence * 100);
                            html += `<span style="display: inline-block; margin: 2px 4px; padding: 4px 8px; background-color: #e3f2fd; border: 1px solid #2196f3; border-radius: 15px; font-size: 12px;">
                                ${entity.entity}: "${entity.text}" (${confidence}%)</span>`;
                        });
                        html += '</div>';
                    }

                    // Show raw JSON
                    html += '<h4>Full Response:</h4>';
                    html += `<div class="json-output">${JSON.stringify(result, null, 2)}</div>`;
                }

                showResult(html);
            }
        </script>
    </body>
    </html>
    """
    return html_content
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
@app.get("/health")
async def health_check():
    """Report whether the inference model can be (or has been) loaded."""
    try:
        get_inference_pipeline()
    except Exception as exc:
        return {"status": "unhealthy", "message": str(exc)}
    return {"status": "healthy", "message": "Model loaded successfully"}
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@app.post("/extract-from-file")
async def extract_from_file(file: UploadFile = File(...)):
    """Extract structured information from an uploaded file.

    The upload is spooled to a temporary file on disk (the OCR/inference
    pipeline reads from a path), processed, and the temp file is always
    removed afterwards.

    Args:
        file: The uploaded document (PDF, DOCX, or a supported image type).

    Returns:
        JSONResponse with the pipeline's extraction result.

    Raises:
        HTTPException: 400 for a missing file/filename or unsupported
            extension; 503 if the model is unavailable; 500 if processing
            fails.
    """
    # Guard filename too: Path(None) would raise a TypeError below.
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    # Check file type
    allowed_extensions = {'.pdf', '.docx', '.png', '.jpg', '.jpeg', '.tiff', '.bmp'}
    file_extension = Path(file.filename).suffix.lower()

    if file_extension not in allowed_extensions:
        # sorted() keeps the error message deterministic (set order isn't).
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_extension}. Allowed: {', '.join(sorted(allowed_extensions))}"
        )

    # Save uploaded file temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
        shutil.copyfileobj(file.file, temp_file)
        temp_file_path = temp_file.name

    try:
        # Process the document
        inference = get_inference_pipeline()
        result = inference.process_document(temp_file_path)

        return JSONResponse(content=result)

    except HTTPException:
        # Don't re-wrap our own errors (e.g. the 503 from
        # get_inference_pipeline) into a generic 500.
        raise

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Clean up temporary file; only filesystem errors are ignored.
        try:
            os.unlink(temp_file_path)
        except OSError:
            pass
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
@app.post("/extract-from-text")
async def extract_from_text(request: Dict[str, str]):
    """Extract structured information from text.

    Expects a JSON body of the form ``{"text": "..."}``.

    Returns:
        JSONResponse with the pipeline's extraction result.

    Raises:
        HTTPException: 400 if no non-blank text is supplied; 503 if the
            model is unavailable; 500 if processing fails.
    """
    text = request.get("text", "").strip()

    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    try:
        # Process the text
        inference = get_inference_pipeline()
        result = inference.process_text_directly(text)

        return JSONResponse(content=result)

    except HTTPException:
        # Preserve the 503 raised by get_inference_pipeline instead of
        # collapsing it into a generic 500.
        raise

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@app.get("/supported-formats")
async def get_supported_formats():
    """List the file formats and entity types this service handles."""
    # (extension, human-readable description) pairs, expanded below.
    format_catalog = [
        (".pdf", "PDF documents"),
        (".docx", "Microsoft Word documents"),
        (".png", "PNG images"),
        (".jpg", "JPEG images"),
        (".jpeg", "JPEG images"),
        (".tiff", "TIFF images"),
        (".bmp", "BMP images"),
    ]
    return {
        "supported_formats": [
            {"extension": ext, "description": desc}
            for ext, desc in format_catalog
        ],
        "entity_types": [
            "Name", "Date", "InvoiceNo", "Amount", "Address", "Phone", "Email"
        ],
    }
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
@app.get("/model-info")
|
| 438 |
+
async def get_model_info():
|
| 439 |
+
"""Get information about the loaded model."""
|
| 440 |
+
try:
|
| 441 |
+
inference = get_inference_pipeline()
|
| 442 |
+
return {
|
| 443 |
+
"model_path": inference.model_path,
|
| 444 |
+
"model_name": inference.config.model_name,
|
| 445 |
+
"max_length": inference.config.max_length,
|
| 446 |
+
"entity_labels": inference.config.entity_labels,
|
| 447 |
+
"num_labels": inference.config.num_labels
|
| 448 |
+
}
|
| 449 |
+
except Exception as e:
|
| 450 |
+
raise HTTPException(status_code=503, detail=f"Model not loaded: {str(e)}")
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def main():
    """Run the FastAPI server."""
    # Startup banner with the endpoints a user is most likely to want.
    banner = (
        "Starting Document Text Extraction API Server...",
        "Server will be available at: http://localhost:8000",
        "Web interface: http://localhost:8000",
        "API docs: http://localhost:8000/docs",
        "Health check: http://localhost:8000/health",
    )
    for line in banner:
        print(line)

    # reload=True is a development convenience; binds on all interfaces.
    uvicorn.run(
        "api.app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )


if __name__ == "__main__":
    main()
|
assets/Screenshot 2025-09-27 184723.png
ADDED
|
config/settings.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the document text extraction system.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Config:
    """Global configuration settings for the document extraction project.

    All values are class-level constants; call ``create_directories`` once at
    startup to make sure the on-disk layout exists.
    """

    # Project layout, anchored at the repository root (two levels above this file).
    PROJECT_ROOT = Path(__file__).parent.parent
    DATA_DIR = PROJECT_ROOT / "data"
    MODELS_DIR = PROJECT_ROOT / "models"
    RESULTS_DIR = PROJECT_ROOT / "results"

    # Data subdirectories.
    RAW_DATA_DIR = DATA_DIR / "raw"
    PROCESSED_DATA_DIR = DATA_DIR / "processed"

    # Model settings.
    DEFAULT_MODEL_NAME = "distilbert-base-uncased"
    DEFAULT_MODEL_PATH = MODELS_DIR / "document_ner_model"

    # Training hyperparameter defaults.
    DEFAULT_BATCH_SIZE = 16
    DEFAULT_LEARNING_RATE = 2e-5
    DEFAULT_NUM_EPOCHS = 3
    DEFAULT_MAX_LENGTH = 512

    # OCR settings; None means "let pytesseract find the binary itself".
    TESSERACT_PATH = os.getenv('TESSERACT_PATH', None)

    # API settings.
    API_HOST = "0.0.0.0"
    API_PORT = 8000

    # BIO tagging scheme for the NER head: 'O' plus B-/I- pairs per entity type.
    ENTITY_LABELS = [
        'O', 'B-NAME', 'I-NAME', 'B-DATE', 'I-DATE',
        'B-INVOICE_NO', 'I-INVOICE_NO', 'B-AMOUNT', 'I-AMOUNT',
        'B-ADDRESS', 'I-ADDRESS', 'B-PHONE', 'I-PHONE',
        'B-EMAIL', 'I-EMAIL'
    ]

    # File extensions the document processor accepts.
    SUPPORTED_FORMATS = ['.pdf', '.docx', '.png', '.jpg', '.jpeg', '.tiff', '.bmp']

    @classmethod
    def create_directories(cls):
        """Create every directory the project writes to (idempotent)."""
        for target in (
            cls.DATA_DIR,
            cls.RAW_DATA_DIR,
            cls.PROCESSED_DATA_DIR,
            cls.MODELS_DIR,
            cls.RESULTS_DIR,
            cls.RESULTS_DIR / "plots",
            cls.RESULTS_DIR / "metrics",
        ):
            target.mkdir(parents=True, exist_ok=True)
|
demo.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple demo script for document text extraction.
|
| 3 |
+
Demonstrates the complete workflow from training to inference.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import jso print(f" {entity['entity']}: '{entity['text']}' ({confidence}%)")
|
| 10 |
+
else:
|
| 11 |
+
print(f"Error: {result['error']}")
|
| 12 |
+
|
| 13 |
+
except Exception as e:
|
| 14 |
+
print(f"Failed to process text: {e}") Add src to path for imports
|
| 15 |
+
sys.path.append(str(Path(__file__).parent))
|
| 16 |
+
|
| 17 |
+
from src.data_preparation import DocumentProcessor, NERDatasetCreator
|
| 18 |
+
from src.training_pipeline import TrainingPipeline, create_custom_config
|
| 19 |
+
from src.inference import DocumentInference
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _train_demo_model():
    """Train a small demo model and return the path it was saved to."""
    config = create_custom_config()
    config.num_epochs = 2  # Quick training for demo
    config.batch_size = 8
    pipeline = TrainingPipeline(config)
    return pipeline.run_complete_pipeline()


def _print_extraction_result(result):
    """Pretty-print the structured data and entity list of one extraction result."""
    structured_data = result.get('structured_data', {})
    entities = result.get('entities', [])

    print(f"\nExtraction Results:")
    if structured_data:
        print("Structured Data:")
        for key, value in structured_data.items():
            print(f" {key}: {value}")
    else:
        print(" No structured data extracted")

    if entities:
        print(f"Found {len(entities)} entities:")
        for entity in entities:
            confidence = int(entity['confidence'] * 100)
            print(f" {entity['entity']}: '{entity['text']}' ({confidence}%)")


def run_quick_demo():
    """Run a quick demonstration of the text extraction system.

    Trains a model on the fly if none is saved yet, extracts entities from a
    few sample documents, prints the results, and writes them to
    results/demo_results.json.
    """
    print("DOCUMENT TEXT EXTRACTION - QUICK DEMO")
    print("=" * 60)

    # Sample documents for demonstration
    demo_texts = [
        {
            "name": "Invoice Example 1",
            "text": "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567"
        },
        {
            "name": "Invoice Example 2",
            "text": "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Email: sarah.johnson@email.com"
        },
        {
            "name": "Receipt Example",
            "text": "Receipt for Michael Brown 456 Oak Street Boston MA 02101 Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75"
        }
    ]

    print("\nSample Documents:")
    for i, doc in enumerate(demo_texts, 1):
        print(f"{i}. {doc['name']}: {doc['text'][:60]}...")

    # Train a model first if no saved one exists.
    model_path = "models/document_ner_model"
    if not Path(model_path).exists():
        print(f"\nModel not found at {model_path}")
        print("Training a new model first...")
        model_path = _train_demo_model()
        print(f"Model trained and saved to {model_path}")

    # Load inference pipeline
    print(f"\nLoading inference pipeline from {model_path}")
    try:
        inference = DocumentInference(model_path)
        print("Inference pipeline loaded successfully")
    except Exception as e:
        print(f"Failed to load inference pipeline: {e}")
        return

    # Process demo texts
    print(f"\nProcessing {len(demo_texts)} demo documents...")
    results = []

    for i, doc in enumerate(demo_texts, 1):
        print(f"\nProcessing Document {i}: {doc['name']}")
        print("-" * 50)
        print(f"Text: {doc['text']}")

        result = inference.process_text_directly(doc['text'])
        results.append({
            'document_name': doc['name'],
            'original_text': doc['text'],
            'result': result
        })

        if 'error' not in result:
            _print_extraction_result(result)
        else:
            print(f"Error: {result['error']}")

    # Persist all results for later inspection.
    output_path = "results/demo_results.json"
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nDemo results saved to: {output_path}")

    # Summary statistics over the successful extractions only.
    successful = [r['result'] for r in results if 'error' not in r['result']]
    total_entities = sum(len(r.get('entities', [])) for r in successful)
    total_structured_fields = sum(len(r.get('structured_data', {})) for r in successful)

    print(f"\nDemo Summary:")
    print(f" Successfully processed: {len(successful)}/{len(demo_texts)} documents")
    print(f" Total entities found: {total_entities}")
    print(f" Total structured fields: {total_structured_fields}")

    print(f"\nDemo completed successfully!")
    print(f"You can now:")
    print(f" - Run the web API: python api/app.py")
    print(f" - Process your own documents using inference.py")
    print(f" - Retrain with your data using training_pipeline.py")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def train_model_only():
    """Train the model without running inference demo."""
    print("TRAINING MODEL ONLY")
    print("=" * 40)

    # Build the pipeline from the default custom config and run it end to end.
    training = TrainingPipeline(create_custom_config())
    saved_path = training.run_complete_pipeline()

    print(f"Model training completed!")
    print(f"Model saved to: {saved_path}")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def test_specific_text():
    """Interactively extract structured information from user-provided text.

    Requires a trained model on disk; prompts on stdin for a single line of
    text and prints the structured fields and entities found in it.
    """
    print("CUSTOM TEXT EXTRACTION")
    print("=" * 40)

    # Bail out early if there is no trained model to load.
    model_path = "models/document_ner_model"
    if not Path(model_path).exists():
        print("No trained model found. Please run training first.")
        return

    print("Enter text to extract information from:")
    print("(Example: Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00)")
    text = input("Text: ").strip()
    if not text:
        print("No text provided.")
        return

    try:
        extractor = DocumentInference(model_path)
        result = extractor.process_text_directly(text)

        print(f"\nExtraction Results:")
        if 'error' in result:
            print(f"Error: {result['error']}")
        else:
            fields = result.get('structured_data', {})
            if fields:
                print("Structured Information:")
                for label, value in fields.items():
                    print(f" {label}: {value}")
            else:
                print("No structured information found.")

            found = result.get('entities', [])
            if found:
                print(f"\nEntities Found ({len(found)}):")
                for ent in found:
                    pct = int(ent['confidence'] * 100)
                    print(f" {ent['entity']}: '{ent['text']}' ({pct}%)")
    except Exception as e:
        print(f"Failed to process text: {e}")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def main():
    """Main demo function with options."""
    # Print the menu once, then loop until a valid choice is entered.
    for line in (
        "DOCUMENT TEXT EXTRACTION SYSTEM",
        "=" * 50,
        "Choose an option:",
        "1. Run complete demo (train + inference)",
        "2. Train model only",
        "3. Test specific text (requires trained model)",
        "4. Exit",
    ):
        print(line)

    # Dispatch table for the runnable menu entries.
    actions = {
        '1': run_quick_demo,
        '2': train_model_only,
        '3': test_specific_text,
    }

    while True:
        choice = input("\nEnter your choice (1-4): ").strip()
        if choice == '4':
            print("👋 Goodbye!")
            break
        action = actions.get(choice)
        if action is not None:
            action()
            break
        print("Invalid choice. Please enter 1, 2, 3, or 4.")


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Document Text Extraction using Small Language Model (SLM)
|
| 2 |
+
# Core ML and NLP libraries
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
transformers>=4.30.0
|
| 5 |
+
tokenizers>=0.13.0
|
| 6 |
+
datasets>=2.14.0
|
| 7 |
+
|
| 8 |
+
# OCR and image processing
|
| 9 |
+
pytesseract>=0.3.10
|
| 10 |
+
easyocr>=1.7.0
|
| 11 |
+
opencv-python>=4.8.0
|
| 12 |
+
Pillow>=10.0.0
|
| 13 |
+
|
| 14 |
+
# PDF and document processing
|
| 15 |
+
PyMuPDF>=1.23.0
|
| 16 |
+
python-docx>=0.8.11
|
| 17 |
+
|
| 18 |
+
# Data processing and analysis
|
| 19 |
+
pandas>=2.0.0
|
| 20 |
+
numpy>=1.24.0
|
| 21 |
+
scikit-learn>=1.3.0
|
| 22 |
+
|
| 23 |
+
# NER evaluation metrics
|
| 24 |
+
seqeval>=1.2.2
|
| 25 |
+
|
| 26 |
+
# Visualization
|
| 27 |
+
matplotlib>=3.7.0
|
| 28 |
+
seaborn>=0.12.0
|
| 29 |
+
|
| 30 |
+
# Web API
|
| 31 |
+
fastapi>=0.100.0
|
| 32 |
+
uvicorn>=0.22.0
|
| 33 |
+
python-multipart>=0.0.6
|
| 34 |
+
|
| 35 |
+
# Utility libraries
|
| 36 |
+
pathlib2>=2.3.7
|
| 37 |
+
tqdm>=4.65.0
|
| 38 |
+
python-dotenv>=1.0.0
|
| 39 |
+
|
| 40 |
+
# Development and testing (optional)
|
| 41 |
+
pytest>=7.4.0
|
| 42 |
+
black>=23.0.0
|
| 43 |
+
flake8>=6.0.0
|
| 44 |
+
jupyter>=1.0.0
|
| 45 |
+
ipykernel>=6.25.0
|
| 46 |
+
|
| 47 |
+
# Optional: For GPU support (uncomment if you have CUDA)
|
| 48 |
+
# torch>=2.0.0+cu118
|
| 49 |
+
# torchvision>=0.15.0+cu118
|
| 50 |
+
# torchaudio>=2.0.0+cu118
|
results/demo_extraction_results.json
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"document_name": "Invoice Example 1",
|
| 4 |
+
"original_text": "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567",
|
| 5 |
+
"entities": [
|
| 6 |
+
{
|
| 7 |
+
"entity": "NAME",
|
| 8 |
+
"text": "Invoice sent",
|
| 9 |
+
"start": 0,
|
| 10 |
+
"end": 12,
|
| 11 |
+
"confidence": 0.8
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"entity": "NAME",
|
| 15 |
+
"text": "to Robert",
|
| 16 |
+
"start": 13,
|
| 17 |
+
"end": 22,
|
| 18 |
+
"confidence": 0.8
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"entity": "NAME",
|
| 22 |
+
"text": "White on",
|
| 23 |
+
"start": 23,
|
| 24 |
+
"end": 31,
|
| 25 |
+
"confidence": 0.8
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"entity": "NAME",
|
| 29 |
+
"text": "Invoice No",
|
| 30 |
+
"start": 43,
|
| 31 |
+
"end": 53,
|
| 32 |
+
"confidence": 0.8
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"entity": "DATE",
|
| 36 |
+
"text": "15/09/2025",
|
| 37 |
+
"start": 32,
|
| 38 |
+
"end": 42,
|
| 39 |
+
"confidence": 0.85
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"entity": "INVOICE_NO",
|
| 43 |
+
"text": "INV-1024",
|
| 44 |
+
"start": 43,
|
| 45 |
+
"end": 63,
|
| 46 |
+
"confidence": 0.9
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"entity": "AMOUNT",
|
| 50 |
+
"text": "$1,250.00",
|
| 51 |
+
"start": 72,
|
| 52 |
+
"end": 81,
|
| 53 |
+
"confidence": 0.85
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"entity": "PHONE",
|
| 57 |
+
"text": "(555) 123-4567",
|
| 58 |
+
"start": 89,
|
| 59 |
+
"end": 103,
|
| 60 |
+
"confidence": 0.9
|
| 61 |
+
}
|
| 62 |
+
],
|
| 63 |
+
"structured_data": {
|
| 64 |
+
"Name": "Invoice Sent",
|
| 65 |
+
"Date": "15/09/2025",
|
| 66 |
+
"InvoiceNo": "INV-1024",
|
| 67 |
+
"Amount": "$1,250.00",
|
| 68 |
+
"Phone": "(555) 123-4567"
|
| 69 |
+
},
|
| 70 |
+
"processing_timestamp": "2025-09-27T18:26:31.996468",
|
| 71 |
+
"total_entities_found": 8,
|
| 72 |
+
"entity_types_found": [
|
| 73 |
+
"AMOUNT",
|
| 74 |
+
"NAME",
|
| 75 |
+
"DATE",
|
| 76 |
+
"INVOICE_NO",
|
| 77 |
+
"PHONE"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"document_name": "Invoice Example 2",
|
| 82 |
+
"original_text": "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Email: sarah.johnson@email.com",
|
| 83 |
+
"entities": [
|
| 84 |
+
{
|
| 85 |
+
"entity": "NAME",
|
| 86 |
+
"text": "Sarah Johnson",
|
| 87 |
+
"start": 9,
|
| 88 |
+
"end": 26,
|
| 89 |
+
"confidence": 0.8
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"entity": "NAME",
|
| 93 |
+
"text": "Bill for",
|
| 94 |
+
"start": 0,
|
| 95 |
+
"end": 8,
|
| 96 |
+
"confidence": 0.8
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"entity": "NAME",
|
| 100 |
+
"text": "dated March",
|
| 101 |
+
"start": 27,
|
| 102 |
+
"end": 38,
|
| 103 |
+
"confidence": 0.8
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"entity": "NAME",
|
| 107 |
+
"text": "Invoice Number",
|
| 108 |
+
"start": 49,
|
| 109 |
+
"end": 63,
|
| 110 |
+
"confidence": 0.8
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"entity": "DATE",
|
| 114 |
+
"text": "March 10, 2025",
|
| 115 |
+
"start": 33,
|
| 116 |
+
"end": 47,
|
| 117 |
+
"confidence": 0.85
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"entity": "INVOICE_NO",
|
| 121 |
+
"text": "BL-2045",
|
| 122 |
+
"start": 49,
|
| 123 |
+
"end": 72,
|
| 124 |
+
"confidence": 0.9
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"entity": "AMOUNT",
|
| 128 |
+
"text": "$2,300.50",
|
| 129 |
+
"start": 81,
|
| 130 |
+
"end": 90,
|
| 131 |
+
"confidence": 0.85
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"entity": "EMAIL",
|
| 135 |
+
"text": "sarah.johnson@email.com",
|
| 136 |
+
"start": 98,
|
| 137 |
+
"end": 121,
|
| 138 |
+
"confidence": 0.95
|
| 139 |
+
}
|
| 140 |
+
],
|
| 141 |
+
"structured_data": {
|
| 142 |
+
"Name": "Sarah Johnson",
|
| 143 |
+
"Date": "March 10, 2025",
|
| 144 |
+
"InvoiceNo": "BL-2045",
|
| 145 |
+
"Amount": "$2,300.50",
|
| 146 |
+
"Email": "sarah.johnson@email.com"
|
| 147 |
+
},
|
| 148 |
+
"processing_timestamp": "2025-09-27T18:26:31.997340",
|
| 149 |
+
"total_entities_found": 8,
|
| 150 |
+
"entity_types_found": [
|
| 151 |
+
"AMOUNT",
|
| 152 |
+
"NAME",
|
| 153 |
+
"EMAIL",
|
| 154 |
+
"DATE",
|
| 155 |
+
"INVOICE_NO"
|
| 156 |
+
]
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"document_name": "Receipt Example",
|
| 160 |
+
"original_text": "Receipt for Michael Brown Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75 Contact: +1-555-987-6543",
|
| 161 |
+
"entities": [
|
| 162 |
+
{
|
| 163 |
+
"entity": "NAME",
|
| 164 |
+
"text": "Receipt for",
|
| 165 |
+
"start": 0,
|
| 166 |
+
"end": 11,
|
| 167 |
+
"confidence": 0.8
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"entity": "NAME",
|
| 171 |
+
"text": "Michael Brown",
|
| 172 |
+
"start": 12,
|
| 173 |
+
"end": 25,
|
| 174 |
+
"confidence": 0.8
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"entity": "DATE",
|
| 178 |
+
"text": "2025-04-22",
|
| 179 |
+
"start": 50,
|
| 180 |
+
"end": 60,
|
| 181 |
+
"confidence": 0.85
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"entity": "INVOICE_NO",
|
| 185 |
+
"text": "REC-3089",
|
| 186 |
+
"start": 35,
|
| 187 |
+
"end": 43,
|
| 188 |
+
"confidence": 0.9
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"entity": "AMOUNT",
|
| 192 |
+
"text": "$890.75",
|
| 193 |
+
"start": 69,
|
| 194 |
+
"end": 76,
|
| 195 |
+
"confidence": 0.85
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"entity": "PHONE",
|
| 199 |
+
"text": "+1-555-987-6543",
|
| 200 |
+
"start": 86,
|
| 201 |
+
"end": 101,
|
| 202 |
+
"confidence": 0.9
|
| 203 |
+
}
|
| 204 |
+
],
|
| 205 |
+
"structured_data": {
|
| 206 |
+
"Name": "Receipt For",
|
| 207 |
+
"Date": "2025-04-22",
|
| 208 |
+
"InvoiceNo": "REC-3089",
|
| 209 |
+
"Amount": "$890.75",
|
| 210 |
+
"Phone": "+1 (555) 987-6543"
|
| 211 |
+
},
|
| 212 |
+
"processing_timestamp": "2025-09-27T18:26:31.998731",
|
| 213 |
+
"total_entities_found": 6,
|
| 214 |
+
"entity_types_found": [
|
| 215 |
+
"AMOUNT",
|
| 216 |
+
"NAME",
|
| 217 |
+
"DATE",
|
| 218 |
+
"INVOICE_NO",
|
| 219 |
+
"PHONE"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"document_name": "Business Document",
|
| 224 |
+
"original_text": "Ms. Emma Wilson 456 Oak Street Payment due: January 15, 2025 Reference: INV-4567 Total: $1,750.25",
|
| 225 |
+
"entities": [
|
| 226 |
+
{
|
| 227 |
+
"entity": "NAME",
|
| 228 |
+
"text": "Emma Wilson",
|
| 229 |
+
"start": 0,
|
| 230 |
+
"end": 15,
|
| 231 |
+
"confidence": 0.8
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"entity": "NAME",
|
| 235 |
+
"text": "Oak Street",
|
| 236 |
+
"start": 20,
|
| 237 |
+
"end": 30,
|
| 238 |
+
"confidence": 0.8
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"entity": "NAME",
|
| 242 |
+
"text": "Payment due",
|
| 243 |
+
"start": 31,
|
| 244 |
+
"end": 42,
|
| 245 |
+
"confidence": 0.8
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"entity": "DATE",
|
| 249 |
+
"text": "January 15, 2025",
|
| 250 |
+
"start": 44,
|
| 251 |
+
"end": 60,
|
| 252 |
+
"confidence": 0.85
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"entity": "INVOICE_NO",
|
| 256 |
+
"text": "INV-4567",
|
| 257 |
+
"start": 72,
|
| 258 |
+
"end": 80,
|
| 259 |
+
"confidence": 0.9
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"entity": "AMOUNT",
|
| 263 |
+
"text": "$1,750.25",
|
| 264 |
+
"start": 88,
|
| 265 |
+
"end": 97,
|
| 266 |
+
"confidence": 0.85
|
| 267 |
+
}
|
| 268 |
+
],
|
| 269 |
+
"structured_data": {
|
| 270 |
+
"Name": "Emma Wilson",
|
| 271 |
+
"Date": "January 15, 2025",
|
| 272 |
+
"InvoiceNo": "INV-4567",
|
| 273 |
+
"Amount": "$1,750.25"
|
| 274 |
+
},
|
| 275 |
+
"processing_timestamp": "2025-09-27T18:26:32.000279",
|
| 276 |
+
"total_entities_found": 6,
|
| 277 |
+
"entity_types_found": [
|
| 278 |
+
"AMOUNT",
|
| 279 |
+
"INVOICE_NO",
|
| 280 |
+
"DATE",
|
| 281 |
+
"NAME"
|
| 282 |
+
]
|
| 283 |
+
}
|
| 284 |
+
]
|
setup.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Setup script for the Document Text Extraction system.
|
| 4 |
+
Creates directories, checks dependencies, and initializes the project.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import subprocess
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import importlib.util
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def check_python_version():
    """Return True if the running interpreter is Python 3.8+; print the verdict."""
    major, minor, micro = sys.version_info[:3]
    if (major, minor) < (3, 8):
        print("Python 3.8 or higher is required.")
        print(f"Current version: {sys.version}")
        return False

    print(f"Python {major}.{minor}.{micro}")
    return True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create_directories():
    """Create the project directory layout (idempotent)."""
    print("\n📁 Creating project directories...")
    # mkdir with exist_ok makes re-running setup safe.
    for name in (
        "data/raw",
        "data/processed",
        "models",
        "results/plots",
        "results/metrics",
        "logs",
    ):
        Path(name).mkdir(parents=True, exist_ok=True)
        print(f" {name}")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def check_dependencies():
    """Report which required packages are importable.

    Returns the display names of the packages that are missing (empty list
    when everything is installed).
    """
    print("\n📦 Checking dependencies...")

    # (importable module name, human-readable package name)
    required = (
        ('torch', 'PyTorch'),
        ('transformers', 'Transformers'),
        ('PIL', 'Pillow'),
        ('cv2', 'OpenCV'),
        ('pandas', 'Pandas'),
        ('numpy', 'NumPy'),
        ('sklearn', 'Scikit-learn'),
    )

    missing = []
    for module_name, display_name in required:
        # find_spec probes importability without actually importing the package.
        if importlib.util.find_spec(module_name) is None:
            missing.append(display_name)
            print(f" {display_name} not found")
        else:
            print(f" {display_name}")

    return missing
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def check_ocr_dependencies():
    """Report on the availability of the OCR stack (EasyOCR, PyTesseract, Tesseract)."""
    print("\nChecking OCR dependencies...")

    # EasyOCR: optional, pure-Python install check.
    try:
        import easyocr  # noqa: F401
    except ImportError:
        print(" EasyOCR not found")
    else:
        print(" EasyOCR")

    # PyTesseract needs both the Python wrapper and the native tesseract binary.
    try:
        import pytesseract
    except ImportError:
        print(" PyTesseract not found")
        return

    print(" PyTesseract")
    try:
        pytesseract.get_tesseract_version()
    except Exception:
        print(" Tesseract OCR engine not found or not in PATH")
        print(" Please install Tesseract OCR:")
        print(" - Windows: https://github.com/UB-Mannheim/tesseract/wiki")
        print(" - Ubuntu: sudo apt install tesseract-ocr")
        print(" - macOS: brew install tesseract")
    else:
        print(" Tesseract OCR engine")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def install_dependencies():
    """Install project dependencies with pip from requirements.txt.

    Returns True on success, False when pip exits non-zero (the captured
    stdout/stderr are printed to help diagnose the failure).
    """
    print("\nInstalling dependencies from requirements.txt...")

    # Use the current interpreter's pip so we install into the right environment.
    cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f" Failed to install dependencies: {e}")
        print(f" Output: {e.stdout}")
        print(f" Error: {e.stderr}")
        return False

    print(" Dependencies installed successfully")
    return True
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def check_gpu_support():
    """Print whether CUDA-capable GPUs are visible to PyTorch."""
    print("\n🖥️ Checking GPU support...")

    try:
        import torch
    except ImportError:
        print(" PyTorch not installed")
        return

    if torch.cuda.is_available():
        print(f" CUDA available - {torch.cuda.device_count()} GPU(s)")
        print(f" Primary GPU: {torch.cuda.get_device_name(0)}")
    else:
        print(" CUDA not available - will use CPU")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def create_sample_documents():
    """Create sample test documents under data/raw/samples for quick trials."""
    print("\nCreating sample test documents...")

    # Representative invoice/bill/receipt snippets covering the entity types
    # the extractor recognizes (names, dates, amounts, phone, email, address).
    sample_texts = [
        "Invoice sent to John Doe on 01/15/2025\nInvoice No: INV-1001\nAmount: $1,500.00\nPhone: (555) 123-4567",
        "Bill for Dr. Sarah Johnson dated March 10, 2025.\nInvoice Number: BL-2045.\nTotal: $2,300.50\nEmail: sarah@email.com",
        "Receipt for Michael Brown\n456 Oak Street, Boston MA 02101\nInvoice: REC-3089\nDate: 2025-04-22\nAmount: $890.75"
    ]

    target_dir = Path("data/raw/samples")
    target_dir.mkdir(parents=True, exist_ok=True)

    # Files are numbered from 1 to keep the established naming scheme.
    for idx, content in enumerate(sample_texts, 1):
        out_path = target_dir / f"sample_document_{idx}.txt"
        out_path.write_text(content, encoding='utf-8')
        print(f" {out_path.name}")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_initial_test():
    """Run a basic smoke test verifying core modules import and work.

    Returns:
        bool: True if every check passed, False on the first failure.
    """
    print("\nRunning initial setup test...")

    try:
        # Any failure below (import error, API mismatch, ...) is treated as
        # a failed setup rather than a crash.
        from src.data_preparation import DocumentProcessor, NERDatasetCreator
        from src.model import ModelConfig
        print(" Core modules imported successfully")

        # Exercise the text-cleaning path on a representative snippet.
        doc_processor = DocumentProcessor()
        doc_processor.clean_text("Invoice sent to John Doe on 01/15/2025 Amount: $500.00")
        print(" Document processor working")

        # Building the bundled sample dataset confirms the data tooling works.
        creator = NERDatasetCreator(doc_processor)
        samples = creator.create_sample_dataset()
        print(f" Dataset creator working - {len(samples)} samples")

        # Instantiating the config validates the label schema.
        cfg = ModelConfig()
        print(f" Model config created - {cfg.num_labels} labels")

        return True

    except Exception as e:
        print(f" Setup test failed: {e}")
        return False
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def display_next_steps():
    """Display next steps for the user after a successful setup."""
    banner = "=" * 30
    # Emit the whole guide as one joined write; output is identical to
    # printing each line individually.
    guide = [
        "\n" + banner,
        "SETUP COMPLETED SUCCESSFULLY!",
        banner,
        "\nNext Steps:",
        "1. Quick Demo:",
        " python demo.py",
        "\n2. Train Your Model:",
        " # Add your documents to data/raw/",
        " # Then run:",
        " python src/training_pipeline.py",
        "\n3. 🌐 Start Web API:",
        " python api/app.py",
        " # Then open: http://localhost:8000",
        "\n4. Run Tests:",
        " python tests/test_extraction.py",
        "\n5. 📚 Documentation:",
        " # View README.md for detailed usage",
        " # API docs: http://localhost:8000/docs",
        "\nPro Tips:",
        " - Place your documents in data/raw/ for training",
        " - Use GPU for faster training (if available)",
        " - Adjust batch_size in config if you get memory errors",
        " - Check logs/ directory for debugging information",
    ]
    print("\n".join(guide))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def main():
    """Main setup function: run every setup stage in order.

    Returns:
        bool: True when setup completed (possibly with warnings),
        False on a fatal problem (bad Python version, failed install).
    """
    print("DOCUMENT TEXT EXTRACTION - SETUP SCRIPT")
    print("=" * 60)

    # A too-old interpreter makes everything else pointless.
    if not check_python_version():
        return False

    create_directories()

    # Offer to install anything missing from requirements.txt.
    missing_packages = check_dependencies()
    if missing_packages:
        print(f"\nMissing packages: {', '.join(missing_packages)}")
        answer = input("Install missing dependencies? (y/n): ").lower().strip()

        if answer == 'y':
            if not install_dependencies():
                print("Failed to install dependencies. Please install manually:")
                print(" pip install -r requirements.txt")
                return False
        else:
            print("Some features may not work without required dependencies.")

    # Informational checks - these never abort the setup.
    check_ocr_dependencies()
    check_gpu_support()
    create_sample_documents()

    # The smoke test is advisory: warn but continue on failure.
    if not run_initial_test():
        print("Setup test failed. Some features may not work correctly.")
        print(" Check error messages above and ensure all dependencies are installed.")

    display_next_steps()

    return True
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Script entry point: run the full setup; exit non-zero so CI/automation can
# detect a failed setup.
if __name__ == "__main__":
    success = main()

    if success:
        print(f"\nSetup completed! Ready to extract text from documents!")
    else:
        print(f"\nSetup encountered issues. Please check the messages above.")
        sys.exit(1)
|
simple_api.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simplified Document Text Extraction API
|
| 4 |
+
Uses regex patterns instead of ML model for demonstration
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Dict, List, Any, Optional
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
# Add current directory to Python path
|
| 16 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile
|
| 20 |
+
from fastapi.responses import HTMLResponse, FileResponse
|
| 21 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 22 |
+
from pydantic import BaseModel
|
| 23 |
+
import uvicorn
|
| 24 |
+
HAS_FASTAPI = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
print("FastAPI not installed. Install with: pip install fastapi uvicorn python-multipart")
|
| 27 |
+
HAS_FASTAPI = False
|
| 28 |
+
|
| 29 |
+
class SimpleDocumentProcessor:
    """Simplified document processor using regex patterns.

    Pure-regex stand-in for an ML NER model: each entity type has a list of
    candidate patterns plus a fixed heuristic confidence score.
    """

    def __init__(self):
        # Patterns per entity type. The first capture group (when present)
        # is the extracted value; otherwise the whole match is used.
        self.patterns = {
            'NAME': [
                r'\b(?:Mr\.|Mrs\.|Ms\.|Dr\.|Prof\.)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
                r'\b([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b',
                r'(?:Invoice|Bill|Receipt)\s+(?:sent\s+)?(?:to\s+|for\s+)?([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            ],
            'DATE': [
                r'\b(\d{1,2}[\/\-]\d{1,2}[\/\-]\d{2,4})\b',
                r'\b(\d{2,4}[\/\-]\d{1,2}[\/\-]\d{1,2})\b',
                r'\b((?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{2,4})\b',
                r'\b((?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},?\s+\d{2,4})\b',
            ],
            'AMOUNT': [
                r'\$\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
                r'(?:Amount|Total|Sum):\s*\$?\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
                r'(\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|dollars?))',
            ],
            'INVOICE_NO': [
                r'(?:Invoice|Bill|Receipt)(?:\s+No\.?|#|Number):\s*([A-Z]{2,4}[-\s]?\d{3,6})',
                r'(?:INV|BL|REC)[-\s]?(\d{3,6})',
                r'Reference:\s*([A-Z]{2,4}[-\s]?\d{3,6})',
            ],
            'EMAIL': [
                r'\b([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})\b',
            ],
            'PHONE': [
                r'\b(\+?1[-.\s]?\(?[2-9]\d{2}\)?[-.\s]?\d{3}[-.\s]?\d{4})\b',
                r'\b(\([2-9]\d{2}\)\s*[2-9]\d{2}[-.\s]?\d{4})\b',
                r'\b([2-9]\d{2}[-.\s]?[2-9]\d{2}[-.\s]?\d{4})\b',
            ],
            'ADDRESS': [
                r'\b(\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Lane|Ln|Drive|Dr|Boulevard|Blvd|Way))\b',
            ]
        }

        # Heuristic confidence per entity type (regex hits carry no real score).
        self.confidence_scores = {
            'NAME': 0.80,
            'DATE': 0.85,
            'AMOUNT': 0.85,
            'INVOICE_NO': 0.90,
            'EMAIL': 0.95,
            'PHONE': 0.90,
            'ADDRESS': 0.75
        }

    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Return every pattern hit as a dict with entity type, text, span, score."""
        hits: List[Dict[str, Any]] = []

        for label, regexes in self.patterns.items():
            score = self.confidence_scores[label]
            for regex in regexes:
                # Case-insensitive scan; hits from different patterns may
                # overlap - deduplication happens downstream when structuring.
                for m in re.finditer(regex, text, re.IGNORECASE):
                    hits.append({
                        'entity': label,
                        'text': m.group(1) if m.groups() else m.group(0),
                        'start': m.start(),
                        'end': m.end(),
                        'confidence': score
                    })

        return hits

    def create_structured_data(self, entities: List[Dict]) -> Dict[str, str]:
        """Collapse raw hits into one value per entity type, under friendly keys."""
        # Map internal tags to user-facing field names.
        field_names = {
            'NAME': 'Name',
            'DATE': 'Date',
            'AMOUNT': 'Amount',
            'INVOICE_NO': 'InvoiceNo',
            'EMAIL': 'Email',
            'PHONE': 'Phone',
            'ADDRESS': 'Address'
        }

        # Bucket hits by entity type, preserving first-seen order.
        buckets: Dict[str, List[Dict]] = {}
        for hit in entities:
            buckets.setdefault(hit['entity'], []).append(hit)

        # Keep the highest-confidence hit per type; since all hits of one
        # type share a score, ties resolve to the earliest match.
        structured: Dict[str, str] = {}
        for label, bucket in buckets.items():
            winner = max(bucket, key=lambda h: h['confidence'])
            structured[field_names.get(label, label)] = winner['text']

        return structured

    def process_text(self, text: str) -> Dict[str, Any]:
        """Run extraction end-to-end and wrap the result in the API envelope."""
        hits = self.extract_entities(text)

        return {
            'status': 'success',
            'data': {
                'original_text': text,
                'entities': hits,
                'structured_data': self.create_structured_data(hits),
                'processing_timestamp': datetime.now().isoformat(),
                'total_entities_found': len(hits),
                'entity_types_found': sorted({h['entity'] for h in hits})
            }
        }
|
| 152 |
+
|
| 153 |
+
# Pydantic models for API (defined only when FastAPI/pydantic imported OK)
if HAS_FASTAPI:
    class TextRequest(BaseModel):
        """Request body for POST /extract-from-text."""
        text: str  # raw document text to run extraction on
|
| 157 |
+
|
| 158 |
+
def create_app():
    """Create and configure FastAPI app.

    Builds the demo API: a self-contained HTML/JS interface at ``/``,
    extraction endpoints for raw text and file uploads, and a health check.

    Raises:
        ImportError: if FastAPI/uvicorn are not installed.

    Returns:
        The configured FastAPI application instance.
    """
    if not HAS_FASTAPI:
        raise ImportError("FastAPI dependencies not installed")

    app = FastAPI(
        title="Simple Document Text Extraction API",
        description="Extract structured information from documents using regex patterns",
        version="1.0.0"
    )

    # Enable CORS (wide-open: demo only, not production-safe)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Initialize processor; shared by all endpoint closures below.
    processor = SimpleDocumentProcessor()

    @app.get("/", response_class=HTMLResponse)
    async def get_interface():
        """Serve the web interface (single inline HTML page, no static files)."""
        return """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Document Text Extraction Demo</title>
            <style>
                body {
                    font-family: Arial, sans-serif;
                    max-width: 1200px;
                    margin: 0 auto;
                    padding: 20px;
                    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: #333;
                }
                .container {
                    background: white;
                    padding: 30px;
                    border-radius: 10px;
                    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
                }
                .header {
                    text-align: center;
                    margin-bottom: 30px;
                }
                .header h1 {
                    color: #2c3e50;
                    font-size: 2.5em;
                    margin-bottom: 10px;
                }
                .header p {
                    color: #7f8c8d;
                    font-size: 1.2em;
                }
                .tabs {
                    display: flex;
                    margin-bottom: 20px;
                }
                .tab {
                    flex: 1;
                    text-align: center;
                    padding: 15px;
                    background: #ecf0f1;
                    border: none;
                    cursor: pointer;
                    font-size: 16px;
                    transition: background 0.3s;
                }
                .tab.active {
                    background: #3498db;
                    color: white;
                }
                .tab:hover {
                    background: #3498db;
                    color: white;
                }
                .tab-content {
                    display: none;
                    padding: 20px;
                    border: 1px solid #ddd;
                    border-radius: 5px;
                }
                .tab-content.active {
                    display: block;
                }
                textarea {
                    width: 100%;
                    height: 150px;
                    margin-bottom: 15px;
                    padding: 10px;
                    border: 1px solid #ddd;
                    border-radius: 5px;
                    font-size: 14px;
                }
                input[type="file"] {
                    margin-bottom: 15px;
                    padding: 10px;
                }
                button {
                    background: #27ae60;
                    color: white;
                    padding: 12px 25px;
                    border: none;
                    border-radius: 5px;
                    cursor: pointer;
                    font-size: 16px;
                    transition: background 0.3s;
                }
                button:hover {
                    background: #2ecc71;
                }
                .results {
                    margin-top: 20px;
                    padding: 20px;
                    background: #f8f9fa;
                    border-radius: 5px;
                    border-left: 4px solid #27ae60;
                }
                .entity {
                    background: #e8f4fd;
                    padding: 8px 12px;
                    margin: 5px;
                    border-radius: 20px;
                    display: inline-block;
                    font-size: 12px;
                    border: 1px solid #3498db;
                }
                .entity.NAME { background: #ffeb3b; border-color: #ff9800; }
                .entity.DATE { background: #4caf50; border-color: #2e7d32; color: white; }
                .entity.AMOUNT { background: #f44336; border-color: #c62828; color: white; }
                .entity.INVOICE_NO { background: #9c27b0; border-color: #6a1b9a; color: white; }
                .entity.EMAIL { background: #00bcd4; border-color: #00838f; color: white; }
                .entity.PHONE { background: #ff5722; border-color: #d84315; color: white; }
                .entity.ADDRESS { background: #795548; border-color: #5d4037; color: white; }
                .structured-data {
                    background: #e8f5e8;
                    padding: 15px;
                    border-radius: 5px;
                    margin-top: 15px;
                }
                .examples {
                    background: #fff3cd;
                    padding: 15px;
                    border-radius: 5px;
                    margin-top: 20px;
                }
                .example-btn {
                    background: #6c757d;
                    font-size: 12px;
                    padding: 5px 10px;
                    margin: 2px;
                }
                pre {
                    background: #f8f9fa;
                    padding: 15px;
                    border-radius: 5px;
                    overflow-x: auto;
                    font-size: 12px;
                    border: 1px solid #dee2e6;
                }
            </style>
        </head>
        <body>
            <div class="container">
                <div class="header">
                    <h1> Document Text Extraction</h1>
                    <p>Extract structured information from documents using AI patterns</p>
                </div>

                <div class="tabs">
                    <button class="tab active" onclick="showTab('text')">Enter Text</button>
                    <button class="tab" onclick="showTab('file')">Upload File</button>
                    <button class="tab" onclick="showTab('api')">API Docs</button>
                </div>

                <div id="text-tab" class="tab-content active">
                    <h3>Enter Text to Extract:</h3>
                    <textarea id="textInput" placeholder="Paste your document text here...">Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567 Email: robert.white@email.com</textarea>
                    <button onclick="extractFromText()">Extract Information</button>

                    <div class="examples">
                        <h4>Try These Examples:</h4>
                        <button class="example-btn" onclick="useExample(0)">Invoice Example</button>
                        <button class="example-btn" onclick="useExample(1)">Receipt Example</button>
                        <button class="example-btn" onclick="useExample(2)">Business Document</button>
                        <button class="example-btn" onclick="useExample(3)">Payment Notice</button>
                    </div>
                </div>

                <div id="file-tab" class="tab-content">
                    <h3>Upload Document:</h3>
                    <input type="file" id="fileInput" accept=".pdf,.docx,.txt,.jpg,.png,.tiff">
                    <br>
                    <button onclick="extractFromFile()">Upload & Extract</button>
                    <p><em>Note: File upload processing is simplified in this demo</em></p>
                </div>

                <div id="api-tab" class="tab-content">
                    <h3>API Documentation</h3>
                    <h4>Endpoints:</h4>
                    <pre><strong>POST /extract-from-text</strong>
Content-Type: application/json
{
    "text": "Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"
}</pre>

                    <pre><strong>POST /extract-from-file</strong>
Content-Type: multipart/form-data
file: [uploaded file]</pre>

                    <h4>Response Format:</h4>
                    <pre>{
    "status": "success",
    "data": {
        "original_text": "...",
        "entities": [...],
        "structured_data": {...},
        "processing_timestamp": "2025-09-27T...",
        "total_entities_found": 7,
        "entity_types_found": ["NAME", "DATE", "AMOUNT", "INVOICE_NO"]
    }
}</pre>
                </div>

                <div id="results"></div>
            </div>

            <script>
                const examples = [
                    "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567 Email: robert.white@email.com",
                    "Receipt for Michael Brown Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75 Contact: +1-555-987-6543",
                    "Ms. Emma Wilson 456 Oak Street Payment due: January 15, 2025 Reference: INV-4567 Total: $1,750.25",
                    "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Email: sarah.johnson@email.com"
                ];

                function showTab(tabName) {
                    // Hide all tabs
                    document.querySelectorAll('.tab-content').forEach(content => {
                        content.classList.remove('active');
                    });
                    document.querySelectorAll('.tab').forEach(tab => {
                        tab.classList.remove('active');
                    });

                    // Show selected tab
                    document.getElementById(tabName + '-tab').classList.add('active');
                    event.target.classList.add('active');
                }

                function useExample(index) {
                    document.getElementById('textInput').value = examples[index];
                }

                async function extractFromText() {
                    const text = document.getElementById('textInput').value;
                    if (!text.trim()) {
                        alert('Please enter some text');
                        return;
                    }

                    try {
                        const response = await fetch('/extract-from-text', {
                            method: 'POST',
                            headers: {
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({ text: text })
                        });

                        const result = await response.json();
                        displayResults(result);
                    } catch (error) {
                        alert('Error: ' + error.message);
                    }
                }

                async function extractFromFile() {
                    const fileInput = document.getElementById('fileInput');
                    if (!fileInput.files[0]) {
                        alert('Please select a file');
                        return;
                    }

                    // For demo purposes, show that file upload would work
                    alert('File upload processing would happen here. For now, using sample text extraction.');
                    document.getElementById('textInput').value = examples[0];
                    showTab('text');
                    extractFromText();
                }

                function displayResults(result) {
                    const resultsDiv = document.getElementById('results');

                    if (result.status !== 'success') {
                        resultsDiv.innerHTML = '<div class="results"><h3>Error</h3><p>' + result.message + '</p></div>';
                        return;
                    }

                    const data = result.data;
                    let html = '<div class="results">';
                    html += '<h3>Extraction Results</h3>';
                    html += '<p><strong>Found:</strong> ' + data.total_entities_found + ' entities of ' + data.entity_types_found.length + ' types</p>';

                    // Show entities
                    html += '<h4>Detected Entities:</h4>';
                    data.entities.forEach(entity => {
                        html += '<span class="entity ' + entity.entity + '">' + entity.entity + ': ' + entity.text + ' (' + Math.round(entity.confidence * 100) + '%)</span> ';
                    });

                    // Show structured data
                    if (Object.keys(data.structured_data).length > 0) {
                        html += '<div class="structured-data">';
                        html += '<h4>Structured Information:</h4>';
                        html += '<ul>';
                        for (const [key, value] of Object.entries(data.structured_data)) {
                            html += '<li><strong>' + key + ':</strong> ' + value + '</li>';
                        }
                        html += '</ul>';
                        html += '</div>';
                    }

                    // Show processing info
                    html += '<p><small>🕒 Processed at: ' + new Date(data.processing_timestamp).toLocaleString() + '</small></p>';
                    html += '</div>';

                    resultsDiv.innerHTML = html;
                }
            </script>
        </body>
        </html>
        """

    @app.post("/extract-from-text")
    async def extract_from_text(request: TextRequest):
        """Extract entities from text (JSON body: {"text": ...})."""
        try:
            result = processor.process_text(request.text)
            return result
        except Exception as e:
            # Surface any processing failure as a 500 with the message.
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/extract-from-file")
    async def extract_from_file(file: UploadFile = File(...)):
        """Extract entities from uploaded file (only .txt is really parsed)."""
        try:
            # Read file content
            content = await file.read()

            # For demo purposes, convert to text (simplified)
            if file.filename.lower().endswith('.txt'):
                text = content.decode('utf-8')
            else:
                # For other file types, use sample text in demo
                text = "Demo processing for " + file.filename + ": Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"

            result = processor.process_text(text)
            return result

        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health")
    async def health_check():
        """Health check endpoint (liveness probe)."""
        return {"status": "healthy", "timestamp": datetime.now().isoformat()}

    return app
|
| 530 |
+
|
| 531 |
+
def main():
    """Entry point: start the demo extraction API on port 7000.

    Prints usage URLs and blocks in uvicorn until interrupted; bails out
    early with install instructions when FastAPI is unavailable.
    """
    if not HAS_FASTAPI:
        print("FastAPI dependencies not installed.")
        print("📦 Install with: pip install fastapi uvicorn python-multipart")
        return

    print("Starting Simple Document Text Extraction API...")
    print("Access the web interface at: http://localhost:7000")
    print("API documentation at: http://localhost:7000/docs")
    print("Health check at: http://localhost:7000/health")
    print("\nServer starting...")

    # Build the app and serve it on all interfaces (blocks until stopped).
    uvicorn.run(create_app(), host="0.0.0.0", port=7000, log_level="info")
|
| 546 |
+
|
| 547 |
+
# Run the API server when executed directly as a script.
if __name__ == "__main__":
    main()
|
simple_demo.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplified demo of document text extraction without heavy ML dependencies.
|
| 3 |
+
This demonstrates the core workflow and patterns without requiring PyTorch/Transformers.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Tuple, Any
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SimpleDocumentProcessor:
    """Simplified regex-based document processor for demo purposes.

    Extracts common business-document entities (names, dates, invoice
    numbers, amounts, phone numbers, e-mail addresses) from plain text
    and condenses them into a flat ``structured_data`` mapping.
    """

    def __init__(self):
        """Initialize with regex patterns for entity extraction."""
        # Patterns are applied with re.IGNORECASE; group 1 (when present)
        # is the entity text, otherwise the whole match is used.
        self.entity_patterns = {
            'NAME': [
                r'\b(?:Mr\.|Mrs\.|Ms\.|Dr\.)\s+([A-Z][a-z]+ [A-Z][a-z]+)\b',
                r'\b([A-Z][a-z]+ [A-Z][a-z]+)\b',
            ],
            'DATE': [
                r'\b(\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4})\b',
                r'\b(\d{4}[/\-]\d{1,2}[/\-]\d{1,2})\b',
                r'\b((?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4})\b'
            ],
            'INVOICE_NO': [
                r'(?:Invoice\s+(?:No|Number|#):\s*)?([A-Z]{2,4}[-]?\d{3,6})',
                r'(INV[-]?\d{3,6})',
                r'(BL[-]?\d{3,6})',
                r'(REC[-]?\d{3,6})',
            ],
            'AMOUNT': [
                r'(\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
                r'(\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP))',
            ],
            'PHONE': [
                r'(\+?\d{1,3}[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4})',
                r'(\(\d{3}\)\s*\d{3}-\d{4})',
            ],
            'EMAIL': [
                r'\b([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,})\b',
            ]
        }

    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from text using regex patterns.

        Fix: the pattern lists overlap (e.g. the generic INVOICE_NO pattern
        and the INV/BL/REC-specific ones match the same span), so the same
        entity used to be reported once per matching pattern, inflating
        ``entity_count``. Identical (type, text) pairs are now emitted once.
        """
        entities = []
        seen = set()  # (entity_type, lowercased text) pairs already emitted

        for entity_type, patterns in self.entity_patterns.items():
            for pattern in patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    entity_text = match.group(1) if match.groups() else match.group(0)
                    key = (entity_type, entity_text.strip().lower())
                    if key in seen:
                        # An overlapping pattern already produced this entity.
                        continue
                    seen.add(key)
                    entities.append({
                        'entity': entity_type,
                        'text': entity_text.strip(),
                        'start': match.start(),
                        'end': match.end(),
                        'confidence': self.get_confidence_score(entity_type)
                    })

        return entities

    def get_confidence_score(self, entity_type: str) -> float:
        """Return a fixed heuristic confidence for an entity type (0.70 default)."""
        confidence_map = {
            'NAME': 0.80,
            'DATE': 0.85,
            'AMOUNT': 0.85,
            'INVOICE_NO': 0.90,
            'EMAIL': 0.95,
            'PHONE': 0.90,
            'ADDRESS': 0.75
        }
        return confidence_map.get(entity_type, 0.70)

    def create_structured_data(self, entities: List[Dict[str, Any]]) -> Dict[str, str]:
        """Collapse entities into one value per type, keyed by readable names.

        For each entity type the highest-confidence entity wins, with ties
        broken by the longest text.
        """
        # Map internal entity labels to human-readable field names.
        # Hoisted out of the selection loop (it was rebuilt on every iteration).
        field_mapping = {
            'NAME': 'Name',
            'DATE': 'Date',
            'AMOUNT': 'Amount',
            'INVOICE_NO': 'InvoiceNo',
            'EMAIL': 'Email',
            'PHONE': 'Phone',
            'ADDRESS': 'Address'
        }

        structured = {}

        # Group entities by type.
        entity_groups = {}
        for entity in entities:
            entity_groups.setdefault(entity['entity'], []).append(entity)

        # Select the best entity for each type.
        for entity_type, group in entity_groups.items():
            if group:
                best_entity = max(group, key=lambda x: (x['confidence'], len(x['text'])))
                field_name = field_mapping.get(entity_type, entity_type)
                structured[field_name] = best_entity['text']

        return structured

    def process_document(self, text: str) -> Dict[str, Any]:
        """Process document text and return entities plus structured data."""
        entities = self.extract_entities(text)
        structured_data = self.create_structured_data(entities)

        return {
            'text': text,
            'entities': entities,
            'structured_data': structured_data,
            'entity_count': len(entities),
            'entity_types': list(set(e['entity'] for e in entities))
        }
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def run_demo():
    """Run the simplified document extraction demo.

    Processes a set of hard-coded sample documents, prints the extracted
    entities and structured fields, writes the aggregate results to
    ``results/demo_extraction_results.json``, and finishes with a mock
    API request/response walkthrough. Returns None.
    """

    print("SIMPLIFIED DOCUMENT TEXT EXTRACTION DEMO")
    print("=" * 60)
    print("This demo shows the core extraction logic using regex patterns")
    print("(without the full ML pipeline for demonstration purposes)")
    print()

    # Initialize processor
    processor = SimpleDocumentProcessor()

    # Sample documents (representative invoice/receipt texts)
    sample_documents = [
        {
            "name": "Invoice Example 1",
            "text": "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567 Email: robert.white@email.com"
        },
        {
            "name": "Invoice Example 2",
            "text": "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Email: sarah.johnson@email.com"
        },
        {
            "name": "Receipt Example",
            "text": "Receipt for Michael Brown Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75 Contact: +1-555-987-6543"
        },
        {
            "name": "Business Document",
            "text": "Ms. Emma Wilson 456 Oak Street Payment due: January 15, 2025 Reference: INV-4567 Total: $1,750.25"
        }
    ]

    # Process each document, accumulating totals for the summary below
    all_results = []
    total_entities = 0
    all_entity_types = set()

    for i, doc in enumerate(sample_documents, 1):
        print(f"\nDocument {i}: {doc['name']}")
        print("-" * 50)
        print(f"Text: {doc['text']}")
        print()

        # Process document
        result = processor.process_document(doc['text'])
        all_results.append(result)

        # Update totals
        total_entities += result['entity_count']
        all_entity_types.update(result['entity_types'])

        print(f"Extraction Results:")
        print(f" Found {result['entity_count']} entities")
        print(f" Entity types: {', '.join(result['entity_types'])}")

        # Show structured data if available
        if result['structured_data']:
            print(f"\nStructured Information:")
            for key, value in result['structured_data'].items():
                print(f" {key}: {value}")

        # Show detailed entities
        if result['entities']:
            print(f"\nDetailed Entities:")
            for entity in result['entities']:
                print(f" {entity['entity']}: '{entity['text']}' (confidence: {entity['confidence']*100:.0f}%)")

    # Save results (creates ./results next to the working directory)
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / "demo_extraction_results.json"

    # Prepare output data
    output_data = {
        'demo_info': {
            'timestamp': datetime.now().isoformat(),
            'documents_processed': len(sample_documents),
            'total_entities_found': total_entities,
            'unique_entity_types': sorted(list(all_entity_types))
        },
        'results': all_results
    }

    # Save to file (ensure_ascii=False keeps any non-ASCII text readable)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {output_file}")

    print(f"\nDemo Summary:")
    print(f" Documents processed: {len(sample_documents)}")
    print(f" Total entities found: {total_entities}")
    print(f" Total structured fields: {sum(len(r['structured_data']) for r in all_results)}")
    print(f" Unique entity types: {', '.join(sorted(all_entity_types))}")

    print(f"\nDemo completed successfully!")

    print(f"\nThis demonstrates the core extraction logic.")
    print(f" The full system would add:")
    print(f" - OCR for scanned documents")
    print(f" - ML model (DistilBERT) for better accuracy")
    print(f" - Web API for file uploads")
    print(f" - Training pipeline for custom domains")

    # Simulate API functionality (prints a mock request/response pair)
    print(f"\nAPI FUNCTIONALITY SIMULATION")
    print("=" * 40)

    sample_text = "Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"

    print('API Request (POST /extract-from-text):')
    print(' {')
    print(f' "text": "{sample_text}"')
    print('}')

    print(f"\nAPI Response:")
    api_result = processor.process_document(sample_text)

    # Mirrors the shape the real API would return (see api/app.py)
    api_response = {
        "status": "success",
        "data": {
            "original_text": sample_text,
            "entities": api_result['entities'],
            "structured_data": api_result['structured_data'],
            "processing_timestamp": datetime.now().isoformat(),
            "total_entities_found": api_result['entity_count'],
            "entity_types_found": api_result['entity_types']
        }
    }

    print(json.dumps(api_response, indent=2))

    print(f"\nTo run the full system:")
    print(f" 1. Install ML dependencies: pip install torch transformers")
    print(f" 2. Run training: python src/training_pipeline.py")
    print(f" 3. Start API: python api/app.py")
    print(f" 4. Open browser: http://localhost:8000")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
    # Script entry point: run the regex-extraction demo end to end.
    # NOTE(review): the module continues past this guard with a second,
    # duplicated copy of the processor/demo code — confirm which copy is
    # canonical and remove the other.
    run_demo()
|
| 268 |
+
"""Simplified document processor for demo purposes."""
|
| 269 |
+
|
| 270 |
+
def __init__(self):
|
| 271 |
+
"""Initialize with regex patterns for entity extraction."""
|
| 272 |
+
self.entity_patterns = {
|
| 273 |
+
'NAME': [
|
| 274 |
+
r'\b(?:Mr\.|Mrs\.|Ms\.|Dr\.)\s+([A-Z][a-z]+ [A-Z][a-z]+)\b',
|
| 275 |
+
r'\b([A-Z][a-z]+ [A-Z][a-z]+)\b',
|
| 276 |
+
],
|
| 277 |
+
'DATE': [
|
| 278 |
+
r'\b(\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4})\b',
|
| 279 |
+
r'\b(\d{4}[/\-]\d{1,2}[/\-]\d{1,2})\b',
|
| 280 |
+
r'\b((?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4})\b'
|
| 281 |
+
],
|
| 282 |
+
'INVOICE_NO': [
|
| 283 |
+
r'(?:Invoice\s+(?:No|Number|#):\s*)?([A-Z]{2,4}[-]?\d{3,6})',
|
| 284 |
+
r'(INV[-]?\d{3,6})',
|
| 285 |
+
r'(BL[-]?\d{3,6})',
|
| 286 |
+
r'(REC[-]?\d{3,6})',
|
| 287 |
+
],
|
| 288 |
+
'AMOUNT': [
|
| 289 |
+
r'(\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
|
| 290 |
+
r'(\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP))',
|
| 291 |
+
],
|
| 292 |
+
'PHONE': [
|
| 293 |
+
r'(\+?\d{1,3}[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4})',
|
| 294 |
+
r'(\(\d{3}\)\s*\d{3}-\d{4})',
|
| 295 |
+
],
|
| 296 |
+
'EMAIL': [
|
| 297 |
+
r'\b([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,})\b',
|
| 298 |
+
]
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from text using regex patterns.

        Unlike the first copy of this method earlier in the file, this
        version de-duplicates via ``_is_duplicate`` and scores matches
        with ``_calculate_confidence``.
        """
        entities = []

        for entity_type, patterns in self.entity_patterns.items():
            for pattern in patterns:
                matches = re.finditer(pattern, text, re.IGNORECASE)
                for match in matches:
                    # Prefer the capture group; fall back to the whole match.
                    entity_text = match.group(1) if match.groups() else match.group(0)

                    # Calculate position
                    start_pos = match.start()
                    end_pos = match.end()

                    # Assign confidence based on pattern strength
                    confidence = self._calculate_confidence(entity_type, entity_text, pattern)

                    entity = {
                        'entity': entity_type,
                        'text': entity_text.strip(),
                        'start': start_pos,
                        'end': end_pos,
                        'confidence': confidence
                    }

                    # Avoid duplicates (same type + same text, case-insensitive)
                    if not self._is_duplicate(entity, entities):
                        entities.append(entity)

        return entities
|
| 331 |
+
|
| 332 |
+
    def _calculate_confidence(self, entity_type: str, text: str, pattern: str) -> float:
        """Calculate a heuristic confidence score for an extracted entity.

        Starts from 0.8 and boosts per-type when the text shows a strong
        signal (e.g. '@' for EMAIL, >=10 digits for PHONE).
        NOTE(review): the `pattern` parameter is currently unused.
        """
        base_confidence = 0.8

        # Boost confidence for specific patterns
        if entity_type == 'EMAIL' and '@' in text:
            base_confidence = 0.95
        elif entity_type == 'PHONE' and len(re.sub(r'[^\d]', '', text)) >= 10:
            base_confidence = 0.90
        elif entity_type == 'AMOUNT' and '$' in text:
            base_confidence = 0.85
        elif entity_type == 'DATE':
            base_confidence = 0.85
        elif entity_type == 'INVOICE_NO' and any(prefix in text.upper() for prefix in ['INV', 'BL', 'REC']):
            base_confidence = 0.90

        # Cap at 0.99 so nothing is reported as certain.
        return min(base_confidence, 0.99)
|
| 349 |
+
|
| 350 |
+
    def _is_duplicate(self, new_entity: Dict, existing_entities: List[Dict]) -> bool:
        """Return True if an entity with the same type and (case-insensitive)
        text has already been collected. Linear scan; fine for short lists."""
        for existing in existing_entities:
            if (existing['entity'] == new_entity['entity'] and
                existing['text'].lower() == new_entity['text'].lower()):
                return True
        return False
|
| 357 |
+
|
| 358 |
+
    def postprocess_entities(self, entities: List[Dict], text: str) -> Dict[str, str]:
        """Convert entities to a flat structured-data mapping.

        Groups entities by type, keeps the single highest-confidence one per
        type, formats it via ``_format_entity_value``, and keys the result by
        human-readable names. The `text` parameter is currently unused.
        """
        structured_data = {}

        # Group entities by type and pick the best one
        entity_groups = {}
        for entity in entities:
            entity_type = entity['entity']
            if entity_type not in entity_groups:
                entity_groups[entity_type] = []
            entity_groups[entity_type].append(entity)

        # Select best entity for each type
        for entity_type, group in entity_groups.items():
            best_entity = max(group, key=lambda x: x['confidence'])

            # Format the value
            formatted_value = self._format_entity_value(best_entity['text'], entity_type)

            # Map to human-readable keys (unknown types pass through unchanged)
            readable_key = {
                'NAME': 'Name',
                'DATE': 'Date',
                'INVOICE_NO': 'InvoiceNo',
                'AMOUNT': 'Amount',
                'PHONE': 'Phone',
                'EMAIL': 'Email'
            }.get(entity_type, entity_type)

            structured_data[readable_key] = formatted_value

        return structured_data
|
| 390 |
+
|
| 391 |
+
    def _format_entity_value(self, text: str, entity_type: str) -> str:
        """Format an entity value based on its type.

        NAME is title-cased; 10/11-digit PHONE numbers are normalized to
        (XXX) XXX-XXXX / +1 (XXX) XXX-XXXX; AMOUNT gains a leading '$'.
        NOTE(review): phone numbers with any other digit count fall through
        and are returned as-is — confirm that is intended.
        """
        text = text.strip()

        if entity_type == 'NAME':
            return ' '.join(word.capitalize() for word in text.split())
        elif entity_type == 'PHONE':
            digits = re.sub(r'[^\d]', '', text)
            if len(digits) == 10:
                return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
            elif len(digits) == 11 and digits[0] == '1':
                return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        elif entity_type == 'AMOUNT':
            # Ensure proper formatting (prepends '$' even to 'USD'-style values)
            if not text.startswith('$'):
                return f"${text}"

        return text
|
| 409 |
+
|
| 410 |
+
    def process_text(self, text: str) -> Dict[str, Any]:
        """Process text and return the full extraction result.

        Returns the original text, the raw entity list, the condensed
        structured data, an ISO processing timestamp, and summary counts.
        """
        # Extract entities
        entities = self.extract_entities(text)

        # Create structured data
        structured_data = self.postprocess_entities(entities, text)

        # Return complete result
        return {
            'original_text': text,
            'entities': entities,
            'structured_data': structured_data,
            'processing_timestamp': datetime.now().isoformat(),
            'total_entities_found': len(entities),
            'entity_types_found': list(set(e['entity'] for e in entities))
        }
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def run_demo():
    """Run the document extraction demo.

    NOTE(review): second definition of ``run_demo`` in this module — it
    shadows the earlier one at import time. This version uses
    ``process_text`` (the v2 method set), tags each result with its
    document name, and returns the results list.
    """
    print("SIMPLIFIED DOCUMENT TEXT EXTRACTION DEMO")
    print("=" * 60)
    print("This demo shows the core extraction logic using regex patterns")
    print("(without the full ML pipeline for demonstration purposes)")
    print()

    # Initialize processor
    processor = SimpleDocumentProcessor()

    # Sample documents
    sample_docs = [
        {
            "name": "Invoice Example 1",
            "text": "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250.00 Phone: (555) 123-4567"
        },
        {
            "name": "Invoice Example 2",
            "text": "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Email: sarah.johnson@email.com"
        },
        {
            "name": "Receipt Example",
            "text": "Receipt for Michael Brown Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75 Contact: +1-555-987-6543"
        },
        {
            "name": "Business Document",
            "text": "Ms. Emma Wilson 456 Oak Street Payment due: January 15, 2025 Reference: INV-4567 Total: $1,750.25"
        }
    ]

    results = []

    for i, doc in enumerate(sample_docs, 1):
        print(f"\nDocument {i}: {doc['name']}")
        print("-" * 50)
        print(f"Text: {doc['text']}")

        # Process the document and tag the result with its name
        result = processor.process_text(doc['text'])
        results.append({
            'document_name': doc['name'],
            **result
        })

        # Display results
        print(f"\nExtraction Results:")
        print(f" Found {result['total_entities_found']} entities")
        print(f" Entity types: {', '.join(result['entity_types_found'])}")

        # Show structured data
        if result['structured_data']:
            print(f"\nStructured Information:")
            for key, value in result['structured_data'].items():
                print(f" {key}: {value}")

        # Show detailed entities
        if result['entities']:
            print(f"\nDetailed Entities:")
            for entity in result['entities']:
                confidence_pct = int(entity['confidence'] * 100)
                print(f" {entity['entity']}: '{entity['text']}' (confidence: {confidence_pct}%)")

    # Save results to ./results/demo_extraction_results.json
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)

    output_file = output_dir / "demo_extraction_results.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n💾 Results saved to: {output_file}")

    # Summary statistics
    total_entities = sum(len(r['entities']) for r in results)
    total_structured_fields = sum(len(r['structured_data']) for r in results)
    unique_entity_types = set()
    for r in results:
        unique_entity_types.update(r['entity_types_found'])

    print(f"\nDemo Summary:")
    print(f" Documents processed: {len(results)}")
    print(f" Total entities found: {total_entities}")
    print(f" Total structured fields: {total_structured_fields}")
    print(f" Unique entity types: {', '.join(sorted(unique_entity_types))}")

    print(f"\nDemo completed successfully!")
    print(f"\nThis demonstrates the core extraction logic.")
    print(f" The full system would add:")
    print(f" - OCR for scanned documents")
    print(f" - ML model (DistilBERT) for better accuracy")
    print(f" - Web API for file uploads")
    print(f" - Training pipeline for custom domains")

    return results
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def show_api_simulation():
    """Simulate the API functionality.

    Prints a mock POST /extract-from-text request and its response,
    built by running a sample invoice text through the processor.
    Returns None; output goes to stdout only.
    """
    print(f"\n🌐 API FUNCTIONALITY SIMULATION")
    print("=" * 40)

    processor = SimpleDocumentProcessor()

    # Simulate API request
    sample_request = {
        "text": "Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"
    }

    print(f"API Request (POST /extract-from-text):")
    print(f" {json.dumps(sample_request, indent=2)}")

    # Process
    result = processor.process_text(sample_request["text"])

    # Simulate API response (same envelope shape as api/app.py would return)
    api_response = {
        "status": "success",
        "data": result
    }

    print(f"\nAPI Response:")
    print(f" {json.dumps(api_response, indent=2)}")
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
if __name__ == "__main__":
    # NOTE(review): second `__main__` guard in this module (an earlier one
    # appears mid-file) — evidence of duplicated/merged content.
    # Run the main demo
    results = run_demo()

    # Show API simulation
    show_api_simulation()

    print(f"\nTo run the full system:")
    print(f" 1. Install ML dependencies: pip install torch transformers")
    print(f" 2. Run training: python src/training_pipeline.py")
    print(f" 3. Start API: python api/app.py")
    print(f" 4. Open browser: http://localhost:8000")
|
src/data_preparation.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data preparation module for document text extraction.
|
| 3 |
+
Handles OCR, text cleaning, and dataset creation for NER training.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
import pytesseract
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
from typing import List, Dict, Tuple, Optional
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import fitz # PyMuPDF for PDF processing
|
| 17 |
+
from docx import Document
|
| 18 |
+
import easyocr
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DocumentProcessor:
    """Handles document processing, OCR, and text extraction.

    Supports PDFs (PyMuPDF), DOCX (python-docx), and images via two OCR
    back-ends (EasyOCR and Tesseract) with mutual fallback. Also carries
    the regex entity patterns used downstream for weak labeling.
    """

    def __init__(self, tesseract_path: Optional[str] = None):
        """Initialize document processor with OCR settings.

        tesseract_path: optional path to the tesseract binary; if given it
        overrides pytesseract's global command (process-wide side effect).
        NOTE(review): constructing easyocr.Reader may download model weights
        on first use — confirm that is acceptable at init time.
        """
        if tesseract_path:
            pytesseract.pytesseract.tesseract_cmd = tesseract_path

        # Initialize EasyOCR reader (English only)
        self.ocr_reader = easyocr.Reader(['en'])

        # Entity patterns for initial (weak-supervision) labeling
        self.entity_patterns = {
            'NAME': [
                r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',  # First Last
                r'(?:Mr\.|Mrs\.|Ms\.|Dr\.)\s+[A-Z][a-z]+ [A-Z][a-z]+',  # Title + Name
            ],
            'DATE': [
                r'\b\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}\b',  # DD/MM/YYYY
                r'\b\d{4}[/\-]\d{1,2}[/\-]\d{1,2}\b',  # YYYY/MM/DD
                r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4}\b'
            ],
            'INVOICE_NO': [
                r'(?:Invoice\s+(?:No|Number|#):\s*)?([A-Z]{2,4}[-]?\d{3,6})',
                r'(?:INV[-]?\d{3,6})',
            ],
            'AMOUNT': [
                r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?',  # $1,000.00
                r'\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP)',  # 1000.00 USD
            ],
            'ADDRESS': [
                r'\d+\s+[A-Za-z\s]+(?:Street|St|Avenue|Ave|Road|Rd|Drive|Dr|Lane|Ln).*',
            ],
            'PHONE': [
                r'\+?\d{1,3}[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
                r'\(\d{3}\)\s*\d{3}-\d{4}',
            ],
            'EMAIL': [
                r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
            ]
        }

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract embedded text from a PDF file.

        Returns "" (and prints the error) on any failure. Does not OCR
        scanned pages — only text the PDF already carries.
        """
        try:
            doc = fitz.open(pdf_path)
            text = ""
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                text += page.get_text()
            doc.close()
            return text
        except Exception as e:
            # Best-effort: report and return empty rather than raising.
            print(f"Error extracting text from PDF {pdf_path}: {e}")
            return ""

    def extract_text_from_docx(self, docx_path: str) -> str:
        """Extract paragraph text from a DOCX file (returns "" on error)."""
        try:
            doc = Document(docx_path)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            print(f"Error extracting text from DOCX {docx_path}: {e}")
            return ""

    def preprocess_image(self, image_path: str) -> np.ndarray:
        """Preprocess an image for better OCR results.

        Grayscale -> Gaussian blur -> adaptive threshold. Raises ValueError
        when OpenCV cannot load the image.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Convert to grayscale
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Apply Gaussian blur to reduce noise
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Apply adaptive threshold (binarize; robust to uneven lighting)
        thresh = cv2.adaptiveThreshold(
            blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )

        return thresh

    def extract_text_with_tesseract(self, image_path: str) -> str:
        """Extract text using Tesseract OCR (returns "" on error)."""
        try:
            preprocessed_img = self.preprocess_image(image_path)

            # Configure Tesseract: OEM 3 = default engine, PSM 6 = uniform block
            custom_config = r'--oem 3 --psm 6'
            text = pytesseract.image_to_string(preprocessed_img, config=custom_config)

            return text
        except Exception as e:
            print(f"Error with Tesseract OCR on {image_path}: {e}")
            return ""

    def extract_text_with_easyocr(self, image_path: str) -> str:
        """Extract text using EasyOCR (returns "" on error).

        EasyOCR readtext results are (bbox, text, confidence) tuples; only
        the text field is kept, joined with spaces.
        """
        try:
            results = self.ocr_reader.readtext(image_path)
            text = " ".join([result[1] for result in results])
            return text
        except Exception as e:
            print(f"Error with EasyOCR on {image_path}: {e}")
            return ""

    def extract_text_from_image(self, image_path: str, use_easyocr: bool = True) -> str:
        """Extract text from an image, falling back to the other OCR engine
        when the preferred one produces no text."""
        if use_easyocr:
            text = self.extract_text_with_easyocr(image_path)
            if not text.strip():  # Fallback to Tesseract
                text = self.extract_text_with_tesseract(image_path)
        else:
            text = self.extract_text_with_tesseract(image_path)
            if not text.strip():  # Fallback to EasyOCR
                text = self.extract_text_with_easyocr(image_path)

        return text

    def clean_text(self, text: str) -> str:
        """Clean and normalize extracted text.

        Collapses whitespace, strips characters outside a punctuation
        whitelist, and normalizes spacing around , . ; : — note this also
        removes newlines, so document structure is flattened.
        """
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)

        # Remove special characters but keep important punctuation
        text = re.sub(r'[^\w\s\.\,\:\;\-\$\(\)\[\]\/]', '', text)

        # Normalize whitespace around punctuation
        text = re.sub(r'\s*([,.;:])\s*', r'\1 ', text)

        return text.strip()

    def process_document(self, file_path: str) -> str:
        """Process any supported document type and return cleaned text.

        Dispatches on file extension (.pdf / .docx / common image types);
        raises ValueError for anything else.
        """
        file_path = Path(file_path)
        file_ext = file_path.suffix.lower()

        if file_ext == '.pdf':
            text = self.extract_text_from_pdf(str(file_path))
        elif file_ext == '.docx':
            text = self.extract_text_from_docx(str(file_path))
        elif file_ext in ['.png', '.jpg', '.jpeg', '.tiff', '.bmp']:
            text = self.extract_text_from_image(str(file_path))
        else:
            raise ValueError(f"Unsupported file type: {file_ext}")

        return self.clean_text(text)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class NERDatasetCreator:
    """Creates NER training datasets from processed documents."""

    def __init__(self, document_processor: DocumentProcessor):
        # The processor supplies both text extraction and the regex
        # patterns used for weak (automatic) labeling.
        self.document_processor = document_processor
        # BIO tag set over the supported entity types ('O' = outside).
        self.entity_labels = ['O', 'B-NAME', 'I-NAME', 'B-DATE', 'I-DATE',
                              'B-INVOICE_NO', 'I-INVOICE_NO', 'B-AMOUNT', 'I-AMOUNT',
                              'B-ADDRESS', 'I-ADDRESS', 'B-PHONE', 'I-PHONE',
                              'B-EMAIL', 'I-EMAIL']
|
| 184 |
+
|
| 185 |
+
def auto_label_text(self, text: str) -> List[Tuple[str, str]]:
|
| 186 |
+
"""Automatically label text using regex patterns."""
|
| 187 |
+
words = text.split()
|
| 188 |
+
labels = ['O'] * len(words)
|
| 189 |
+
|
| 190 |
+
# Track word positions in original text
|
| 191 |
+
word_positions = []
|
| 192 |
+
start = 0
|
| 193 |
+
for word in words:
|
| 194 |
+
pos = text.find(word, start)
|
| 195 |
+
word_positions.append((pos, pos + len(word)))
|
| 196 |
+
start = pos + len(word)
|
| 197 |
+
|
| 198 |
+
# Apply entity patterns
|
| 199 |
+
for entity_type, patterns in self.document_processor.entity_patterns.items():
|
| 200 |
+
for pattern in patterns:
|
| 201 |
+
matches = list(re.finditer(pattern, text, re.IGNORECASE))
|
| 202 |
+
for match in matches:
|
| 203 |
+
match_start, match_end = match.span()
|
| 204 |
+
|
| 205 |
+
# Find which words overlap with this match
|
| 206 |
+
first_word_idx = None
|
| 207 |
+
last_word_idx = None
|
| 208 |
+
|
| 209 |
+
for i, (word_start, word_end) in enumerate(word_positions):
|
| 210 |
+
if word_start >= match_start and word_end <= match_end:
|
| 211 |
+
if first_word_idx is None:
|
| 212 |
+
first_word_idx = i
|
| 213 |
+
last_word_idx = i
|
| 214 |
+
elif word_start < match_end and word_end > match_start:
|
| 215 |
+
# Partial overlap
|
| 216 |
+
if first_word_idx is None:
|
| 217 |
+
first_word_idx = i
|
| 218 |
+
last_word_idx = i
|
| 219 |
+
|
| 220 |
+
# Apply BIO labeling
|
| 221 |
+
if first_word_idx is not None:
|
| 222 |
+
labels[first_word_idx] = f'B-{entity_type}'
|
| 223 |
+
for i in range(first_word_idx + 1, last_word_idx + 1):
|
| 224 |
+
labels[i] = f'I-{entity_type}'
|
| 225 |
+
|
| 226 |
+
return list(zip(words, labels))
|
| 227 |
+
|
| 228 |
+
def create_training_example(self, text: str) -> Dict:
|
| 229 |
+
"""Create a training example from text."""
|
| 230 |
+
labeled_tokens = self.auto_label_text(text)
|
| 231 |
+
|
| 232 |
+
tokens = [token for token, _ in labeled_tokens]
|
| 233 |
+
labels = [label for _, label in labeled_tokens]
|
| 234 |
+
|
| 235 |
+
return {
|
| 236 |
+
'tokens': tokens,
|
| 237 |
+
'labels': labels,
|
| 238 |
+
'text': text
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
def create_sample_dataset(self) -> List[Dict]:
|
| 242 |
+
"""Create sample training data for demonstration."""
|
| 243 |
+
sample_texts = [
|
| 244 |
+
"Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250",
|
| 245 |
+
"Bill for Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50",
|
| 246 |
+
"Payment due from Michael Brown on 01/12/2025. Reference: PAY-3067. Sum: $890.00",
|
| 247 |
+
"Receipt for Emma Wilson Invoice: REC-4089 Date: 2025-04-22 Amount: $1,750.25",
|
| 248 |
+
"Dr. James Smith 123 Main Street Boston MA 02101 Phone: (555) 123-4567 Email: james@email.com",
|
| 249 |
+
"Ms. Lisa Anderson 456 Oak Avenue New York NY 10001 Contact: +1-555-987-6543",
|
| 250 |
+
"Invoice INV-5678 issued to David Lee on February 5, 2025 for $3,400.00",
|
| 251 |
+
"Bill #BIL-9012 for Jennifer Garcia dated 2025-05-15. Total amount: $567.89"
|
| 252 |
+
]
|
| 253 |
+
|
| 254 |
+
dataset = []
|
| 255 |
+
for text in sample_texts:
|
| 256 |
+
example = self.create_training_example(text)
|
| 257 |
+
dataset.append(example)
|
| 258 |
+
|
| 259 |
+
return dataset
|
| 260 |
+
|
| 261 |
+
def process_documents_folder(self, folder_path: str) -> List[Dict]:
|
| 262 |
+
"""Process all documents in a folder and create training dataset."""
|
| 263 |
+
folder_path = Path(folder_path)
|
| 264 |
+
dataset = []
|
| 265 |
+
|
| 266 |
+
if not folder_path.exists():
|
| 267 |
+
print(f"Folder {folder_path} does not exist. Creating sample dataset instead.")
|
| 268 |
+
return self.create_sample_dataset()
|
| 269 |
+
|
| 270 |
+
supported_extensions = ['.pdf', '.docx', '.png', '.jpg', '.jpeg', '.tiff', '.bmp']
|
| 271 |
+
|
| 272 |
+
for file_path in folder_path.rglob('*'):
|
| 273 |
+
if file_path.suffix.lower() in supported_extensions:
|
| 274 |
+
try:
|
| 275 |
+
print(f"Processing {file_path.name}...")
|
| 276 |
+
text = self.document_processor.process_document(str(file_path))
|
| 277 |
+
|
| 278 |
+
if text.strip(): # Only process non-empty texts
|
| 279 |
+
example = self.create_training_example(text)
|
| 280 |
+
example['source_file'] = str(file_path)
|
| 281 |
+
dataset.append(example)
|
| 282 |
+
print(f"Processed {file_path.name}")
|
| 283 |
+
else:
|
| 284 |
+
print(f"No text extracted from {file_path.name}")
|
| 285 |
+
|
| 286 |
+
except Exception as e:
|
| 287 |
+
print(f"Error processing {file_path.name}: {e}")
|
| 288 |
+
|
| 289 |
+
if not dataset:
|
| 290 |
+
print("No documents processed. Creating sample dataset.")
|
| 291 |
+
return self.create_sample_dataset()
|
| 292 |
+
|
| 293 |
+
return dataset
|
| 294 |
+
|
| 295 |
+
def save_dataset(self, dataset: List[Dict], output_path: str):
|
| 296 |
+
"""Save dataset to JSON file."""
|
| 297 |
+
output_path = Path(output_path)
|
| 298 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 299 |
+
|
| 300 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 301 |
+
json.dump(dataset, f, indent=2, ensure_ascii=False)
|
| 302 |
+
|
| 303 |
+
print(f"Dataset saved to {output_path}")
|
| 304 |
+
print(f"Total examples: {len(dataset)}")
|
| 305 |
+
|
| 306 |
+
# Print statistics
|
| 307 |
+
all_labels = []
|
| 308 |
+
for example in dataset:
|
| 309 |
+
all_labels.extend(example['labels'])
|
| 310 |
+
|
| 311 |
+
label_counts = {}
|
| 312 |
+
for label in all_labels:
|
| 313 |
+
label_counts[label] = label_counts.get(label, 0) + 1
|
| 314 |
+
|
| 315 |
+
print("\nLabel distribution:")
|
| 316 |
+
for label, count in sorted(label_counts.items()):
|
| 317 |
+
print(f" {label}: {count}")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def main():
    """Demonstrate the data-preparation workflow end to end."""
    # Wire the OCR/text processor into the dataset builder.
    doc_processor = DocumentProcessor()
    creator = NERDatasetCreator(doc_processor)

    # Build the dataset from raw documents (falls back to samples if absent).
    dataset = creator.process_documents_folder("data/raw")

    # Persist the labeled examples for the training pipeline.
    creator.save_dataset(dataset, "data/processed/ner_dataset.json")

    print(f"\nData preparation completed!")
    print(f"Processed {len(dataset)} documents")
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# Run the data-preparation demo when executed as a script.
if __name__ == "__main__":
    main()
|
src/inference.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference pipeline for document text extraction.
|
| 3 |
+
Processes new documents and extracts structured information using trained SLM.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
import re
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from src.data_preparation import DocumentProcessor
|
| 15 |
+
from src.model import DocumentNERModel, NERTrainer, ModelConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DocumentInference:
    """Inference pipeline for extracting structured data from documents.

    Wraps text extraction (DocumentProcessor), NER prediction with the
    trained model, and regex-based post-processing into JSON-ready fields.
    """

    def __init__(self, model_path: str):
        """Initialize inference pipeline with trained model.

        Args:
            model_path: Directory containing the saved model weights and,
                optionally, training_config.json.
        """
        self.model_path = model_path
        self.config = self._load_config()
        self.model = None
        self.trainer = None
        self.document_processor = DocumentProcessor()

        # Load the trained model
        self._load_model()

        # Post-processing patterns for field validation and formatting
        self.postprocess_patterns = {
            'DATE': [
                r'\b\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}\b',
                r'\b\d{4}[/\-]\d{1,2}[/\-]\d{1,2}\b',
                r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4}\b'
            ],
            'AMOUNT': [
                r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?',
                r'\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP)'
            ],
            'PHONE': [
                r'\+?\d{1,3}[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
                r'\(\d{3}\)\s*\d{3}-\d{4}'
            ],
            'EMAIL': [
                r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
            ]
        }

    def _load_config(self) -> ModelConfig:
        """Load training configuration saved alongside the model, or defaults."""
        config_path = Path(self.model_path) / "training_config.json"

        if config_path.exists():
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
                config = ModelConfig(**config_dict)
        else:
            print("No training config found. Using default configuration.")
            config = ModelConfig()

        return config

    def _load_model(self):
        """Load the trained model and tokenizer.

        Raises:
            Exception: wrapping any failure during model/weight loading.
        """
        try:
            # Create model and trainer
            self.model = DocumentNERModel(self.config)
            self.trainer = NERTrainer(self.model, self.config)

            # Load the trained weights
            self.trainer.load_model(self.model_path)

            print(f"Model loaded successfully from {self.model_path}")

        except Exception as e:
            raise Exception(f"Failed to load model from {self.model_path}: {e}")

    def predict_entities(self, text: str) -> List[Dict[str, Any]]:
        """Predict entities from text using the trained model.

        Returns a list of entity dicts with keys: entity, text, start, end,
        confidence.
        """
        # Tokenize the text
        tokens = text.split()

        # Prepare input for the model
        inputs = self.trainer.tokenizer(
            tokens,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        )

        # Move to device
        inputs = {k: v.to(self.trainer.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            predictions, probabilities = self.model.predict(
                inputs['input_ids'],
                inputs['attention_mask']
            )

        # Convert predictions to labels
        word_ids = inputs['input_ids'][0].cpu().numpy()
        pred_labels = predictions[0].cpu().numpy()
        probs = probabilities[0].cpu().numpy()

        # Align predictions with original tokens
        # NOTE(review): despite the name, word_ids_list holds subword token
        # *strings* (e.g. '##ing'), not word indices — the extraction step
        # below relies on the '##' prefix to stitch subwords back together.
        word_ids_list = self.trainer.tokenizer.convert_ids_to_tokens(word_ids)

        # Extract entities
        entities = self._extract_entities_from_predictions(
            tokens, pred_labels, probs, word_ids_list
        )

        return entities

    def _extract_entities_from_predictions(self, tokens: List[str],
                                           pred_labels: np.ndarray,
                                           probs: np.ndarray,
                                           word_ids_list: List[str]) -> List[Dict[str, Any]]:
        """Extract entities from model predictions.

        Walks the subword token stream in parallel with predicted label ids
        and groups contiguous B-/I- labels into entity dicts. The span
        confidence is the minimum softmax probability over its tokens.
        """
        entities = []
        current_entity = None

        # Map tokenizer output back to original tokens
        token_idx = 0

        for i, (token_id, label_id) in enumerate(zip(word_ids_list, pred_labels)):
            if token_id in ['[CLS]', '[SEP]', '[PAD]']:
                continue

            label = self.config.id2label.get(label_id, 'O')
            confidence = float(np.max(probs[i]))

            if label.startswith('B-'):
                # Start of new entity
                if current_entity:
                    entities.append(current_entity)

                entity_type = label[2:]  # Remove 'B-' prefix
                current_entity = {
                    'entity': entity_type,
                    'text': token_id if not token_id.startswith('##') else token_id[2:],
                    'start': token_idx,
                    'end': token_idx + 1,
                    'confidence': confidence
                }

            elif label.startswith('I-') and current_entity:
                # Continue current entity
                entity_type = label[2:]  # Remove 'I-' prefix
                if current_entity['entity'] == entity_type:
                    # Subword pieces are glued without a space; new words get one.
                    if token_id.startswith('##'):
                        current_entity['text'] += token_id[2:]
                    else:
                        current_entity['text'] += ' ' + token_id
                    current_entity['end'] = token_idx + 1
                    # Span confidence = weakest token in the span.
                    current_entity['confidence'] = min(current_entity['confidence'], confidence)

            else:
                # 'O' label or end of entity
                if current_entity:
                    entities.append(current_entity)
                    current_entity = None

            # Advance the word counter only on the first subword of each word.
            if not token_id.startswith('##'):
                token_idx += 1

        # Add the last entity if it exists
        if current_entity:
            entities.append(current_entity)

        return entities

    def postprocess_entities(self, entities: List[Dict[str, Any]],
                             original_text: str) -> Dict[str, Any]:
        """Post-process and structure extracted entities.

        Validates entities against regex patterns, formats their values, and
        keeps only the highest-confidence candidate per entity type.
        `original_text` is currently unused; kept for interface stability.
        """
        structured_data = {}

        for entity in entities:
            entity_type = entity['entity']
            entity_text = entity['text']
            confidence = entity['confidence']

            # Apply post-processing patterns for validation
            if entity_type in self.postprocess_patterns:
                is_valid = self._validate_entity(entity_text, entity_type)
                if not is_valid:
                    continue

            # Format the entity value
            formatted_value = self._format_entity_value(entity_text, entity_type)

            # Store the best entity for each type (highest confidence)
            if entity_type not in structured_data or confidence > structured_data[entity_type]['confidence']:
                structured_data[entity_type] = {
                    'value': formatted_value,
                    'confidence': confidence,
                    'original_text': entity_text
                }

        # Convert to final format
        final_data = {}
        entity_mapping = {
            'NAME': 'Name',
            'DATE': 'Date',
            'INVOICE_NO': 'InvoiceNo',
            'AMOUNT': 'Amount',
            'ADDRESS': 'Address',
            'PHONE': 'Phone',
            'EMAIL': 'Email'
        }

        for entity_type, entity_data in structured_data.items():
            human_readable_key = entity_mapping.get(entity_type, entity_type)
            final_data[human_readable_key] = entity_data['value']

        return final_data

    def _validate_entity(self, text: str, entity_type: str) -> bool:
        """Return True if text matches any validation pattern for entity_type."""
        patterns = self.postprocess_patterns.get(entity_type, [])

        for pattern in patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True

        return False

    def _format_entity_value(self, text: str, entity_type: str) -> str:
        """Format entity value based on its type.

        Falls through to returning the stripped text when no type-specific
        rule matches.
        """
        text = text.strip()

        if entity_type == 'DATE':
            # Normalize date format
            date_patterns = [
                (r'(\d{1,2})[/\-](\d{1,2})[/\-](\d{2,4})', r'\1/\2/\3'),
                (r'(\d{4})[/\-](\d{1,2})[/\-](\d{1,2})', r'\3/\2/\1')
            ]

            for pattern, replacement in date_patterns:
                match = re.search(pattern, text)
                if match:
                    return re.sub(pattern, replacement, text)

        elif entity_type == 'AMOUNT':
            # Normalize amount format
            amount_match = re.search(r'[\$\d,\.]+', text)
            if amount_match:
                return amount_match.group()

        elif entity_type == 'PHONE':
            # Normalize phone format to (xxx) xxx-xxxx / +1 (xxx) xxx-xxxx.
            digits = re.sub(r'[^\d]', '', text)
            if len(digits) == 10:
                return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
            elif len(digits) == 11 and digits[0] == '1':
                return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"

        elif entity_type == 'NAME':
            # Capitalize name properly
            return ' '.join(word.capitalize() for word in text.split())

        return text

    def process_document(self, file_path: str) -> Dict[str, Any]:
        """Process a document and extract structured information.

        Returns a result dict; on failure the dict contains an 'error' key
        instead of entities/structured_data.
        """
        print(f"Processing document: {file_path}")

        try:
            # Extract text from document
            text = self.document_processor.process_document(file_path)

            if not text.strip():
                return {
                    'error': 'No text could be extracted from the document',
                    'file_path': file_path
                }

            # Predict entities
            entities = self.predict_entities(text)

            # Post-process and structure data
            structured_data = self.postprocess_entities(entities, text)

            # Create result
            result = {
                'file_path': file_path,
                # Stored preview is truncated to keep result payloads compact.
                'extracted_text': text[:500] + '...' if len(text) > 500 else text,
                'entities': entities,
                'structured_data': structured_data,
                'processing_timestamp': datetime.now().isoformat(),
                'model_path': self.model_path
            }

            print(f"Successfully processed {file_path}")
            print(f"  Found {len(entities)} entities")
            print(f"  Structured fields: {list(structured_data.keys())}")

            return result

        except Exception as e:
            error_result = {
                'error': str(e),
                'file_path': file_path,
                'processing_timestamp': datetime.now().isoformat()
            }
            print(f"Error processing {file_path}: {e}")
            return error_result

    def process_text_directly(self, text: str) -> Dict[str, Any]:
        """Process text directly without file operations.

        Same pipeline as process_document but starting from a raw string.
        """
        print("Processing text directly...")

        try:
            # Clean the text
            cleaned_text = self.document_processor.clean_text(text)

            # Predict entities
            entities = self.predict_entities(cleaned_text)

            # Post-process and structure data
            structured_data = self.postprocess_entities(entities, cleaned_text)

            # Create result
            result = {
                'original_text': text,
                'cleaned_text': cleaned_text,
                'entities': entities,
                'structured_data': structured_data,
                'processing_timestamp': datetime.now().isoformat(),
                'model_path': self.model_path
            }

            print(f"Successfully processed text")
            print(f"  Found {len(entities)} entities")
            print(f"  Structured fields: {list(structured_data.keys())}")

            return result

        except Exception as e:
            error_result = {
                'error': str(e),
                'original_text': text,
                'processing_timestamp': datetime.now().isoformat()
            }
            print(f"Error processing text: {e}")
            return error_result

    def batch_process_documents(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """Process multiple documents in batch; errors are per-file, not fatal."""
        print(f"Processing {len(file_paths)} documents...")

        results = []
        for i, file_path in enumerate(file_paths):
            print(f"\nProcessing {i+1}/{len(file_paths)}: {Path(file_path).name}")
            result = self.process_document(file_path)
            results.append(result)

        print(f"\nBatch processing completed!")
        print(f"  Successfully processed: {sum(1 for r in results if 'error' not in r)}")
        print(f"  Errors: {sum(1 for r in results if 'error' in r)}")

        return results

    def save_results(self, results: List[Dict[str, Any]], output_path: str):
        """Save processing results to a JSON file, creating parent dirs."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"Results saved to: {output_path}")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def create_demo_inference(model_path: str = "models/document_ner_model") -> DocumentInference:
    """Build a DocumentInference for the demo, re-raising with guidance on failure."""
    try:
        return DocumentInference(model_path)
    except Exception as e:
        # Surface a hint about the likely cause before propagating the error.
        print(f"Failed to create inference pipeline: {e}")
        print("Make sure you have trained the model first by running training_pipeline.py")
        raise
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def demo_text_extraction():
    """Run the extraction pipeline over built-in sample texts and save results."""
    print("DOCUMENT TEXT EXTRACTION - INFERENCE DEMO")
    print("=" * 60)

    # Representative invoice/receipt snippets covering the supported fields.
    sample_texts = [
        "Invoice sent to Robert White on 15/09/2025 Invoice No: INV-1024 Amount: $1,250",
        "Bill for Dr. Sarah Johnson dated March 10, 2025. Invoice Number: BL-2045. Total: $2,300.50 Phone: (555) 123-4567",
        "Receipt for Michael Brown 456 Oak Street Boston MA Email: michael@email.com Invoice: REC-3089 Date: 2025-04-22 Amount: $890.75"
    ]

    try:
        pipeline = create_demo_inference()

        outputs = []
        for index, sample in enumerate(sample_texts, start=1):
            print(f"\nProcessing Sample Text {index}:")
            print("-" * 40)
            print(f"Text: {sample}")

            outcome = pipeline.process_text_directly(sample)
            outputs.append(outcome)

            if 'error' not in outcome:
                print(f"Structured Output: {json.dumps(outcome['structured_data'], indent=2)}")
            else:
                print(f"Error: {outcome['error']}")

        # Persist the demo run for inspection.
        pipeline.save_results(outputs, "results/demo_extraction_results.json")

        print("\nDemo completed successfully!")

    except Exception as e:
        print(f"Demo failed: {e}")
| 430 |
+
|
| 431 |
+
def main():
    """Main function for inference demonstration."""
    # Delegates to the sample-text demo; no CLI arguments are parsed.
    demo_text_extraction()
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# Run the inference demo when executed as a script.
if __name__ == "__main__":
    main()
|
src/model.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Small Language Model (SLM) architecture for document text extraction.
|
| 3 |
+
Uses DistilBERT with transfer learning for Named Entity Recognition.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from transformers import (
|
| 10 |
+
DistilBertTokenizer,
|
| 11 |
+
DistilBertForTokenClassification,
|
| 12 |
+
DistilBertConfig,
|
| 13 |
+
get_linear_schedule_with_warmup
|
| 14 |
+
)
|
| 15 |
+
from typing import List, Dict, Tuple, Optional
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
from sklearn.model_selection import train_test_split
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class ModelConfig:
    """Configuration for the SLM model.

    Holds DistilBERT fine-tuning hyperparameters plus the BIO label
    inventory shared by the training and inference pipelines.
    """
    model_name: str = "distilbert-base-uncased"  # HF checkpoint to fine-tune
    max_length: int = 512        # maximum subword sequence length
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_epochs: int = 3
    warmup_steps: int = 500      # LR warmup steps for the linear schedule
    weight_decay: float = 0.01
    dropout_rate: float = 0.3

    # Entity labels. None means "use the default BIO label set", which is
    # filled in by __post_init__. Annotated Optional because None is the
    # documented default (the original `List[str] = None` annotation was wrong).
    entity_labels: Optional[List[str]] = None

    def __post_init__(self):
        # Populate the default BIO label inventory when none is supplied.
        if self.entity_labels is None:
            self.entity_labels = [
                'O', 'B-NAME', 'I-NAME', 'B-DATE', 'I-DATE',
                'B-INVOICE_NO', 'I-INVOICE_NO', 'B-AMOUNT', 'I-AMOUNT',
                'B-ADDRESS', 'I-ADDRESS', 'B-PHONE', 'I-PHONE',
                'B-EMAIL', 'I-EMAIL'
            ]

    @property
    def num_labels(self) -> int:
        """Number of distinct BIO labels."""
        return len(self.entity_labels)

    @property
    def label2id(self) -> Dict[str, int]:
        """Map label string -> integer id (its position in entity_labels)."""
        return {label: i for i, label in enumerate(self.entity_labels)}

    @property
    def id2label(self) -> Dict[int, str]:
        """Map integer id -> label string (inverse of label2id)."""
        return {i: label for i, label in enumerate(self.entity_labels)}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class NERDataset(Dataset):
    """PyTorch Dataset for NER training.

    Wraps a list of pre-tokenized examples (each a dict with parallel
    'tokens' and 'labels' lists) and aligns the word-level BIO labels with
    DistilBERT's subword tokens. All tokenization happens eagerly in
    __init__ so __getitem__ is a cheap dictionary lookup.
    """

    def __init__(self, dataset: List[Dict], tokenizer: DistilBertTokenizer,
                 config: ModelConfig, mode: str = 'train'):
        # dataset: examples with parallel 'tokens'/'labels' lists.
        # mode: informational tag ('train'/'val'); not branched on here.
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.config = config
        self.mode = mode

        # Prepare tokenized data eagerly (pads/truncates to max_length).
        self.tokenized_data = self._tokenize_and_align_labels()

    def _tokenize_and_align_labels(self) -> List[Dict]:
        """Tokenize text and align labels with subword tokens.

        Alignment rules:
        - special tokens ([CLS]/[SEP]/padding) get -100, which the loss ignores;
        - the first subword of each word keeps the word's original label;
        - later subwords of the same word get the I- variant of a B- label;
        - word indices beyond the label list (truncation artifacts) get -100.
        """
        tokenized_data = []

        for example in self.dataset:
            tokens = example['tokens']
            labels = example['labels']

            # Tokenize each word and track alignments
            tokenized_inputs = self.tokenizer(
                tokens,
                is_split_into_words=True,
                padding='max_length',
                truncation=True,
                max_length=self.config.max_length,
                return_tensors='pt'
            )

            # Align labels with subword tokens. word_ids() maps each subword
            # position back to its source word index (or None for specials).
            word_ids = tokenized_inputs.word_ids()
            aligned_labels = []
            previous_word_idx = None

            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens get -100 (ignored in loss computation)
                    aligned_labels.append(-100)
                elif word_idx != previous_word_idx:
                    # First subword of a word gets the original label
                    if word_idx < len(labels):
                        label = labels[word_idx]
                        # Unknown labels fall back to id 0 ('O').
                        aligned_labels.append(self.config.label2id.get(label, 0))
                    else:
                        aligned_labels.append(-100)
                else:
                    # Subsequent subwords of the same word
                    if word_idx < len(labels):
                        label = labels[word_idx]
                        if label.startswith('B-'):
                            # Convert B- to I- for subword tokens so an entity
                            # is not "re-opened" mid-word.
                            i_label = label.replace('B-', 'I-')
                            aligned_labels.append(self.config.label2id.get(i_label, 0))
                        else:
                            aligned_labels.append(self.config.label2id.get(label, 0))
                    else:
                        aligned_labels.append(-100)

                previous_word_idx = word_idx

            # Keep the original tokens/labels alongside the tensors for
            # later inspection/debugging; __getitem__ drops them.
            tokenized_data.append({
                'input_ids': tokenized_inputs['input_ids'].squeeze(),
                'attention_mask': tokenized_inputs['attention_mask'].squeeze(),
                'labels': torch.tensor(aligned_labels, dtype=torch.long),
                'original_tokens': tokens,
                'original_labels': labels
            })

        return tokenized_data

    def __len__(self) -> int:
        """Number of tokenized examples."""
        return len(self.tokenized_data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return only the model-ready tensors for example *idx*."""
        return {
            'input_ids': self.tokenized_data[idx]['input_ids'],
            'attention_mask': self.tokenized_data[idx]['attention_mask'],
            'labels': self.tokenized_data[idx]['labels']
        }
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DocumentNERModel(nn.Module):
    """DistilBERT-based model for document NER.

    Thin wrapper around DistilBertForTokenClassification that wires the
    label vocabulary and dropout from a ModelConfig into the pre-trained
    checkpoint.
    """

    def __init__(self, config: ModelConfig):
        """Build the pre-trained backbone with a token-classification head."""
        super().__init__()
        self.config = config

        # Pre-trained DistilBERT configuration, extended with the label
        # vocabulary and the requested dropout rates.
        hf_config = DistilBertConfig.from_pretrained(
            config.model_name,
            num_labels=config.num_labels,
            id2label=config.id2label,
            label2id=config.label2id,
            dropout=config.dropout_rate,
            attention_dropout=config.dropout_rate,
        )

        # Backbone plus token-classification head, initialized from the
        # checkpoint weights.
        self.model = DistilBertForTokenClassification.from_pretrained(
            config.model_name, config=hf_config
        )

        # Extra dropout module kept for regularization parity with config.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model; the output includes a loss when labels are given."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def predict(self, input_ids, attention_mask):
        """Return (argmax label ids, softmax probabilities) with gradients disabled."""
        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits
            label_ids = torch.argmax(logits, dim=-1)
            label_probs = torch.softmax(logits, dim=-1)

        return label_ids, label_probs
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class NERTrainer:
    """Trainer class for the NER model.

    Owns device placement, tokenizer loading, data splitting, the
    train/validate loops (AdamW + linear warmup schedule), and model
    checkpoint save/load.
    """

    def __init__(self, model: DocumentNERModel, config: ModelConfig):
        # Prefer GPU when available; every batch is moved to this device.
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Initialize tokenizer from the same checkpoint as the model.
        self.tokenizer = DistilBertTokenizer.from_pretrained(config.model_name)

    def prepare_dataloaders(self, dataset: List[Dict],
                            test_size: float = 0.2) -> Tuple[DataLoader, DataLoader]:
        """Prepare training and validation dataloaders.

        Args:
            dataset: list of examples with 'tokens'/'labels' lists.
            test_size: fraction of data held out for validation.

        Returns:
            (train_dataloader, val_dataloader); only the training split
            is shuffled.
        """
        # Split dataset (fixed seed so splits are reproducible across calls)
        train_data, val_data = train_test_split(
            dataset, test_size=test_size, random_state=42
        )

        # Create datasets
        train_dataset = NERDataset(train_data, self.tokenizer, self.config, 'train')
        val_dataset = NERDataset(val_data, self.tokenizer, self.config, 'val')

        # Create dataloaders
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )

        return train_dataloader, val_dataloader

    def train(self, train_dataloader: DataLoader,
              val_dataloader: DataLoader) -> Dict[str, List[float]]:
        """Train the NER model.

        Runs config.num_epochs epochs of training + validation and returns
        a history dict with per-epoch 'train_loss', 'val_loss' and
        'val_accuracy' lists.
        """
        # Initialize optimizer and scheduler
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        total_steps = len(train_dataloader) * self.config.num_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )

        # Training history accumulated across epochs
        history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': []
        }

        print(f"Training on device: {self.device}")
        print(f"Total training steps: {total_steps}")

        for epoch in range(self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")
            print("-" * 50)

            # Training phase
            train_loss = self._train_epoch(train_dataloader, optimizer, scheduler)
            history['train_loss'].append(train_loss)

            # Validation phase
            val_loss, val_accuracy = self._validate_epoch(val_dataloader)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.4f}")

        return history

    def _train_epoch(self, dataloader: DataLoader, optimizer, scheduler) -> float:
        """Train for one epoch; returns the mean batch loss."""
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(dataloader):
            # Move batch to device
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Forward pass (labels are present, so the model returns a loss)
            outputs = self.model(**batch)
            loss = outputs.loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to stabilize training
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            # Step the LR schedule once per optimizer step
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

        return total_loss / len(dataloader)

    def _validate_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Validate for one epoch; returns (mean loss, token-level accuracy)."""
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_tokens = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(**batch)
                loss = outputs.loss

                total_loss += loss.item()

                # Calculate accuracy (ignoring -100 labels)
                predictions = torch.argmax(outputs.logits, dim=-1)
                labels = batch['labels']

                # Mask for valid labels (not -100, i.e. real subword positions)
                valid_mask = labels != -100

                correct = (predictions == labels) & valid_mask
                total_correct += correct.sum().item()
                total_tokens += valid_mask.sum().item()

        avg_loss = total_loss / len(dataloader)
        # Guard against an empty validation set (all labels masked)
        accuracy = total_correct / total_tokens if total_tokens > 0 else 0

        return avg_loss, accuracy

    def save_model(self, save_path: str):
        """Save the trained model, tokenizer, and training config to *save_path*."""
        self.model.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        # Save config alongside the weights so training can be reconstructed
        config_path = f"{save_path}/training_config.json"
        with open(config_path, 'w') as f:
            json.dump(vars(self.config), f, indent=2)

        print(f"Model saved to {save_path}")

    def load_model(self, model_path: str):
        """Load a pre-trained model and tokenizer from *model_path*."""
        self.model.model = DistilBertForTokenClassification.from_pretrained(model_path)
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)
        self.model.to(self.device)
        print(f"Model loaded from {model_path}")
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def create_model_and_trainer(config: Optional[ModelConfig] = None) -> Tuple[DocumentNERModel, NERTrainer]:
    """Build a DocumentNERModel plus its NERTrainer.

    Uses a default ModelConfig when none is supplied.
    """
    cfg = config if config is not None else ModelConfig()

    ner_model = DocumentNERModel(cfg)
    return ner_model, NERTrainer(ner_model, cfg)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def main():
    """Demonstrate model creation and setup.

    Builds a lightweight demo configuration, prints it, instantiates the
    model/trainer pair, and reports parameter counts.
    """
    # Lighter settings so the demonstration runs quickly.
    demo_config = ModelConfig(
        batch_size=8,  # Smaller batch size for demo
        num_epochs=2,
        learning_rate=3e-5,
    )

    print("Model Configuration:")
    print(f"Model: {demo_config.model_name}")
    print(f"Max Length: {demo_config.max_length}")
    print(f"Batch Size: {demo_config.batch_size}")
    print(f"Learning Rate: {demo_config.learning_rate}")
    print(f"Number of Labels: {demo_config.num_labels}")
    print(f"Entity Labels: {demo_config.entity_labels}")

    # Create model and trainer
    model, trainer = create_model_and_trainer(demo_config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nModel created successfully!")
    print(f"Model parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model, trainer
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# Run the demonstration only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
src/training_pipeline.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Complete training pipeline for document text extraction using SLM.
|
| 3 |
+
Handles data loading, model training, evaluation, and saving.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Optional
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import seaborn as sns
|
| 13 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 14 |
+
import numpy as np
|
| 15 |
+
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report as seq_classification_report
|
| 16 |
+
|
| 17 |
+
from src.data_preparation import DocumentProcessor, NERDatasetCreator
|
| 18 |
+
from src.model import DocumentNERModel, NERTrainer, ModelConfig, create_model_and_trainer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TrainingPipeline:
    """Complete training pipeline for document NER.

    Orchestrates data preparation, model initialization, training,
    evaluation, plotting, and model persistence end to end.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize training pipeline with an optional model config."""
        self.config = config or ModelConfig()
        self.model = None    # set by initialize_model()
        self.trainer = None  # set by initialize_model()
        self.history = {}    # filled by train_model()

        # Create necessary directories
        self._create_directories()

    def _create_directories(self):
        """Create necessary directories for training artifacts (idempotent)."""
        directories = [
            "data/raw",
            "data/processed",
            "models",
            "results/plots",
            "results/metrics"
        ]

        for directory in directories:
            Path(directory).mkdir(parents=True, exist_ok=True)

    def prepare_data(self, data_path: Optional[str] = None) -> List[Dict]:
        """Prepare training data from documents or create sample data.

        Falls back to a synthetic sample dataset when *data_path* is not
        given or does not exist; the processed dataset is saved to disk
        either way.
        """
        print("=" * 60)
        print("STEP 1: DATA PREPARATION")
        print("=" * 60)

        # Initialize document processor and dataset creator
        processor = DocumentProcessor()
        dataset_creator = NERDatasetCreator(processor)

        # Process documents or create sample data
        if data_path and Path(data_path).exists():
            print(f"Processing documents from: {data_path}")
            dataset = dataset_creator.process_documents_folder(data_path)
        else:
            print("No document path provided or path doesn't exist.")
            print("Creating sample dataset for demonstration...")
            dataset = dataset_creator.create_sample_dataset()

        # Save processed dataset
        output_path = "data/processed/ner_dataset.json"
        dataset_creator.save_dataset(dataset, output_path)

        print(f"Data preparation completed!")
        print(f"Dataset saved to: {output_path}")
        print(f"Total examples: {len(dataset)}")

        return dataset

    def initialize_model(self):
        """Initialize model and trainer from the pipeline config."""
        print("\n" + "=" * 60)
        print("STEP 2: MODEL INITIALIZATION")
        print("=" * 60)

        self.model, self.trainer = create_model_and_trainer(self.config)

        print(f"Model initialized: {self.config.model_name}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"Device: {self.trainer.device}")
        print(f"Number of entity labels: {self.config.num_labels}")

        return self.model, self.trainer

    def train_model(self, dataset: List[Dict]) -> Dict[str, List[float]]:
        """Train the NER model and record the per-epoch history."""
        print("\n" + "=" * 60)
        print("STEP 3: MODEL TRAINING")
        print("=" * 60)

        # Prepare dataloaders
        print("Preparing training and validation data...")
        train_dataloader, val_dataloader = self.trainer.prepare_dataloaders(dataset)

        print(f"Training samples: {len(train_dataloader.dataset)}")
        print(f"Validation samples: {len(val_dataloader.dataset)}")
        print(f"Training batches: {len(train_dataloader)}")
        print(f"Validation batches: {len(val_dataloader)}")

        # Start training
        print(f"\nStarting training for {self.config.num_epochs} epochs...")
        self.history = self.trainer.train(train_dataloader, val_dataloader)

        print(f"Training completed!")
        return self.history

    def evaluate_model(self, dataset: List[Dict]) -> Dict:
        """Evaluate the trained model and persist the metrics.

        NOTE(review): the test split is drawn from the same dataset used
        for training, so these metrics are optimistic rather than held-out.
        """
        print("\n" + "=" * 60)
        print("STEP 4: MODEL EVALUATION")
        print("=" * 60)

        # Prepare test data (re-split with a 30% hold-out; the "validation"
        # half of the split is used as the test loader)
        _, test_dataloader = self.trainer.prepare_dataloaders(dataset, test_size=0.3)

        # Evaluate
        evaluation_results = self._detailed_evaluation(test_dataloader)

        # Save evaluation results
        results_path = "results/metrics/evaluation_results.json"
        with open(results_path, 'w') as f:
            json.dump(evaluation_results, f, indent=2)

        print(f"Evaluation completed!")
        print(f"Results saved to: {results_path}")

        return evaluation_results

    def _detailed_evaluation(self, test_dataloader) -> Dict:
        """Compute entity-level precision/recall/F1 with seqeval."""
        self.model.eval()

        all_predictions = []
        all_labels = []
        all_tokens = []  # NOTE(review): never populated below; kept as-is

        print("Running evaluation on test set...")

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_dataloader):
                # Move to device
                batch = {k: v.to(self.trainer.device) for k, v in batch.items()}

                # Get predictions
                predictions, probabilities = self.model.predict(
                    batch['input_ids'],
                    batch['attention_mask']
                )

                # Convert to numpy
                pred_np = predictions.cpu().numpy()
                labels_np = batch['labels'].cpu().numpy()

                # Process each sequence in the batch
                for i in range(pred_np.shape[0]):
                    pred_seq = []
                    label_seq = []

                    for j in range(pred_np.shape[1]):
                        if labels_np[i][j] != -100:  # Valid label (not special/padding)
                            pred_label = self.config.id2label[pred_np[i][j]]
                            true_label = self.config.id2label[labels_np[i][j]]

                            pred_seq.append(pred_label)
                            label_seq.append(true_label)

                    if pred_seq and label_seq:  # Non-empty sequences
                        all_predictions.append(pred_seq)
                        all_labels.append(label_seq)

        print(f"Processed {len(all_predictions)} sequences")

        # Calculate metrics using seqeval (entity-level, not token-level)
        f1 = f1_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions)
        recall = recall_score(all_labels, all_predictions)

        # Detailed classification report
        report = seq_classification_report(all_labels, all_predictions)

        evaluation_results = {
            'f1_score': f1,
            'precision': precision,
            'recall': recall,
            'classification_report': report,
            'num_test_sequences': len(all_predictions)
        }

        # Print results
        print(f"\nEvaluation Results:")
        print(f"F1 Score: {f1:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"\nDetailed Classification Report:")
        print(report)

        return evaluation_results

    def plot_training_history(self):
        """Plot per-epoch loss and validation accuracy and save to disk."""
        if not self.history:
            print("No training history available.")
            return

        print("\n" + "=" * 60)
        print("STEP 5: PLOTTING TRAINING HISTORY")
        print("=" * 60)

        # Create plots: loss on the left, accuracy on the right
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Loss plot
        epochs = range(1, len(self.history['train_loss']) + 1)
        axes[0].plot(epochs, self.history['train_loss'], 'b-', label='Training Loss')
        axes[0].plot(epochs, self.history['val_loss'], 'r-', label='Validation Loss')
        axes[0].set_title('Model Loss')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].legend()
        axes[0].grid(True)

        # Accuracy plot
        axes[1].plot(epochs, self.history['val_accuracy'], 'g-', label='Validation Accuracy')
        axes[1].set_title('Model Accuracy')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].legend()
        axes[1].grid(True)

        plt.tight_layout()

        # Save plot to file (never shown interactively; pipeline may run headless)
        plot_path = "results/plots/training_history.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Training history plot saved to: {plot_path}")

    def save_model(self, model_name: str = "document_ner_model"):
        """Save the trained model plus its training history; returns the path."""
        print("\n" + "=" * 60)
        print("STEP 6: SAVING MODEL")
        print("=" * 60)

        save_path = f"models/{model_name}"
        self.trainer.save_model(save_path)

        # Save training history next to the model weights
        history_path = f"{save_path}/training_history.json"
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)

        print(f"Model saved to: {save_path}")
        print(f"Training history saved to: {history_path}")

        return save_path

    def run_complete_pipeline(self, data_path: Optional[str] = None,
                              model_name: str = "document_ner_model") -> str:
        """Run the complete training pipeline; returns the saved model path.

        Any exception is logged and re-raised so callers can handle failure.
        """
        print("STARTING COMPLETE TRAINING PIPELINE")
        print("=" * 80)

        try:
            # Step 1: Prepare data
            dataset = self.prepare_data(data_path)

            # Step 2: Initialize model
            self.initialize_model()

            # Step 3: Train model
            self.train_model(dataset)

            # Step 4: Evaluate model
            self.evaluate_model(dataset)

            # Step 5: Plot training history
            self.plot_training_history()

            # Step 6: Save model
            model_path = self.save_model(model_name)

            print("\n" + "=" * 20)
            print("TRAINING PIPELINE COMPLETED SUCCESSFULLY!")
            print("=" * 20)
            print(f"Model saved to: {model_path}")
            print(f"Training completed in {self.config.num_epochs} epochs")
            print(f"Final validation accuracy: {self.history['val_accuracy'][-1]:.4f}")

            return model_path

        except Exception as e:
            # Surface the failure to the caller after logging it
            print(f"\nError in training pipeline: {e}")
            raise
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def create_custom_config() -> ModelConfig:
    """Create a custom configuration for training.

    Uses shorter sequences and a lower dropout than the defaults so
    training runs faster on modest hardware.
    """
    return ModelConfig(
        model_name="distilbert-base-uncased",
        max_length=256,      # Shorter sequences for faster training
        batch_size=16,       # Adjust based on your GPU memory
        learning_rate=2e-5,
        num_epochs=3,
        warmup_steps=500,
        weight_decay=0.01,
        dropout_rate=0.1,
    )
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def main():
    """Main function to run the complete training pipeline."""
    print("Document Text Extraction - Training Pipeline")
    print("=" * 50)

    # Build a pipeline around the custom configuration.
    pipeline = TrainingPipeline(create_custom_config())

    # Run complete pipeline
    # You can provide a path to your document folder here
    # pipeline.run_complete_pipeline(data_path="data/raw")

    # For demonstration, we'll use sample data
    model_path = pipeline.run_complete_pipeline()

    print(f"\nTraining completed! Model saved to: {model_path}")
    print("You can now use this model for document text extraction!")
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Run the full training pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
tests/test_extraction.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test cases for the document text extraction system.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import unittest
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import tempfile
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from src.data_preparation import DocumentProcessor, NERDatasetCreator
|
| 12 |
+
from src.model import ModelConfig, create_model_and_trainer
|
| 13 |
+
from src.inference import DocumentInference
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestDocumentProcessor(unittest.TestCase):
    """Unit tests for the DocumentProcessor text utilities."""

    def setUp(self):
        """Create a fresh processor for every test."""
        self.processor = DocumentProcessor()

    def test_clean_text(self):
        """clean_text should strip padding whitespace and collapse repeated punctuation."""
        raw = "   This is a test text!!!   "
        self.assertEqual(self.processor.clean_text(raw), "This is a test text!")

    def test_entity_patterns(self):
        """The processor must expose regex patterns for every core entity type."""
        sample = "Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"

        for entity_type in ('NAME', 'DATE', 'INVOICE_NO', 'AMOUNT'):
            self.assertIn(entity_type, self.processor.entity_patterns)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TestNERDatasetCreator(unittest.TestCase):
    """Unit tests for automatic NER dataset creation."""

    def setUp(self):
        """Build a dataset creator backed by a fresh DocumentProcessor."""
        self.processor = DocumentProcessor()
        self.dataset_creator = NERDatasetCreator(self.processor)

    def test_auto_label_text(self):
        """auto_label_text should yield a non-empty list of (token, label) string pairs."""
        sentence = "Invoice sent to Robert White on 15/09/2025 Amount: $1,250"
        pairs = self.dataset_creator.auto_label_text(sentence)

        # A non-empty list comes back...
        self.assertIsInstance(pairs, list)
        self.assertGreater(len(pairs), 0)

        # ...and every item unpacks into two strings.
        for token, label in pairs:
            self.assertIsInstance(token, str)
            self.assertIsInstance(label, str)

    def test_create_training_example(self):
        """A training example must carry aligned 'tokens', 'labels' and 'text' fields."""
        sentence = "Invoice INV-1001 for $500"
        example = self.dataset_creator.create_training_example(sentence)

        for field in ('tokens', 'labels', 'text'):
            self.assertIn(field, example)

        # Tokens and labels stay in lockstep.
        self.assertEqual(len(example['tokens']), len(example['labels']))

    def test_create_sample_dataset(self):
        """create_sample_dataset should return a non-empty list of well-formed examples."""
        dataset = self.dataset_creator.create_sample_dataset()

        self.assertIsInstance(dataset, list)
        self.assertGreater(len(dataset), 0)

        # Spot-check the structure of the first example.
        first = dataset[0]
        for field in ('tokens', 'labels', 'text'):
            self.assertIn(field, first)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class TestModelConfig(unittest.TestCase):
    """Unit tests for ModelConfig defaults and keyword overrides."""

    def test_default_config(self):
        """A default ModelConfig carries the documented defaults and label maps."""
        cfg = ModelConfig()

        # Core hyperparameter defaults.
        self.assertEqual(cfg.model_name, "distilbert-base-uncased")
        self.assertEqual(cfg.max_length, 512)
        self.assertEqual(cfg.batch_size, 16)

        # Entity label list is non-empty and includes the "outside" tag.
        self.assertIsInstance(cfg.entity_labels, list)
        self.assertGreater(len(cfg.entity_labels), 0)
        self.assertIn("O", cfg.entity_labels)

        # Forward/backward label mappings mirror the label list.
        self.assertIsInstance(cfg.label2id, dict)
        self.assertIsInstance(cfg.id2label, dict)
        self.assertEqual(len(cfg.label2id), len(cfg.entity_labels))

    def test_custom_config(self):
        """Keyword overrides are reflected verbatim on the config instance."""
        labels = ["O", "B-TEST", "I-TEST"]
        cfg = ModelConfig(
            batch_size=32,
            learning_rate=1e-5,
            entity_labels=labels,
        )

        self.assertEqual(cfg.batch_size, 32)
        self.assertEqual(cfg.learning_rate, 1e-5)
        self.assertEqual(cfg.entity_labels, labels)
        self.assertEqual(cfg.num_labels, 3)
class TestModelCreation(unittest.TestCase):
    """Unit tests for the model/trainer factory."""

    def test_create_model_and_trainer(self):
        """create_model_and_trainer returns both objects wired to the given config."""
        cfg = ModelConfig(
            batch_size=4,  # Small batch for testing
            num_epochs=1,
            entity_labels=["O", "B-TEST", "I-TEST"],
        )

        model, trainer = create_model_and_trainer(cfg)

        # Both objects must exist and the trainer must carry our settings.
        self.assertIsNotNone(model)
        self.assertIsNotNone(trainer)
        self.assertEqual(trainer.config.batch_size, 4)
        self.assertEqual(trainer.config.num_epochs, 1)
class TestInference(unittest.TestCase):
    """Tests for inference-side entity validation.

    No model is loaded here; only the regex patterns used for entity
    validation are exercised. ``model_path``/``test_text`` are placeholders
    for future model-backed tests.
    """

    @classmethod
    def setUpClass(cls):
        """Set up class-level fixtures."""
        # Placeholder - in real testing, you'd point at a pre-trained model.
        cls.model_path = "test_model"
        cls.test_text = "Invoice sent to John Doe on 01/15/2025 Amount: $500.00"

    def test_entity_validation(self):
        """Spot-check validation regexes against representative samples."""
        import re

        # Numeric dates such as 01/15/2025 or 15-09-25.
        date_pattern = r'\b\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}\b'
        self.assertTrue(re.search(date_pattern, '01/15/2025'))

        # Dollar amounts with optional thousands separators and cents.
        amount_pattern = r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?'
        self.assertTrue(re.search(amount_pattern, '$1,250.50'))

        # BUGFIX: the TLD class was previously [A-Z|a-z], which also matched a
        # literal '|' ('|' is not an alternation operator inside a character
        # class). Use [A-Za-z] so only alphabetic TLDs are accepted.
        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        self.assertTrue(re.search(email_pattern, 'test@email.com'))
        # The old pattern wrongly accepted a pipe as a "TLD"; the fixed one must not.
        self.assertIsNone(re.fullmatch(r'[A-Za-z]{2,}', '||'))
class TestEndToEnd(unittest.TestCase):
    """End-to-end integration tests."""

    def test_data_preparation_flow(self):
        """Building the sample dataset end-to-end yields well-formed examples."""
        dataset = NERDatasetCreator(DocumentProcessor()).create_sample_dataset()

        self.assertIsInstance(dataset, list)
        self.assertGreater(len(dataset), 0)

        # Every example carries the three required fields with aligned lengths.
        for example in dataset:
            for field in ("tokens", "labels", "text"):
                self.assertIn(field, example)
            self.assertEqual(len(example["tokens"]), len(example["labels"]))

    def test_model_config_flow(self):
        """Config -> factory flow produces a trainer carrying that config."""
        cfg = ModelConfig(batch_size=4, num_epochs=1)
        model, trainer = create_model_and_trainer(cfg)

        self.assertIsNotNone(model)
        self.assertIsNotNone(trainer)
        self.assertEqual(trainer.config.batch_size, 4)
        self.assertEqual(trainer.config.num_epochs, 1)

    def test_save_and_load_dataset(self):
        """A dataset round-trips through a temporary JSON file unchanged."""
        dataset = NERDatasetCreator(DocumentProcessor()).create_sample_dataset()

        # delete=False so the file survives the `with` and can be reopened.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            temp_path = f.name
            json.dump(dataset, f, indent=2)

        try:
            with open(temp_path, 'r') as f:
                restored = json.load(f)

            self.assertEqual(len(restored), len(dataset))
            self.assertEqual(restored[0]['text'], dataset[0]['text'])
        finally:
            # Always remove the temporary artifact.
            os.unlink(temp_path)
def run_tests():
    """Run every test class in this module and print a summary.

    Returns:
        bool: True when all tests passed, False otherwise.
    """
    print("Running Document Text Extraction Tests")
    print("=" * 50)

    # Reuse a single loader; the previous version instantiated a fresh
    # unittest.TestLoader() on every loop iteration for no benefit.
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()

    test_classes = [
        TestDocumentProcessor,
        TestNERDatasetCreator,
        TestModelConfig,
        TestModelCreation,
        TestInference,
        TestEndToEnd,
    ]
    for test_class in test_classes:
        test_suite.addTests(loader.loadTestsFromTestCase(test_class))

    # Run tests with per-test output.
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    # Human-readable summary.
    if result.wasSuccessful():
        print(f"\nAll tests passed! ({result.testsRun} tests)")
    else:
        print(f"\n{len(result.failures)} failures, {len(result.errors)} errors")

        if result.failures:
            print("\nFailures:")
            for test, failure in result.failures:
                print(f" {test}: {failure}")

        if result.errors:
            print("\nErrors:")
            for test, error in result.errors:
                print(f" {test}: {error}")

    return result.wasSuccessful()
| 289 |
+
if __name__ == "__main__":
    import sys

    # Propagate the test outcome as the process exit code so CI detects
    # failures; the previous version discarded run_tests()'s return value
    # and always exited 0.
    sys.exit(0 if run_tests() else 1)