|
|
"""
FastAPI web service for document text extraction.

Provides REST API endpoints for uploading and processing documents.
"""
|
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
import uvicorn
|
|
|
import tempfile
|
|
|
import os
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from typing import List, Optional, Dict, Any
|
|
|
import shutil
|
|
|
|
|
|
from src.inference import DocumentInference
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; every route and middleware in this module is
# registered on it.
app = FastAPI(
    title="Document Text Extraction API",
    description="Extract structured information from documents using Small Language Model (SLM)",
    version="1.0.0",
)
|
|
|
|
|
|
|
|
|
# CORS: the API may be called from any origin, with any method and header.
# Credentials are disabled on purpose: wildcard origins must not be combined
# with credentialed requests, and no endpoint in this module relies on
# cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
# Module-level cache for the inference pipeline; stays None until
# get_inference_pipeline() successfully constructs a DocumentInference.
inference_pipeline: Optional[DocumentInference] = None
|
|
|
|
|
|
def get_inference_pipeline() -> DocumentInference:
    """Return the shared inference pipeline, creating it on first use.

    The instance is cached in the module-level ``inference_pipeline``
    variable, so the model is constructed at most once per successful load.

    Returns:
        The cached ``DocumentInference`` instance.

    Raises:
        HTTPException: 503 when the model directory does not exist or when
            constructing ``DocumentInference`` fails.
    """
    global inference_pipeline

    # Already loaded: nothing to do.
    if inference_pipeline is not None:
        return inference_pipeline

    model_path = "models/document_ner_model"

    if not Path(model_path).exists():
        raise HTTPException(
            status_code=503,
            detail="Model not found. Please train the model first by running training_pipeline.py",
        )

    try:
        inference_pipeline = DocumentInference(model_path)
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise HTTPException(
            status_code=503,
            detail=f"Failed to load model: {e}",
        ) from e

    return inference_pipeline
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Try to load the model when the application starts.

    A failure here is only reported on stdout; it never prevents the
    application from starting.
    """
    try:
        get_inference_pipeline()
    except Exception as load_error:
        print(f"Failed to load model on startup: {load_error}")
        print("Model will be loaded on first request")
    else:
        print("Model loaded successfully on startup")
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface.

    The page is a self-contained document (inline CSS and JavaScript) with
    two tabs: one posts a file to ``/extract-from-file`` and the other posts
    JSON text to ``/extract-from-text``.

    Returns:
        The HTML document as a string.
    """
    # Notes on the embedded script:
    # * Everything that comes back from the API is passed through
    #   escapeHtml() before being inserted with innerHTML, and plain status
    #   messages are written with textContent, so response data cannot
    #   inject markup.
    # * Non-2xx responses carry their message under "detail"; those are
    #   mapped to the page's "error" display instead of being rendered as an
    #   empty result.
    # * showTab() locates the clicked tab through its data-tab attribute
    #   rather than the implicit global `event` object.
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>Document Text Extraction</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }
        .container {
            background: white;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        .header {
            text-align: center;
            color: #333;
            margin-bottom: 30px;
        }
        .upload-area {
            border: 2px dashed #ccc;
            padding: 40px;
            text-align: center;
            margin: 20px 0;
            border-radius: 8px;
            background-color: #fafafa;
        }
        .upload-area:hover {
            border-color: #007bff;
            background-color: #f0f8ff;
        }
        .btn {
            background-color: #007bff;
            color: white;
            padding: 10px 20px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        .btn:hover {
            background-color: #0056b3;
        }
        .result {
            margin-top: 20px;
            padding: 20px;
            background-color: #f8f9fa;
            border-radius: 5px;
            border: 1px solid #dee2e6;
        }
        .json-output {
            background-color: #f4f4f4;
            padding: 15px;
            border-radius: 5px;
            font-family: monospace;
            white-space: pre-wrap;
            overflow-x: auto;
            max-height: 400px;
            overflow-y: auto;
        }
        .text-input {
            width: 100%;
            height: 100px;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            font-family: monospace;
            resize: vertical;
        }
        .tab-container {
            margin: 20px 0;
        }
        .tabs {
            display: flex;
            border-bottom: 1px solid #ccc;
        }
        .tab {
            padding: 10px 20px;
            cursor: pointer;
            border-bottom: 2px solid transparent;
            background-color: #f8f9fa;
            margin-right: 5px;
        }
        .tab.active {
            border-bottom-color: #007bff;
            background-color: white;
        }
        .tab-content {
            display: none;
            padding: 20px 0;
        }
        .tab-content.active {
            display: block;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>Document Text Extraction</h1>
            <p>Extract structured information from documents using AI</p>
        </div>

        <div class="tab-container">
            <div class="tabs">
                <div class="tab active" data-tab="file" onclick="showTab('file')">Upload File</div>
                <div class="tab" data-tab="text" onclick="showTab('text')">Enter Text</div>
            </div>

            <div id="file-tab" class="tab-content active">
                <form id="uploadForm" enctype="multipart/form-data">
                    <div class="upload-area">
                        <p>Choose a document to extract information</p>
                        <p><small>Supported: PDF, DOCX, Images (PNG, JPG, etc.)</small></p>
                        <input type="file" id="fileInput" name="file" accept=".pdf,.docx,.png,.jpg,.jpeg,.tiff,.bmp" style="margin: 10px 0;">
                        <br>
                        <button type="submit" class="btn">Extract Information</button>
                    </div>
                </form>
            </div>

            <div id="text-tab" class="tab-content">
                <form id="textForm">
                    <p>Enter text directly for information extraction:</p>
                    <textarea id="textInput" class="text-input" placeholder="Enter document text here, e.g.: Invoice sent to John Doe on 01/15/2025 Invoice No: INV-1001 Amount: $1,500.00"></textarea>
                    <br><br>
                    <button type="submit" class="btn">Extract from Text</button>
                </form>
            </div>
        </div>

        <div id="result" class="result" style="display: none;">
            <h3>Extraction Results</h3>
            <div id="resultContent"></div>
        </div>
    </div>

    <script>
        function showTab(tabName) {
            // Hide all tab contents
            document.querySelectorAll('.tab-content').forEach(content => {
                content.classList.remove('active');
            });

            // Activate only the tab whose data-tab matches the requested name
            document.querySelectorAll('.tab').forEach(tab => {
                tab.classList.toggle('active', tab.dataset.tab === tabName);
            });

            // Show selected tab content
            document.getElementById(tabName + '-tab').classList.add('active');
        }

        // Escape a value so it can be safely interpolated into HTML markup
        function escapeHtml(value) {
            const replacements = {
                '&': '&amp;',
                '<': '&lt;',
                '>': '&gt;',
                '"': '&quot;',
                "'": '&#39;'
            };
            return String(value).replace(/[&<>"']/g, ch => replacements[ch]);
        }

        // Parse an API response; non-2xx responses become { error: ... }
        async function readApiResponse(response) {
            let payload = null;
            try {
                payload = await response.json();
            } catch (parseError) {
                payload = null;
            }

            if (!response.ok) {
                let detail = response.statusText || ('HTTP ' + response.status);
                if (payload && payload.detail !== undefined) {
                    detail = typeof payload.detail === 'string'
                        ? payload.detail
                        : JSON.stringify(payload.detail);
                }
                return { error: detail };
            }

            return payload || {};
        }

        // File upload form handler
        document.getElementById('uploadForm').addEventListener('submit', async function(e) {
            e.preventDefault();

            const fileInput = document.getElementById('fileInput');
            if (!fileInput.files[0]) {
                alert('Please select a file');
                return;
            }

            const formData = new FormData();
            formData.append('file', fileInput.files[0]);

            try {
                showResult('Processing document, please wait...');

                const response = await fetch('/extract-from-file', {
                    method: 'POST',
                    body: formData
                });

                displayResult(await readApiResponse(response));
            } catch (error) {
                showResult('Error: ' + error.message);
            }
        });

        // Text form handler
        document.getElementById('textForm').addEventListener('submit', async function(e) {
            e.preventDefault();

            const text = document.getElementById('textInput').value;
            if (!text.trim()) {
                alert('Please enter some text');
                return;
            }

            try {
                showResult('Processing text, please wait...');

                const response = await fetch('/extract-from-text', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ text: text })
                });

                displayResult(await readApiResponse(response));
            } catch (error) {
                showResult('Error: ' + error.message);
            }
        });

        // Show a plain-text status message (never interpreted as HTML)
        function showResult(message) {
            const resultDiv = document.getElementById('result');
            const contentDiv = document.getElementById('resultContent');
            contentDiv.textContent = message;
            resultDiv.style.display = 'block';
        }

        // Show markup that was built from already-escaped values
        function showResultHtml(html) {
            const resultDiv = document.getElementById('result');
            const contentDiv = document.getElementById('resultContent');
            contentDiv.innerHTML = html;
            resultDiv.style.display = 'block';
        }

        function displayResult(result) {
            let html = '';

            if (result.error) {
                html = `<div style="color: red;">Error: ${escapeHtml(result.error)}</div>`;
            } else {
                // Show structured data
                if (result.structured_data && Object.keys(result.structured_data).length > 0) {
                    html += '<h4>Extracted Information:</h4>';
                    html += '<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">';
                    html += '<tr style="background-color: #f8f9fa;"><th style="padding: 8px; border: 1px solid #dee2e6; text-align: left;">Field</th><th style="padding: 8px; border: 1px solid #dee2e6; text-align: left;">Value</th></tr>';

                    for (const [key, value] of Object.entries(result.structured_data)) {
                        html += `<tr><td style="padding: 8px; border: 1px solid #dee2e6; font-weight: bold;">${escapeHtml(key)}</td><td style="padding: 8px; border: 1px solid #dee2e6;">${escapeHtml(value)}</td></tr>`;
                    }
                    html += '</table>';
                } else {
                    html += '<div style="color: orange;">No structured information found in the document.</div>';
                }

                // Show entities
                if (result.entities && result.entities.length > 0) {
                    html += '<h4>Detected Entities:</h4>';
                    html += '<div style="margin: 10px 0;">';
                    result.entities.forEach(entity => {
                        const confidence = Math.round(entity.confidence * 100);
                        html += `<span style="display: inline-block; margin: 2px 4px; padding: 4px 8px; background-color: #e3f2fd; border: 1px solid #2196f3; border-radius: 15px; font-size: 12px;">
                            ${escapeHtml(entity.entity)}: "${escapeHtml(entity.text)}" (${confidence}%)</span>`;
                    });
                    html += '</div>';
                }

                // Show raw JSON
                html += '<h4>Full Response:</h4>';
                html += `<div class="json-output">${escapeHtml(JSON.stringify(result, null, 2))}</div>`;
            }

            showResultHtml(html);
        }
    </script>
</body>
</html>
"""
    return html_content
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report whether the inference pipeline can be obtained.

    Returns:
        A dict with ``status`` ("healthy" or "unhealthy") and a ``message``;
        on failure the message is the string form of the raised exception.
    """
    try:
        get_inference_pipeline()
    except Exception as failure:
        return {"status": "unhealthy", "message": str(failure)}
    return {"status": "healthy", "message": "Model loaded successfully"}
|
|
|
|
|
|
|
|
|
@app.post("/extract-from-file")
async def extract_from_file(file: UploadFile = File(...)):
    """Extract structured information from an uploaded file.

    The upload is copied to a temporary file (keeping its extension), passed
    to the inference pipeline's ``process_document``, and the temporary file
    is removed afterwards.

    Args:
        file: The uploaded document; its extension must be one of the
            allowed types below.

    Returns:
        A ``JSONResponse`` wrapping the result of ``process_document``.

    Raises:
        HTTPException: 400 for a missing file or unsupported extension;
            whatever ``get_inference_pipeline`` raises (503) when the model
            is unavailable; 500 for any other processing failure.
    """
    if not file:
        raise HTTPException(status_code=400, detail="No file provided")

    allowed_extensions = {'.pdf', '.docx', '.png', '.jpg', '.jpeg', '.tiff', '.bmp'}
    # The filename can be absent; treat that as "no extension" so the request
    # is rejected with a 400 instead of failing inside Path().
    file_extension = Path(file.filename or "").suffix.lower()

    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # Sorted so the message is deterministic across runs.
            detail=f"Unsupported file type: {file_extension}. Allowed: {', '.join(sorted(allowed_extensions))}"
        )

    # delete=False because the pipeline re-opens the file by path; removal is
    # handled in the finally block below, which also covers a failed copy.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
    temp_file_path = temp_file.name

    try:
        with temp_file:
            shutil.copyfileobj(file.file, temp_file)

        inference = get_inference_pipeline()
        result = inference.process_document(temp_file_path)

        return JSONResponse(content=result)

    except HTTPException:
        # Keep the status code chosen upstream (e.g. 503 model unavailable).
        raise

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Best-effort cleanup: only filesystem errors are ignored.
        try:
            os.unlink(temp_file_path)
        except OSError:
            pass
|
|
|
|
|
|
|
|
|
@app.post("/extract-from-text")
async def extract_from_text(request: Dict[str, str]):
    """Extract structured information from text.

    Args:
        request: JSON object whose ``"text"`` entry holds the text to
            analyse; surrounding whitespace is stripped.

    Returns:
        A ``JSONResponse`` wrapping the result of the pipeline's
        ``process_text_directly``.

    Raises:
        HTTPException: 400 when no non-blank text is supplied; whatever
            ``get_inference_pipeline`` raises (503) when the model is
            unavailable; 500 for any other processing failure.
    """
    text = request.get("text", "").strip()

    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    try:
        inference = get_inference_pipeline()
        result = inference.process_text_directly(text)

        return JSONResponse(content=result)

    except HTTPException:
        # Keep the status code chosen upstream (e.g. 503 model unavailable).
        raise

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
@app.get("/supported-formats")
async def get_supported_formats():
    """List the accepted file extensions and the entity types reported.

    Returns:
        A dict with ``supported_formats`` (extension/description pairs) and
        ``entity_types`` (entity label names).
    """
    format_table = (
        (".pdf", "PDF documents"),
        (".docx", "Microsoft Word documents"),
        (".png", "PNG images"),
        (".jpg", "JPEG images"),
        (".jpeg", "JPEG images"),
        (".tiff", "TIFF images"),
        (".bmp", "BMP images"),
    )
    formats = [
        {"extension": suffix, "description": label}
        for suffix, label in format_table
    ]
    entity_names = ["Name", "Date", "InvoiceNo", "Amount", "Address", "Phone", "Email"]
    return {"supported_formats": formats, "entity_types": entity_names}
|
|
|
|
|
|
|
|
|
@app.get("/model-info")
async def get_model_info():
    """Describe the currently loaded model.

    Returns:
        A dict with the pipeline's ``model_path`` and, from its ``config``,
        ``model_name``, ``max_length``, ``entity_labels`` and ``num_labels``.

    Raises:
        HTTPException: 503 when the pipeline cannot be obtained or any of
            the attributes cannot be read.
    """
    try:
        pipeline = get_inference_pipeline()
        info = {"model_path": pipeline.model_path}
        settings = pipeline.config
        info["model_name"] = settings.model_name
        info["max_length"] = settings.max_length
        info["entity_labels"] = settings.entity_labels
        info["num_labels"] = settings.num_labels
        return info
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {str(e)}")
|
|
|
|
|
|
|
|
|
def main():
    """Print the service URLs and start the uvicorn server.

    The server is started for the ``api.app:app`` import string on
    0.0.0.0:8000 with auto-reload enabled and "info" log level.
    """
    banner = (
        "Starting Document Text Extraction API Server...",
        "Server will be available at: http://localhost:8000",
        "Web interface: http://localhost:8000",
        "API docs: http://localhost:8000/docs",
        "Health check: http://localhost:8000/health",
    )
    for line in banner:
        print(line)

    uvicorn.run(
        "api.app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )
|
|
|
|
|
|
|
|
|
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()