Amandeep01 committed on
Commit
3321e09
·
verified ·
1 Parent(s): 2ba73b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -130
app.py CHANGED
@@ -1,166 +1,253 @@
1
  import gradio as gr
2
- import easyocr
3
  import numpy as np
 
4
  from PIL import Image, ImageDraw, ImageFont
5
- from transformers import MarianMTModel, MarianTokenizer
6
- import time
7
- import logging
8
 
9
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load OCR Reader (CPU only). If model download/initialization fails,
# fall back to None; callers must check before using the reader.
try:
    ocr_reader = easyocr.Reader(['hi', 'mr', 'en'], gpu=False)
except Exception as e:
    logger.error(f"Error loading OCR Reader: {e}")
    ocr_reader = None

# Translator Cache: per-model tokenizers/models keyed by the
# Helsinki-NLP model name, so each language pair loads only once.
tokenizers = {}
models = {}
23
-
24
# Translation function
def translate_text_batch(texts, src_lang, tgt_lang):
    """Translate a batch of strings with a cached Helsinki-NLP MarianMT model.

    Args:
        texts: a string or a list of strings to translate.
        src_lang: ISO-639-1 code of the source language.
        tgt_lang: ISO-639-1 code of the target language.

    Returns:
        List of translated strings; on any failure, a list of
        "Translation failed: ..." messages, one per input text.
    """
    # Normalize to a list up front so the except-path below also sees a
    # list (previously a bare string input that failed before this point
    # would be iterated character by character in the error list comp).
    if isinstance(texts, str):
        texts = [texts]
    try:
        model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"

        # Lazily load and cache the tokenizer/model per language pair.
        if model_name not in models:
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            tokenizers[model_name] = tokenizer
            models[model_name] = model
        else:
            tokenizer = tokenizers[model_name]
            model = models[model_name]

        # Batch-encode, generate, and decode all texts in a single pass.
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        translated = model.generate(**inputs)
        translated_texts = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]

        return translated_texts
    except Exception as e:
        logger.error(f"Translation error: {e}")
        return [f"Translation failed: {e}" for _ in texts]
52
-
53
# Overlay text on image
def overlay_text_on_image(image_np, results, translated_texts):
    """Draw OCR bounding boxes and their translations onto an image.

    Args:
        image_np: numpy array of the source image.
        results: list of (bbox, text) pairs; bbox is the 4-point quad
            easyocr returns (bbox[0] = top-left, bbox[2] = bottom-right).
        translated_texts: translations aligned 1:1 with `results`.

    Returns:
        numpy array with overlays drawn, or the untouched input on error.
    """
    try:
        pil_img = Image.fromarray(image_np)
        draw = ImageDraw.Draw(pil_img)

        # Fallback font handling
        try:
            font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
            font = ImageFont.truetype(font_path, 24)
        except IOError:
            # Fallback to default font if specific font not found
            font = ImageFont.load_default()

        # Validate inputs: the two lists must be aligned pairwise.
        if len(results) != len(translated_texts):
            logger.warning("Mismatch between OCR results and translated texts")
            return image_np

        for ((bbox, text), translated) in zip(results, translated_texts):
            # Ensure bbox coordinates are valid
            try:
                top_left = tuple(map(int, bbox[0]))
                bottom_right = tuple(map(int, bbox[2]))

                # Draw bounding box
                draw.rectangle([top_left, bottom_right], outline="red", width=2)

                # Draw translated text above the bounding box (clamped so
                # it never goes off the top edge of the image)
                text_position = (top_left[0], max(0, top_left[1] - 30))
                draw.text(text_position, translated, fill="yellow", font=font)
            except Exception as e:
                # Skip a single bad region rather than failing the whole image
                logger.error(f"Error drawing text for {text}: {e}")

        return np.array(pil_img)
    except Exception as e:
        logger.error(f"Overlay text error: {e}")
        return image_np
91
-
92
# Main function
def process_image(image, target_lang):
    """Run the OCR -> translate -> overlay pipeline on an uploaded image.

    Args:
        image: PIL image from the Gradio input (or None).
        target_lang: ISO-639-1 code of the translation target.

    Returns:
        numpy array of the annotated image on success, or an error
        message string on failure.
    """
    try:
        # Validate inputs
        if image is None:
            return "Please upload an image."

        # The module-level loader sets ocr_reader to None when easyocr
        # fails to initialize; guard here instead of crashing with
        # AttributeError on readtext.
        if ocr_reader is None:
            return "OCR engine is unavailable. Please try again later."

        start_time = time.time()

        # Convert image to numpy array
        image_np = np.array(image)

        # Perform OCR; each result is a (bbox, text, confidence) triple.
        results = ocr_reader.readtext(image_np)

        if not results:
            return "No text detected in the image."

        # Extract texts from OCR results
        texts = [item[1] for item in results]

        # Detect source language (assume first result's language)
        # NOTE(review): easyocr.readtext returns a confidence float at
        # index 2, not a language code, so this branch likely never
        # matches and src_lang stays 'en' — verify against easyocr docs.
        src_lang = 'en'  # Default to English
        if results and len(results[0]) > 2:
            detected_lang = results[0][2]
            if detected_lang in ['hi', 'en', 'mr', 'fr', 'de', 'es']:
                src_lang = detected_lang[:2]

        # Translate texts
        translated_texts = translate_text_batch(texts, src_lang, target_lang)

        # Overlay translated text
        overlaid_image = overlay_text_on_image(image_np,
                                               [(r[0], r[1]) for r in results],
                                               translated_texts)

        # Calculate and log processing time
        end_time = time.time()
        logger.info(f"Processing time: {end_time - start_time} seconds")

        return overlaid_image

    except Exception as e:
        logger.error(f"Process image error: {e}")
        return f"An error occurred: {str(e)}"
