Arsh124 commited on
Commit
ebcc7d1
·
1 Parent(s): 6972ce0

Initial RenAI app

Browse files
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ RUN useradd user
4
+
5
+ USER user
6
+
7
+ ENV HOME=/home/user \
8
+ PATH="/home/user/.local/bin:$PATH" \
9
+ PYTHONUNBUFFERED=1 \
10
+ PYTHONDONTWRITEBYTECODE=1
11
+
12
+ WORKDIR $HOME/app
13
+
14
+ RUN apt-get update && apt-get install -y --no-install-recommends \
15
+ libglib2.0-0 \
16
+ libsm6 \
17
+ libxext6 \
18
+ libxrender-dev \
19
+ libgomp1 \
20
+ libgtk-3-0 \
21
+ libavcodec-dev \
22
+ libavformat-dev \
23
+ libswscale-dev \
24
+ && rm -rf /var/lib/apt/lists/*
25
+
26
+ COPY requirements.txt .
27
+
28
+ RUN pip install --no-cache-dir --timeout=100 -r requirements.txt
29
+
30
+ COPY . .
31
+
32
+ EXPOSE 7860
33
+
34
+ CMD ["uvicorn", "app:app", "--host=0.0.0.0", "--port=7860"]
__pycache__/configs.cpython-312.pyc ADDED
Binary file (284 Bytes). View file
 
__pycache__/inference.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
__pycache__/vit.cpython-312.pyc ADDED
Binary file (7.01 kB). View file
 
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ from typing import Optional, Any, Dict, Union
5
+ import shutil
6
+ import os
7
+ import json
8
+ from loguru import logger
9
+ from pathlib import Path
10
+ from main import RenAITranscription
11
+ import tempfile
12
+ import numpy as np
13
+ from datetime import datetime
14
+ import base64
15
+ from io import BytesIO
16
+ from PIL import Image
17
+
18
+ app = FastAPI(title="RenAI Transcription API", version="1.0.0")
19
+
20
+ # Add CORS middleware
21
+ # app.add_middleware(
22
+ # CORSMiddleware,
23
+ # allow_origins=["*"],
24
+ # allow_credentials=True,
25
+ # allow_methods=["*"],
26
+ # allow_headers=["*"],
27
+ # )
28
+
29
+ ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
30
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
31
+
32
+ def numpy_to_base64(array: np.ndarray, format: str = 'PNG', quality: int = 85) -> str:
33
+ """
34
+ Convert numpy array (image) to base64 encoded string for web display.
35
+
36
+ Args:
37
+ array: Numpy array representing the image
38
+ format: Image format ('PNG' or 'JPEG')
39
+ quality: JPEG quality (1-100), only used if format is JPEG
40
+
41
+ Returns:
42
+ Data URI string that can be directly used in HTML <img> src attribute
43
+ """
44
+ try:
45
+ # Convert numpy array to PIL Image
46
+ img = Image.fromarray(array)
47
+
48
+ # Save to bytes buffer
49
+ buffer = BytesIO()
50
+ if format.upper() == 'JPEG':
51
+ # Convert to RGB if needed (JPEG doesn't support transparency)
52
+ if img.mode in ('RGBA', 'LA', 'P'):
53
+ background = Image.new('RGB', img.size, (255, 255, 255))
54
+ if img.mode == 'P':
55
+ img = img.convert('RGBA')
56
+ background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
57
+ img = background
58
+ img.save(buffer, format='JPEG', quality=quality, optimize=True)
59
+ mime_type = 'image/jpeg'
60
+ else:
61
+ img.save(buffer, format='PNG', optimize=True)
62
+ mime_type = 'image/png'
63
+
64
+ # Encode to base64
65
+ img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
66
+ return f"data:{mime_type};base64,{img_str}"
67
+ except Exception as e:
68
+ logger.error(f"Error converting numpy array to base64: {e}")
69
+ return None
70
+
71
+ def format_transcription_result(result: Dict, include_images: bool = False, image_format: str = 'PNG') -> Dict[str, Any]:
72
+ """
73
+ Format transcription result into a structured response.
74
+
75
+ Args:
76
+ result: Dictionary with line IDs as keys, each containing 'image' and 'transcription'
77
+ include_images: Whether to include base64 encoded images in response
78
+ image_format: Image format for base64 encoding ('PNG' or 'JPEG')
79
+
80
+ Returns:
81
+ Formatted dictionary with transcription data
82
+ """
83
+ formatted_lines = {}
84
+ transcription_text = []
85
+
86
+ for line_id, line_data in result.items():
87
+ formatted_line = {
88
+ 'line_id': line_id,
89
+ 'transcription': line_data.get('transcription', '')
90
+ }
91
+
92
+ # Optionally include image as base64 (web-ready format)
93
+ if include_images and 'image' in line_data:
94
+ image_array = line_data['image']
95
+ if isinstance(image_array, np.ndarray):
96
+ image_base64 = numpy_to_base64(image_array, format=image_format)
97
+ if image_base64:
98
+ formatted_line['image'] = image_base64
99
+
100
+ formatted_lines[line_id] = formatted_line
101
+ transcription_text.append(f"{line_id}: {line_data.get('transcription', '')}")
102
+
103
+ return {
104
+ 'lines': formatted_lines,
105
+ 'full_text': '\n'.join(transcription_text),
106
+ 'total_lines': len(result)
107
+ }
108
+
109
+ @app.get("/")
110
+ def home():
111
+ return {
112
+ "message": "Hello, RenAI!",
113
+ "version": "1.0.0",
114
+ "endpoints": {
115
+ "transcribe": "/renai-transcribe (POST)",
116
+ "transcribe_base64": "/renai-transcribe-base64 (POST)",
117
+ "health": "/health (GET)"
118
+ }
119
+ }
120
+
121
+ @app.post("/renai-transcribe")
122
+ async def transcription_endpoint(
123
+ image: UploadFile = File(..., description="Image file to transcribe"),
124
+ userToken: Optional[str] = Form(None, description="User authentication token"),
125
+ post_processing_enabled: bool = Form(False, description="Enable post-processing"),
126
+ unet_enabled: bool = Form(False, description="Enable UNet processing"),
127
+ include_images: bool = Form(True, description="Include base64 encoded line images in response"),
128
+ image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
129
+ ):
130
+ """
131
+ Upload an image file and get transcription results.
132
+
133
+ - **image**: Image file (JPG, PNG, BMP, TIFF, WebP)
134
+ - **userToken**: Optional user authentication token
135
+ - **post_processing_enabled**: Enable/disable post-processing
136
+ - **unet_enabled**: Enable/disable UNet processing
137
+ - **include_images**: Include base64 encoded images of each line (web-ready format)
138
+ - **image_format**: Format for line images: 'PNG' (higher quality, larger) or 'JPEG' (smaller, faster)
139
+ """
140
+ start_time = datetime.now()
141
+ logger.info(f"Transcription request received for file: {image.filename} by userToken: {userToken if userToken else 'Anonymous'}")
142
+
143
+ # Validate file type
144
+ if not image.filename:
145
+ raise HTTPException(status_code=400, detail="No file provided")
146
+
147
+ file_extension = Path(image.filename).suffix.lower()
148
+ if file_extension not in ALLOWED_EXTENSIONS:
149
+ raise HTTPException(
150
+ status_code=400,
151
+ detail=f"Invalid file type. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
152
+ )
153
+
154
+ # Check file size
155
+ if image.size and image.size > MAX_FILE_SIZE:
156
+ raise HTTPException(
157
+ status_code=400,
158
+ detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
159
+ )
160
+
161
+ temp_file_path = None
162
+ try:
163
+ # Create temporary file
164
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
165
+ shutil.copyfileobj(image.file, temp_file)
166
+ temp_file_path = temp_file.name
167
+
168
+ logger.info(f"Processing image: {temp_file_path}")
169
+
170
+ # Call transcription function
171
+ result = RenAITranscription(
172
+ image=temp_file_path,
173
+ post_processing_enabled=post_processing_enabled,
174
+ unet_enabled=unet_enabled
175
+ )
176
+
177
+ logger.info(f"Transcription completed. Result type: {type(result)}, Lines: {len(result)}")
178
+
179
+ # Format the result
180
+ formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)
181
+
182
+ # Clean up
183
+ os.unlink(temp_file_path)
184
+
185
+ processing_time = (datetime.now() - start_time).total_seconds()
186
+ logger.info(f"Request completed in {processing_time:.2f}s")
187
+
188
+ response_data = {
189
+ "success": True,
190
+ "filename": image.filename,
191
+ "transcription": formatted_result,
192
+ "metadata": {
193
+ "processing_time_seconds": round(processing_time, 2),
194
+ "timestamp": datetime.now().isoformat(),
195
+ "total_lines": formatted_result['total_lines'],
196
+ "parameters": {
197
+ "post_processing_enabled": post_processing_enabled,
198
+ "unet_enabled": unet_enabled,
199
+ "include_images": include_images,
200
+ "userToken": userToken if userToken else "Anonymous"
201
+ }
202
+ }
203
+ }
204
+
205
+ return JSONResponse(content=response_data)
206
+
207
+ except Exception as e:
208
+ # Clean up
209
+ if temp_file_path and os.path.exists(temp_file_path):
210
+ try:
211
+ os.unlink(temp_file_path)
212
+ except:
213
+ pass
214
+
215
+ logger.error(f"Transcription failed: {e}")
216
+
217
+ raise HTTPException(
218
+ status_code=500,
219
+ detail={
220
+ "error": str(e),
221
+ "type": type(e).__name__
222
+ }
223
+ )
224
+
225
+ @app.post("/renai-transcribe-base64")
226
+ async def transcription_base64_endpoint(
227
+ image_data: str = Form(..., description="Base64 encoded image data"),
228
+ userToken: Optional[str] = Form(None, description="User authentication token"),
229
+ post_processing_enabled: bool = Form(False, description="Enable post-processing"),
230
+ unet_enabled: bool = Form(False, description="Enable UNet processing"),
231
+ include_images: bool = Form(False, description="Include base64 encoded line images in response"),
232
+ image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
233
+ ):
234
+ """
235
+ Alternative endpoint that accepts base64 encoded image data.
236
+ """
237
+ import base64
238
+ import io
239
+ from PIL import Image
240
+
241
+ start_time = datetime.now()
242
+ logger.info(f"Base64 transcription request received by userToken: {userToken if userToken else 'Anonymous'}")
243
+
244
+ temp_file_path = None
245
+ try:
246
+ # Remove data URL prefix if present
247
+ if "," in image_data:
248
+ image_data = image_data.split(",", 1)[1]
249
+
250
+ # Decode base64 image
251
+ image_bytes = base64.b64decode(image_data)
252
+ image_pil = Image.open(io.BytesIO(image_bytes))
253
+
254
+ # Create temporary file
255
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
256
+ image_pil.save(temp_file.name)
257
+ temp_file_path = temp_file.name
258
+
259
+ logger.info(f"Processing base64 image: {temp_file_path}")
260
+
261
+ # Call transcription function
262
+ result = RenAITranscription(
263
+ image=temp_file_path,
264
+ post_processing_enabled=post_processing_enabled,
265
+ unet_enabled=unet_enabled
266
+ )
267
+
268
+ # Format the result
269
+ formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)
270
+
271
+ # Clean up
272
+ os.unlink(temp_file_path)
273
+
274
+ processing_time = (datetime.now() - start_time).total_seconds()
275
+ logger.info(f"Base64 request completed in {processing_time:.2f}s")
276
+
277
+ response_data = {
278
+ "success": True,
279
+ "transcription": formatted_result,
280
+ "metadata": {
281
+ "processing_time_seconds": round(processing_time, 2),
282
+ "timestamp": datetime.now().isoformat(),
283
+ "total_lines": formatted_result['total_lines'],
284
+ "parameters": {
285
+ "post_processing_enabled": post_processing_enabled,
286
+ "unet_enabled": unet_enabled,
287
+ "include_images": include_images,
288
+ "image_format": image_format if include_images else None,
289
+ "userToken": userToken if userToken else "Anonymous"
290
+ }
291
+ }
292
+ }
293
+
294
+ return JSONResponse(content=response_data)
295
+
296
+ except Exception as e:
297
+ if temp_file_path and os.path.exists(temp_file_path):
298
+ try:
299
+ os.unlink(temp_file_path)
300
+ except:
301
+ pass
302
+
303
+ logger.error(f"Base64 transcription failed: {e}")
304
+
305
+ raise HTTPException(
306
+ status_code=500,
307
+ detail={
308
+ "error": str(e),
309
+ "type": type(e).__name__
310
+ }
311
+ )
312
+
313
+ @app.get("/health")
314
+ def health_check():
315
+ try:
316
+ return {
317
+ "status": "healthy",
318
+ "service": "RenAI Transcription API",
319
+ "timestamp": datetime.now().isoformat()
320
+ }
321
+ except Exception as e:
322
+ logger.error(f"Health check failed: {e}")
323
+ raise HTTPException(status_code=500, detail="Service unhealthy")
inference.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from PIL import Image
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import re
7
+ import cv2
8
+ import string
9
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
10
+ from vit import LineDataset, collate_fn
11
+ from loguru import logger
12
+ import os
13
+
14
+ class Inference:
15
+ def __init__(self, model_path, processor_path, target_size=(256, 64), batch_size=32):
16
+ """
17
+ Initialize the TextGenerator with model and processor paths.
18
+
19
+ Args:
20
+ model_path (str): Path to the pre-trained model
21
+ processor_path (str): Path to the pre-trained processor
22
+ target_size (tuple): Target size for input images (height, width)
23
+ batch_size (int): Batch size for inference
24
+ """
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ self.model_path = self._get_absolute_path(model_path)
27
+ self.processor_path = self._get_absolute_path(processor_path)
28
+ self.target_size = target_size
29
+ self.batch_size = batch_size
30
+
31
+ # Initialize model and processor
32
+ self.processor = None
33
+ self.model = None
34
+ self._initialize_model()
35
+
36
+ def _get_absolute_path(self, path):
37
+ """Convert relative path to absolute path"""
38
+ if os.path.isabs(path):
39
+ return path
40
+ # If it's a relative path, make it absolute relative to the current working directory
41
+ return os.path.join(os.getcwd(), path.lstrip('./'))
42
+
43
+
44
+ def _initialize_model(self):
45
+ """Load and initialize the model and processor."""
46
+ logger.info("Loading model...")
47
+
48
+ # Check if paths exist
49
+ if not os.path.exists(self.model_path):
50
+ raise FileNotFoundError(f"Model path not found: {self.model_path}")
51
+ if not os.path.exists(self.processor_path):
52
+ raise FileNotFoundError(f"Processor path not found: {self.processor_path}")
53
+
54
+ # List all files in the model directory
55
+ all_files = os.listdir(self.model_path)
56
+
57
+ # Validate that we have the necessary files
58
+ if not any(f in all_files for f in ['pytorch_model.bin', 'model.safetensors']):
59
+ logger.error("No model weights file found! (pytorch_model.bin or model.safetensors)")
60
+ raise FileNotFoundError("Model weights file missing")
61
+
62
+ if 'config.json' not in all_files:
63
+ logger.error("config.json file not found!")
64
+ raise FileNotFoundError("config.json missing")
65
+
66
+ logger.info(f"Loading model from: {self.model_path}")
67
+ logger.info(f"Loading processor from: {self.processor_path}")
68
+
69
+ try:
70
+ # Load processor
71
+ self.processor = TrOCRProcessor.from_pretrained(self.processor_path, do_rescale=False, use_fast=True)
72
+ logger.info("Processor loaded successfully")
73
+
74
+ # Try different loading methods for the model
75
+ logger.info("Attempting to load model...")
76
+
77
+ # Method 1: Try with explicit device mapping
78
+ try:
79
+ self.model = VisionEncoderDecoderModel.from_pretrained(
80
+ self.model_path,
81
+ use_safetensors=True,
82
+ device_map="auto" if torch.cuda.is_available() else None
83
+ )
84
+ logger.info("Model loaded with safetensors=True and device_map")
85
+ except Exception as e1:
86
+ logger.warning(f"Method 1 failed: {e1}")
87
+
88
+ # Method 2: Try without device mapping
89
+ try:
90
+ self.model = VisionEncoderDecoderModel.from_pretrained(
91
+ self.model_path,
92
+ use_safetensors=True
93
+ )
94
+ logger.info("Model loaded with safetensors=True")
95
+ except Exception as e2:
96
+ logger.warning(f"Method 2 failed: {e2}")
97
+
98
+ # Method 3: Try without safetensors
99
+ try:
100
+ self.model = VisionEncoderDecoderModel.from_pretrained(
101
+ self.model_path,
102
+ use_safetensors=True
103
+ )
104
+ logger.info("Model loaded with safetensors=False")
105
+ except Exception as e3:
106
+ logger.error(f"All loading methods failed: {e3}")
107
+ raise
108
+
109
+ # Move model to device if not already done by device_map
110
+ if not hasattr(self.model, 'device') or str(self.model.device) != str(self.device):
111
+ logger.info(f"Moving model to device: {self.device}")
112
+ self.model.to(self.device)
113
+
114
+ self.model.eval()
115
+ logger.info("Model loaded successfully and moved to device")
116
+
117
+ except Exception as e:
118
+ logger.error(f"Error loading model or processor: {e}")
119
+ import traceback
120
+ logger.error(f"Traceback: {traceback.format_exc()}")
121
+ raise
122
+ def preprocess_images(self, line_segments):
123
+ """
124
+ Prepare line images for inference.
125
+
126
+ Args:
127
+ line_segments (dict): Dictionary containing line segment information
128
+
129
+ Returns:
130
+ tuple: (keys, line_images) - keys and corresponding images
131
+ """
132
+ keys = list(line_segments.keys())
133
+ line_images = [line_segments[k]["image"] for k in keys]
134
+ return keys, line_images
135
+
136
+ def create_dataloader(self, line_images):
137
+ """
138
+ Create DataLoader for inference.
139
+
140
+ Args:
141
+ line_images (list): List of line images
142
+
143
+ Returns:
144
+ DataLoader: Configured DataLoader for inference
145
+ """
146
+ # Create dummy labels for inference
147
+ dummy_labels = [""] * len(line_images)
148
+
149
+ dataset = LineDataset(
150
+ self.processor,
151
+ self.model,
152
+ line_images,
153
+ dummy_labels,
154
+ self.target_size,
155
+ apply_augmentation=False
156
+ )
157
+
158
+ dataloader = DataLoader(
159
+ dataset,
160
+ batch_size=self.batch_size,
161
+ shuffle=False,
162
+ collate_fn=collate_fn
163
+ )
164
+
165
+ return dataloader
166
+
167
+ def generate_texts(self, dataloader):
168
+ """
169
+ Generate texts from images using the model.
170
+
171
+ Args:
172
+ dataloader (DataLoader): DataLoader containing preprocessed images
173
+
174
+ Returns:
175
+ list: List of generated texts
176
+ """
177
+ generated_texts = []
178
+
179
+ with torch.no_grad():
180
+ for batch in dataloader:
181
+ pixel_values = batch["pixel_values"].to(self.device)
182
+ generated_ids = self.model.generate(pixel_values)
183
+ generated_texts_batch = self.processor.batch_decode(
184
+ generated_ids,
185
+ skip_special_tokens=True
186
+ )
187
+ generated_texts.extend(generated_texts_batch)
188
+
189
+ return generated_texts
190
+
191
+ def update_line_segments(self, line_segments, keys, generated_texts):
192
+ """
193
+ Update line segments dictionary with generated transcriptions.
194
+
195
+ Args:
196
+ line_segments (dict): Original line segments dictionary
197
+ keys (list): List of keys corresponding to the line segments
198
+ generated_texts (list): List of generated texts
199
+
200
+ Returns:
201
+ dict: Updated line segments dictionary with transcriptions
202
+ """
203
+ for key, text in zip(keys, generated_texts):
204
+ line_segments[key]["transcription"] = text
205
+
206
+ return line_segments
207
+
208
+ def generate_texts_from_images(self, line_segments):
209
+ """
210
+ Main method to generate texts from line segment images.
211
+
212
+ Args:
213
+ line_segments (dict): Dictionary containing line segment information
214
+ with "image" key for each segment
215
+
216
+ Returns:
217
+ dict: Updated line segments dictionary with "transcription" key added
218
+ """
219
+ logger.info("Starting text generation from images...")
220
+ # Preprocess images
221
+ keys, line_images = self.preprocess_images(line_segments)
222
+
223
+ # Create dataloader
224
+ dataloader = self.create_dataloader(line_images)
225
+
226
+ # Generate texts
227
+ generated_texts = self.generate_texts(dataloader)
228
+
229
+ # Update line segments with transcriptions
230
+ updated_line_segments = self.update_line_segments(
231
+ line_segments, keys, generated_texts
232
+ )
233
+
234
+ return updated_line_segments
235
+
236
+ def generate_single_image(self, image):
237
+ """
238
+ Generate text from a single image.
239
+
240
+ Args:
241
+ image: PIL Image or numpy array
242
+
243
+ Returns:
244
+ str: Generated text
245
+ """
246
+ if isinstance(image, np.ndarray):
247
+ image = Image.fromarray(image)
248
+
249
+ # Create a temporary line_segments-like structure
250
+ temp_segments = {"temp_key": {"image": image}}
251
+
252
+ # Use the main method
253
+ result = self.generate_texts_from_images(temp_segments)
254
+
255
+ return result["temp_key"]["transcription"]
main.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage.io import imread, imsave
2
+ from skimage.color import rgb2gray
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from skimage.transform import resize
7
+ from utils.preprocessing import preprocessImage, postProcessImage, process_segment_and_crop_image
8
+ from utils.line_segmentation import segment_image_to_lines
9
+ from configs import unet_enabled
10
+ from utils.helper import load_images_from_json
11
+ from inference import Inference
12
+ from configs import model_path, processor_path, unet_model_path
13
+ from utils.postprocessing import PostProcessing
14
+ from loguru import logger
15
+
16
+ def RenAITranscription(image, post_processing_enabled=False,unet_enabled=False):
17
+ # 1- preprocessing
18
+ org_img = imread(image)[: , : ,:]
19
+
20
+ logger.info(f'Image Dimensions : {org_img.shape[0]} x {org_img.shape[1]}')
21
+
22
+ intial_process_image = preprocessImage(org_img)
23
+
24
+ if unet_enabled:
25
+ logger.info("Masked based segmentation and cropping enabled...")
26
+ cropped_img = process_segment_and_crop_image(unet_model_path, org_img, intial_process_image, padding=10, min_contour_area=100)
27
+ processed_image = postProcessImage(cropped_img)
28
+ logger.info(f"Image cropped and Pre-processed successfully.....")
29
+ else:
30
+ logger.info("Image Preprocessing started......")
31
+ processed_image = postProcessImage(intial_process_image)
32
+ logger.info(f"Image Pre-processed successfully.....")
33
+
34
+ # 2 - Line segmentation Algorithm
35
+ line_segments = segment_image_to_lines(processed_image, base_key="line",ct=0)
36
+
37
+ # 3 - Model Inference
38
+
39
+ transciption_generator = Inference(
40
+ model_path=model_path,
41
+ processor_path=processor_path,
42
+ target_size=(256, 64),
43
+ batch_size=32
44
+ )
45
+ result = transciption_generator.generate_texts_from_images(line_segments)
46
+
47
+ # Generated texts
48
+ for key, value in result.items():
49
+ print(f"{key}: {value['transcription']}")
50
+
51
+ # 4 - Post processing
52
+ # Dictionary based fuzzy matching
53
+ if post_processing_enabled:
54
+ for key, value in result.items():
55
+ corrected = PostProcessing(value['transcription'])
56
+ result[key]['post_processed'] = corrected
57
+ print(f"{key}: {value['post_processed']}")
58
+
59
+ print(result)
60
+
61
+ logger.info("Transcription completed successfully!")
62
+ return result
63
+
64
+ if __name__ == "__main__":
65
+ RenAITranscription("1.png", post_processing_enabled=False, unet_enabled=False)
requirements.txt ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==1.5.1
3
+ aiohappyeyeballs==2.5.0
4
+ aiohttp==3.11.13
5
+ aiosignal==1.3.2
6
+ albucore==0.0.23
7
+ albumentations==2.0.5
8
+ annotated-types==0.7.0
9
+ anyio==4.8.0
10
+ argon2-cffi==23.1.0
11
+ argon2-cffi-bindings==21.2.0
12
+ arrow==1.3.0
13
+ asttokens==3.0.0
14
+ astunparse==1.6.3
15
+ async-lru==2.0.4
16
+ attrs==25.1.0
17
+ babel==2.17.0
18
+ beautifulsoup4==4.13.3
19
+ bleach==6.2.0
20
+ blinker==1.9.0
21
+ blis==1.2.0
22
+ catalogue==2.0.10
23
+ certifi==2025.1.31
24
+ cffi==1.17.1
25
+ charset-normalizer==3.4.1
26
+ click==8.1.8
27
+ cloudpathlib==0.21.0
28
+ colorama==0.4.6
29
+ comm==0.2.2
30
+ confection==0.1.5
31
+ contourpy==1.3.1
32
+ cycler==0.12.1
33
+ cymem==2.0.11
34
+ datasets==3.3.2
35
+ datetime
36
+ debugpy==1.8.13
37
+ decorator==5.2.1
38
+ defusedxml==0.7.1
39
+ deskew
40
+ dill==0.3.8
41
+ editdistance==0.8.1
42
+ einops==0.8.1
43
+ evaluate==0.4.3
44
+ executing==2.2.0
45
+ fastapi
46
+ fastjsonschema==2.21.1
47
+ filelock==3.17.0
48
+ Flask==3.1.0
49
+ flatbuffers==25.2.10
50
+ fonttools==4.56.0
51
+ fqdn==1.5.1
52
+ frozenlist==1.5.0
53
+ fsspec==2024.12.0
54
+ gast==0.6.0
55
+ gensim==4.3.3
56
+ google-pasta==0.2.0
57
+ greenlet==3.1.1
58
+ grpcio==1.71.0
59
+ h11==0.14.0
60
+ h5py==3.13.0
61
+ httpcore==1.0.7
62
+ httpx==0.28.1
63
+ huggingface-hub==0.34.4
64
+ idna==3.10
65
+ imageio==2.37.0
66
+ iniconfig==2.0.0
67
+ inquirerpy==0.3.4
68
+ ipykernel==6.29.5
69
+ ipython==9.0.2
70
+ ipython_pygments_lexers==1.1.1
71
+ ipywidgets==8.1.5
72
+ isoduration==20.11.0
73
+ itsdangerous==2.2.0
74
+ jedi==0.19.2
75
+ Jinja2==3.1.6
76
+ jiwer==3.1.0
77
+ joblib==1.4.2
78
+ json5==0.10.0
79
+ jsonpointer==3.0.0
80
+ jsonschema==4.23.0
81
+ jsonschema-specifications==2024.10.1
82
+ jupyter==1.1.1
83
+ jupyter-console==6.6.3
84
+ jupyter-events==0.12.0
85
+ jupyter-lsp==2.2.5
86
+ jupyter_client==8.6.3
87
+ jupyter_core==5.7.2
88
+ jupyter_server==2.15.0
89
+ jupyter_server_terminals==0.5.3
90
+ jupyterlab==4.3.5
91
+ jupyterlab_pygments==0.3.0
92
+ jupyterlab_server==2.27.3
93
+ jupyterlab_widgets==3.0.13
94
+ keras==3.9.0
95
+ kiwisolver==1.4.8
96
+ langcodes==3.5.0
97
+ language_data==1.3.0
98
+ lazy_loader==0.4
99
+ Levenshtein==0.27.1
100
+ libclang==18.1.1
101
+ loguru
102
+ lxml==5.3.1
103
+ marisa-trie==1.2.1
104
+ Markdown==3.7
105
+ markdown-it-py==3.0.0
106
+ MarkupSafe==3.0.2
107
+ matplotlib==3.10.1
108
+ matplotlib-inline==0.1.7
109
+ mdurl==0.1.2
110
+ mistune==3.1.2
111
+ ml_dtypes==0.5.1
112
+ mpmath==1.3.0
113
+ multidict==6.1.0
114
+ multiprocess==0.70.16
115
+ murmurhash==1.0.12
116
+ namex==0.0.8
117
+ narwhals==1.30.0
118
+ nbclient==0.10.2
119
+ nbconvert==7.16.6
120
+ nbformat==5.10.4
121
+ nest-asyncio==1.6.0
122
+ networkx==3.3
123
+ ninja==1.11.1.4
124
+ nltk==3.9.1
125
+ notebook==7.3.2
126
+ notebook_shim==0.2.4
127
+ numpy==1.26.4
128
+ opencv-python==4.11.0.86
129
+ opencv-python-headless
130
+ opt_einsum==3.4.0
131
+ optree==0.14.1
132
+ overrides==7.7.0
133
+ packaging==24.2
134
+ pandas==2.2.3
135
+ pandocfilters==1.5.1
136
+ parso==0.8.4
137
+ pfzy==0.3.4
138
+ pillow==11.1.0
139
+ platformdirs==4.3.6
140
+ plotly==6.0.0
141
+ pluggy==1.5.0
142
+ preshed==3.0.9
143
+ prometheus_client==0.21.1
144
+ prompt_toolkit==3.0.50
145
+ propcache==0.3.0
146
+ protobuf==4.25.6
147
+ psutil==7.0.0
148
+ pure_eval==0.2.3
149
+ pyarrow==19.0.1
150
+ pycparser==2.22
151
+ pydantic==2.10.6
152
+ pydantic_core==2.27.2
153
+ Pygments==2.19.1
154
+ PyMuPDF==1.25.3
155
+ pyparsing==3.2.1
156
+ pytest==8.3.5
157
+ python-dateutil==2.9.0.post0
158
+ python-docx==1.1.2
159
+ python-json-logger==3.3.0
160
+ python-Levenshtein==0.27.1
161
+ python-multipart
162
+ pytz==2025.1
163
+ # pywin32==309
164
+ # pywinpty==2.0.15
165
+ PyYAML==6.0.2
166
+ pyzmq==26.2.1
167
+ RapidFuzz==3.12.2
168
+ referencing==0.36.2
169
+ regex==2024.11.6
170
+ requests==2.32.3
171
+ rfc3339-validator==0.1.4
172
+ rfc3986-validator==0.1.1
173
+ rich==13.9.4
174
+ rpds-py==0.23.1
175
+ safetensors==0.5.3
176
+ scikit-image==0.25.2
177
+ scikit-learn==1.6.1
178
+ scipy==1.13.1
179
+ seaborn==0.13.2
180
+ Send2Trash==1.8.3
181
+ setuptools==75.8.0
182
+ shellingham==1.5.4
183
+ simsimd==6.2.1
184
+ six==1.17.0
185
+ smart-open==7.1.0
186
+ sniffio==1.3.1
187
+ soupsieve==2.6
188
+ spacy==3.8.4
189
+ spacy-legacy==3.0.12
190
+ spacy-loggers==1.0.5
191
+ SQLAlchemy==2.0.38
192
+ srsly==2.5.1
193
+ stack-data==0.6.3
194
+ stringzilla==3.12.3
195
+ sympy==1.13.1
196
+ tensorboard==2.19.0
197
+ tensorboard-data-server==0.7.2
198
+ tensorflow==2.19.0
199
+ # tensorflow-intel==2.16.1
200
+ termcolor==2.5.0
201
+ terminado==0.18.1
202
+ tf_keras==2.19.0
203
+ thinc==8.3.4
204
+ threadpoolctl==3.5.0
205
+ tifffile==2025.2.18
206
+ tinycss2==1.4.0
207
+ tokenizers==0.21.0
208
+ torch==2.4.1
209
+ torchaudio==2.4.1
210
+ torchvision==0.19.1
211
+ tornado==6.4.2
212
+ tqdm==4.67.1
213
+ traitlets==5.14.3
214
+ transformers==4.49.0
215
+ typer==0.15.2
216
+ types-python-dateutil==2.9.0.20241206
217
+ typing_extensions==4.12.2
218
+ tzdata==2025.1
219
+ uri-template==1.3.0
220
+ urllib3==2.3.0
221
+ uvicorn
222
+ wasabi==1.1.3
223
+ wcwidth==0.2.13
224
+ weasel==0.4.1
225
+ webcolors==24.11.1
226
+ webencodings==0.5.1
227
+ websocket-client==1.8.0
228
+ Werkzeug==3.1.3
229
+ wheel==0.45.1
230
+ widgetsnbextension==4.0.13
231
+ wrapt==1.17.2
232
+ xxhash==3.5.0
233
+ yarl==1.18.3
utils/__pycache__/configs.cpython-312.pyc ADDED
Binary file (331 Bytes). View file
 
utils/__pycache__/helper.cpython-312.pyc ADDED
Binary file (660 Bytes). View file
 
utils/__pycache__/inference.cpython-312.pyc ADDED
Binary file (7.26 kB). View file
 
utils/__pycache__/line_segmentation.cpython-312.pyc ADDED
Binary file (12.4 kB). View file
 
utils/__pycache__/postprocessing.cpython-312.pyc ADDED
Binary file (20.1 kB). View file
 
utils/__pycache__/preprocessing.cpython-312.pyc ADDED
Binary file (9.67 kB). View file
 
utils/__pycache__/vit.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
utils/helper.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ def load_images_from_json(line_segments):
4
+ line_images = []
5
+ image_paths = []
6
+ for key, value in line_segments.items():
7
+ line_images.append(value["image"])
8
+ image_paths.append(value.get("image_path", f"{key}.png"))
9
+
10
+ return line_images, image_paths
utils/line_segmentation.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage.io import imread
2
+ from skimage.color import rgb2gray
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from skimage.filters import threshold_otsu
6
+ import os
7
+ from skimage.graph import route_through_array
8
+ from heapq import heappush, heappop
9
+ from loguru import logger
10
+
11
def heuristic(a, b):
    """Squared Euclidean distance between grid points *a* and *b*."""
    row_delta = b[0] - a[0]
    col_delta = b[1] - a[1]
    return row_delta * row_delta + col_delta * col_delta
14
+
15
+
16
def get_binary(img):
    """Binarize *img* with Otsu's threshold.

    Foreground (dark) pixels become 1, background 0. A perfectly flat black
    or white image is returned unchanged, since Otsu cannot split it.
    """
    mean_value = np.mean(img)
    # Nothing to threshold on an all-black / all-white image.
    if mean_value in (0.0, 1.0):
        return img

    otsu_threshold = threshold_otsu(img)
    return (img <= otsu_threshold).astype(np.uint8)
26
+
27
+
28
def astar(array, start, goal):
    """Find a path through a binary grid with A* search (8-connected).

    Args:
        array: 2-D grid (indexable as ``array[r][c]`` with a ``.shape``);
            cells equal to 1 are obstacles, anything else is walkable.
        start: (row, col) starting cell.
        goal: (row, col) target cell.

    Returns:
        list: Path as (row, col) tuples ordered from *goal* back to (but
        excluding) *start*; an empty list when the goal is unreachable.

    Notes:
        The step cost is the *squared* step distance (1 for straight moves,
        2 for diagonals), preserving the original cost model; it is not a
        true Euclidean metric.
    """

    def _sq_dist(a, b):
        # Squared Euclidean distance; inlined so the block is self-contained.
        return (b[0] - a[0]) ** 2 + (b[1] - a[1]) ** 2

    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0),
                 (1, 1), (1, -1), (-1, 1), (-1, -1)]
    rows, cols = array.shape[0], array.shape[1]

    close_set = set()
    came_from = {}
    gscore = {start: 0}
    oheap = []
    heappush(oheap, (_sq_dist(start, goal), start))

    while oheap:
        current = heappop(oheap)[1]
        if current in close_set:
            # Stale duplicate left in the heap after a better score was pushed.
            continue

        if current == goal:
            # Walk the parent chain back; start itself is excluded, matching
            # the original return convention.
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data

        close_set.add(current)
        for dr, dc in neighbors:
            neighbor = (current[0] + dr, current[1] + dc)
            # Skip out-of-bounds cells and obstacles.
            if not (0 <= neighbor[0] < rows and 0 <= neighbor[1] < cols):
                continue
            if array[neighbor[0]][neighbor[1]] == 1:
                continue
            if neighbor in close_set:
                continue

            tentative_g = gscore[current] + _sq_dist(current, neighbor)
            # BUG FIX: the original defaulted unknown g-scores to 0, which made
            # the "is this an improvement?" comparison meaningless, and used an
            # O(n) scan of the heap for open-set membership. Defaulting to +inf
            # and re-pushing on improvement is the standard formulation.
            if tentative_g < gscore.get(neighbor, float("inf")):
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g
                heappush(oheap, (tentative_g + _sq_dist(neighbor, goal), neighbor))

    return []
74
+
75
+
76
def preprocess_image(img, target_size):
    """Optionally crop *img*, drop any alpha channel, and convert to grayscale.

    Returns the processed image, or None when any step fails.
    """
    try:
        if target_size is not None:
            # target_size is (row_start, row_end, col_start, col_end);
            # this indexing requires a 3-D (colour) input.
            img = img[target_size[0]:target_size[1], target_size[2]:target_size[3], :]
        has_alpha = img.ndim == 3 and img.shape[2] == 4
        if has_alpha:
            img = img[..., :3]
        if img.ndim > 2:
            img = rgb2gray(img)
        return img
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        return None
89
+
90
+
91
def horizontal_projections(sobel_image):
    """Sum the (binary) image across columns, yielding one value per row."""
    return np.asarray(sobel_image).sum(axis=1)
94
+
95
+
96
def binarize_image(image):
    """Return a boolean mask that is True where a pixel is darker than Otsu's threshold."""
    otsu_threshold = threshold_otsu(image)
    return image < otsu_threshold
100
+
101
+
102
def find_peak_regions(hpp, threshold):
    """Indices of projection-profile rows whose value falls below *threshold*.

    Low-projection rows correspond to the gaps between text lines.
    """
    return [index for index, value in enumerate(hpp) if value < threshold]
109
+
110
+
111
def line_segmentation(image, threshold=None, min_peak_group_size=7, target_size=None,
                      ct=0, parent_line_num=None, recursive=False, recursive_count=1,
                      base_key="line"):
    """
    Segment an image into text lines using horizontal projections and A*/graph routing.

    Gaps between lines show up as low rows in the horizontal projection
    profile; a minimum-cost path is routed through each gap band with
    `route_through_array`, and the image is sliced between consecutive paths.
    Unusually tall slices (above the 90th-percentile height) are re-segmented
    recursively once with a smaller peak-group size.

    Args:
        image: Input image (numpy array)
        threshold (float, optional): Threshold for peak detection; defaults to
            the midpoint of the projection profile's range.
        min_peak_group_size (int): Minimum size of peak groups to consider
        target_size (tuple, optional): Crop window for image preprocessing
        ct (int): Counter for line numbering (threaded through recursive calls)
        parent_line_num (str, optional): Parent line number for recursive segmentation
        recursive (bool): Whether this is a recursive call
        recursive_count (int): Counter for recursive segmentation numbering
        base_key (str): Base key for dictionary entries

    Returns:
        tuple: (segmented_images_dict, counter value, bool indicating if valid
        separations were found). Dict values are {"image": ndarray,
        "transcription": key-string placeholder}.
    """
    segmented_images_dict = {}

    img = preprocess_image(image, target_size)
    if img is None:
        print(f"Failed to preprocess image")
        return segmented_images_dict, ct, False

    # Binarize image and get projections (row sums of the dark-pixel mask).
    binarized_image = binarize_image(img)
    hpp = horizontal_projections(binarized_image)

    # Default threshold: midpoint of the projection profile's range.
    if threshold is None:
        threshold = (np.max(hpp) - np.min(hpp)) / 2

    # Find peaks: rows below threshold, i.e. candidate inter-line gaps.
    peaks = find_peak_regions(hpp, threshold)
    if not peaks:
        print(f"No peaks found in image")
        return segmented_images_dict, ct, False

    peaks_indexes = np.array(peaks).astype(int)

    # Visualization aid: zero out gap rows on a copy (the copy is not used
    # further below).
    segmented_img = np.copy(img)
    r, c = segmented_img.shape
    for ri in range(r):
        if ri in peaks_indexes:
            segmented_img[ri, :] = 0

    # Group consecutive gap rows into bands; drop bands that are too thin.
    diff_between_consec_numbers = np.diff(peaks_indexes)
    indexes_with_larger_diff = np.where(diff_between_consec_numbers > 1)[0].flatten()
    peak_groups = np.split(peaks_indexes, indexes_with_larger_diff + 1)
    peak_groups = [item for item in peak_groups if len(item) > min_peak_group_size]

    if not peak_groups:
        print(f"No valid peak groups found in image")
        return segmented_images_dict, ct, False

    binary_image = get_binary(img)
    segment_separating_lines = []

    # Route a minimum-cost separating path through each gap band.
    for sub_image_index in peak_groups:
        try:
            start_row = sub_image_index[0]
            end_row = sub_image_index[-1]

            start_row = max(0, start_row)
            end_row = min(binary_image.shape[0], end_row)

            if end_row <= start_row:
                continue

            nmap = binary_image[start_row:end_row, :]

            if nmap.size == 0:
                continue

            # Route from the band's vertical midpoint, left edge to right edge.
            start_point = (int(nmap.shape[0] / 2), 0)
            end_point = (int(nmap.shape[0] / 2), nmap.shape[1] - 1)

            path, _ = route_through_array(nmap, start_point, end_point)
            # NOTE(review): this offsets BOTH coordinates by start_row, so the
            # column values are shifted too; only path[:, 0] (rows) is consumed
            # below, which is why this has gone unnoticed — confirm intent.
            path = np.array(path) + start_row
            segment_separating_lines.append(path)
        except Exception as e:
            print(f"Failed to process sub-image: {e}")
            continue

    if not segment_separating_lines:
        print(f"No valid segment separating lines found in image")
        return segmented_images_dict, ct, False

    # Separate images based on line segments: slice between the lowest row of
    # one separator and the highest row of the next.
    seperated_images = []
    for index in range(len(segment_separating_lines) - 1):
        try:
            lower_line = np.min(segment_separating_lines[index][:, 0])
            upper_line = np.max(segment_separating_lines[index + 1][:, 0])

            if lower_line < upper_line and upper_line <= img.shape[0]:
                line_image = img[lower_line:upper_line]
                if line_image.size > 0:
                    seperated_images.append(line_image)
        except Exception as e:
            print(f"Failed to separate image at index {index}: {e}")
            continue

    if not seperated_images:
        print(f"No valid separated images found in image")
        return segmented_images_dict, ct, False

    # Calculate height threshold: slices taller than the 90th percentile are
    # suspected to contain multiple merged lines and get re-segmented.
    try:
        image_heights = [line_image.shape[0] for line_image in seperated_images if line_image.size > 0]
        if not image_heights:
            print(f"No valid image heights found")
            return segmented_images_dict, ct, False
        height_threshold = np.percentile(image_heights, 90)
    except Exception as e:
        print(f"Failed to calculate height threshold: {e}")
        return segmented_images_dict, ct, False

    # Process each separated image.
    for index, line_image in enumerate(seperated_images):
        try:
            if line_image.size == 0 or line_image.shape[0] == 0 or line_image.shape[1] == 0:
                continue

            if parent_line_num is None:
                # Top-level call: number lines sequentially via ct.
                dict_key = f'{base_key}_{ct + 1}'
            else:
                dict_key = f'{base_key}_{recursive_count}'
                # NOTE(review): in recursive mode this skips every separated
                # segment except the last one, so a recursive call stores at
                # most one entry — confirm this is the intended behaviour.
                if index < len(seperated_images) - 1:
                    continue

            segmented_images_dict[dict_key] = {
                "image": line_image.copy(),
                "transcription": f"{dict_key}"
            }

            # print(f"Added line image to dictionary with key: {dict_key}")

            # Handle recursive segmentation (only one level deep: recursive
            # calls never recurse again).
            if line_image.shape[0] > height_threshold and not recursive:
                try:
                    # Create recursive base key from the current counter.
                    recursive_base_key = f"{base_key}_{ct + 1}"

                    # Do recursive segmentation; ct is threaded through so
                    # numbering stays globally consistent.
                    recursive_dict, ct, found_valid_separations = line_segmentation(
                        line_image, threshold=threshold,
                        min_peak_group_size=3,
                        parent_line_num=str(ct + 1),
                        recursive=True,
                        ct=ct,
                        recursive_count=1,
                        base_key=recursive_base_key
                    )

                    if found_valid_separations:
                        # Replace the merged slice with its finer sub-lines.
                        del segmented_images_dict[dict_key]
                        segmented_images_dict.update(recursive_dict)
                        print(f"Replaced {dict_key} with recursive segmentation results")
                    else:
                        print(f"Keeping original image {dict_key} as no valid separations were found")

                except Exception as e:
                    print(f"Failed during recursive segmentation of {dict_key}: {e}")

            ct += 1
            if recursive:
                recursive_count += 1

        except Exception as e:
            print(f"Failed to process line image at index {index}: {e}")
            continue
    logger.info(f"Total lines segment found: {len(segmented_images_dict)}")
    return segmented_images_dict, ct, len(seperated_images) > 0
288
+
289
+
290
def segment_image_to_lines(image_array, **kwargs):
    """
    Convenience wrapper around :func:`line_segmentation`.

    Args:
        image_array: Input image as numpy array
        **kwargs: Forwarded verbatim to line_segmentation

    Returns:
        dict: Mapping of line keys to segmented image entries; empty dict on
        any failure.
    """
    try:
        logger.info("Starting line segmentation...")
        lines_dict, _, succeeded = line_segmentation(image_array, **kwargs)
        if succeeded:
            logger.info("Line segmentation successful.....")
        return lines_dict
    except Exception as e:
        logger.error(f"Line segmentation failed: {e}")
        return {}
312
+
313
+ # if __name__ == "__main__":
314
+ # # Example usage
315
+ # image_path = "./renAI-deploy/1.png"
316
+ # image = imread(image_path)
317
+ # segmented_lines = segment_image_to_lines(image, threshold=None, min_peak_group_size=10)
318
+
319
+
320
+ # print(len(segmented_lines.values()))
321
+
322
+ # for key, value in segmented_lines.items():
323
+ # print(f"{key}: {value['image'].shape}")
324
+ # print(f"{key}: {value['transcription']}")
325
+ # # plt.imshow(img, cmap='gray')
326
+ # # plt.title(key)
327
+ # # plt.show()
utils/postprocessing.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import unicodedata
3
+ from collections import defaultdict
4
+ from typing import List, Tuple, Dict, Set
5
+ import heapq
6
+ from loguru import logger
7
+
8
class SpanishFuzzyMatcher:
    """Fuzzy dictionary matcher for correcting Spanish OCR output.

    Loads a word list, builds length- and trigram-based candidate indexes,
    and scores candidates with a blend of Levenshtein, Damerau-Levenshtein
    and Jaro-Winkler similarities plus Spanish-specific bonuses.
    """

    def __init__(self, dictionary_path: str):
        # Full vocabulary of cleaned, lower-cased words.
        self.dictionary = set()
        # Words bucketed by length, for cheap length-window candidate lookup.
        self.word_by_length = defaultdict(list)
        # Character trigram -> set of dictionary words containing it.
        self.ngram_index = defaultdict(set)
        # Frequent Spanish words present in the dictionary (get a score bonus).
        self.common_words = set()

        self._load_dictionary(dictionary_path)
        self._build_indexes()
        self._load_common_words()

    def _detect_encoding(self, path: str) -> str:
        """Return the first candidate encoding that can read the file's first 1KB."""
        encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252', 'utf-16']

        for encoding in encodings:
            try:
                with open(path, 'r', encoding=encoding) as f:
                    f.read(1024)  # Try to read first 1KB
                return encoding
            except (UnicodeDecodeError, UnicodeError):
                continue

        # Fallback; _load_dictionary opens with errors='ignore' anyway.
        return 'utf-8'

    def _load_dictionary(self, path: str):
        """Load and clean the word list, filling dictionary and word_by_length."""
        try:
            encoding = self._detect_encoding(path)
            print(f"Detected encoding: {encoding}")

            with open(path, 'r', encoding=encoding, errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    try:
                        word = line.strip().lower()
                        if word and len(word) > 1:
                            # Remove any non-alphabetic characters except hyphens and apostrophes
                            cleaned_word = re.sub(r"[^a-záéíóúüñç\-']", "", word)
                            if cleaned_word and len(cleaned_word) > 1:
                                self.dictionary.add(cleaned_word)
                                self.word_by_length[len(cleaned_word)].append(cleaned_word)
                    except Exception as e:
                        print(f"Warning: Skipping line {line_num} due to error: {e}")
                        continue

            print(f"Loaded {len(self.dictionary)} words from dictionary")

        except FileNotFoundError:
            raise FileNotFoundError(f"Dictionary file not found: {path}")
        except Exception as e:
            raise Exception(f"Error loading dictionary: {e}")

    def _load_common_words(self):
        """Intersect a hard-coded frequent-Spanish-word list with the loaded dictionary."""
        common_spanish = {
            'el', 'la', 'de', 'que', 'y', 'a', 'en', 'un', 'es', 'se', 'no', 'te', 'lo', 'le', 'da', 'su', 'por', 'son', 'con', 'para', 'al', 'las', 'del', 'los', 'una', 'mi', 'muy', 'mas', 'me', 'si', 'ya', 'todo', 'como', 'pero', 'hay', 'o', 'cuando', 'esta', 'ser', 'tiene', 'estar', 'hacer', 'sobre', 'entre', 'poder', 'antes', 'tiempo', 'año', 'casa', 'día', 'vida', 'trabajo', 'hombre', 'mujer', 'mundo', 'parte', 'momento', 'lugar', 'país', 'forma', 'manera', 'estado', 'caso', 'grupo', 'agua', 'punto', 'vez', 'donde', 'quien', 'haber', 'tener', 'hacer', 'decir', 'ir', 'ver', 'dar', 'saber', 'querer', 'llegar', 'pasar', 'deber', 'poner', 'parecer', 'quedar', 'creer', 'hablar', 'llevar', 'dejar', 'seguir', 'encontrar', 'llamar', 'venir', 'pensar', 'salir', 'volver', 'tomar', 'conocer', 'vivir', 'sentir', 'tratar', 'mirar', 'contar', 'empezar', 'esperar', 'buscar', 'existir', 'entrar', 'trabajar', 'escribir', 'perder', 'producir', 'ocurrir', 'entender', 'pedir', 'recibir', 'recordar', 'terminar', 'permitir', 'aparecer', 'conseguir', 'comenzar', 'servir', 'sacar', 'necesitar', 'mantener', 'resultar', 'leer', 'caer', 'cambiar', 'presentar', 'crear', 'abrir', 'considerar', 'oír', 'acabar', 'convertir', 'ganar', 'traer', 'realizar', 'suponer', 'comprender', 'explicar', 'dedicar', 'andar', 'estudiar', 'mano', 'cabeza', 'ojo', 'cara', 'pie', 'corazón', 'vez', 'palabra', 'número', 'color', 'mesa', 'silla', 'libro', 'papel', 'coche', 'calle', 'puerta', 'ventana', 'ciudad', 'pueblo', 'escuela', 'hospital', 'iglesia', 'tienda', 'mercado', 'banco', 'hotel', 'restaurante', 'café', 'bar', 'teatro', 'cine', 'museo', 'parque', 'jardín', 'playa', 'montaña', 'río', 'mar', 'lago', 'bosque', 'árbol', 'flor', 'animal', 'perro', 'gato', 'pájaro', 'pez', 'comida', 'pan', 'carne', 'pollo', 'pescado', 'leche', 'huevo', 'queso', 'fruta', 'verdura', 'patata', 'tomate', 'cebolla', 'ajo', 'sal', 'azúcar', 'aceite', 'vino', 'cerveza', 'café', 'té', 'agua', 'fuego', 'aire', 'tierra', 'sol', 'luna', 'estrella', 'nube', 'lluvia', 'nieve', 'viento', 'calor', 'frío', 'luz', 'sombra', 'mañana', 'tarde', 'noche', 'hoy', 'ayer', 'mañana', 'semana', 'mes', 'año', 'hora', 'minuto', 'segundo', 'lunes', 'martes', 'miércoles', 'jueves', 'viernes', 'sábado', 'domingo', 'enero', 'febrero', 'marzo', 'abril', 'mayo', 'junio', 'julio', 'agosto', 'septiembre', 'octubre', 'noviembre', 'diciembre', 'primavera', 'verano', 'otoño', 'invierno', 'bueno', 'malo', 'grande', 'pequeño', 'alto', 'bajo', 'largo', 'corto', 'ancho', 'estrecho', 'grueso', 'delgado', 'fuerte', 'débil', 'rápido', 'lento', 'fácil', 'difícil', 'nuevo', 'viejo', 'joven', 'mayor', 'blanco', 'negro', 'rojo', 'azul', 'verde', 'amarillo', 'gris', 'marrón', 'rosa', 'naranja', 'morado', 'feliz', 'triste', 'contento', 'enfadado', 'cansado', 'aburrido', 'interesante', 'divertido', 'importante', 'necesario', 'posible', 'imposible', 'seguro', 'peligroso', 'rico', 'pobre', 'caro', 'barato', 'limpio', 'sucio', 'sano', 'enfermo', 'vivo', 'muerto', 'lleno', 'vacío', 'abierto', 'cerrado', 'caliente', 'frío', 'seco', 'mojado', 'duro', 'blando', 'suave', 'áspero', 'dulce', 'amargo', 'salado', 'picante', 'conocerte', 'tengas'
        }
        self.common_words = {word for word in common_spanish if word in self.dictionary}
        print(f"Loaded {len(self.common_words)} common words")

    def _is_common_spanish_error(self, ocr_word: str, dict_word: str) -> bool:
        """True when the two equal-length words differ by exactly one character
        that is a known Spanish OCR confusion (b/v, c/s, etc.)."""
        ocr_lower = ocr_word.lower()
        dict_lower = dict_word.lower()

        # Common OCR confusions in Spanish.
        # NOTE(review): this dict literal has duplicate keys, so later entries
        # win: 's' maps only to 'z' (not 'c'), 'n' only to 'ñ' (not 'u'),
        # 'y' only to 'll' (not 'i'); the two-character 'll' entries can never
        # match the single-character comparison below. Confirm intent.
        ocr_substitutions = {
            'b': 'v', 'v': 'b',  # b/v confusion
            'c': 's', 's': 'c',  # c/s confusion
            'z': 's', 's': 'z',  # z/s confusion
            'j': 'g', 'g': 'j',  # j/g confusion
            'y': 'i', 'i': 'y',  # y/i confusion
            'u': 'n', 'n': 'u',  # u/n confusion (handwriting)
            'll': 'y', 'y': 'll',  # ll/y confusion
            'ñ': 'n', 'n': 'ñ',  # ñ/n confusion
        }

        if len(ocr_lower) == len(dict_lower):
            diff_count = sum(1 for a, b in zip(ocr_lower, dict_lower) if a != b)
            if diff_count == 1:
                # Locate the single differing position and check the confusion table.
                for i, (a, b) in enumerate(zip(ocr_lower, dict_lower)):
                    if a != b:
                        return a in ocr_substitutions and ocr_substitutions[a] == b

        return False

    def _build_indexes(self):
        """Index every dictionary word by its character trigrams ('$' pads the ends)."""
        for word in self.dictionary:
            padded_word = f"${word}$"
            for i in range(len(padded_word) - 2):
                trigram = padded_word[i:i+3]
                self.ngram_index[trigram].add(word)

    def _normalize_text(self, text: str) -> str:
        """Lower-case *text* and strip diacritics (NFD decomposition, drop combining marks)."""
        text = unicodedata.normalize('NFD', text)
        text = ''.join(c for c in text if unicodedata.category(c) != 'Mn')
        return text.lower()

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Classic two-row dynamic-programming edit distance."""
        if len(s1) < len(s2):
            # Keep s2 as the shorter string so the row buffer is minimal.
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def _damerau_levenshtein_distance(self, s1: str, s2: str) -> int:
        """Full Damerau-Levenshtein distance (edit distance with transpositions)."""
        len1, len2 = len(s1), len(s2)

        # da: last row in s1 where each character was seen.
        da = {}
        for char in s1 + s2:
            if char not in da:
                da[char] = 0

        max_dist = len1 + len2
        h = [[max_dist for _ in range(len2 + 2)] for _ in range(len1 + 2)]

        h[0][0] = max_dist
        for i in range(0, len1 + 1):
            h[i + 1][0] = max_dist
            h[i + 1][1] = i
        for j in range(0, len2 + 1):
            h[0][j + 1] = max_dist
            h[1][j + 1] = j

        for i in range(1, len1 + 1):
            db = 0
            for j in range(1, len2 + 1):
                k = da[s2[j - 1]]
                l = db
                if s1[i - 1] == s2[j - 1]:
                    cost = 0
                    db = j
                else:
                    cost = 1

                h[i + 1][j + 1] = min(
                    h[i][j] + cost,  # substitution
                    h[i + 1][j] + 1,  # insertion
                    h[i][j + 1] + 1,  # deletion
                    h[k][l] + (i - k - 1) + 1 + (j - l - 1)  # transposition
                )

            da[s1[i - 1]] = i

        return h[len1 + 1][len2 + 1]

    def _jaro_winkler_similarity(self, s1: str, s2: str) -> float:
        """Jaro similarity boosted by up to a 4-character common prefix (factor 0.1)."""
        def jaro_similarity(s1: str, s2: str) -> float:
            if s1 == s2:
                return 1.0

            len1, len2 = len(s1), len(s2)
            if len1 == 0 or len2 == 0:
                return 0.0

            # Characters match when within half the longer length of each other.
            match_window = max(len1, len2) // 2 - 1
            if match_window < 0:
                match_window = 0

            s1_matches = [False] * len1
            s2_matches = [False] * len2

            matches = 0
            transpositions = 0

            for i in range(len1):
                start = max(0, i - match_window)
                end = min(i + match_window + 1, len2)

                for j in range(start, end):
                    if s2_matches[j] or s1[i] != s2[j]:
                        continue
                    s1_matches[i] = s2_matches[j] = True
                    matches += 1
                    break

            if matches == 0:
                return 0.0

            # Count out-of-order matched characters (transpositions).
            k = 0
            for i in range(len1):
                if not s1_matches[i]:
                    continue
                while not s2_matches[k]:
                    k += 1
                if s1[i] != s2[k]:
                    transpositions += 1
                k += 1

            jaro = (matches / len1 + matches / len2 +
                    (matches - transpositions / 2) / matches) / 3
            return jaro

        jaro = jaro_similarity(s1, s2)

        prefix_len = 0
        for i in range(min(len(s1), len(s2), 4)):
            if s1[i] == s2[i]:
                prefix_len += 1
            else:
                break

        return jaro + (0.1 * prefix_len * (1 - jaro))

    def _get_candidates(self, word: str, max_candidates: int = 200) -> Set[str]:
        """Gather a bounded candidate pool: common words of similar length,
        words within +/-2 of the query length, and top shared-trigram words."""
        candidates = set()
        word_len = len(word)

        common_candidates = set()
        for common_word in self.common_words:
            if abs(len(common_word) - word_len) <= 2:
                common_candidates.add(common_word)

        candidates.update(common_candidates)

        for length in range(max(1, word_len - 2), word_len + 3):
            length_words = self.word_by_length[length]
            # Sort by length (shorter words first) and limit
            sorted_words = sorted(length_words, key=len)[:max_candidates//3]
            candidates.update(sorted_words)

        padded_word = f"${word}$"
        trigram_candidates = set()
        trigram_scores = defaultdict(int)

        # Rank words by how many trigrams they share with the query.
        for i in range(len(padded_word) - 2):
            trigram = padded_word[i:i+3]
            if trigram in self.ngram_index:
                for candidate in self.ngram_index[trigram]:
                    trigram_scores[candidate] += 1

        sorted_trigram = sorted(trigram_scores.items(), key=lambda x: x[1], reverse=True)
        trigram_candidates = {word for word, score in sorted_trigram[:max_candidates//2]}
        candidates.update(trigram_candidates)

        return candidates

    def _calculate_composite_score(self, word1: str, word2: str) -> float:
        """Blend several similarity measures into one score in [0, 1].

        Base score = 0.25*Levenshtein + 0.45*Damerau + 0.25*Jaro-Winkler
        + 0.05*length penalty, then multiplied by bonuses for common words
        (1.3), known Spanish OCR confusions (1.2) and exact-length matches
        (1.1), clipped to 1.0.
        """
        norm_word1 = self._normalize_text(word1)
        norm_word2 = self._normalize_text(word2)

        levenshtein = self._levenshtein_distance(norm_word1, norm_word2)
        damerau = self._damerau_levenshtein_distance(norm_word1, norm_word2)
        jaro_winkler = self._jaro_winkler_similarity(norm_word1, norm_word2)

        max_len = max(len(norm_word1), len(norm_word2))
        if max_len == 0:
            return 1.0

        # Convert distances to similarities in [0, 1].
        levenshtein_sim = 1 - (levenshtein / max_len)
        damerau_sim = 1 - (damerau / max_len)

        length_diff = abs(len(norm_word1) - len(norm_word2))
        length_penalty = 1 - (length_diff / max(len(norm_word1), len(norm_word2)))

        frequency_bonus = 1.0
        if norm_word2 in self.common_words:
            frequency_bonus = 1.3

        spanish_error_bonus = 1.0
        if self._is_common_spanish_error(word1, word2):
            spanish_error_bonus = 1.2

        exact_length_bonus = 1.0
        if len(norm_word1) == len(norm_word2):
            exact_length_bonus = 1.1

        base_score = (
            0.25 * levenshtein_sim +
            0.45 * damerau_sim +
            0.25 * jaro_winkler +
            0.05 * length_penalty
        )

        final_score = base_score * frequency_bonus * spanish_error_bonus * exact_length_bonus

        return min(final_score, 1.0)

    def find_best_matches(self, word: str, top_k: int = 5, threshold: float = 0.4) -> List[Tuple[str, float]]:
        """Return up to *top_k* (candidate, score) pairs with score >= *threshold*.

        Exact dictionary hits (accent-insensitive or lower-cased) short-circuit
        with a score of 1.0.
        """
        if not word or len(word) < 2:
            return []

        normalized_word = self._normalize_text(word)
        if normalized_word in self.dictionary:
            return [(word, 1.0)]

        if word.lower() in self.dictionary:
            return [(word.lower(), 1.0)]

        candidates = self._get_candidates(normalized_word)

        # Max-heap via negated scores; ties break alphabetically on candidate.
        scored_matches = []
        for candidate in candidates:
            score = self._calculate_composite_score(word, candidate)
            if score >= threshold:
                heapq.heappush(scored_matches, (-score, candidate, score))

        results = []
        seen_words = set()
        for _ in range(min(top_k, len(scored_matches))):
            if scored_matches:
                _, candidate, score = heapq.heappop(scored_matches)
                if candidate not in seen_words:
                    results.append((candidate, score))
                    seen_words.add(candidate)

        return results

    def correct_sentence(self, sentence: str, confidence_threshold: float = 0.6) -> str:
        """Replace each word token whose best match scores >= *confidence_threshold*.

        Non-word runs (spaces, punctuation) are passed through untouched, so
        the sentence's spacing and punctuation are preserved exactly.
        """
        words = re.findall(r'\b\w+\b|\W+', sentence)
        corrected_words = []

        for token in words:
            if re.match(r'\b\w+\b', token):
                matches = self.find_best_matches(token, top_k=1, threshold=0.3)

                if matches and matches[0][1] >= confidence_threshold:
                    corrected_words.append(matches[0][0])
                else:
                    corrected_words.append(token)
            else:
                corrected_words.append(token)

        return ''.join(corrected_words)
339
+
340
def PostProcessing(ocr_sentence):
    """Spell-correct an OCR'd Spanish sentence against the bundled dictionary.

    Builds a SpanishFuzzyMatcher from the packaged 136k-word list and runs the
    sentence through it; on any failure the original sentence is returned
    unchanged (best-effort correction).
    """
    try:
        logger.info("Post processing started......")
        matcher = SpanishFuzzyMatcher('Diccionario.Espanol.136k.palabras.txt')
        logger.info("Dictionary loaded successfully!")

        result = matcher.correct_sentence(ocr_sentence, confidence_threshold=0.6)
        logger.info("Post processing completed successfully!")
        return result

    except Exception as e:
        print(e)
        logger.error(f"Post processing failed: {e}")
        return ocr_sentence
+ return ocr_sentence
utils/preprocessing.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from deskew import determine_skew
6
+ from typing import Tuple, Union
7
+ import math
8
+ from loguru import logger
9
+
10
def preprocessImage(image):
    """
    Clean a document image for OCR: denoise, binarize with Otsu, erase long
    horizontal/vertical ruling lines, then morphologically repair the
    remaining strokes and re-invert.

    Args:
        image: BGR colour image (numpy array, e.g. as returned by cv2.imread).

    Returns:
        numpy array: cleaned grayscale image, dark text on a white background.
    """

    # Convert the image to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply denoising
    gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)

    # Apply binary thresholding using Otsu's method (text becomes white on black)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # Copy the original image to preserve it
    removed = image.copy()

    # Remove vertical lines: the tall 1x40 opening keeps only long vertical
    # strokes, whose contours are then painted over in white.
    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
    remove_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
    cnts = cv2.findContours(remove_vertical, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns 3 values, 4.x returns 2 — pick the contour list either way.
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    for c in cnts:
        cv2.drawContours(removed, [c], -1, (255, 255, 255), 4)

    # Remove horizontal lines with the transposed (40x1) kernel.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    remove_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
    cnts = cv2.findContours(remove_horizontal, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    for c in cnts:
        cv2.drawContours(removed, [c], -1, (255, 255, 255), 5)

    # Repair kernel: reconnect strokes broken by the line removal.
    repair_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    removed = 255 - removed
    dilate = cv2.dilate(removed, repair_kernel, iterations=5)
    dilate = cv2.cvtColor(dilate, cv2.COLOR_BGR2GRAY)
    # Keep only dilated pixels that were text in the Otsu mask.
    pre_result = cv2.bitwise_and(dilate, thresh)

    # Final result: close small gaps, then mask against the text pixels again.
    result = cv2.morphologyEx(pre_result, cv2.MORPH_CLOSE, repair_kernel, iterations=5)
    final = cv2.bitwise_and(result, thresh)

    # Invert the final image back to dark text on a white background.
    invert_final = 255 - final

    # processed_image_path = os.path.join(folder_path, f"{os.path.splitext(os.path.basename(image_path))[0]}-preprocessed.png")
    # Save the final image
    # cv2.imwrite(processed_image_path, invert_final)

    return invert_final
70
+
71
def process_segment_and_crop_image(model, image, preprocess_image_path, padding=10, min_contour_area=100):
    """
    Predict a text-region mask with a U-Net and crop the image stored at
    *preprocess_image_path* to the padded bounding box of all detected regions.

    Args:
        model: Trained U-Net; fed a (1, 512, 512, 1) tensor built below —
            assumes the model expects that input shape, TODO confirm.
        image: BGR image used as the model input (numpy array).
        preprocess_image_path (str): Path of the image that is actually cropped.
        padding (int): Extra pixels added around the detected region.
        min_contour_area (int): Contours smaller than this are treated as noise.

    Returns:
        numpy array or None: the cropped image, or None if no valid text
        region was found.

    NOTE(review): the mask is predicted from *image* but the crop is applied to
    the file read from *preprocess_image_path*; the two are presumably the same
    page in different processing states — verify against callers.
    """
    # Read the original image in grayscale

    img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply thresholding to create a binary image
    _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)

    # Resize the image to the model input size (512x512)
    img = cv2.resize(img, (512, 512))

    # Expand dimensions to match model input: (H, W) -> (1, H, W, 1)
    img = np.expand_dims(img, axis=-1)
    img_np = np.expand_dims(img, axis=0)

    # Predict the segmentation mask using the U-Net model
    pred = model.predict(img_np)
    pred = np.squeeze(np.squeeze(pred, axis=0), axis=-1)

    # # Display the segmentation result
    # plt.imshow(pred, cmap='gray')
    # plt.title('U-Net Segmentation')
    # plt.axis('off')
    # plt.show()

    # Read the original image
    original_img = cv2.imread(preprocess_image_path)

    # Get original dimensions
    ori_height, ori_width = original_img.shape[:2]

    # Resize the mask to match the original image dimensions
    resized_mask = cv2.resize(pred, (ori_width, ori_height))

    # Convert the resized mask to 8-bit unsigned integer type
    resized_mask = (resized_mask * 255).astype(np.uint8)

    # Apply Otsu's threshold to get a binary image
    _, binary_mask = cv2.threshold(resized_mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Apply morphological operations to remove noise and connect nearby text
    kernel = np.ones((5, 5), np.uint8)
    cleaned_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    cleaned_mask = cv2.morphologyEx(cleaned_mask, cv2.MORPH_OPEN, kernel)

    # Find contours in the cleaned mask
    contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Filter contours based on area to remove small noise
    valid_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > min_contour_area]

    if not valid_contours:
        print("No valid text regions found.")
        return None

    # Find the bounding rectangle that encompasses all valid contours
    x_min, y_min = ori_width, ori_height
    x_max, y_max = 0, 0

    for contour in valid_contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    # Grow the box by the padding, clamped to the image bounds.
    x_min = max(0, x_min - padding)
    y_min = max(0, y_min - padding)
    x_max = min(ori_width, x_max + padding)
    y_max = min(ori_height, y_max + padding)

    # Crop the original image
    cropped_img = original_img[y_min:y_max, x_min:x_max]

    return cropped_img
160
+
161
+
162
def postProcessImage(cropped_image):
    """
    Post-process a cropped page/line image: (deskew currently disabled),
    sharpen via an unsharp-mask blend, and compute a dilated variant.

    Args:
        cropped_image: input image (numpy array).

    Returns:
        numpy array: the sharpened image.

    NOTE(review): `dilated` is computed but never returned — and with a 1x1
    kernel the dilation is a no-op anyway; the deskew path and the inner
    `rotate` helper are also unused. Confirm whether these should be
    re-enabled or deleted.
    """
    def rotate(
        image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]]
    ) -> np.ndarray:
        # Rotate `image` by `angle` degrees, enlarging the canvas so nothing
        # is clipped. NOTE(review): numpy shape is (rows, cols) = (height,
        # width), so these names are swapped; the swap appears to be applied
        # consistently below, but verify before re-enabling deskew.
        old_width, old_height = image.shape[:2]
        angle_radian = math.radians(angle)
        width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width)
        height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height)

        image_center = tuple(np.array(image.shape[1::-1]) / 2)
        rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
        # Shift so the rotated content stays centred on the enlarged canvas.
        rot_mat[1, 2] += (width - old_width) / 2
        rot_mat[0, 2] += (height - old_height) / 2
        return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background)

    # Deskew Image (disabled — input is passed through unrotated)
    # grayscale = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
    # angle = determine_skew(grayscale)
    # rotated = rotate(image, angle, (0, 0, 0))
    rotated = cropped_image

    # Sharpening (reduced intensity): unsharp mask — subtract a blurred copy.
    blurred = cv2.GaussianBlur(rotated, (1,1), sigmaX=3, sigmaY=3)
    sharpened = cv2.addWeighted(rotated, 1.5, blurred, -0.5, 0)

    # Morphological dilation to thicken the text
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
    dilated = cv2.dilate(sharpened, dilate_kernel, iterations=1)

    return sharpened
utils/unet.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Importing required libraries.
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import os
6
+ from keras.layers import *
7
+ from keras.models import Model
8
+ from keras.optimizers import Adam
9
+ import random
10
+
11
def unet(pretrained_weights = None,input_size = (512,512,1)):
    """
    Build and compile a U-Net segmentation model (Ronneberger et al., 2015).

    Args:
        pretrained_weights (str, optional): path to a weights file to load
            into the model after compilation.
        input_size (tuple): input tensor shape as (height, width, channels).

    Returns:
        keras.Model: model compiled with Adam (lr=1e-4) and binary
        cross-entropy, producing a single-channel sigmoid mask at the
        input's spatial resolution.
    """
    def _double_conv(x, filters):
        # Two 3x3 same-padded ReLU convolutions — the basic U-Net block.
        x = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
        return Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)

    def _up_concat(x, skip, filters):
        # 2x upsample, channel-reducing 2x2 conv, then concat with the
        # matching encoder feature map (skip connection) on the channel axis.
        up = Conv2D(filters, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(x))
        return concatenate([skip, up], axis=3)

    inputs = Input(input_size)

    # Contracting path (encoder).
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    drop4 = Dropout(0.5)(_double_conv(pool3, 512))
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck.
    drop5 = Dropout(0.5)(_double_conv(pool4, 1024))

    # Expanding path (decoder) with skip connections.
    conv6 = _double_conv(_up_concat(drop5, drop4, 512), 512)
    conv7 = _double_conv(_up_concat(conv6, conv3, 256), 256)
    conv8 = _double_conv(_up_concat(conv7, conv2, 128), 128)
    conv9 = _double_conv(_up_concat(conv8, conv1, 64), 64)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    # 1x1 sigmoid conv collapses to a single-channel probability mask.
    conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs, conv10)

    model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model
63
+
64
+
vit.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader, random_split
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments, EarlyStoppingCallback
9
+ from PIL import Image
10
+ import numpy as np
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from torch.optim import AdamW
13
+ import torch.nn.functional as F
14
+ from evaluate import load
15
+ import albumentations as A
16
+ import os
17
+
18
+ from configs import model_path, processor_path
19
+
20
# NOTE(review): cudnn.benchmark enables cuDNN conv-algorithm autotuning
# (it does not enable mixed precision, despite the original comment);
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, which
# helps debugging but hurts throughput — consider removing for training.
torch.backends.cudnn.benchmark = True
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


# Character- and word-error-rate metrics (from the `evaluate` library),
# used by compute_metrics below.
cer_metric = load("cer")
wer_metric = load("wer")

# Pretrained TrOCR processor/model loaded from local paths (see configs.py).
# do_rescale=False because LineDataset already scales pixels to [0, 1]
# before handing them to the processor.
processor = TrOCRProcessor.from_pretrained(processor_path, do_rescale=False,use_fast=True)
model = VisionEncoderDecoderModel.from_pretrained(model_path,use_safetensors=True)
31
+
32
def compute_metrics(eval_pred):
    """
    Compute CER and WER for a batch of TrOCR predictions.

    Args:
        eval_pred: (logits, labels) pair from the HF Trainer; logits may
            arrive wrapped in a tuple.

    Returns:
        dict: {"cer": float, "wer": float}.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]

    predicted_ids = logits.argmax(-1)
    decoded_preds = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)

    # -100 marks padding positions ignored by the loss; strip them before decoding.
    decoded_labels = [
        processor.tokenizer.decode([tok for tok in label if tok != -100], skip_special_tokens=True)
        for label in labels
    ]

    return {
        "cer": cer_metric.compute(predictions=decoded_preds, references=decoded_labels),
        "wer": wer_metric.compute(predictions=decoded_preds, references=decoded_labels),
    }
