Initial RenAI app
Browse files- Dockerfile +34 -0
- __pycache__/configs.cpython-312.pyc +0 -0
- __pycache__/inference.cpython-312.pyc +0 -0
- __pycache__/vit.cpython-312.pyc +0 -0
- app.py +323 -0
- inference.py +255 -0
- main.py +65 -0
- requirements.txt +233 -0
- utils/__pycache__/configs.cpython-312.pyc +0 -0
- utils/__pycache__/helper.cpython-312.pyc +0 -0
- utils/__pycache__/inference.cpython-312.pyc +0 -0
- utils/__pycache__/line_segmentation.cpython-312.pyc +0 -0
- utils/__pycache__/postprocessing.cpython-312.pyc +0 -0
- utils/__pycache__/preprocessing.cpython-312.pyc +0 -0
- utils/__pycache__/vit.cpython-312.pyc +0 -0
- utils/helper.py +10 -0
- utils/line_segmentation.py +327 -0
- utils/postprocessing.py +353 -0
- utils/preprocessing.py +202 -0
- utils/unet.py +64 -0
- vit.py +111 -0
Dockerfile
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim

# Install system libraries needed by OpenCV/image processing as ROOT,
# BEFORE dropping privileges — the original ran this after `USER user`,
# which makes apt-get fail with a permission error.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    libgtk-3-0 \
    libavcodec-dev \
    libavformat-dev \
    libswscale-dev \
    && rm -rf /var/lib/apt/lists/*

# Create an unprivileged user; -m ensures /home/user actually exists
# (plain `useradd user` does not create the home directory).
RUN useradd -m user

USER user

ENV HOME=/home/user \
    PATH="/home/user/.local/bin:$PATH" \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR $HOME/app

COPY --chown=user requirements.txt .

RUN pip install --no-cache-dir --timeout=100 -r requirements.txt

# --chown so the app files are owned by the runtime user, not root.
COPY --chown=user . .

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host=0.0.0.0", "--port=7860"]
|
__pycache__/configs.cpython-312.pyc
ADDED
|
Binary file (284 Bytes). View file
|
|
|
__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
__pycache__/vit.cpython-312.pyc
ADDED
|
Binary file (7.01 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import JSONResponse
|
| 4 |
+
from typing import Optional, Any, Dict, Union
|
| 5 |
+
import shutil
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from main import RenAITranscription
|
| 11 |
+
import tempfile
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import base64
|
| 15 |
+
from io import BytesIO
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
# FastAPI application instance; served by uvicorn (see Dockerfile CMD).
app = FastAPI(title="RenAI Transcription API", version="1.0.0")

# Add CORS middleware
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# Upload validation limits enforced by the /renai-transcribe endpoint.
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
|
| 31 |
+
|
| 32 |
+
def numpy_to_base64(array: np.ndarray, format: str = 'PNG', quality: int = 85) -> Optional[str]:
    """
    Convert a numpy image array to a base64 data URI for web display.

    Args:
        array: Numpy array representing the image
        format: Image format ('PNG' or 'JPEG')
        quality: JPEG quality (1-100), only used if format is JPEG

    Returns:
        Data URI string usable directly in an HTML <img> src attribute,
        or None if conversion fails (the error is logged).
        Note: the original annotation claimed `-> str` but the except
        branch returns None; the annotation now reflects that.
    """
    try:
        # Convert numpy array to PIL Image
        img = Image.fromarray(array)

        # Save to an in-memory buffer
        buffer = BytesIO()
        if format.upper() == 'JPEG':
            # JPEG has no alpha channel: flatten transparent images onto white.
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background
            img.save(buffer, format='JPEG', quality=quality, optimize=True)
            mime_type = 'image/jpeg'
        else:
            img.save(buffer, format='PNG', optimize=True)
            mime_type = 'image/png'

        # Encode to base64 and wrap in a data URI
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:{mime_type};base64,{img_str}"
    except Exception as e:
        # Best-effort conversion: callers treat None as "no image available".
        logger.error(f"Error converting numpy array to base64: {e}")
        return None
|
| 70 |
+
|
| 71 |
+
def format_transcription_result(result: Dict, include_images: bool = False, image_format: str = 'PNG') -> Dict[str, Any]:
    """
    Build a structured API response from raw transcription output.

    Args:
        result: Mapping of line IDs to dicts holding 'image' and 'transcription'
        include_images: When True, embed each line image as a base64 data URI
        image_format: Encoding used for embedded images ('PNG' or 'JPEG')

    Returns:
        Dict with per-line entries, the joined full text, and the line count
    """
    lines_out: Dict[str, Any] = {}
    text_parts = []

    for key, data in result.items():
        text = data.get('transcription', '')
        entry = {'line_id': key, 'transcription': text}

        # Optionally embed the line image in web-ready base64 form.
        if include_images and 'image' in data:
            raw = data['image']
            if isinstance(raw, np.ndarray):
                encoded = numpy_to_base64(raw, format=image_format)
                if encoded:
                    entry['image'] = encoded

        lines_out[key] = entry
        text_parts.append(f"{key}: {text}")

    return {
        'lines': lines_out,
        'full_text': '\n'.join(text_parts),
        'total_lines': len(result),
    }
|
| 108 |
+
|
| 109 |
+
@app.get("/")
def home():
    """Root endpoint: service greeting plus a map of the available routes."""
    endpoints = {
        "transcribe": "/renai-transcribe (POST)",
        "transcribe_base64": "/renai-transcribe-base64 (POST)",
        "health": "/health (GET)",
    }
    return {
        "message": "Hello, RenAI!",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
|
| 120 |
+
|
| 121 |
+
@app.post("/renai-transcribe")
async def transcription_endpoint(
    image: UploadFile = File(..., description="Image file to transcribe"),
    userToken: Optional[str] = Form(None, description="User authentication token"),
    post_processing_enabled: bool = Form(False, description="Enable post-processing"),
    unet_enabled: bool = Form(False, description="Enable UNet processing"),
    include_images: bool = Form(True, description="Include base64 encoded line images in response"),
    image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
):
    """
    Upload an image file and get transcription results.

    - **image**: Image file (JPG, PNG, BMP, TIFF, WebP)
    - **userToken**: Optional user authentication token
    - **post_processing_enabled**: Enable/disable post-processing
    - **unet_enabled**: Enable/disable UNet processing
    - **include_images**: Include base64 encoded images of each line (web-ready format)
    - **image_format**: Format for line images: 'PNG' (higher quality, larger) or 'JPEG' (smaller, faster)
    """
    start_time = datetime.now()
    logger.info(f"Transcription request received for file: {image.filename} by userToken: {userToken if userToken else 'Anonymous'}")

    # Validate file type
    if not image.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    file_extension = Path(image.filename).suffix.lower()
    if file_extension not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
        )

    # Check file size (UploadFile.size may be None for streamed uploads)
    if image.size and image.size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
        )

    temp_file_path = None
    try:
        # Persist the upload to a temp file; the pipeline reads from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
            shutil.copyfileobj(image.file, temp_file)
            temp_file_path = temp_file.name

        logger.info(f"Processing image: {temp_file_path}")

        # Call transcription function
        result = RenAITranscription(
            image=temp_file_path,
            post_processing_enabled=post_processing_enabled,
            unet_enabled=unet_enabled
        )

        logger.info(f"Transcription completed. Result type: {type(result)}, Lines: {len(result)}")

        # Format the result
        formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)

        processing_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Request completed in {processing_time:.2f}s")

        response_data = {
            "success": True,
            "filename": image.filename,
            "transcription": formatted_result,
            "metadata": {
                "processing_time_seconds": round(processing_time, 2),
                "timestamp": datetime.now().isoformat(),
                "total_lines": formatted_result['total_lines'],
                "parameters": {
                    "post_processing_enabled": post_processing_enabled,
                    "unet_enabled": unet_enabled,
                    "include_images": include_images,
                    "userToken": userToken if userToken else "Anonymous"
                }
            }
        }

        return JSONResponse(content=response_data)

    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__
            }
        )
    finally:
        # Always remove the temp file — success and failure paths alike
        # (the original duplicated this cleanup in both branches).
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Cleanup is best-effort; never mask the real response/error.
                pass
|
| 224 |
+
|
| 225 |
+
@app.post("/renai-transcribe-base64")
async def transcription_base64_endpoint(
    image_data: str = Form(..., description="Base64 encoded image data"),
    userToken: Optional[str] = Form(None, description="User authentication token"),
    post_processing_enabled: bool = Form(False, description="Enable post-processing"),
    unet_enabled: bool = Form(False, description="Enable UNet processing"),
    include_images: bool = Form(False, description="Include base64 encoded line images in response"),
    image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
):
    """
    Alternative endpoint that accepts base64 encoded image data.

    - **image_data**: Base64 string, with or without a "data:...;base64," prefix
    - Other parameters mirror /renai-transcribe.
    """
    # NOTE: base64, BytesIO and Image are already imported at module level;
    # the original re-imported them locally on every request.
    start_time = datetime.now()
    logger.info(f"Base64 transcription request received by userToken: {userToken if userToken else 'Anonymous'}")

    temp_file_path = None
    try:
        # Remove data URL prefix if present (e.g. "data:image/png;base64,...")
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]

        # Decode base64 image
        image_bytes = base64.b64decode(image_data)
        image_pil = Image.open(BytesIO(image_bytes))

        # Persist to a temp file; the pipeline reads from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            image_pil.save(temp_file.name)
            temp_file_path = temp_file.name

        logger.info(f"Processing base64 image: {temp_file_path}")

        # Call transcription function
        result = RenAITranscription(
            image=temp_file_path,
            post_processing_enabled=post_processing_enabled,
            unet_enabled=unet_enabled
        )

        # Format the result
        formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)

        processing_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Base64 request completed in {processing_time:.2f}s")

        response_data = {
            "success": True,
            "transcription": formatted_result,
            "metadata": {
                "processing_time_seconds": round(processing_time, 2),
                "timestamp": datetime.now().isoformat(),
                "total_lines": formatted_result['total_lines'],
                "parameters": {
                    "post_processing_enabled": post_processing_enabled,
                    "unet_enabled": unet_enabled,
                    "include_images": include_images,
                    "image_format": image_format if include_images else None,
                    "userToken": userToken if userToken else "Anonymous"
                }
            }
        }

        return JSONResponse(content=response_data)

    except Exception as e:
        logger.error(f"Base64 transcription failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__
            }
        )
    finally:
        # Always remove the temp file — success and failure paths alike.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Cleanup is best-effort; never mask the real response/error.
                pass
|
| 312 |
+
|
| 313 |
+
@app.get("/health")
def health_check():
    """Liveness probe: report service status with the current timestamp."""
    try:
        status = {
            "status": "healthy",
            "service": "RenAI Transcription API",
            "timestamp": datetime.now().isoformat(),
        }
        return status
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(status_code=500, detail="Service unhealthy")
|
inference.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import re
|
| 7 |
+
import cv2
|
| 8 |
+
import string
|
| 9 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 10 |
+
from vit import LineDataset, collate_fn
|
| 11 |
+
from loguru import logger
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
class Inference:
    """TrOCR-based text generator for segmented handwriting line images."""

    def __init__(self, model_path, processor_path, target_size=(256, 64), batch_size=32):
        """
        Initialize the inference engine with model and processor paths.

        Args:
            model_path (str): Path to the pre-trained model
            processor_path (str): Path to the pre-trained processor
            target_size (tuple): Target size for input images (height, width)
            batch_size (int): Batch size for inference
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_path = self._get_absolute_path(model_path)
        self.processor_path = self._get_absolute_path(processor_path)
        self.target_size = target_size
        self.batch_size = batch_size

        # Initialize model and processor
        self.processor = None
        self.model = None
        self._initialize_model()

    def _get_absolute_path(self, path):
        """Return *path* unchanged if absolute, else resolve it against CWD.

        Fix: the previous ``path.lstrip('./')`` stripped ANY leading '.' and
        '/' characters, silently rewriting '../model' to '<cwd>/model'.
        ``os.path.abspath`` resolves relative paths (including '..') correctly.
        """
        if os.path.isabs(path):
            return path
        return os.path.abspath(path)

    def _initialize_model(self):
        """Load and initialize the model and processor.

        Raises:
            FileNotFoundError: If model/processor paths or required files
                (weights, config.json) are missing.
        """
        logger.info("Loading model...")

        # Check if paths exist
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path not found: {self.model_path}")
        if not os.path.exists(self.processor_path):
            raise FileNotFoundError(f"Processor path not found: {self.processor_path}")

        # List all files in the model directory
        all_files = os.listdir(self.model_path)

        # Validate that we have the necessary files
        if not any(f in all_files for f in ['pytorch_model.bin', 'model.safetensors']):
            logger.error("No model weights file found! (pytorch_model.bin or model.safetensors)")
            raise FileNotFoundError("Model weights file missing")

        if 'config.json' not in all_files:
            logger.error("config.json file not found!")
            raise FileNotFoundError("config.json missing")

        logger.info(f"Loading model from: {self.model_path}")
        logger.info(f"Loading processor from: {self.processor_path}")

        try:
            # Load processor
            self.processor = TrOCRProcessor.from_pretrained(self.processor_path, do_rescale=False, use_fast=True)
            logger.info("Processor loaded successfully")

            # Try different loading methods for the model
            logger.info("Attempting to load model...")

            # Method 1: safetensors with explicit device mapping
            try:
                self.model = VisionEncoderDecoderModel.from_pretrained(
                    self.model_path,
                    use_safetensors=True,
                    device_map="auto" if torch.cuda.is_available() else None
                )
                logger.info("Model loaded with safetensors=True and device_map")
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}")

                # Method 2: safetensors without device mapping
                try:
                    self.model = VisionEncoderDecoderModel.from_pretrained(
                        self.model_path,
                        use_safetensors=True
                    )
                    logger.info("Model loaded with safetensors=True")
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}")

                    # Method 3: without safetensors (pytorch_model.bin).
                    # Fix: this fallback previously passed use_safetensors=True,
                    # making it a dead duplicate of Method 2.
                    try:
                        self.model = VisionEncoderDecoderModel.from_pretrained(
                            self.model_path,
                            use_safetensors=False
                        )
                        logger.info("Model loaded with safetensors=False")
                    except Exception as e3:
                        logger.error(f"All loading methods failed: {e3}")
                        raise

            # Move model to device if not already done by device_map
            if not hasattr(self.model, 'device') or str(self.model.device) != str(self.device):
                logger.info(f"Moving model to device: {self.device}")
                self.model.to(self.device)

            self.model.eval()
            logger.info("Model loaded successfully and moved to device")

        except Exception as e:
            logger.error(f"Error loading model or processor: {e}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise

    def preprocess_images(self, line_segments):
        """
        Prepare line images for inference.

        Args:
            line_segments (dict): Dictionary containing line segment information

        Returns:
            tuple: (keys, line_images) - keys and corresponding images
        """
        keys = list(line_segments.keys())
        line_images = [line_segments[k]["image"] for k in keys]
        return keys, line_images

    def create_dataloader(self, line_images):
        """
        Create DataLoader for inference.

        Args:
            line_images (list): List of line images

        Returns:
            DataLoader: Configured DataLoader for inference
        """
        # Dummy labels: LineDataset requires labels but inference ignores them
        dummy_labels = [""] * len(line_images)

        dataset = LineDataset(
            self.processor,
            self.model,
            line_images,
            dummy_labels,
            self.target_size,
            apply_augmentation=False
        )

        # shuffle=False so outputs stay aligned with the input keys
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_fn
        )

        return dataloader

    def generate_texts(self, dataloader):
        """
        Generate texts from images using the model.

        Args:
            dataloader (DataLoader): DataLoader containing preprocessed images

        Returns:
            list: List of generated texts
        """
        generated_texts = []

        with torch.no_grad():
            for batch in dataloader:
                pixel_values = batch["pixel_values"].to(self.device)
                generated_ids = self.model.generate(pixel_values)
                generated_texts_batch = self.processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True
                )
                generated_texts.extend(generated_texts_batch)

        return generated_texts

    def update_line_segments(self, line_segments, keys, generated_texts):
        """
        Update line segments dictionary with generated transcriptions (in place).

        Args:
            line_segments (dict): Original line segments dictionary
            keys (list): List of keys corresponding to the line segments
            generated_texts (list): List of generated texts

        Returns:
            dict: Updated line segments dictionary with transcriptions
        """
        for key, text in zip(keys, generated_texts):
            line_segments[key]["transcription"] = text

        return line_segments

    def generate_texts_from_images(self, line_segments):
        """
        Main method to generate texts from line segment images.

        Args:
            line_segments (dict): Dictionary containing line segment information
                                  with "image" key for each segment

        Returns:
            dict: Updated line segments dictionary with "transcription" key added
        """
        logger.info("Starting text generation from images...")
        # Preprocess images
        keys, line_images = self.preprocess_images(line_segments)

        # Create dataloader
        dataloader = self.create_dataloader(line_images)

        # Generate texts
        generated_texts = self.generate_texts(dataloader)

        # Update line segments with transcriptions
        updated_line_segments = self.update_line_segments(
            line_segments, keys, generated_texts
        )

        return updated_line_segments

    def generate_single_image(self, image):
        """
        Generate text from a single image.

        Args:
            image: PIL Image or numpy array

        Returns:
            str: Generated text
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Wrap the single image in a line_segments-shaped dict and reuse
        # the batch pipeline.
        temp_segments = {"temp_key": {"image": image}}

        result = self.generate_texts_from_images(temp_segments)

        return result["temp_key"]["transcription"]
|
main.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from skimage.io import imread, imsave
|
| 2 |
+
from skimage.color import rgb2gray
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.transform import resize
|
| 7 |
+
from utils.preprocessing import preprocessImage, postProcessImage, process_segment_and_crop_image
|
| 8 |
+
from utils.line_segmentation import segment_image_to_lines
|
| 9 |
+
from configs import unet_enabled
|
| 10 |
+
from utils.helper import load_images_from_json
|
| 11 |
+
from inference import Inference
|
| 12 |
+
from configs import model_path, processor_path, unet_model_path
|
| 13 |
+
from utils.postprocessing import PostProcessing
|
| 14 |
+
from loguru import logger
|
| 15 |
+
|
| 16 |
+
def RenAITranscription(image, post_processing_enabled=False, unet_enabled=False):
    """
    Run the full handwriting-transcription pipeline on an image file.

    Pipeline: preprocessing -> (optional UNet crop) -> line segmentation
    -> TrOCR inference -> (optional dictionary-based post-processing).

    Args:
        image: Path to the input image file.
        post_processing_enabled: Apply fuzzy-matching correction to each line.
        unet_enabled: Use UNet mask-based segmentation/cropping first.

    Returns:
        dict: line_id -> {'image': ndarray, 'transcription': str,
              and 'post_processed': str when post-processing is enabled}.
    """
    # 1 - preprocessing
    # NOTE(review): the [:, :, :] slice assumes a 3-channel (RGB) input;
    # a grayscale image would raise here — confirm upstream inputs.
    org_img = imread(image)[:, :, :]

    logger.info(f'Image Dimensions : {org_img.shape[0]} x {org_img.shape[1]}')

    initial_processed_image = preprocessImage(org_img)

    if unet_enabled:
        logger.info("Masked based segmentation and cropping enabled...")
        cropped_img = process_segment_and_crop_image(unet_model_path, org_img, initial_processed_image, padding=10, min_contour_area=100)
        processed_image = postProcessImage(cropped_img)
        logger.info("Image cropped and Pre-processed successfully.....")
    else:
        logger.info("Image Preprocessing started......")
        processed_image = postProcessImage(initial_processed_image)
        logger.info("Image Pre-processed successfully.....")

    # 2 - Line segmentation Algorithm
    line_segments = segment_image_to_lines(processed_image, base_key="line", ct=0)

    # 3 - Model Inference
    transcription_generator = Inference(
        model_path=model_path,
        processor_path=processor_path,
        target_size=(256, 64),
        batch_size=32
    )
    result = transcription_generator.generate_texts_from_images(line_segments)

    # Generated texts
    for key, value in result.items():
        print(f"{key}: {value['transcription']}")

    # 4 - Post processing
    # Dictionary based fuzzy matching
    if post_processing_enabled:
        for key, value in result.items():
            corrected = PostProcessing(value['transcription'])
            result[key]['post_processed'] = corrected
            print(f"{key}: {value['post_processed']}")

    # Fix: removed the leftover debug `print(result)` — it dumped the whole
    # dict, including raw numpy line images, to stdout on every call.

    logger.info("Transcription completed successfully!")
    return result
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
    # Smoke-test the pipeline on a local sample image with all optional
    # stages disabled.
    RenAITranscription("1.png", post_processing_enabled=False, unet_enabled=False)
|
requirements.txt
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
accelerate==1.5.1
|
| 3 |
+
aiohappyeyeballs==2.5.0
|
| 4 |
+
aiohttp==3.11.13
|
| 5 |
+
aiosignal==1.3.2
|
| 6 |
+
albucore==0.0.23
|
| 7 |
+
albumentations==2.0.5
|
| 8 |
+
annotated-types==0.7.0
|
| 9 |
+
anyio==4.8.0
|
| 10 |
+
argon2-cffi==23.1.0
|
| 11 |
+
argon2-cffi-bindings==21.2.0
|
| 12 |
+
arrow==1.3.0
|
| 13 |
+
asttokens==3.0.0
|
| 14 |
+
astunparse==1.6.3
|
| 15 |
+
async-lru==2.0.4
|
| 16 |
+
attrs==25.1.0
|
| 17 |
+
babel==2.17.0
|
| 18 |
+
beautifulsoup4==4.13.3
|
| 19 |
+
bleach==6.2.0
|
| 20 |
+
blinker==1.9.0
|
| 21 |
+
blis==1.2.0
|
| 22 |
+
catalogue==2.0.10
|
| 23 |
+
certifi==2025.1.31
|
| 24 |
+
cffi==1.17.1
|
| 25 |
+
charset-normalizer==3.4.1
|
| 26 |
+
click==8.1.8
|
| 27 |
+
cloudpathlib==0.21.0
|
| 28 |
+
colorama==0.4.6
|
| 29 |
+
comm==0.2.2
|
| 30 |
+
confection==0.1.5
|
| 31 |
+
contourpy==1.3.1
|
| 32 |
+
cycler==0.12.1
|
| 33 |
+
cymem==2.0.11
|
| 34 |
+
datasets==3.3.2
|
| 35 |
+
datetime
|
| 36 |
+
debugpy==1.8.13
|
| 37 |
+
decorator==5.2.1
|
| 38 |
+
defusedxml==0.7.1
|
| 39 |
+
deskew
|
| 40 |
+
dill==0.3.8
|
| 41 |
+
editdistance==0.8.1
|
| 42 |
+
einops==0.8.1
|
| 43 |
+
evaluate==0.4.3
|
| 44 |
+
executing==2.2.0
|
| 45 |
+
fastapi
|
| 46 |
+
fastjsonschema==2.21.1
|
| 47 |
+
filelock==3.17.0
|
| 48 |
+
Flask==3.1.0
|
| 49 |
+
flatbuffers==25.2.10
|
| 50 |
+
fonttools==4.56.0
|
| 51 |
+
fqdn==1.5.1
|
| 52 |
+
frozenlist==1.5.0
|
| 53 |
+
fsspec==2024.12.0
|
| 54 |
+
gast==0.6.0
|
| 55 |
+
gensim==4.3.3
|
| 56 |
+
google-pasta==0.2.0
|
| 57 |
+
greenlet==3.1.1
|
| 58 |
+
grpcio==1.71.0
|
| 59 |
+
h11==0.14.0
|
| 60 |
+
h5py==3.13.0
|
| 61 |
+
httpcore==1.0.7
|
| 62 |
+
httpx==0.28.1
|
| 63 |
+
huggingface-hub==0.34.4
|
| 64 |
+
idna==3.10
|
| 65 |
+
imageio==2.37.0
|
| 66 |
+
iniconfig==2.0.0
|
| 67 |
+
inquirerpy==0.3.4
|
| 68 |
+
ipykernel==6.29.5
|
| 69 |
+
ipython==9.0.2
|
| 70 |
+
ipython_pygments_lexers==1.1.1
|
| 71 |
+
ipywidgets==8.1.5
|
| 72 |
+
isoduration==20.11.0
|
| 73 |
+
itsdangerous==2.2.0
|
| 74 |
+
jedi==0.19.2
|
| 75 |
+
Jinja2==3.1.6
|
| 76 |
+
jiwer==3.1.0
|
| 77 |
+
joblib==1.4.2
|
| 78 |
+
json5==0.10.0
|
| 79 |
+
jsonpointer==3.0.0
|
| 80 |
+
jsonschema==4.23.0
|
| 81 |
+
jsonschema-specifications==2024.10.1
|
| 82 |
+
jupyter==1.1.1
|
| 83 |
+
jupyter-console==6.6.3
|
| 84 |
+
jupyter-events==0.12.0
|
| 85 |
+
jupyter-lsp==2.2.5
|
| 86 |
+
jupyter_client==8.6.3
|
| 87 |
+
jupyter_core==5.7.2
|
| 88 |
+
jupyter_server==2.15.0
|
| 89 |
+
jupyter_server_terminals==0.5.3
|
| 90 |
+
jupyterlab==4.3.5
|
| 91 |
+
jupyterlab_pygments==0.3.0
|
| 92 |
+
jupyterlab_server==2.27.3
|
| 93 |
+
jupyterlab_widgets==3.0.13
|
| 94 |
+
keras==3.9.0
|
| 95 |
+
kiwisolver==1.4.8
|
| 96 |
+
langcodes==3.5.0
|
| 97 |
+
language_data==1.3.0
|
| 98 |
+
lazy_loader==0.4
|
| 99 |
+
Levenshtein==0.27.1
|
| 100 |
+
libclang==18.1.1
|
| 101 |
+
loguru
|
| 102 |
+
lxml==5.3.1
|
| 103 |
+
marisa-trie==1.2.1
|
| 104 |
+
Markdown==3.7
|
| 105 |
+
markdown-it-py==3.0.0
|
| 106 |
+
MarkupSafe==3.0.2
|
| 107 |
+
matplotlib==3.10.1
|
| 108 |
+
matplotlib-inline==0.1.7
|
| 109 |
+
mdurl==0.1.2
|
| 110 |
+
mistune==3.1.2
|
| 111 |
+
ml_dtypes==0.5.1
|
| 112 |
+
mpmath==1.3.0
|
| 113 |
+
multidict==6.1.0
|
| 114 |
+
multiprocess==0.70.16
|
| 115 |
+
murmurhash==1.0.12
|
| 116 |
+
namex==0.0.8
|
| 117 |
+
narwhals==1.30.0
|
| 118 |
+
nbclient==0.10.2
|
| 119 |
+
nbconvert==7.16.6
|
| 120 |
+
nbformat==5.10.4
|
| 121 |
+
nest-asyncio==1.6.0
|
| 122 |
+
networkx==3.3
|
| 123 |
+
ninja==1.11.1.4
|
| 124 |
+
nltk==3.9.1
|
| 125 |
+
notebook==7.3.2
|
| 126 |
+
notebook_shim==0.2.4
|
| 127 |
+
numpy==1.26.4
|
| 128 |
+
opencv-python==4.11.0.86
|
| 129 |
+
opencv-python-headless
|
| 130 |
+
opt_einsum==3.4.0
|
| 131 |
+
optree==0.14.1
|
| 132 |
+
overrides==7.7.0
|
| 133 |
+
packaging==24.2
|
| 134 |
+
pandas==2.2.3
|
| 135 |
+
pandocfilters==1.5.1
|
| 136 |
+
parso==0.8.4
|
| 137 |
+
pfzy==0.3.4
|
| 138 |
+
pillow==11.1.0
|
| 139 |
+
platformdirs==4.3.6
|
| 140 |
+
plotly==6.0.0
|
| 141 |
+
pluggy==1.5.0
|
| 142 |
+
preshed==3.0.9
|
| 143 |
+
prometheus_client==0.21.1
|
| 144 |
+
prompt_toolkit==3.0.50
|
| 145 |
+
propcache==0.3.0
|
| 146 |
+
protobuf==4.25.6
|
| 147 |
+
psutil==7.0.0
|
| 148 |
+
pure_eval==0.2.3
|
| 149 |
+
pyarrow==19.0.1
|
| 150 |
+
pycparser==2.22
|
| 151 |
+
pydantic==2.10.6
|
| 152 |
+
pydantic_core==2.27.2
|
| 153 |
+
Pygments==2.19.1
|
| 154 |
+
PyMuPDF==1.25.3
|
| 155 |
+
pyparsing==3.2.1
|
| 156 |
+
pytest==8.3.5
|
| 157 |
+
python-dateutil==2.9.0.post0
|
| 158 |
+
python-docx==1.1.2
|
| 159 |
+
python-json-logger==3.3.0
|
| 160 |
+
python-Levenshtein==0.27.1
|
| 161 |
+
python-multipart
|
| 162 |
+
pytz==2025.1
|
| 163 |
+
# pywin32==309
|
| 164 |
+
# pywinpty==2.0.15
|
| 165 |
+
PyYAML==6.0.2
|
| 166 |
+
pyzmq==26.2.1
|
| 167 |
+
RapidFuzz==3.12.2
|
| 168 |
+
referencing==0.36.2
|
| 169 |
+
regex==2024.11.6
|
| 170 |
+
requests==2.32.3
|
| 171 |
+
rfc3339-validator==0.1.4
|
| 172 |
+
rfc3986-validator==0.1.1
|
| 173 |
+
rich==13.9.4
|
| 174 |
+
rpds-py==0.23.1
|
| 175 |
+
safetensors==0.5.3
|
| 176 |
+
scikit-image==0.25.2
|
| 177 |
+
scikit-learn==1.6.1
|
| 178 |
+
scipy==1.13.1
|
| 179 |
+
seaborn==0.13.2
|
| 180 |
+
Send2Trash==1.8.3
|
| 181 |
+
setuptools==75.8.0
|
| 182 |
+
shellingham==1.5.4
|
| 183 |
+
simsimd==6.2.1
|
| 184 |
+
six==1.17.0
|
| 185 |
+
smart-open==7.1.0
|
| 186 |
+
sniffio==1.3.1
|
| 187 |
+
soupsieve==2.6
|
| 188 |
+
spacy==3.8.4
|
| 189 |
+
spacy-legacy==3.0.12
|
| 190 |
+
spacy-loggers==1.0.5
|
| 191 |
+
SQLAlchemy==2.0.38
|
| 192 |
+
srsly==2.5.1
|
| 193 |
+
stack-data==0.6.3
|
| 194 |
+
stringzilla==3.12.3
|
| 195 |
+
sympy==1.13.1
|
| 196 |
+
tensorboard==2.19.0
|
| 197 |
+
tensorboard-data-server==0.7.2
|
| 198 |
+
tensorflow==2.19.0
|
| 199 |
+
# tensorflow-intel==2.16.1
|
| 200 |
+
termcolor==2.5.0
|
| 201 |
+
terminado==0.18.1
|
| 202 |
+
tf_keras==2.19.0
|
| 203 |
+
thinc==8.3.4
|
| 204 |
+
threadpoolctl==3.5.0
|
| 205 |
+
tifffile==2025.2.18
|
| 206 |
+
tinycss2==1.4.0
|
| 207 |
+
tokenizers==0.21.0
|
| 208 |
+
torch==2.4.1
|
| 209 |
+
torchaudio==2.4.1
|
| 210 |
+
torchvision==0.19.1
|
| 211 |
+
tornado==6.4.2
|
| 212 |
+
tqdm==4.67.1
|
| 213 |
+
traitlets==5.14.3
|
| 214 |
+
transformers==4.49.0
|
| 215 |
+
typer==0.15.2
|
| 216 |
+
types-python-dateutil==2.9.0.20241206
|
| 217 |
+
typing_extensions==4.12.2
|
| 218 |
+
tzdata==2025.1
|
| 219 |
+
uri-template==1.3.0
|
| 220 |
+
urllib3==2.3.0
|
| 221 |
+
uvicorn
|
| 222 |
+
wasabi==1.1.3
|
| 223 |
+
wcwidth==0.2.13
|
| 224 |
+
weasel==0.4.1
|
| 225 |
+
webcolors==24.11.1
|
| 226 |
+
webencodings==0.5.1
|
| 227 |
+
websocket-client==1.8.0
|
| 228 |
+
Werkzeug==3.1.3
|
| 229 |
+
wheel==0.45.1
|
| 230 |
+
widgetsnbextension==4.0.13
|
| 231 |
+
wrapt==1.17.2
|
| 232 |
+
xxhash==3.5.0
|
| 233 |
+
yarl==1.18.3
|
utils/__pycache__/configs.cpython-312.pyc
ADDED
|
Binary file (331 Bytes). View file
|
|
|
utils/__pycache__/helper.cpython-312.pyc
ADDED
|
Binary file (660 Bytes). View file
|
|
|
utils/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (7.26 kB). View file
|
|
|
utils/__pycache__/line_segmentation.cpython-312.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
utils/__pycache__/postprocessing.cpython-312.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
utils/__pycache__/preprocessing.cpython-312.pyc
ADDED
|
Binary file (9.67 kB). View file
|
|
|
utils/__pycache__/vit.cpython-312.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
utils/helper.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
def load_images_from_json(line_segments):
    """Split a line-segments mapping into parallel lists of images and paths.

    Args:
        line_segments: Mapping of line key -> entry dict that holds at least
            an "image" value and, optionally, an "image_path" value.

    Returns:
        tuple: (images, paths); entries without an "image_path" fall back to
        "<key>.png".
    """
    line_images = [entry["image"] for entry in line_segments.values()]
    image_paths = [
        entry.get("image_path", f"{key}.png")
        for key, entry in line_segments.items()
    ]
    return line_images, image_paths
|
utils/line_segmentation.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from skimage.io import imread
|
| 2 |
+
from skimage.color import rgb2gray
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from skimage.filters import threshold_otsu
|
| 6 |
+
import os
|
| 7 |
+
from skimage.graph import route_through_array
|
| 8 |
+
from heapq import heappush, heappop
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
def heuristic(a, b):
    """Return the squared Euclidean distance between grid points *a* and *b*."""
    dy = b[0] - a[0]
    dx = b[1] - a[1]
    return dy * dy + dx * dx
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_binary(img):
    """Binarize *img* with Otsu's threshold (1 = ink, 0 = background).

    A uniform image (mean exactly 0.0 or 1.0) is returned unchanged, because
    Otsu's method is undefined on a single-valued histogram.
    """
    if np.mean(img) in (0.0, 1.0):
        return img

    otsu_level = threshold_otsu(img)
    return (img <= otsu_level).astype(np.uint8)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def astar(array, start, goal):
    """Find a path from *start* to *goal* through a binary grid using A*.

    Args:
        array: 2-D array where cells equal to 1 are obstacles and any other
            value is traversable.
        start: (row, col) starting cell.
        goal: (row, col) target cell.

    Returns:
        list: Cells from *goal* back towards *start* (excluding *start*
        itself), or an empty list when no path exists.
    """
    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    close_set = set()
    came_from = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    oheap = []
    open_members = {start}  # O(1) membership test instead of scanning the heap

    heappush(oheap, (fscore[start], start))

    while oheap:
        current = heappop(oheap)[1]
        open_members.discard(current)

        if current == goal:
            # Walk the parent chain back to reconstruct the path.
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data

        close_set.add(current)
        for i, j in neighbors:
            neighbor = current[0] + i, current[1] + j

            # Skip cells outside the grid (array bound walls) or blocked cells.
            if not (0 <= neighbor[0] < array.shape[0] and 0 <= neighbor[1] < array.shape[1]):
                continue
            if array[neighbor[0]][neighbor[1]] == 1:
                continue

            tentative_g_score = gscore[current] + heuristic(current, neighbor)

            # BUG FIX: the original used gscore.get(neighbor, 0), treating
            # never-seen nodes as cost 0; unseen nodes must compare as
            # infinitely expensive so that any tentative score improves them.
            if neighbor in close_set and tentative_g_score >= gscore.get(neighbor, float('inf')):
                continue

            if tentative_g_score < gscore.get(neighbor, float('inf')) or neighbor not in open_members:
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heappush(oheap, (fscore[neighbor], neighbor))
                open_members.add(neighbor)

    # Open set exhausted without reaching the goal: no path exists.
    return []
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def preprocess_image(img, target_size):
    """Optionally crop *img* and convert it to a single-channel grayscale array.

    Args:
        img: Image array; 2-D grayscale or 3-D colour (RGB or RGBA).
        target_size: Optional (row_start, row_stop, col_start, col_stop) crop
            window, or None to keep the full image.

    Returns:
        The grayscale image, or None when any step fails.
    """
    try:
        if target_size is not None:
            r0, r1, c0, c1 = target_size[0], target_size[1], target_size[2], target_size[3]
            img = img[r0:r1, c0:c1, :]
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[..., :3]  # drop the alpha channel
        if img.ndim > 2:
            img = rgb2gray(img)
        return img
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def horizontal_projections(sobel_image):
    """Return the row-wise pixel sums (the horizontal projection profile)."""
    return np.asarray(sobel_image).sum(axis=1)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def binarize_image(image):
    """Return a boolean mask, True where *image* is darker than Otsu's threshold."""
    otsu_level = threshold_otsu(image)
    return image < otsu_level
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def find_peak_regions(hpp, threshold):
    """Return the row indices whose projection value falls below *threshold*.

    These low-projection rows correspond to the gaps between text lines.
    """
    return [row for row, value in enumerate(hpp) if value < threshold]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def line_segmentation(image, threshold=None, min_peak_group_size=7, target_size=None,
                      ct=0, parent_line_num=None, recursive=False, recursive_count=1,
                      base_key="line"):
    """
    Segment an image into lines using horizontal projections and A*.

    Rows whose horizontal projection falls below *threshold* are treated as
    inter-line gaps; a minimum-cost path (route_through_array) traced through
    each gap band becomes a separating line.  Segments taller than the 90th
    height percentile are re-segmented once recursively.

    Args:
        image: Input image (numpy array)
        threshold (float, optional): Threshold for peak detection
        min_peak_group_size (int): Minimum size of peak groups to consider
        target_size (tuple, optional): Target size for image preprocessing
        ct (int): Counter for line numbering
        parent_line_num (str, optional): Parent line number for recursive segmentation
        recursive (bool): Whether this is a recursive call
        recursive_count (int): Counter for recursive segmentation numbering
        base_key (str): Base key for dictionary entries

    Returns:
        tuple: (segmented_images_dict, counter value, bool indicating if valid separations were found)
    """
    segmented_images_dict = {}

    img = preprocess_image(image, target_size)
    if img is None:
        print(f"Failed to preprocess image")
        return segmented_images_dict, ct, False

    # Binarize image and get projections
    binarized_image = binarize_image(img)
    hpp = horizontal_projections(binarized_image)

    # Default threshold: midpoint of the projection profile's range.
    if threshold is None:
        threshold = (np.max(hpp) - np.min(hpp)) / 2

    # Find peaks
    peaks = find_peak_regions(hpp, threshold)
    if not peaks:
        print(f"No peaks found in image")
        return segmented_images_dict, ct, False

    peaks_indexes = np.array(peaks).astype(int)

    # Visualisation aid only: blank out the gap rows on a copy of the image
    # (segmented_img is never returned or used below).
    segmented_img = np.copy(img)
    r, c = segmented_img.shape
    for ri in range(r):
        if ri in peaks_indexes:
            segmented_img[ri, :] = 0

    # Group peaks: consecutive gap rows form one inter-line band.
    diff_between_consec_numbers = np.diff(peaks_indexes)
    indexes_with_larger_diff = np.where(diff_between_consec_numbers > 1)[0].flatten()
    peak_groups = np.split(peaks_indexes, indexes_with_larger_diff + 1)
    peak_groups = [item for item in peak_groups if len(item) > min_peak_group_size]

    if not peak_groups:
        print(f"No valid peak groups found in image")
        return segmented_images_dict, ct, False

    binary_image = get_binary(img)
    segment_separating_lines = []

    # Trace a minimum-cost left-to-right path through each gap band.
    for sub_image_index in peak_groups:
        try:
            start_row = sub_image_index[0]
            end_row = sub_image_index[-1]

            start_row = max(0, start_row)
            end_row = min(binary_image.shape[0], end_row)

            if end_row <= start_row:
                continue

            nmap = binary_image[start_row:end_row, :]

            if nmap.size == 0:
                continue

            # Route from the vertical middle of the band, left edge to right edge.
            start_point = (int(nmap.shape[0] / 2), 0)
            end_point = (int(nmap.shape[0] / 2), nmap.shape[1] - 1)

            path, _ = route_through_array(nmap, start_point, end_point)
            path = np.array(path) + start_row  # back to full-image row coordinates
            segment_separating_lines.append(path)
        except Exception as e:
            print(f"Failed to process sub-image: {e}")
            continue

    if not segment_separating_lines:
        print(f"No valid segment separating lines found in image")
        return segmented_images_dict, ct, False

    # Separate images based on line segments
    seperated_images = []
    for index in range(len(segment_separating_lines) - 1):
        try:
            lower_line = np.min(segment_separating_lines[index][:, 0])
            upper_line = np.max(segment_separating_lines[index + 1][:, 0])

            if lower_line < upper_line and upper_line <= img.shape[0]:
                line_image = img[lower_line:upper_line]
                if line_image.size > 0:
                    seperated_images.append(line_image)
        except Exception as e:
            print(f"Failed to separate image at index {index}: {e}")
            continue

    if not seperated_images:
        print(f"No valid separated images found in image")
        return segmented_images_dict, ct, False

    # Calculate height threshold
    try:
        image_heights = [line_image.shape[0] for line_image in seperated_images if line_image.size > 0]
        if not image_heights:
            print(f"No valid image heights found")
            return segmented_images_dict, ct, False
        height_threshold = np.percentile(image_heights, 90)
    except Exception as e:
        print(f"Failed to calculate height threshold: {e}")
        return segmented_images_dict, ct, False

    # Process each separated image
    for index, line_image in enumerate(seperated_images):
        try:
            if line_image.size == 0 or line_image.shape[0] == 0 or line_image.shape[1] == 0:
                continue

            if parent_line_num is None:
                dict_key = f'{base_key}_{ct + 1}'
            else:
                dict_key = f'{base_key}_{recursive_count}'
                # NOTE(review): this skips every recursive sub-line except the
                # last one in the list — confirm this is intentional and not a
                # misplaced guard.
                if index < len(seperated_images) - 1:
                    continue

            segmented_images_dict[dict_key] = {
                "image": line_image.copy(),
                "transcription": f"{dict_key}"
            }

            # print(f"Added line image to dictionary with key: {dict_key}")

            # Handle recursive segmentation
            if line_image.shape[0] > height_threshold and not recursive:
                try:
                    # Create recursive base key
                    recursive_base_key = f"{base_key}_{ct + 1}"

                    # Do recursive segmentation (one level only: recursive=True
                    # prevents further nesting inside the call).
                    recursive_dict, ct, found_valid_separations = line_segmentation(
                        line_image, threshold=threshold,
                        min_peak_group_size=3,
                        parent_line_num=str(ct + 1),
                        recursive=True,
                        ct=ct,
                        recursive_count=1,
                        base_key=recursive_base_key
                    )

                    if found_valid_separations:
                        # Replace the tall composite line with its sub-lines.
                        del segmented_images_dict[dict_key]
                        segmented_images_dict.update(recursive_dict)
                        print(f"Replaced {dict_key} with recursive segmentation results")
                    else:
                        print(f"Keeping original image {dict_key} as no valid separations were found")

                except Exception as e:
                    print(f"Failed during recursive segmentation of {dict_key}: {e}")

            ct += 1
            if recursive:
                recursive_count += 1

        except Exception as e:
            print(f"Failed to process line image at index {index}: {e}")
            continue
    logger.info(f"Total lines segment found: {len(segmented_images_dict)}")
    return segmented_images_dict, ct, len(seperated_images) > 0
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def segment_image_to_lines(image_array, **kwargs):
    """Segment *image_array* into text-line images.

    Thin wrapper around line_segmentation that hides the internal counters
    and converts any failure into an empty result.

    Args:
        image_array: Input image as numpy array.
        **kwargs: Forwarded verbatim to ``line_segmentation``.

    Returns:
        dict: Mapping of line keys to segmented image entries; empty on error.
    """
    try:
        logger.info("Starting line segmentation...")
        lines_dict, _, found_lines = line_segmentation(image_array, **kwargs)
        if found_lines:
            logger.info(f"Line segmentation successful.....")
        return lines_dict
    except Exception as e:
        logger.error(f"Line segmentation failed: {e}")
        return {}
|
| 312 |
+
|
| 313 |
+
# if __name__ == "__main__":
|
| 314 |
+
# # Example usage
|
| 315 |
+
# image_path = "./renAI-deploy/1.png"
|
| 316 |
+
# image = imread(image_path)
|
| 317 |
+
# segmented_lines = segment_image_to_lines(image, threshold=None, min_peak_group_size=10)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# print(len(segmented_lines.values()))
|
| 321 |
+
|
| 322 |
+
# for key, value in segmented_lines.items():
|
| 323 |
+
# print(f"{key}: {value['image'].shape}")
|
| 324 |
+
# print(f"{key}: {value['transcription']}")
|
| 325 |
+
# # plt.imshow(img, cmap='gray')
|
| 326 |
+
# # plt.title(key)
|
| 327 |
+
# # plt.show()
|
utils/postprocessing.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import unicodedata
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import List, Tuple, Dict, Set
|
| 5 |
+
import heapq
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
class SpanishFuzzyMatcher:
|
| 9 |
+
    def __init__(self, dictionary_path: str):
        """Build the fuzzy matcher from a newline-separated Spanish word list.

        Args:
            dictionary_path: Path to a text file with one word per line.
        """
        # Full vocabulary (lowercased, cleaned words).
        self.dictionary = set()
        # Words bucketed by length, for length-bounded candidate lookup.
        self.word_by_length = defaultdict(list)
        # Character-trigram -> set of words, for fast fuzzy candidate retrieval.
        self.ngram_index = defaultdict(set)
        # Frequent Spanish words that are also in the dictionary.
        self.common_words = set()

        self._load_dictionary(dictionary_path)
        self._build_indexes()
        self._load_common_words()
|
| 18 |
+
|
| 19 |
+
    def _detect_encoding(self, path: str) -> str:
        """Guess the text encoding of the file at *path*.

        Tries a fixed list of candidate encodings and returns the first one
        that can decode the file's first kilobyte; falls back to 'utf-8'.
        Only the first 1 KB is probed, so a file that breaks later may still
        be mis-detected (the caller also opens with errors='ignore').
        """
        encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252', 'utf-16']

        for encoding in encodings:
            try:
                with open(path, 'r', encoding=encoding) as f:
                    f.read(1024)  # Try to read first 1KB
                return encoding
            except (UnicodeDecodeError, UnicodeError):
                continue

        return 'utf-8'
|
| 31 |
+
|
| 32 |
+
    def _load_dictionary(self, path: str):
        """Populate ``self.dictionary`` and ``self.word_by_length`` from *path*.

        Each line is lowercased and stripped of everything except Spanish
        letters, hyphens and apostrophes; words shorter than two characters
        are discarded.

        Raises:
            FileNotFoundError: If *path* does not exist.
            Exception: Wrapper re-raised for any other loading error.
        """
        try:
            encoding = self._detect_encoding(path)
            print(f"Detected encoding: {encoding}")

            with open(path, 'r', encoding=encoding, errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    try:
                        word = line.strip().lower()
                        if word and len(word) > 1:
                            # Remove any non-alphabetic characters except hyphens and apostrophes
                            cleaned_word = re.sub(r"[^a-záéíóúüñç\-']", "", word)
                            if cleaned_word and len(cleaned_word) > 1:
                                self.dictionary.add(cleaned_word)
                                self.word_by_length[len(cleaned_word)].append(cleaned_word)
                    except Exception as e:
                        print(f"Warning: Skipping line {line_num} due to error: {e}")
                        continue

            print(f"Loaded {len(self.dictionary)} words from dictionary")

        except FileNotFoundError:
            raise FileNotFoundError(f"Dictionary file not found: {path}")
        except Exception as e:
            raise Exception(f"Error loading dictionary: {e}")
|
| 57 |
+
|
| 58 |
+
    def _load_common_words(self):
        """Intersect a built-in list of frequent Spanish words with the loaded
        dictionary and store the result in ``self.common_words``."""
        # Hard-coded high-frequency Spanish vocabulary (articles, verbs, nouns,
        # adjectives, time words); only entries also present in the loaded
        # dictionary are kept.
        common_spanish = {
            'el', 'la', 'de', 'que', 'y', 'a', 'en', 'un', 'es', 'se', 'no', 'te', 'lo', 'le', 'da', 'su', 'por', 'son', 'con', 'para', 'al', 'las', 'del', 'los', 'una', 'mi', 'muy', 'mas', 'me', 'si', 'ya', 'todo', 'como', 'pero', 'hay', 'o', 'cuando', 'esta', 'ser', 'tiene', 'estar', 'hacer', 'sobre', 'entre', 'poder', 'antes', 'tiempo', 'año', 'casa', 'día', 'vida', 'trabajo', 'hombre', 'mujer', 'mundo', 'parte', 'momento', 'lugar', 'país', 'forma', 'manera', 'estado', 'caso', 'grupo', 'agua', 'punto', 'vez', 'donde', 'quien', 'haber', 'tener', 'hacer', 'decir', 'ir', 'ver', 'dar', 'saber', 'querer', 'llegar', 'pasar', 'deber', 'poner', 'parecer', 'quedar', 'creer', 'hablar', 'llevar', 'dejar', 'seguir', 'encontrar', 'llamar', 'venir', 'pensar', 'salir', 'volver', 'tomar', 'conocer', 'vivir', 'sentir', 'tratar', 'mirar', 'contar', 'empezar', 'esperar', 'buscar', 'existir', 'entrar', 'trabajar', 'escribir', 'perder', 'producir', 'ocurrir', 'entender', 'pedir', 'recibir', 'recordar', 'terminar', 'permitir', 'aparecer', 'conseguir', 'comenzar', 'servir', 'sacar', 'necesitar', 'mantener', 'resultar', 'leer', 'caer', 'cambiar', 'presentar', 'crear', 'abrir', 'considerar', 'oír', 'acabar', 'convertir', 'ganar', 'traer', 'realizar', 'suponer', 'comprender', 'explicar', 'dedicar', 'andar', 'estudiar', 'mano', 'cabeza', 'ojo', 'cara', 'pie', 'corazón', 'vez', 'palabra', 'número', 'color', 'mesa', 'silla', 'libro', 'papel', 'coche', 'calle', 'puerta', 'ventana', 'ciudad', 'pueblo', 'escuela', 'hospital', 'iglesia', 'tienda', 'mercado', 'banco', 'hotel', 'restaurante', 'café', 'bar', 'teatro', 'cine', 'museo', 'parque', 'jardín', 'playa', 'montaña', 'río', 'mar', 'lago', 'bosque', 'árbol', 'flor', 'animal', 'perro', 'gato', 'pájaro', 'pez', 'comida', 'pan', 'carne', 'pollo', 'pescado', 'leche', 'huevo', 'queso', 'fruta', 'verdura', 'patata', 'tomate', 'cebolla', 'ajo', 'sal', 'azúcar', 'aceite', 'vino', 'cerveza', 'café', 'té', 'agua', 'fuego', 'aire', 'tierra', 'sol', 'luna', 'estrella', 'nube', 'lluvia', 'nieve', 'viento', 'calor', 'frío', 'luz', 'sombra', 'mañana', 'tarde', 'noche', 'hoy', 'ayer', 'mañana', 'semana', 'mes', 'año', 'hora', 'minuto', 'segundo', 'lunes', 'martes', 'miércoles', 'jueves', 'viernes', 'sábado', 'domingo', 'enero', 'febrero', 'marzo', 'abril', 'mayo', 'junio', 'julio', 'agosto', 'septiembre', 'octubre', 'noviembre', 'diciembre', 'primavera', 'verano', 'otoño', 'invierno', 'bueno', 'malo', 'grande', 'pequeño', 'alto', 'bajo', 'largo', 'corto', 'ancho', 'estrecho', 'grueso', 'delgado', 'fuerte', 'débil', 'rápido', 'lento', 'fácil', 'difícil', 'nuevo', 'viejo', 'joven', 'mayor', 'blanco', 'negro', 'rojo', 'azul', 'verde', 'amarillo', 'gris', 'marrón', 'rosa', 'naranja', 'morado', 'feliz', 'triste', 'contento', 'enfadado', 'cansado', 'aburrido', 'interesante', 'divertido', 'importante', 'necesario', 'posible', 'imposible', 'seguro', 'peligroso', 'rico', 'pobre', 'caro', 'barato', 'limpio', 'sucio', 'sano', 'enfermo', 'vivo', 'muerto', 'lleno', 'vacío', 'abierto', 'cerrado', 'caliente', 'frío', 'seco', 'mojado', 'duro', 'blando', 'suave', 'áspero', 'dulce', 'amargo', 'salado', 'picante', 'conocerte', 'tengas'
        }
        self.common_words = {word for word in common_spanish if word in self.dictionary}
        print(f"Loaded {len(self.common_words)} common words")
|
| 64 |
+
|
| 65 |
+
def _is_common_spanish_error(self, ocr_word: str, dict_word: str) -> bool:
|
| 66 |
+
ocr_lower = ocr_word.lower()
|
| 67 |
+
dict_lower = dict_word.lower()
|
| 68 |
+
|
| 69 |
+
# Common OCR confusions in Spanish
|
| 70 |
+
ocr_substitutions = {
|
| 71 |
+
'b': 'v', 'v': 'b', # b/v confusion
|
| 72 |
+
'c': 's', 's': 'c', # c/s confusion
|
| 73 |
+
'z': 's', 's': 'z', # z/s confusion
|
| 74 |
+
'j': 'g', 'g': 'j', # j/g confusion
|
| 75 |
+
'y': 'i', 'i': 'y', # y/i confusion
|
| 76 |
+
'u': 'n', 'n': 'u', # u/n confusion (handwriting)
|
| 77 |
+
'll': 'y', 'y': 'll', # ll/y confusion
|
| 78 |
+
'ñ': 'n', 'n': 'ñ', # ñ/n confusion
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
if len(ocr_lower) == len(dict_lower):
|
| 82 |
+
diff_count = sum(1 for a, b in zip(ocr_lower, dict_lower) if a != b)
|
| 83 |
+
if diff_count == 1:
|
| 84 |
+
for i, (a, b) in enumerate(zip(ocr_lower, dict_lower)):
|
| 85 |
+
if a != b:
|
| 86 |
+
return a in ocr_substitutions and ocr_substitutions[a] == b
|
| 87 |
+
|
| 88 |
+
return False
|
| 89 |
+
def _build_indexes(self):
|
| 90 |
+
for word in self.dictionary:
|
| 91 |
+
padded_word = f"${word}$"
|
| 92 |
+
for i in range(len(padded_word) - 2):
|
| 93 |
+
trigram = padded_word[i:i+3]
|
| 94 |
+
self.ngram_index[trigram].add(word)
|
| 95 |
+
|
| 96 |
+
def _normalize_text(self, text: str) -> str:
|
| 97 |
+
text = unicodedata.normalize('NFD', text)
|
| 98 |
+
text = ''.join(c for c in text if unicodedata.category(c) != 'Mn')
|
| 99 |
+
return text.lower()
|
| 100 |
+
|
| 101 |
+
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
| 102 |
+
if len(s1) < len(s2):
|
| 103 |
+
return self._levenshtein_distance(s2, s1)
|
| 104 |
+
|
| 105 |
+
if len(s2) == 0:
|
| 106 |
+
return len(s1)
|
| 107 |
+
|
| 108 |
+
previous_row = list(range(len(s2) + 1))
|
| 109 |
+
for i, c1 in enumerate(s1):
|
| 110 |
+
current_row = [i + 1]
|
| 111 |
+
for j, c2 in enumerate(s2):
|
| 112 |
+
insertions = previous_row[j + 1] + 1
|
| 113 |
+
deletions = current_row[j] + 1
|
| 114 |
+
substitutions = previous_row[j] + (c1 != c2)
|
| 115 |
+
current_row.append(min(insertions, deletions, substitutions))
|
| 116 |
+
previous_row = current_row
|
| 117 |
+
|
| 118 |
+
return previous_row[-1]
|
| 119 |
+
|
| 120 |
+
def _damerau_levenshtein_distance(self, s1: str, s2: str) -> int:
    """Damerau-Levenshtein edit distance between *s1* and *s2*.

    Like Levenshtein distance but additionally counts transposition of
    two adjacent characters as a single edit. This is the full
    (unrestricted) variant that uses the last-occurrence table `da`.
    """
    len1, len2 = len(s1), len(s2)

    # da[ch] = index (1-based row) of the last occurrence of ch in s1
    # processed so far; initialized to 0 for every character involved.
    da = {}
    for char in s1 + s2:
        if char not in da:
            da[char] = 0

    # (len1+2) x (len2+2) matrix with a sentinel border of `max_dist`
    # so transposition lookups can never win at the boundary.
    max_dist = len1 + len2
    h = [[max_dist for _ in range(len2 + 2)] for _ in range(len1 + 2)]

    h[0][0] = max_dist
    for i in range(0, len1 + 1):
        h[i + 1][0] = max_dist
        h[i + 1][1] = i
    for j in range(0, len2 + 1):
        h[0][j + 1] = max_dist
        h[1][j + 1] = j

    for i in range(1, len1 + 1):
        db = 0  # column of the last match found in this row
        for j in range(1, len2 + 1):
            k = da[s2[j - 1]]  # last row where s2[j-1] occurred in s1
            l = db
            if s1[i - 1] == s2[j - 1]:
                cost = 0
                db = j
            else:
                cost = 1

            h[i + 1][j + 1] = min(
                h[i][j] + cost, # substitution
                h[i + 1][j] + 1, # insertion
                h[i][j + 1] + 1, # deletion
                h[k][l] + (i - k - 1) + 1 + (j - l - 1) # transposition
            )

        da[s1[i - 1]] = i

    return h[len1 + 1][len2 + 1]
|
| 160 |
+
|
| 161 |
+
def _jaro_winkler_similarity(self, s1: str, s2: str) -> float:
    """Jaro-Winkler similarity of *s1* and *s2*, in [0, 1].

    The base Jaro score is boosted by up to 10% per shared leading
    character (at most four), which favors words that start the same way
    — useful because OCR errors cluster toward word interiors.
    """
    def jaro_similarity(s1: str, s2: str) -> float:
        if s1 == s2:
            return 1.0

        len1, len2 = len(s1), len(s2)
        if len1 == 0 or len2 == 0:
            return 0.0

        # Characters only count as matching within this distance window.
        match_window = max(len1, len2) // 2 - 1
        if match_window < 0:
            match_window = 0

        s1_matches = [False] * len1
        s2_matches = [False] * len2

        matches = 0
        transpositions = 0

        # First pass: greedily pair up equal characters inside the window.
        for i in range(len1):
            start = max(0, i - match_window)
            end = min(i + match_window + 1, len2)

            for j in range(start, end):
                if s2_matches[j] or s1[i] != s2[j]:
                    continue
                s1_matches[i] = s2_matches[j] = True
                matches += 1
                break

        if matches == 0:
            return 0.0

        # Second pass: count matched characters that appear out of order.
        k = 0
        for i in range(len1):
            if not s1_matches[i]:
                continue
            while not s2_matches[k]:
                k += 1
            if s1[i] != s2[k]:
                transpositions += 1
            k += 1

        # Standard Jaro formula (half-transpositions counted).
        jaro = (matches / len1 + matches / len2 +
                (matches - transpositions / 2) / matches) / 3
        return jaro

    jaro = jaro_similarity(s1, s2)

    # Winkler prefix bonus: length of the common prefix, capped at 4.
    prefix_len = 0
    for i in range(min(len(s1), len(s2), 4)):
        if s1[i] == s2[i]:
            prefix_len += 1
        else:
            break

    # 0.1 is the conventional Winkler scaling factor.
    return jaro + (0.1 * prefix_len * (1 - jaro))
|
| 218 |
+
|
| 219 |
+
def _get_candidates(self, word: str, max_candidates: int = 200) -> Set[str]:
|
| 220 |
+
candidates = set()
|
| 221 |
+
word_len = len(word)
|
| 222 |
+
|
| 223 |
+
common_candidates = set()
|
| 224 |
+
for common_word in self.common_words:
|
| 225 |
+
if abs(len(common_word) - word_len) <= 2:
|
| 226 |
+
common_candidates.add(common_word)
|
| 227 |
+
|
| 228 |
+
candidates.update(common_candidates)
|
| 229 |
+
|
| 230 |
+
for length in range(max(1, word_len - 2), word_len + 3):
|
| 231 |
+
length_words = self.word_by_length[length]
|
| 232 |
+
# Sort by length (shorter words first) and limit
|
| 233 |
+
sorted_words = sorted(length_words, key=len)[:max_candidates//3]
|
| 234 |
+
candidates.update(sorted_words)
|
| 235 |
+
|
| 236 |
+
padded_word = f"${word}$"
|
| 237 |
+
trigram_candidates = set()
|
| 238 |
+
trigram_scores = defaultdict(int)
|
| 239 |
+
|
| 240 |
+
for i in range(len(padded_word) - 2):
|
| 241 |
+
trigram = padded_word[i:i+3]
|
| 242 |
+
if trigram in self.ngram_index:
|
| 243 |
+
for candidate in self.ngram_index[trigram]:
|
| 244 |
+
trigram_scores[candidate] += 1
|
| 245 |
+
|
| 246 |
+
sorted_trigram = sorted(trigram_scores.items(), key=lambda x: x[1], reverse=True)
|
| 247 |
+
trigram_candidates = {word for word, score in sorted_trigram[:max_candidates//2]}
|
| 248 |
+
candidates.update(trigram_candidates)
|
| 249 |
+
|
| 250 |
+
return candidates
|
| 251 |
+
|
| 252 |
+
def _calculate_composite_score(self, word1: str, word2: str) -> float:
    """Blend several similarity measures into a single score in [0, 1].

    *word1* is the OCR output, *word2* a dictionary candidate. Distances
    are computed on accent-stripped lowercase forms; multiplicative
    bonuses favor frequent words, typical Spanish OCR confusions, and
    candidates of exactly the same length.
    """
    norm_word1 = self._normalize_text(word1)
    norm_word2 = self._normalize_text(word2)

    levenshtein = self._levenshtein_distance(norm_word1, norm_word2)
    damerau = self._damerau_levenshtein_distance(norm_word1, norm_word2)
    jaro_winkler = self._jaro_winkler_similarity(norm_word1, norm_word2)

    max_len = max(len(norm_word1), len(norm_word2))
    if max_len == 0:
        # Both words empty after normalization: treat as identical.
        return 1.0

    # Convert edit distances into similarities in [0, 1].
    levenshtein_sim = 1 - (levenshtein / max_len)
    damerau_sim = 1 - (damerau / max_len)

    # Proportional penalty for length mismatch.
    length_diff = abs(len(norm_word1) - len(norm_word2))
    length_penalty = 1 - (length_diff / max(len(norm_word1), len(norm_word2)))

    # Prefer high-frequency Spanish words.
    frequency_bonus = 1.0
    if norm_word2 in self.common_words:
        frequency_bonus = 1.3

    # Prefer candidates exactly one typical OCR confusion away (b/v, ...).
    spanish_error_bonus = 1.0
    if self._is_common_spanish_error(word1, word2):
        spanish_error_bonus = 1.2

    # Slight preference for same-length candidates.
    exact_length_bonus = 1.0
    if len(norm_word1) == len(norm_word2):
        exact_length_bonus = 1.1

    # Damerau gets the largest weight: adjacent-character transpositions
    # are a frequent OCR/handwriting error.
    base_score = (
        0.25 * levenshtein_sim +
        0.45 * damerau_sim +
        0.25 * jaro_winkler +
        0.05 * length_penalty
    )

    final_score = base_score * frequency_bonus * spanish_error_bonus * exact_length_bonus

    # Bonuses can push the product past 1; clamp so scores stay comparable.
    return min(final_score, 1.0)
|
| 292 |
+
|
| 293 |
+
def find_best_matches(self, word: str, top_k: int = 5, threshold: float = 0.4) -> List[Tuple[str, float]]:
    """Return up to *top_k* (candidate, score) pairs for *word*, best first.

    Words already present in the dictionary (exactly, lowercased, or
    after accent stripping) short-circuit with a perfect score.
    Candidates scoring below *threshold* are discarded.
    """
    # Words shorter than two characters are too ambiguous to correct.
    if not word or len(word) < 2:
        return []

    normalized_word = self._normalize_text(word)
    if normalized_word in self.dictionary:
        # Return the caller's original form, preserving casing/accents.
        return [(word, 1.0)]

    if word.lower() in self.dictionary:
        return [(word.lower(), 1.0)]

    candidates = self._get_candidates(normalized_word)

    # Max-heap via negated score; ties break on the candidate string.
    scored_matches = []
    for candidate in candidates:
        score = self._calculate_composite_score(word, candidate)
        if score >= threshold:
            heapq.heappush(scored_matches, (-score, candidate, score))

    results = []
    seen_words = set()
    for _ in range(min(top_k, len(scored_matches))):
        if scored_matches:
            _, candidate, score = heapq.heappop(scored_matches)
            if candidate not in seen_words:
                results.append((candidate, score))
                seen_words.add(candidate)

    return results
|
| 322 |
+
|
| 323 |
+
def correct_sentence(self, sentence: str, confidence_threshold: float = 0.6) -> str:
    """Spell-correct each word token of *sentence*.

    Punctuation and whitespace runs are preserved verbatim; a word is
    replaced only when its best match scores at least
    *confidence_threshold*, otherwise the original token is kept.
    """
    tokens = re.findall(r'\b\w+\b|\W+', sentence)
    word_pattern = re.compile(r'\b\w+\b')

    pieces = []
    for tok in tokens:
        replacement = tok
        if word_pattern.match(tok):
            best = self.find_best_matches(tok, top_k=1, threshold=0.3)
            if best and best[0][1] >= confidence_threshold:
                replacement = best[0][0]
        pieces.append(replacement)

    return ''.join(pieces)
|
| 339 |
+
|
| 340 |
+
def PostProcessing(ocr_sentence):
    """Spell-correct an OCR sentence against the Spanish dictionary.

    Best-effort: on any failure the input is returned unchanged so the
    OCR pipeline never crashes on a post-processing error.

    Args:
    - ocr_sentence (str): Raw OCR output to correct.

    Returns:
    - str: The corrected sentence, or the input on failure.
    """
    try:
        logger.info("Post processing started......")
        # Building the matcher loads the 136k-word dictionary and its
        # trigram index. Cache it on the function so repeated calls pay
        # that cost only once instead of on every sentence.
        matcher = getattr(PostProcessing, "_matcher", None)
        if matcher is None:
            matcher = SpanishFuzzyMatcher('Diccionario.Espanol.136k.palabras.txt')
            PostProcessing._matcher = matcher
            logger.info("Dictionary loaded successfully!")

        corrected = matcher.correct_sentence(ocr_sentence, confidence_threshold=0.6)
        logger.info("Post processing completed successfully!")
        return corrected

    except Exception as e:
        # Log (the redundant print(e) was removed) and fall back to the
        # uncorrected text rather than propagate the error.
        logger.error(f"Post processing failed: {e}")
        return ocr_sentence
|
utils/preprocessing.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
from deskew import determine_skew
|
| 6 |
+
from typing import Tuple, Union
|
| 7 |
+
import math
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
def preprocessImage(image):
    """
    Preprocesses an image for OCR: denoises, binarizes with Otsu's method,
    removes long horizontal/vertical ruling lines, and repairs text strokes
    that the line removal may have cut through.

    Args:
    - image (np.ndarray): Input image in BGR channel order (as read by cv2).

    Returns:
    - np.ndarray: The processed image (dark text on white background).
    """

    # Convert the image to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply denoising
    gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)

    # Apply binary thresholding using Otsu's method (text becomes white)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # Copy the original image to preserve it
    removed = image.copy()

    # Remove vertical lines: openings with a 1x40 kernel keep only long
    # vertical structures, which are then painted white on `removed`.
    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
    remove_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
    cnts = cv2.findContours(remove_vertical, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]  # OpenCV 3 vs 4 return shape
    for c in cnts:
        cv2.drawContours(removed, [c], -1, (255, 255, 255), 4)

    # Remove horizontal lines (e.g. ruled-paper lines), same technique
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    remove_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
    cnts = cv2.findContours(remove_horizontal, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    for c in cnts:
        cv2.drawContours(removed, [c], -1, (255, 255, 255), 5)

    # Repair strokes the line removal painted over: invert, dilate to grow
    # the surviving ink, then intersect with the original threshold so the
    # repair never adds pixels outside the true text mask.
    repair_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    removed = 255 - removed
    dilate = cv2.dilate(removed, repair_kernel, iterations=5)
    dilate = cv2.cvtColor(dilate, cv2.COLOR_BGR2GRAY)
    pre_result = cv2.bitwise_and(dilate, thresh)

    # Close small gaps, then intersect with the threshold mask again
    result = cv2.morphologyEx(pre_result, cv2.MORPH_CLOSE, repair_kernel, iterations=5)
    final = cv2.bitwise_and(result, thresh)

    # Invert back to dark text on a white background
    invert_final = 255 - final

    # processed_image_path = os.path.join(folder_path, f"{os.path.splitext(os.path.basename(image_path))[0]}-preprocessed.png")
    # Save the final image
    # cv2.imwrite(processed_image_path, invert_final)

    return invert_final
|
| 70 |
+
|
| 71 |
+
def process_segment_and_crop_image(model, image, preprocess_image_path, padding=10, min_contour_area=100):
    """
    Segments the text region of an image with a U-Net model and crops the
    preprocessed image to the bounding box of all detected text regions.

    Args:
    - model (tf.keras.Model): Trained U-Net model for image segmentation.
    - image (np.ndarray): Original image (BGR) fed to the model.
    - preprocess_image_path (str): Path of the preprocessed image that is
      actually cropped (read back from disk here).
    - padding (int): Padding in pixels added around the detected region.
    - min_contour_area (int): Minimum contour area considered to be text.

    Returns:
    - np.ndarray | None: The cropped image, or None if no text was found.
    """
    # Convert the input to grayscale for the single-channel U-Net

    img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply thresholding to create a binary image (text becomes white)
    _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)

    # Resize the image to the model input size (512x512)
    img = cv2.resize(img, (512, 512))

    # Expand dimensions to match model input: (1, 512, 512, 1)
    img = np.expand_dims(img, axis=-1)
    img_np = np.expand_dims(img, axis=0)

    # Predict the segmentation mask using the U-Net model
    pred = model.predict(img_np)
    pred = np.squeeze(np.squeeze(pred, axis=0), axis=-1)

    # # Display the segmentation result
    # plt.imshow(pred, cmap='gray')
    # plt.title('U-Net Segmentation')
    # plt.axis('off')
    # plt.show()

    # Read the preprocessed image that will be cropped
    original_img = cv2.imread(preprocess_image_path)

    # Get original dimensions
    ori_height, ori_width = original_img.shape[:2]

    # Resize the mask to match the original image dimensions
    resized_mask = cv2.resize(pred, (ori_width, ori_height))

    # Convert the resized mask (sigmoid floats in [0, 1]) to 8-bit
    resized_mask = (resized_mask * 255).astype(np.uint8)

    # Apply Otsu's threshold to get a binary image
    _, binary_mask = cv2.threshold(resized_mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Apply morphological operations to remove noise and connect nearby text
    kernel = np.ones((5, 5), np.uint8)
    cleaned_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    cleaned_mask = cv2.morphologyEx(cleaned_mask, cv2.MORPH_OPEN, kernel)

    # Find contours in the cleaned mask
    contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Filter contours based on area to remove small noise
    valid_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > min_contour_area]

    if not valid_contours:
        print("No valid text regions found.")
        return None

    # Find the bounding rectangle that encompasses all valid contours
    x_min, y_min = ori_width, ori_height
    x_max, y_max = 0, 0

    for contour in valid_contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    # Expand by the padding, clamped to the image bounds
    x_min = max(0, x_min - padding)
    y_min = max(0, y_min - padding)
    x_max = min(ori_width, x_max + padding)
    y_max = min(ori_height, y_max + padding)

    # Crop the original image
    cropped_img = original_img[y_min:y_max, x_min:x_max]

    return cropped_img
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def postProcessImage(cropped_image):
    """
    Post-processes a cropped text image for OCR by applying a mild
    unsharp-mask sharpening step.

    Deskewing (determine_skew + rotate) was deliberately disabled and the
    old morphological dilation step was dead code — its result was never
    returned — so both have been removed.

    Args:
    - cropped_image (np.ndarray): Cropped image to be post-processed.

    Returns:
    - np.ndarray: The sharpened image.
    """
    # Deskew disabled; kept for reference:
    # grayscale = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
    # angle = determine_skew(grayscale)
    # rotated = rotate(cropped_image, angle, (0, 0, 0))
    rotated = cropped_image

    # Unsharp mask: sharpened = 1.5*img - 0.5*blur(img).
    # NOTE(review): a (1, 1) Gaussian kernel is an identity blur regardless
    # of sigma, which makes this step effectively a no-op — confirm whether
    # a larger kernel (e.g. (3, 3)) was intended before changing behavior.
    blurred = cv2.GaussianBlur(rotated, (1,1), sigmaX=3, sigmaY=3)
    sharpened = cv2.addWeighted(rotated, 1.5, blurred, -0.5, 0)

    return sharpened
|
utils/unet.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Importing required libraries.
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import cv2
|
| 5 |
+
import os
|
| 6 |
+
from keras.layers import *
|
| 7 |
+
from keras.models import Model
|
| 8 |
+
from keras.optimizers import Adam
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
def unet(pretrained_weights = None,input_size = (512,512,1)):
    """Build the classic U-Net for binary segmentation.

    Four down-sampling stages (64->512 filters, dropout on the deepest),
    a 1024-filter bottleneck, and four up-sampling stages with skip
    connections, ending in a 1-channel sigmoid output. Compiled with
    Adam (lr=1e-4) and binary cross-entropy; optionally loads weights.
    """
    def double_conv(filters, tensor):
        # Two 3x3 ReLU convolutions, 'same' padding, He initialization.
        tensor = Conv2D(filters, 3, activation = 'relu', padding = 'same',
                        kernel_initializer = 'he_normal')(tensor)
        return Conv2D(filters, 3, activation = 'relu', padding = 'same',
                      kernel_initializer = 'he_normal')(tensor)

    def up_conv(filters, tensor):
        # 2x nearest-neighbor upsample followed by a 2x2 convolution.
        return Conv2D(filters, 2, activation = 'relu', padding = 'same',
                      kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(tensor))

    inputs = Input(input_size)

    # Contracting path.
    c1 = double_conv(64, inputs)
    p1 = MaxPooling2D(pool_size=(2, 2))(c1)
    c2 = double_conv(128, p1)
    p2 = MaxPooling2D(pool_size=(2, 2))(c2)
    c3 = double_conv(256, p2)
    p3 = MaxPooling2D(pool_size=(2, 2))(c3)
    c4 = double_conv(512, p3)
    d4 = Dropout(0.5)(c4)
    p4 = MaxPooling2D(pool_size=(2, 2))(d4)

    # Bottleneck.
    c5 = double_conv(1024, p4)
    d5 = Dropout(0.5)(c5)

    # Expanding path with skip connections to the contracting stages.
    u6 = up_conv(512, d5)
    c6 = double_conv(512, concatenate([d4, u6], axis = 3))
    u7 = up_conv(256, c6)
    c7 = double_conv(256, concatenate([c3, u7], axis = 3))
    u8 = up_conv(128, c7)
    c8 = double_conv(128, concatenate([c2, u8], axis = 3))
    u9 = up_conv(64, c8)
    c9 = double_conv(64, concatenate([c1, u9], axis = 3))
    c9 = Conv2D(2, 3, activation = 'relu', padding = 'same',
                kernel_initializer = 'he_normal')(c9)
    outputs = Conv2D(1, 1, activation = 'sigmoid')(c9)

    model = Model(inputs, outputs)

    model.compile(optimizer = Adam(learning_rate=1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    #model.summary()

    if(pretrained_weights):
        model.load_weights(pretrained_weights)

    return model
|
| 63 |
+
|
| 64 |
+
|
vit.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 8 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments, EarlyStoppingCallback
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import numpy as np
|
| 11 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from evaluate import load
|
| 15 |
+
import albumentations as A
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
from configs import model_path, processor_path
|
| 19 |
+
|
| 20 |
+
# Runtime/debug settings.
# Let cuDNN benchmark and pick the fastest convolution algorithms for
# fixed input sizes (speeds up training when shapes do not vary).
torch.backends.cudnn.benchmark = True
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches.
# This is a debugging aid and slows GPU execution — confirm it is intended
# for production use.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


# Load metrics
cer_metric = load("cer")  # character error rate
wer_metric = load("wer")  # word error rate

# TrOCR processor (image preprocessor + tokenizer) and encoder-decoder model,
# loaded from the paths configured in configs.py.
processor = TrOCRProcessor.from_pretrained(processor_path, do_rescale=False,use_fast=True)
model = VisionEncoderDecoderModel.from_pretrained(model_path,use_safetensors=True)
|
| 31 |
+
|
| 32 |
+
def compute_metrics(eval_pred):
    """Compute CER and WER for a HuggingFace Trainer evaluation step.

    *eval_pred* is a (logits, labels) pair; label positions equal to -100
    (padding ignored by the loss) are dropped before decoding.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]

    predictions = logits.argmax(-1)
    decoded_preds = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)

    decoded_labels = [
        processor.tokenizer.decode(
            [token for token in label if token != -100],
            skip_special_tokens=True,
        )
        for label in labels
    ]

    return {
        "cer": cer_metric.compute(predictions=decoded_preds, references=decoded_labels),
        "wer": wer_metric.compute(predictions=decoded_preds, references=decoded_labels),
    }
|
| 46 |
+
|
| 47 |
+
class LineDataset(Dataset):
    """Dataset of (line image, transcription) pairs for TrOCR fine-tuning.

    Each item is optionally augmented, resized to *target_size*,
    normalized, and run through the TrOCR processor to produce
    `pixel_values` and tokenized `labels`.
    """

    def __init__(self, processor, model, line_images, texts, target_size=(384, 96), max_length=512, apply_augmentation=False):
        # processor: TrOCRProcessor (image preprocessor + tokenizer).
        # model: VisionEncoderDecoderModel whose max length is set here.
        # line_images: sequence of line images (PIL Images or arrays).
        #   NOTE(review): __getitem__ multiplies by 255, so pixel values
        #   are assumed to be floats in [0, 1] — confirm against the loader.
        # texts: ground-truth transcription per image.
        # target_size: (width, height) each image is resized to.
        # max_length: token limit applied to processor, tokenizer and model.
        # apply_augmentation: enable the albumentations pipeline below.
        self.line_images = line_images
        self.texts = texts
        self.processor = processor
        # Propagate the token limit to every component that enforces one.
        self.processor.image_processor.max_length = max_length
        self.processor.tokenizer.model_max_length = max_length
        self.model = model
        self.model.config.max_length = max_length
        self.target_size = target_size
        self.max_length = max_length
        self.apply_augmentation = apply_augmentation

        if apply_augmentation:
            # Apply exactly one of these mild geometric/photometric
            # distortions, 70% of the time, to simulate handwriting and
            # scanning variance.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=0.3, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.02, 0.02), shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ], p=0.7),
            ])
        else:
            # Identity pipeline, so __getitem__ needs no branching on it.
            self.transform = A.Compose([])

    def __len__(self):
        # Number of line images (texts is assumed to be the same length).
        return len(self.line_images)

    def __getitem__(self, idx):
        """Return a dict with 'pixel_values' and 'labels' tensors for idx."""
        image = self.line_images[idx]
        text = self.texts[idx]

        if isinstance(image, Image.Image):
            image = np.array(image)

        # Grayscale -> 3 identical channels so the ViT encoder gets an
        # RGB-shaped input.
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)

        # NOTE(review): assumes the image is float in [0, 1]; a uint8
        # 0-255 input would overflow here — confirm the upstream loader.
        image = (image * 255).astype(np.uint8)

        if self.apply_augmentation and self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Resize, normalize back to [0, 1], and convert HWC -> CHW.
        image = Image.fromarray(image)
        image = image.resize(self.target_size, Image.LANCZOS)
        image = np.array(image) / 255.0
        image = np.transpose(image, (2, 0, 1))

        # The processor tokenizes the text and packages the pixel tensor
        # (do_rescale=False upstream, matching the manual /255 above).
        encoding = self.processor(images=image, text=text, return_tensors="pt")
        # Truncate labels to the configured token limit.
        encoding['labels'] = encoding['labels'][:, :self.max_length]
        # Drop the batch dimension added by return_tensors="pt".
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding
|
| 107 |
+
|
| 108 |
+
def collate_fn(batch):
    """Collate dataset items into a training batch.

    Pixel tensors are stacked; variable-length label sequences are padded
    with -100, the index ignored by the cross-entropy loss.
    """
    images = [sample['pixel_values'] for sample in batch]
    label_seqs = [sample['labels'] for sample in batch]
    return {
        'pixel_values': torch.stack(images),
        'labels': pad_sequence(label_seqs, batch_first=True, padding_value=-100),
    }
|