NaseemTahir committed on
Commit
1448da0
·
verified ·
1 Parent(s): 75adf64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+ import pytesseract
8
+ import re
9
+ import csv
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from ultralytics import YOLO
12
+ import keras_ocr
13
+ from datetime import datetime
14
+ from sentence_transformers import SentenceTransformer
15
+ from sklearn.metrics.pairwise import cosine_similarity
16
+ from huggingface_hub import hf_hub_download
17
+
18
# Multilingual sentence-embedding model, created once at module load and
# shared by load_translations() and find_best_match().
similarity_model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
20
+
21
def preprocess_text(text):
    """Return *text* in a canonical form suitable for similarity comparison.

    The text is lowercased, every character that is neither a word character
    nor whitespace is removed, and runs of whitespace are collapsed to single
    spaces with leading/trailing whitespace dropped.
    """
    lowered = text.lower()
    # Strip punctuation: keep only word characters and whitespace.
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    # split() with no argument treats any whitespace run as one separator and
    # ignores leading/trailing whitespace, so re-joining normalizes spacing.
    tokens = without_punctuation.split()
    return ' '.join(tokens)
27
+
28
def load_translations(csv_path):
    """Load translation pairs from a CSV file and precompute their embeddings.

    The CSV must have ``original`` and ``translated`` columns.

    Args:
        csv_path: Path to the translations CSV.

    Returns:
        A list of dicts, one per CSV row, with the keys ``original_raw``
        (stripped source text), ``original_processed`` (source text after
        ``preprocess_text``), ``translated`` (stripped translation) and
        ``embedding`` (the ``similarity_model`` encoding of the processed
        source text).

    Raises:
        KeyError: If a row lacks the ``original`` or ``translated`` column.
    """
    translations = []
    # newline='' lets the csv module handle newlines embedded in quoted
    # fields itself; 'utf-8-sig' transparently drops a leading BOM so the
    # first header is read as 'original' rather than '\ufefforiginal'.
    with open(csv_path, mode='r', encoding='utf-8-sig', newline='') as file:
        for row in csv.DictReader(file):
            original = preprocess_text(row['original'])
            # Encode once at load time so matching only has to embed the query.
            embedding = similarity_model.encode(original, convert_to_tensor=False)
            translations.append({
                'original_raw': row['original'].strip(),
                'original_processed': original,
                'translated': row['translated'].strip(),
                'embedding': embedding,
            })
    return translations
44
+
45
def find_best_match(text, translations, threshold=0.6):
    """Find the translation entry most similar to *text*.

    Similarity is the cosine similarity between the ``similarity_model``
    embedding of the preprocessed query and each entry's precomputed
    ``embedding``.

    Args:
        text: Raw query text (e.g. OCR output).
        translations: Entries as produced by ``load_translations``.
        threshold: Minimum cosine similarity for a match to be accepted.

    Returns:
        A shallow copy of the best entry with an added ``score`` key (the
        similarity as a percentage rounded to one decimal), or ``None`` when
        there are no entries or the best similarity is not positive or is
        below *threshold*. The entries in *translations* are left unmodified.
    """
    if not translations:
        return None

    processed = preprocess_text(text)
    query_embed = similarity_model.encode(processed, convert_to_tensor=False)

    # Score the query against every stored embedding in a single call.
    embedding_matrix = np.asarray([entry['embedding'] for entry in translations])
    scores = cosine_similarity([query_embed], embedding_matrix)[0]

    # argmax returns the first maximum, so ties resolve to the earliest entry.
    best_index = int(np.argmax(scores))
    best_score = float(scores[best_index])
    if best_score <= 0 or best_score < threshold:
        return None

    # Copy so the caller's shared entry is not polluted with a per-query score.
    best_match = dict(translations[best_index])
    best_match['score'] = round(best_score * 100, 1)  # Convert to percentage
    return best_match
61
+
62
+ # Enhanced Inpainting Functions
63
def create_text_mask(region, pipeline):
    """Build a dilated binary mask over the text that *pipeline* finds in *region*.

    Args:
        region: Image array; only its first two dimensions are used to size
            the mask.
        pipeline: Object whose ``recognize`` method takes a list of images and
            returns one list of detections per image. The second element of
            each detection is used as the polygon vertices.

    Returns:
        A ``uint8`` mask of the same height and width as *region*, with 255
        inside the detected polygons, grown by two dilation passes with a
        5x5 kernel.
    """
    detections = pipeline.recognize([region])[0]
    text_mask = np.zeros(region.shape[:2], dtype="uint8")
    for detection in detections:
        polygon = np.array(detection[1], dtype=np.int32)
        cv2.fillPoly(text_mask, [polygon], 255)
    # Grow the mask so the glyph edges are covered as well.
    kernel = np.ones((5, 5), np.uint8)
    return cv2.dilate(text_mask, kernel, iterations=2)
70
+
71
def process_bubble_region(region, pipeline):
    """Remove detected text from *region* and return the smoothed result.

    The text mask from ``create_text_mask`` is inpainted with the Telea
    method (radius 25), and the result is passed through a 5-pixel median
    blur.
    """
    text_mask = create_text_mask(region, pipeline)
    filled = cv2.inpaint(region, text_mask, 25, cv2.INPAINT_TELEA)
    smoothed = cv2.medianBlur(filled, 5)
    return smoothed
75
+
76
+ # Text Rendering Functions (Improved Version)
77
def fit_text_to_box(x, y, w, h, text, font_path, max_size=50, min_size=8, padding_top=3):
    """Pick the largest font size at which word-wrapped *text* fits the box height.

    Font sizes are tried from *max_size* down to *min_size*. At each size the
    text is greedily wrapped to 90% of the box width, and the size is accepted
    when the wrapped lines fit within 90% of the box height.

    Args:
        x: Left edge of the box (unused here; kept for the caller's signature).
        y: Top edge of the box.
        w: Box width.
        h: Box height.
        text: Text to lay out.
        font_path: Path to a TrueType/OpenType font file.
        max_size: Largest font size to try.
        min_size: Smallest font size to try.
        padding_top: Extra vertical offset added to the returned y position.

    Returns:
        A tuple ``(lines, font, line_height, y_position)``: the wrapped lines,
        the loaded font, the height of one line, and the y coordinate at which
        to draw the first line so the text is vertically centred. If no size
        fits, the whole text is returned as a single line at *min_size*,
        positioned at ``y + padding_top``.
    """
    # These do not depend on the font size, so compute them once.
    words = text.split()
    max_width = w * 0.9   # Allow 10% horizontal padding
    max_height = h * 0.9  # Allow 10% vertical padding
    measure = ImageDraw.Draw(Image.new('RGB', (1, 1)))

    for size in range(max_size, min_size - 1, -1):
        font = ImageFont.truetype(font_path, size)

        # Greedy word wrap.
        lines = []
        current_line = []
        for word in words:
            candidate = ' '.join(current_line + [word])
            bbox = measure.textbbox((0, 0), candidate, font=font)
            if bbox[2] - bbox[0] < max_width:
                current_line.append(word)
            else:
                # Only flush a non-empty line: when the very first word is
                # already too wide there is nothing to flush, and emitting
                # '' would add a blank line and inflate the height.
                if current_line:
                    lines.append(' '.join(current_line))
                current_line = [word]
        lines.append(' '.join(current_line))

        # Calculate total height from a reference string with ascender and descender.
        reference_bbox = font.getbbox("Mg")
        line_height = reference_bbox[3] - reference_bbox[1]
        total_height = len(lines) * line_height

        if total_height <= max_height:
            y_position = y + padding_top + (h - total_height) // 2
            return lines, font, line_height, y_position

    # Fallback to minimum size, with the line height measured the same way
    # as in the main path.
    font = ImageFont.truetype(font_path, min_size)
    reference_bbox = font.getbbox("Mg")
    return [text], font, reference_bbox[3] - reference_bbox[1], y + padding_top
112
+
113
def refine_ocr_text(text):
    """Tidy OCR output by neutralizing control characters and extra whitespace.

    Three substitutions are applied in order, each replacing its match with a
    single space, and the result is finally stripped:

    1. control characters (``\\x00-\\x1F`` and ``\\x7F-\\x9F``),
    2. runs of two or more whitespace characters,
    3. leading and trailing whitespace.
    """
    without_controls = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
    single_spaced = re.sub(r'\s{2,}', ' ', without_controls)
    edges_replaced = re.sub(r'^\s+|\s+$', ' ', single_spaced)
    return edges_replaced.strip()
123
+
124
+ # Main Processing Pipeline
125
def process_image(input_path, output_path, model_path, font_path, csv_path, match_threshold=0.5):
    """Detect text regions in a comic page, erase them and draw translations.

    Steps:
      1. Detect text regions with the YOLO model at *model_path*.
      2. Erase the text in every region via ``process_bubble_region``.
      3. OCR each region of the untouched original with Tesseract (Italian),
         look up a translation with ``find_best_match`` and draw it into the
         cleaned image. When no translation matches, the OCR text itself is
         drawn back.
      4. Save the image to *output_path* and a per-region CSV report next to
         it (same stem, ``_report.csv`` suffix).

    Args:
        input_path: Path of the image to process.
        output_path: Path the rendered image is written to.
        model_path: Path of the YOLO weights file.
        font_path: Path of the font used for rendering.
        csv_path: Path of the translations CSV (see ``load_translations``).
        match_threshold: Minimum cosine similarity (0-1) for a translation
            match to be accepted.

    Returns:
        A tuple ``(pil_img, processing_log)``: the rendered PIL image and a
        list with one dict per region that produced OCR text (keys
        ``region``, ``coordinates``, ``original``, ``translated``, ``score``).

    Raises:
        ValueError: If the image at *input_path* cannot be read.
    """
    # Initialize components
    model = YOLO(model_path)
    pipeline = keras_ocr.pipeline.Pipeline()
    translations = load_translations(csv_path)
    processing_log = []

    # Load original image. imread signals failure by returning None rather
    # than raising, so check explicitly to give a meaningful error.
    loaded = cv2.imread(input_path)
    if loaded is None:
        raise ValueError(f"Could not read image: {input_path}")
    original = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
    working_img = original.copy()
    img_h, img_w = original.shape[:2]

    # Detect text regions
    results = model.predict(original, verbose=False)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    # Clamp every box to the image once and reuse the result in both passes.
    # A negative coordinate would otherwise wrap around when used in a numpy
    # slice and select the wrong pixels. Region numbers keep the detection
    # order even when an empty box is dropped.
    regions = []
    for idx, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box)
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        if x2 <= x1 or y2 <= y1:
            continue
        regions.append((idx + 1, x1, y1, x2, y2))

    # First pass: Clean all text regions
    for _, x1, y1, x2, y2 in regions:
        bubble_region = original[y1:y2, x1:x2]
        working_img[y1:y2, x1:x2] = process_bubble_region(bubble_region, pipeline)

    # Prepare image for text rendering
    pil_img = Image.fromarray(working_img)
    draw = ImageDraw.Draw(pil_img)

    # Second pass: OCR and text placement
    for region_no, x1, y1, x2, y2 in regions:
        w, h = x2 - x1, y2 - y1

        # OCR processing on original image (the working copy has the text erased)
        bubble_region = original[y1:y2, x1:x2]
        text = pytesseract.image_to_string(bubble_region, lang='ita').strip()
        text = re.sub(r'\s+', ' ', text)
        if not text:
            continue
        print(f"Processing region {region_no}: Extracted text: {text}")

        # Find best matching translation
        best_match = find_best_match(text, translations, match_threshold)
        if best_match:
            translated_text = best_match['translated']
            print(f"Matched (Score: {best_match['score']}): {best_match['original_raw']}")
        else:
            translated_text = text  # Fallback to original text
            print(f"No good match found for: {text}")

        # Render text, centring each line horizontally within the box
        lines, font, line_height, y_pos = fit_text_to_box(
            x1, y1, w, h, translated_text, font_path
        )
        for line in lines:
            bbox = draw.textbbox((x1, y_pos), line, font=font)
            text_w = bbox[2] - bbox[0]
            draw.text(
                (x1 + (w - text_w) // 2, y_pos),
                line,
                font=font,
                fill="black",
            )
            y_pos += line_height

        # Log results
        processing_log.append({
            "region": region_no,
            "coordinates": f"({x1},{y1})-({x2},{y2})",
            "original": text,
            "translated": translated_text,
            "score": best_match['score'] if best_match else 0,
        })

    # Save outputs
    pil_img.save(output_path)
    report_path = os.path.splitext(output_path)[0] + "_report.csv"
    # Fixed field names so an empty log still produces a header-only report
    # instead of failing on processing_log[0]; newline='' is what the csv
    # module expects for files it writes to.
    fieldnames = ["region", "coordinates", "original", "translated", "score"]
    with open(report_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(processing_log)

    return pil_img, processing_log
212
+
213
# Streamlit page setup: browser-tab title and wide layout.
st.set_page_config(page_title="Comic Translation Pipeline", layout="wide")

# Sidebar: fetch the text-segmentation weights from the Hugging Face Hub and
# let the user choose the match threshold as a percentage (0-100, default 75).
with st.sidebar:
    st.header("Configuration")
    yolo_model_path = hf_hub_download(
        repo_id="NaseemTahir/comic-text-segmenter",
        filename="comic-text-segmenter.pt",
    )
    match_threshold = st.slider(
        "Translation Match Threshold",
        min_value=0,
        max_value=100,
        value=75,
    )
224
+
225
# Page heading and short usage hint.
st.title("Comic Translation Pipeline")
st.write("Upload a comic image and translation CSV to get started")

# Three side-by-side uploaders: comic image, translations CSV, and font.
col1, col2, col3 = st.columns(3)
image_file = col1.file_uploader("Upload Comic Image", type=["jpg", "png", "jpeg"])
csv_file = col2.file_uploader("Upload Translations CSV", type=["csv"])
font_file = col3.file_uploader("Upload Font File", type=["ttf", "otf"])
237
+
238
# Processing Pipeline: runs when the button is clicked and all three uploads
# are present. All intermediate files live in a temporary directory that is
# removed once the results have been handed to Streamlit.
if st.button("Run Full Pipeline"):
    if not all([image_file, csv_file, font_file]):
        # Without this branch a click with missing uploads would silently
        # do nothing, leaving the user with no hint about what is wrong.
        st.warning(
            "Please upload a comic image, a translations CSV and a font file "
            "before running the pipeline."
        )
    else:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save uploaded files. basename() keeps a client-supplied name
            # with directory components from escaping the temp directory.
            image_path = os.path.join(tmp_dir, os.path.basename(image_file.name))
            with open(image_path, "wb") as f:
                f.write(image_file.getbuffer())

            csv_path = os.path.join(tmp_dir, os.path.basename(csv_file.name))
            with open(csv_path, "wb") as f:
                f.write(csv_file.getbuffer())

            font_path = os.path.join(tmp_dir, os.path.basename(font_file.name))
            with open(font_path, "wb") as f:
                f.write(font_file.getbuffer())

            # Create output directory
            output_dir = os.path.join(tmp_dir, "output")
            os.makedirs(output_dir, exist_ok=True)

            # Run pipeline
            try:
                with st.spinner("Processing..."):
                    final_output = os.path.join(output_dir, "final_output.png")
                    process_image(
                        input_path=image_path,
                        output_path=final_output,
                        model_path=yolo_model_path,
                        font_path=font_path,
                        csv_path=csv_path,
                        # The slider is a percentage; the pipeline expects 0-1.
                        match_threshold=match_threshold / 100,
                    )

                # Display results
                st.success("Processing complete!")
                st.image(Image.open(final_output), caption="Final Result", use_column_width=True)

                # Download button
                with open(final_output, "rb") as f:
                    st.download_button(
                        label="Download Final Image",
                        data=f,
                        file_name="translated_comic.png",
                        mime="image/png",
                    )

            except Exception as e:
                # Top-level UI boundary: surface any pipeline failure to the
                # user instead of crashing the app.
                st.error(f"Error processing image: {str(e)}")