46
+
47
class LineDataset(Dataset):
    """
    Torch Dataset mapping text-line images + transcripts to TrOCR inputs.

    Each item is a dict with 'pixel_values' (preprocessed image tensor) and
    'labels' (token ids truncated to max_length). When apply_augmentation
    is True, a mild albumentations pipeline runs before resizing.
    """
    def __init__(self, processor, model, line_images, texts, target_size=(384, 96), max_length=512, apply_augmentation=False):
        # processor: TrOCRProcessor used for both image and text encoding.
        # model: VisionEncoderDecoderModel; its config.max_length is set
        #     here so generation length matches the label truncation.
        # line_images: sequence of PIL Images or numpy arrays.
        # texts: one transcription string per image (same order/length).
        # target_size: (width, height) as passed to PIL.Image.resize.
        # max_length: token-length cap applied to labels.
        # apply_augmentation: enable the albumentations pipeline below.
        self.line_images = line_images
        self.texts = texts
        self.processor = processor
        # NOTE(review): these assignments mutate the shared processor/model
        # objects, not just this dataset — the side effect is visible to
        # every other user of those objects.
        self.processor.image_processor.max_length = max_length
        self.processor.tokenizer.model_max_length = max_length
        self.model = model
        self.model.config.max_length = max_length
        self.target_size = target_size
        self.max_length = max_length
        self.apply_augmentation = apply_augmentation

        if apply_augmentation:
            # OneOf picks exactly one distortion per sample, and with p=0.7
            # skips augmentation entirely 30% of the time.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=0.3, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.02, 0.02), shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ], p=0.7),
            ])
        else:
            # Empty pipeline; __getitem__ also re-checks apply_augmentation
            # before using it, so this is effectively never applied.
            self.transform = A.Compose([])

    def __len__(self):
        # Number of samples; texts is assumed to be the same length.
        return len(self.line_images)

    def __getitem__(self, idx):
        image = self.line_images[idx]
        text = self.texts[idx]

        # Normalize the input to an HxWx3 numpy array.
        if isinstance(image, Image.Image):
            image = np.array(image)

        if image.ndim == 2:
            # Grayscale: replicate the single channel to 3 channels.
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)

        # Assumes incoming pixel values are floats in [0, 1] — TODO confirm;
        # a uint8 [0, 255] input would wrap around here.
        image = (image * 255).astype(np.uint8)

        if self.apply_augmentation and self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Resize (PIL takes (width, height)), rescale back to [0, 1], and
        # move channels first: (C, H, W).
        image = Image.fromarray(image)
        image = image.resize(self.target_size, Image.LANCZOS)
        image = np.array(image) / 255.0
        image = np.transpose(image, (2, 0, 1))

        # The module-level processor in this file is built with
        # do_rescale=False, so the [0, 1] values are used as-is; `text` is
        # tokenized into 'labels' by the same call.
        encoding = self.processor(images=image, text=text, return_tensors="pt")
        # Truncate labels, then drop the batch dim the processor adds.
        encoding['labels'] = encoding['labels'][:, :self.max_length]
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding
107
+
108
def collate_fn(batch):
    """
    Collate LineDataset samples into a training batch.

    Args:
        batch (list[dict]): each item holds 'pixel_values' (equal-shaped
            tensors) and 'labels' (variable-length 1-D token tensors).

    Returns:
        dict: 'pixel_values' stacked along a new batch dimension and
        'labels' right-padded with -100 (the loss-ignore index).
    """
    images = [sample['pixel_values'] for sample in batch]
    label_seqs = [sample['labels'] for sample in batch]
    return {
        'pixel_values': torch.stack(images),
        'labels': pad_sequence(label_seqs, batch_first=True, padding_value=-100),
    }
+ return {'pixel_values': pixel_values, 'labels': labels}