daneigh committed
Commit bddb462 · verified · 1 Parent(s): f2f67b5

Upload 6 files

Files changed (6)
  1. app.py +43 -0
  2. best_multimodal_v3.pth +3 -0
  3. flask_app.py +36 -0
  4. requirements.txt +10 -0
  5. test_mode.py +434 -0
  6. util.py +65 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ import gradio as gr
+ import sys
+ import os
+ from flask import Flask, request, jsonify
+ import threading
+
+ # Import the inference function directly
+ from test_mode import run_inference
+
+ def classify_meme(image):
+     try:
+         # Convert a PIL image to bytes if needed
+         import io
+         if hasattr(image, 'save'):
+             img_bytes = io.BytesIO()
+             image.save(img_bytes, format='PNG')
+             img_bytes = img_bytes.getvalue()
+         else:
+             img_bytes = image
+
+         result = run_inference(img_bytes)
+
+         if "error" in result:
+             return f"Error: {result['error']}"
+
+         prediction = result['prediction']
+         confidence = max(result['probabilities'][0]) * 100
+
+         return f"Classification: {prediction}\nConfidence: {confidence:.1f}%"
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ # Simple Gradio interface with the API enabled
+ iface = gr.Interface(
+     fn=classify_meme,
+     inputs=gr.Image(type="pil"),
+     outputs="text",
+     title="MemeSenseX Backend",
+     description="Meme content classifier"
+ )
+
+ # Launch locally (share=False keeps it off a public gradio.live URL)
+ iface.launch(share=False)
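Because gr.Interface also exposes a programmatic API alongside the web UI, the running app can be called from Python. A minimal sketch, assuming a recent gradio_client, an instance running locally on Gradio's default port, and a placeholder test file meme.png:

    from gradio_client import Client, handle_file

    # Point the client at the locally running app (the URL is a placeholder)
    client = Client("http://127.0.0.1:7860/")

    # "/predict" is the default api_name for a single gr.Interface
    result = client.predict(handle_file("meme.png"), api_name="/predict")
    print(result)  # the "Classification: ... / Confidence: ..." string from classify_meme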
best_multimodal_v3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40016216a9127daf17477fcd088970536ccc2ba905135d7563f07f8f94eb8f5d
+ size 516666381
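Note that this entry is a Git LFS pointer, not the weights themselves: the actual checkpoint is about 517 MB (size 516666381 bytes, content-addressed by the sha256 oid above) and is fetched by Git LFS, e.g. via git lfs pull after cloning the repository.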
flask_app.py ADDED
@@ -0,0 +1,36 @@
+ from flask import Flask, request, jsonify
+ from test_mode import run_inference
+ from flask_cors import CORS
+
+ app = Flask(__name__)
+ CORS(app)
+
+
+ # Endpoint that accepts an uploaded image and returns the inference result
+ @app.route('/process_predict', methods=['POST'])
+ def process_predict():
+     # Check that the POST request has the file part
+     if 'image' not in request.files:
+         return jsonify({"error": "No image file provided"}), 400
+
+     # Read the upload as raw bytes
+     image = request.files['image']
+     image_bytes = image.read()
+
+     # Run the multimodal classifier
+     result = run_inference(image_bytes)
+
+     print(f"Processed result: {result}")
+
+     if "error" in result:
+         return jsonify(result), 500
+
+     return jsonify({
+         "status": "success",
+         "message": "Image processed successfully",
+         "data": result
+     }), 200
+
+
+ if __name__ == '__main__':
+     app.run(debug=True, port=5001)
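For completeness, a minimal client for the endpoint above; a sketch that assumes the Flask app is running locally on port 5001 and that meme.png is a placeholder test image:

    import requests

    # The multipart field name must match request.files['image'] above
    with open("meme.png", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:5001/process_predict",
            files={"image": f},
        )
    print(resp.status_code, resp.json())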
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ transformers
+ easyocr
+ Pillow
+ opencv-python-headless
+ flask
+ flask-cors
+ gradio
+ matplotlib
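All of these are left unpinned, so pip install -r requirements.txt installs the latest release of each package; that keeps the Space easy to rebuild but makes installs non-reproducible, with torch and its dependents dominating the download.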
test_mode.py ADDED
@@ -0,0 +1,434 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models, transforms
+ import torch.nn.functional as F
+ import math
+ from transformers import AutoModel, AutoTokenizer
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import easyocr
+ import numpy as np
+ import re
+ import os
+ import io
+ import cv2
+
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ MODEL_PATH = os.path.join(BASE_DIR, "best_multimodal_v3.pth")
+
+ # =========================
+ # 1. Text Preprocessing
+ # =========================
+ def preprocess_text(text):
+     emoji_pattern = re.compile(
+         "["
+         "\U0001F600-\U0001F64F"  # emoticons
+         "\U0001F300-\U0001F5FF"  # symbols & pictographs
+         "\U0001F680-\U0001F6FF"  # transport & map symbols
+         "\U0001F1E0-\U0001F1FF"  # flags
+         "\U00002700-\U000027BF"  # dingbats
+         "\U0001F900-\U0001F9FF"  # supplemental symbols
+         "\U00002600-\U000026FF"  # misc symbols
+         "\U00002B00-\U00002BFF"  # arrows, etc.
+         "\U0001FA70-\U0001FAFF"  # extended symbols
+         "]+",
+         flags=re.UNICODE
+     )
+     # Remove emojis
+     text = emoji_pattern.sub(r'', text)
+     # Lowercase and strip
+     text = text.lower().strip()
+     # Keep letters (including accented ones) and spaces
+     text = re.sub(r'[^a-zñáéíóúü\s]', '', text)
+     # Normalize whitespace
+     text = re.sub(r'\s+', ' ', text)
+
+     return text
+
+ # =========================
+ # 2. OCR Extraction
+ # =========================
+ def ocr_extract_text(image, confidence_threshold=0.6):
+     reader = easyocr.Reader(['en', 'tl'], gpu=torch.cuda.is_available())  # note: rebuilt on every call; could be hoisted to module scope
+     # # Earlier single-pass version, kept for reference:
+     # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     # # image = cv2.GaussianBlur(image, (5, 5), 0)
+
+     # result = reader.readtext(image, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)
+
+     # # Extract text and confidence scores
+     # texts = []
+     # confidences = []
+
+     # for detection in result:
+     #     bbox, text, confidence = detection
+     #     texts.append(text)
+     #     confidences.append(confidence)
+     # final_text = " ".join(texts)
+     # preprocess_txt = preprocess_text(final_text)
+     # avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
+     # return final_text, preprocess_txt, avg_confidence
+
+     # Convert to grayscale
+     gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+
+     # First pass: OCR on raw grayscale
+     result = reader.readtext(gray, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)
+     texts, confidences = [], []
+
+     for detection in result:
+         if len(detection) == 3:
+             _, text, conf = detection
+         else:
+             text, conf = detection
+
+         if isinstance(text, list):
+             text = " ".join([str(t) for t in text if isinstance(t, str)])
+
+         texts.append(text)
+         try:
+             confidences.append(float(conf))
+         except (ValueError, TypeError):
+             confidences.append(0.0)
+
+     final_text = " ".join(texts)
+     avg_conf = sum(confidences) / len(confidences) if confidences else 0.0
+
+     # If confidence is low, retry on a Gaussian-blurred image
+     if avg_conf < confidence_threshold:
+         texts, confidences = [], []
+         gauss_img = cv2.GaussianBlur(gray, (5, 5), 0)
+         result = reader.readtext(gauss_img, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)
+
+         for detection in result:
+             if len(detection) == 3:
+                 _, text, conf = detection
+             else:
+                 text, conf = detection
+
+             if isinstance(text, list):
+                 text = " ".join([str(t) for t in text if isinstance(t, str)])
+
+             texts.append(text)
+             try:
+                 confidences.append(float(conf))
+             except (ValueError, TypeError):
+                 confidences.append(0.0)
+
+         final_text_gauss = " ".join(texts)
+         avg_conf_gauss = sum(confidences) / len(confidences) if confidences else 0.0
+
+         # Keep the version with the higher confidence
+         if avg_conf_gauss > avg_conf:
+             final_text, avg_conf = final_text_gauss, avg_conf_gauss
+
+     if not final_text:
+         return "", "", 0.0
+
+     preprocess_txt = preprocess_text(final_text)
+     return final_text, preprocess_txt, avg_conf
+
+
+ # =========================
+ # 3. Image Preprocessing
+ # =========================
+ def resize_normalize_image(image, target_size=(224, 224)):
+
+     preprocess_image = transforms.Compose([
+         transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225]
+         )
+     ])
+
+     image_tensor = preprocess_image(image).unsqueeze(0)  # add batch dimension
+     return image_tensor
+
+
+ # =========================
+ # 4. Model Definitions
+ # =========================
+ class CrossAttentionModule(nn.Module):
+     def __init__(self, query_dim, key_value_dim, hidden_dim=256, num_heads=8, dropout=0.1):
+         super(CrossAttentionModule, self).__init__()
+
+         self.hidden_dim = hidden_dim
+         self.num_heads = num_heads
+         self.head_dim = hidden_dim // num_heads
+         self.scale = math.sqrt(self.head_dim)  # √d_k
+
+         assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
+
+         # Query projection for H (e.g., image features)
+         self.query_proj = nn.Linear(query_dim, hidden_dim)
+
+         # Key and Value projections for G (e.g., text features)
+         self.key_proj = nn.Linear(key_value_dim, hidden_dim)
+         self.value_proj = nn.Linear(key_value_dim, hidden_dim)
+
+         # Output projection W^O
+         self.out_proj = nn.Linear(hidden_dim, query_dim)
+
+         # Layer normalization
+         self.norm1 = nn.LayerNorm(query_dim)
+         self.norm2 = nn.LayerNorm(query_dim)
+
+         # MLP for the final transformation
+         self.mlp = nn.Sequential(
+             nn.Linear(query_dim, query_dim * 4),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(query_dim * 4, query_dim),
+             nn.Dropout(dropout)
+         )
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, H, G):
+         """
+         Args:
+             H: Query features [batch_size, seq_len_h, query_dim] (e.g., image patches)
+             G: Key/Value features [batch_size, seq_len_g, key_value_dim] (e.g., text tokens)
+         """
+         batch_size, seq_len_h, _ = H.shape
+         seq_len_g = G.shape[1]
+
+         # Step 1: Project to Q, K, V
+         Q = self.query_proj(H)  # W_i^Q H
+         K = self.key_proj(G)    # W_i^K G
+         V = self.value_proj(G)  # W_i^V G
+
+         # Step 2: Reshape for multi-head attention
+         Q = Q.view(batch_size, seq_len_h, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, seq_len_g, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, seq_len_g, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Step 3: Compute attention: ATT_i(H, G) = softmax((W_i^Q H)(W_i^K G)^T / √d_k)(W_i^V G)
+         attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
+         attention_weights = F.softmax(attention_scores, dim=-1)
+         attention_weights = self.dropout(attention_weights)
+         attention_output = torch.matmul(attention_weights, V)
+
+         # Step 4: Concatenate heads and apply the output projection
+         attention_output = attention_output.transpose(1, 2).contiguous().view(
+             batch_size, seq_len_h, self.hidden_dim
+         )
+
+         # MATT(H, G) = [ATT_1; ...; ATT_h] W^O
+         matt_output = self.out_proj(attention_output)
+
+         # Step 5: Z = LN(H + MATT(H, G))
+         Z = self.norm1(H + matt_output)
+
+         # Step 6: TIM(H, G) = LN(Z + MLP(Z))
+         mlp_output = self.mlp(Z)
+         tim_output = self.norm2(Z + mlp_output)
+
+         return tim_output
+
+ class MultimodalClassifier(nn.Module):
+     def __init__(self, num_classes=2, model_name='jcblaise/roberta-tagalog-base'):
+         super(MultimodalClassifier, self).__init__()
+
+         # Image encoder (ResNet-18, keep spatial features)
+         resnet = models.resnet18(pretrained=True)
+         modules = list(resnet.children())[:-2]  # keep up to the last conv block (before avgpool)
+         self.image_encoder = nn.Sequential(*modules)  # output: (B, 512, 7, 7)
+
+         # Text encoder
+         self.text_encoder = AutoModel.from_pretrained(model_name)
+
+         # Cross-attention using the paper formula
+         # Image attends to text
+         self.img_to_text_attention = CrossAttentionModule(
+             query_dim=512,
+             key_value_dim=self.text_encoder.config.hidden_size,
+             hidden_dim=256,
+             num_heads=8
+         )
+
+         # Text attends to image
+         self.text_to_img_attention = CrossAttentionModule(
+             query_dim=self.text_encoder.config.hidden_size,
+             key_value_dim=512,
+             hidden_dim=256,
+             num_heads=8
+         )
+
+         # Fusion & classifier
+         self.fusion = nn.Sequential(
+             nn.Linear(512 + self.text_encoder.config.hidden_size, 512),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(512, 128),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(128, num_classes)
+         )
+
+     def forward(self, images, input_ids, attention_mask):
+         # Extract image features
+         batch_size = images.size(0)
+         img_feats = self.image_encoder(images)  # (B, 512, 7, 7)
+         img_feats = img_feats.flatten(2).permute(0, 2, 1)  # (B, 49, 512)
+
+         # Extract text features
+         text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
+         txt_feats = text_outputs.last_hidden_state  # (B, seq_len, H)
+
+         # Cross-attention following the paper formula:
+         # TIM(img_feats, txt_feats) and TIM(txt_feats, img_feats)
+         attended_img = self.img_to_text_attention(img_feats, txt_feats)
+         attended_txt = self.text_to_img_attention(txt_feats, img_feats)
+
+         # Pool the attended outputs
+         img_repr = attended_img.mean(dim=1)  # (B, 512)
+         txt_repr = attended_txt[:, 0, :]  # CLS token (B, hidden_size)
+
+         # Fusion
+         fused = torch.cat([img_repr, txt_repr], dim=1)
+         return self.fusion(fused)
+
+ # =========================
+ # 5. Load Model & Tokenizer
+ # =========================
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model = MultimodalClassifier(num_classes=2)
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
+ model.to(device)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained("jcblaise/roberta-tagalog-base")
+
+ # =========================
+ # 6. Inference Function
+ # =========================
+ def run_inference(image_input):
+     # Accept raw bytes, a file path, or a PIL image
+     if isinstance(image_input, (bytes, bytearray)):
+         pil_img = Image.open(io.BytesIO(image_input)).convert("RGB")
+     elif isinstance(image_input, str):
+         pil_img = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         pil_img = image_input.convert("RGB")
+     else:
+         raise TypeError(f"Unsupported input type: {type(image_input)}")
+
+     # OCR
+     np_image = np.array(pil_img)
+     raw_text, clean_text, confidence = ocr_extract_text(np_image)
+
+     if clean_text == "":
+         return {
+             "error": "This is not a meme. Upload a valid meme image with text.",
+         }
+
+     elif len(clean_text.split()) < 3:
+         return {
+             "error": "Insufficient text detected in the meme. Please upload a meme with more text. The minimum is 3 words.",
+             "clean_text": clean_text,
+             "raw_text": raw_text,
+             "confidence": confidence
+         }
+
+     # Image
+     img_tensor = resize_normalize_image(pil_img).to(device)
+
+     # Tokenize text
+     encoding = tokenizer(
+         clean_text, return_tensors='pt',
+         padding=True, truncation=True, max_length=128
+     )
+     input_ids = encoding['input_ids'].to(device)
+     attention_mask = encoding['attention_mask'].to(device)
+
+
+     # Forward pass
+     with torch.no_grad():
+         logits = model(img_tensor, input_ids, attention_mask)
+         probs = torch.softmax(logits, dim=1)
+         pred_class = torch.argmax(probs, dim=1).item()
+         pred_class = 'sexual' if pred_class == 1 else 'non-sexual'
+
+     return {
+         'original_size': pil_img.size,
+         'prediction': pred_class,
+         'probabilities': probs.cpu().numpy().tolist(),
+         'raw_text': raw_text,
+         'clean_text': clean_text,
+         'confidence': confidence
+     }
+
+
+ # =========================
+ # 7. Run as main
+ # =========================
+ # if __name__ == "__main__":
+ #     # Example: load an image from a path
+ #     IMAGE_PATH = "backend/OIP (1).jfif"
+
+ #     # test_dimension_sensitivity(IMAGE_PATH)
+
+ #     result = run_inference(IMAGE_PATH)
+
+ #     # Print results
+ #     print("Original Image Size:", result['original_size'])
+ #     print("Prediction:", result['prediction'])
+ #     print("Probabilities:", result['probabilities'])
+ #     print("Raw OCR Text:", result['raw_text'])
+ #     print("Clean OCR Text:", result['clean_text'])
+ #     print("OCR Confidence:", result['confidence'])
+
+
+ #     # Preprocess image
+ #     pil_img = Image.open(IMAGE_PATH).convert("RGB")
+ #     img_tensor = resize_normalize_image(pil_img).to(device)
+
+ #     # -----------------------------
+ #     # Generate ResNet heatmap
+ #     # -----------------------------
+ #     features = {}
+ #     def hook_fn(module, input, output):
+ #         features['value'] = output.detach()
+
+ #     last_conv = model.image_encoder[-1]
+ #     hook_handle = last_conv.register_forward_hook(hook_fn)
+
+ #     with torch.no_grad():
+ #         _ = model(img_tensor,
+ #                   input_ids=torch.zeros(1, 1, dtype=torch.long, device=device),
+ #                   attention_mask=torch.ones(1, 1, dtype=torch.long, device=device))
+
+ #     hook_handle.remove()
+
+ #     feat_tensor = features['value']
+ #     heatmap_grid = feat_tensor[0].mean(dim=0).cpu().numpy()
+ #     heatmap_grid = (heatmap_grid - heatmap_grid.min()) / (heatmap_grid.max() - heatmap_grid.min())
+ #     heatmap_resized = np.uint8(255 * heatmap_grid)
+ #     heatmap_resized = Image.fromarray(heatmap_resized).resize(pil_img.size, Image.NEAREST)
+ #     heatmap_resized = np.array(heatmap_resized)
+
+ #     probs = result['probabilities'][0]
+ #     prob_text = f"non-sexual: {probs[0]:.2f}, sexual: {probs[1]:.2f}"
+
+ #     # -----------------------------
+ #     # Plot everything in one figure
+ #     # -----------------------------
+ #     fig, ax = plt.subplots(figsize=(6, 6))
+
+ #     ax.imshow(pil_img)  # original image
+ #     ax.imshow(heatmap_resized, cmap='jet', alpha=0.4, interpolation='nearest')  # overlay heatmap
+ #     ax.axis('off')
+ #     ax.set_title(f"{result['prediction']} ({prob_text})", fontsize=14, color='blue')
+
+ #     # Add a colorbar
+ #     sm = plt.cm.ScalarMappable(cmap='jet', norm=plt.Normalize(vmin=0, vmax=1))
+ #     sm.set_array([])
+ #     cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
+ #     cbar.set_label('Feature Intensity')
+
+ #     plt.show()
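For reference, the ASCII formulas in the CrossAttentionModule comments correspond to standard scaled dot-product cross-attention, with H supplying the queries and G the keys and values:

    \mathrm{ATT}_i(H, G) = \mathrm{softmax}\!\left( \frac{(W_i^Q H)(W_i^K G)^{\top}}{\sqrt{d_k}} \right) (W_i^V G)
    \mathrm{MATT}(H, G) = [\mathrm{ATT}_1; \dots; \mathrm{ATT}_h]\, W^O
    Z = \mathrm{LN}(H + \mathrm{MATT}(H, G)), \qquad \mathrm{TIM}(H, G) = \mathrm{LN}(Z + \mathrm{MLP}(Z))

And a minimal end-to-end sketch of the inference entry point (sample_meme.jpg is a placeholder; note that importing test_mode loads the full checkpoint at import time):

    from test_mode import run_inference

    # run_inference accepts a file path, raw bytes, or a PIL image
    result = run_inference("sample_meme.jpg")
    if "error" in result:
        print(result["error"])
    else:
        print(result["prediction"], result["probabilities"], result["clean_text"])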
util.py ADDED
@@ -0,0 +1,65 @@
+ import easyocr
+ from PIL import Image
+ import re
+ from torchvision import transforms
+ import matplotlib.pyplot as plt
+
+ # data = {"image_path": "", "text": "", "preprocess_image": ""}
+
+ # Preprocess the text extracted from the meme
+ def preprocess_text(text):
+     emoji_pattern = re.compile(
+         "["
+         "\U0001F600-\U0001F64F"  # emoticons
+         "\U0001F300-\U0001F5FF"  # symbols & pictographs
+         "\U0001F680-\U0001F6FF"  # transport & map symbols
+         "\U0001F1E0-\U0001F1FF"  # flags
+         "\U00002700-\U000027BF"  # dingbats
+         "\U0001F900-\U0001F9FF"  # supplemental symbols
+         "\U00002600-\U000026FF"  # miscellaneous symbols
+         "\U00002B00-\U00002BFF"  # arrows, etc.
+         "\U0001FA70-\U0001FAFF"  # extended symbols
+         "]+",
+         flags=re.UNICODE
+     )
+     text = emoji_pattern.sub(r'', text)      # remove emojis
+     text = text.lower().strip()              # lowercase and trim
+     text = re.sub(r'[^a-z0-9\s]', '', text)  # keep alphanumerics and spaces
+     text = re.sub(r'\s+', ' ', text)         # normalize whitespace
+     text = re.sub(r'\b\w\b', '', text)       # drop single-character tokens
+     text = re.sub(r'[^\w\s]', '', text)      # (redundant after the filters above)
+     return text
+
+ # Extract and preprocess text from an image using OCR
+ def ocr_extract_text(image_path):
+     reader = easyocr.Reader(['en', 'tl'], gpu=True)  # falls back to CPU with a warning if CUDA is unavailable
+     result = reader.readtext(image_path, detail=0)
+     final_text = " ".join(result)
+     preprocess_txt = preprocess_text(final_text)
+     return final_text, preprocess_txt
+
+ # Resize and normalize an image for model input
+ def resize_normalize_image(image, target_size=(224, 224)):
+     preprocess_image = transforms.Compose([
+         transforms.Resize(target_size),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225]
+         )
+     ])
+     image = Image.open(image).convert('RGB')
+
+     image = preprocess_image(image)
+     image = image.unsqueeze(0)  # add batch dimension
+     return image
+
+ # if __name__ == "__main__":
+ #     input_image = "backend/test_image.jpg"
+ #     data["image_path"] = input_image
+ #     data["text"] = ocr_extract_text(input_image)
+ #     data["preprocess_image"] = resize_normalize_image(input_image)
+ #     print(data)
+
+
+
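A quick sanity check of the legacy cleaner above (the input string is a made-up example):

    from util import preprocess_text

    # Emojis and punctuation are stripped and the text is lowercased;
    # single-character tokens would also be dropped
    print(preprocess_text("Grabe 😂 ang LAKAS mo!!!"))  # -> "grabe ang lakas mo"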