137
 
138
# Gradio UI
def interface():
    """Build the Gradio Blocks UI wiring the image input to process_image.

    Returns:
        The assembled gr.Blocks app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 🌍 TravelOCR: Multilingual Signboard Reader + Translator")
        gr.Markdown("Upload a signboard image in any language and translate it!")

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Signboard Image")
            lang_dropdown = gr.Dropdown(
                label="Target Language",
                choices=["en", "hi", "fr", "de", "es"],
                value="en"
            )
        translate_btn = gr.Button("Translate & Overlay")
        output_img = gr.Image(type="numpy", label="Translated Output")

        # Button click runs the full OCR -> translate -> overlay pipeline.
        translate_btn.click(
            fn=process_image,
            inputs=[image_input, lang_dropdown],
            outputs=output_img
        )

    return demo
161
 
162
# Create and launch the app. Building at import time keeps `demo`
# importable by hosting platforms as well as runnable directly.
demo = interface()

if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
+ import cv2
3
  import numpy as np
4
+ import pytesseract
5
  from PIL import Image, ImageDraw, ImageFont
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
+ import re
9
 
10
class UltimateTravelOCR:
    """OCR -> translate -> overlay pipeline for signboard images.

    Uses OpenCV for preprocessing and region detection, Tesseract for
    text extraction, and cached Helsinki-NLP MarianMT models for
    translation.
    """

    def __init__(self):
        # Tesseract configuration for multiple languages (English + Hindi)
        self.tesseract_config = r'--oem 3 --psm 6 -l eng+hin'

        # Translation model cache, keyed by "src-tgt" language pair
        self.translation_models = {}
        self.translation_tokenizers = {}

    def preprocess_image(self, image):
        """
        Advanced image preprocessing for better OCR accuracy:
        grayscale -> adaptive threshold -> denoise.
        """
        # Convert to grayscale
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Apply adaptive thresholding (handles uneven signboard lighting)
        thresh = cv2.adaptiveThreshold(
            gray, 255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, 11, 2
        )

        # Denoise
        denoised = cv2.fastNlMeansDenoising(thresh, None, 10, 7, 21)

        return denoised

    def extract_text(self, preprocessed_image):
        """
        Extract text with Tesseract and return non-empty cleaned lines.
        """
        # Extract text using Tesseract
        text = pytesseract.image_to_string(
            preprocessed_image,
            config=self.tesseract_config
        )

        def clean_text(txt):
            # Remove special characters and collapse extra whitespace
            txt = re.sub(r'[^\w\s]', '', txt)
            return ' '.join(txt.split())

        # Clean each line exactly once (previously clean_text ran twice
        # per line) and drop lines that are empty after cleaning.
        cleaned_lines = []
        for line in text.split('\n'):
            cleaned = clean_text(line)
            if cleaned:
                cleaned_lines.append(cleaned)

        return cleaned_lines

    def get_text_regions(self, preprocessed_image):
        """
        Detect candidate text regions as (x, y, w, h) bounding boxes.
        """
        # Find contours
        contours, _ = cv2.findContours(
            preprocessed_image,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        # Filter and process contours
        text_regions = []
        for contour in contours:
            # Filter contours by area to remove noise
            area = cv2.contourArea(contour)
            if 100 < area < 10000:  # Adjust these thresholds as needed
                x, y, w, h = cv2.boundingRect(contour)
                text_regions.append((x, y, w, h))

        # findContours returns contours in arbitrary order, but callers
        # zip these boxes with Tesseract's top-to-bottom text lines —
        # sort top-to-bottom (then left-to-right) so pairing is sane.
        text_regions.sort(key=lambda r: (r[1], r[0]))

        return text_regions

    def _load_translation_model(self, src_lang, tgt_lang):
        """
        Load and cache the MarianMT model/tokenizer for a language pair.

        Returns:
            (model, tokenizer), or (None, None) when loading fails.
        """
        model_key = f"{src_lang}-{tgt_lang}"

        if model_key not in self.translation_models:
            try:
                model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

                self.translation_models[model_key] = model
                self.translation_tokenizers[model_key] = tokenizer
            except Exception as e:
                print(f"Translation model loading error: {e}")
                return None, None

        return self.translation_models[model_key], self.translation_tokenizers[model_key]

    def translate_text(self, text, target_lang):
        """
        Translate `text` to `target_lang`; return the input unchanged on
        any failure (model missing, generation error, ...).
        """
        try:
            # Determine source language (default to English; no detection)
            src_lang = 'en'

            # Load translation model
            model, tokenizer = self._load_translation_model(src_lang, target_lang)

            if not model or not tokenizer:
                return text

            # Prepare and translate
            inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

            with torch.no_grad():
                outputs = model.generate(**inputs)

            translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return translated
        except Exception as e:
            print(f"Translation error for '{text}': {e}")
            return text

    def overlay_translations(self, original_image, preprocessed_image, text_regions, lines, target_lang):
        """
        Translate each text line and draw it (plus its bounding box) on
        the original image. Regions and lines are paired positionally.
        """
        # Convert to PIL for drawing
        pil_image = Image.fromarray(original_image)
        draw = ImageDraw.Draw(pil_image)

        # Load a robust font, falling back to PIL's built-in default
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 25)
        except IOError:
            font = ImageFont.load_default()

        # Translate and overlay each text region
        for (x, y, w, h), text in zip(text_regions, lines):
            # Skip empty texts
            if not text.strip():
                continue

            # Translate text
            translated_text = self.translate_text(text, target_lang)

            # Draw bounding box
            draw.rectangle(
                [x, y, x + w, y + h],
                outline='red',
                width=2
            )

            # Position translation text just above the region (clamped)
            text_position = (x, max(0, y - 35))

            # Draw text backdrop.
            # NOTE(review): the image is RGB, so the alpha in this fill
            # is ignored and the backdrop renders as solid black, not
            # semi-transparent — confirm whether that is acceptable.
            text_bbox = draw.textbbox(text_position, translated_text, font=font)
            draw.rectangle(
                text_bbox,
                fill=(0, 0, 0, 128)  # Semi-transparent black
            )

            # Draw translated text
            draw.text(
                text_position,
                translated_text,
                fill='white',
                font=font
            )

        return np.array(pil_image)

    def process_image(self, image, target_lang):
        """
        Comprehensive image processing pipeline.

        Args:
            image: PIL image from Gradio (or None).
            target_lang: ISO-639-1 code of the translation target.

        Returns:
            Annotated numpy image; the unmodified image when no text is
            found or processing fails; None when no image was supplied.
        """
        if image is None:
            return None

        # Convert to numpy OUTSIDE the try block: the except path below
        # returns original_image, which previously raised NameError when
        # a failure occurred before this assignment.
        original_image = np.array(image)

        try:
            # Preprocess image
            preprocessed_image = self.preprocess_image(original_image)

            # Extract text
            lines = self.extract_text(preprocessed_image)

            if not lines:
                print("No text detected in the image.")
                return original_image

            # Get text regions
            text_regions = self.get_text_regions(preprocessed_image)

            # Fall back to full-width horizontal strips when contour
            # detection found fewer regions than text lines.
            if len(text_regions) < len(lines):
                text_regions = [(0, i * 30, original_image.shape[1], 30) for i in range(len(lines))]

            # Overlay translations
            result_image = self.overlay_translations(
                original_image,
                preprocessed_image,
                text_regions[:len(lines)],
                lines,
                target_lang
            )

            return result_image

        except Exception as e:
            print(f"Comprehensive processing error: {e}")
            return original_image
 
 
221
 
222
# Create global OCR translator instance
# (module-level singleton shared by the Gradio click handler below;
# models are loaded lazily on first translation, so this is cheap)
ocr_translator = UltimateTravelOCR()
224
+
225
# Gradio Interface
def create_interface():
    """Assemble and return the Gradio Blocks UI for the translator app."""
    with gr.Blocks() as app:
        gr.Markdown("# 🌍 Ultimate TravelOCR: Multilingual Signboard Translator")

        # Input row: the signboard photo and the desired output language.
        with gr.Row():
            img_in = gr.Image(type="pil", label="Upload Signboard Image")
            target_dd = gr.Dropdown(
                label="Target Language",
                choices=['en', 'hi', 'fr', 'de', 'es'],
                value="hi",
            )

        run_btn = gr.Button("Translate & Overlay")
        result_view = gr.Image(label="Translated Output")

        # Clicking the button runs the full OCR -> translate -> overlay
        # pipeline on the shared translator instance.
        run_btn.click(
            fn=ocr_translator.process_image,
            inputs=[img_in, target_dd],
            outputs=result_view,
        )

    return app
248
 
249
# Launch the app. The UI is built at import time so `demo` is
# importable by hosting platforms (e.g. Hugging Face Spaces) as well
# as runnable directly as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch()