Amandeep01 committed on
Commit
7c076ee
·
verified ·
1 Parent(s): 0684873

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -149
app.py CHANGED
@@ -1,163 +1,64 @@
 
1
  import gradio as gr
2
  import easyocr
3
- import numpy as np
4
- from PIL import Image, ImageDraw, ImageFont
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
- import torch
7
 
8
- # Simplified Language Mapping
9
- LANG_MAP = {
10
- 'en': 'eng',
11
- 'hi': 'hin',
12
- 'mr': 'mar',
13
- 'fr': 'fra',
14
- 'de': 'deu',
15
- 'es': 'spa'
 
 
 
16
  }
17
 
18
- # Initialize OCR Reader with optimized languages
19
- ocr_reader = easyocr.Reader(['en', 'hi'], gpu=False)
 
 
 
 
 
 
 
20
 
21
- # Translation Model Cache
22
- class TranslationCache:
23
- def __init__(self):
24
- self.models = {}
25
- self.tokenizers = {}
26
-
27
- def get_model(self, src_lang, tgt_lang):
28
- model_key = f"{src_lang}-{tgt_lang}"
29
-
30
- if model_key not in self.models:
31
- try:
32
- model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
33
- tokenizer = AutoTokenizer.from_pretrained(model_name)
34
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
35
-
36
- self.models[model_key] = model
37
- self.tokenizers[model_key] = tokenizer
38
- except Exception as e:
39
- print(f"Error loading translation model {model_key}: {e}")
40
- return None, None
41
-
42
- return self.models[model_key], self.tokenizers[model_key]
43
 
44
- # Global translation cache
45
- translation_cache = TranslationCache()
46
 
47
- def detect_language(text):
48
- """Attempt to detect language more accurately"""
49
- # Simple language detection based on script
50
- if any('\u0900' <= char <= '\u097F' for char in text):
51
- return 'hi'
52
- return 'en'
53
 
54
- def translate_text(text, src_lang, tgt_lang):
55
- """Improved translation function with better error handling"""
56
- try:
57
- # Ensure language codes match model requirements
58
- src_lang = src_lang.lower()[:2]
59
- tgt_lang = tgt_lang.lower()[:2]
60
-
61
- # Get model and tokenizer
62
- model, tokenizer = translation_cache.get_model(src_lang, tgt_lang)
63
-
64
- if not model or not tokenizer:
65
- return text # Fallback to original text if model fails
66
-
67
- # Prepare inputs
68
- inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
69
-
70
- # Generate translation
71
- with torch.no_grad():
72
- outputs = model.generate(**inputs)
73
-
74
- # Decode translation
75
- translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
76
- return translated
77
- except Exception as e:
78
- print(f"Translation error: {e}")
79
- return text
80
 
81
- def process_image(image, target_lang):
82
- """Optimized image processing with improved error handling"""
83
- if image is None:
84
- return "Please upload an image."
85
-
86
- try:
87
- # Convert image to numpy
88
- image_np = np.array(image)
89
-
90
- # Perform OCR with confidence filtering
91
- results = ocr_reader.readtext(image_np, threshold=0.3, low_text=0.4)
92
-
93
- if not results:
94
- return "No clear text detected in the image."
95
-
96
- # Prepare PIL image for drawing
97
- pil_img = Image.fromarray(image_np)
98
- draw = ImageDraw.Draw(pil_img)
99
-
100
- # Use a more universal font
101
- try:
102
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
103
- except IOError:
104
- font = ImageFont.load_default()
105
-
106
- # Process each detected text
107
- for detection in results:
108
- bbox, text, confidence = detection
109
-
110
- # Detect source language
111
- src_lang = detect_language(text)
112
-
113
- # Translate text
114
- translated_text = translate_text(text, src_lang, target_lang)
115
-
116
- # Convert bbox to integers
117
- bbox = np.array(bbox).astype(int)
118
-
119
- # Draw bounding box
120
- draw.polygon(bbox.reshape(-1, 2).tolist(), outline='red', width=2)
121
-
122
- # Draw translated text
123
- text_bbox = bbox[0] # Top-left corner
124
- draw.text((text_bbox[0], text_bbox[1] - 25),
125
- translated_text,
126
- fill='yellow',
127
- font=font)
128
-
129
- return np.array(pil_img)
130
-
131
  except Exception as e:
132
- print(f"Processing error: {e}")
133
- return f"An error occurred: {str(e)}"
134
 
135
  # Gradio Interface
136
- def create_interface():
137
- with gr.Blocks() as demo:
138
- gr.Markdown("# 🌍 TravelOCR: Multilingual Signboard Translator")
139
-
140
- with gr.Row():
141
- image_input = gr.Image(type="pil", label="Upload Signboard Image")
142
- lang_dropdown = gr.Dropdown(
143
- label="Target Language",
144
- choices=["en", "hi", "fr", "de", "es"],
145
- value="en"
146
- )
147
-
148
- translate_btn = gr.Button("Translate & Overlay")
149
- output_img = gr.Image(label="Translated Output")
150
-
151
- translate_btn.click(
152
- fn=process_image,
153
- inputs=[image_input, lang_dropdown],
154
- outputs=output_img
155
- )
156
-
157
- return demo
158
-
159
- # Launch the app
160
- demo = create_interface()
161
 
162
- if __name__ == "__main__":
163
- demo.launch()
 
1
+ # app.py
2
  import gradio as gr
3
  import easyocr
4
+ from transformers import MarianMTModel, MarianTokenizer
 
 
 
5
 
6
# OCR reader initialization.
# An EasyOCR Reader serves every language in its list with a single
# recognition model, so the languages must be mutually compatible. Hindi
# (Devanagari) and Russian (Cyrillic) can each be combined only with English;
# listing them together with French/German/Spanish makes easyocr.Reader raise
# ValueError as soon as this module is imported. This reader is therefore
# limited to the compatible Latin-script set. To read Devanagari or Cyrillic
# text, create a separate Reader such as ['hi', 'en'] or ['ru', 'en'].
reader = easyocr.Reader(['en', 'fr', 'de', 'es'], gpu=False)

# Supported target languages for translation: the display name shown in the
# UI mapped to the two-letter language code used to pick a translation model.
LANGUAGE_CODES = {
    "English": "en",
    "Hindi": "hi",
    "French": "fr",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
}
18
 
19
# Loaded MarianMT (tokenizer, model) pairs, keyed by checkpoint name, so each
# checkpoint is downloaded and instantiated at most once per process.
model_cache = {}


def get_model(target_lang):
    """Return the cached ``(tokenizer, model)`` pair for a target language.

    Args:
        target_lang: Two-letter code of the language to translate into
            (a value of ``LANGUAGE_CODES``, e.g. ``"fr"``).

    Returns:
        A ``(MarianTokenizer, MarianMTModel)`` tuple for the chosen checkpoint.

    Raises:
        Whatever ``from_pretrained`` raises when the checkpoint cannot be
        loaded (for example, no checkpoint exists for ``target_lang``).
    """
    # OPUS-MT checkpoints are named opus-mt-{source}-{target}. The earlier
    # "opus-mt-ROMANCE-{lang}" name pointed the wrong way (that family
    # translates *from* Romance languages) and is not a published checkpoint
    # for fr/es, so French and Spanish could never load.
    if target_lang == "en":
        # There is no en->en checkpoint; use the multilingual-to-English one.
        model_name = "Helsinki-NLP/opus-mt-mul-en"
    else:
        # The OCR text is assumed to be English for every other target.
        model_name = f"Helsinki-NLP/opus-mt-en-{target_lang}"

    if model_name not in model_cache:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]
28
 
29
def translate_image_text(image, target_lang):
    """Extract text from an image with OCR and translate it.

    Args:
        image: Image to read, passed straight to ``reader.readtext``; ``None``
            when nothing was uploaded.
        target_lang: Display name of the language to translate into, expected
            to be a key of ``LANGUAGE_CODES`` (e.g. ``"French"``).

    Returns:
        The translated text, or a human-readable message when the image or
        language is missing, no text is found, or processing fails.
    """
    # Reject missing inputs up front so the user gets an actionable message
    # instead of an opaque "Error: None" from a failed lookup or OCR call.
    if image is None:
        return "Please upload an image."
    if not target_lang:
        return "Please choose a language to translate to."

    try:
        # Resolve the language before the (slow) OCR pass.
        code = LANGUAGE_CODES.get(target_lang)
        if code is None:
            return f"Unsupported language: {target_lang}"

        # OCR: detail=0 yields plain strings, paragraph=True merges nearby
        # lines into paragraphs.
        result = reader.readtext(image, detail=0, paragraph=True)
        extracted_text = " ".join(result)

        if not extracted_text.strip():
            return "No text found in the image."

        tokenizer, model = get_model(code)

        # truncation=True keeps long OCR output within the model's maximum
        # input length instead of failing inside generate().
        batch = tokenizer(
            [extracted_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        gen = model.generate(**batch)
        return tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure in the output box rather than
        # crashing the request.
        return f"Error: {str(e)}"
 
51
 
52
  # Gradio Interface
53
# Input widgets: the uploaded image reaches the callback as a file path, and
# the dropdown offers the display names defined in LANGUAGE_CODES.
image_input = gr.Image(type="filepath", label="Upload Image")
language_input = gr.Dropdown(choices=list(LANGUAGE_CODES), label="Translate To")

# Output widget: the translated text (or a status/error message).
translated_output = gr.Textbox(label="Translated Text")

iface = gr.Interface(
    fn=translate_image_text,
    inputs=[image_input, language_input],
    outputs=translated_output,
    title="Image Text Translator",
    description="Upload an image containing text, and choose a language to translate the extracted text.",
)

iface.launch()