fisherman611 committed on
Commit ef43578 · verified · 1 Parent(s): 9e7b5ca

Update app.py

Files changed (1): app.py (+777 -771)
app.py CHANGED
@@ -1,771 +1,777 @@
import gradio as gr
import torch
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import json
from models.can.can import CAN, create_can_model
from models.can.can_dataloader import Vocabulary, INPUT_HEIGHT, INPUT_WIDTH

# Load configuration
with open("config.json", "r") as json_file:
    cfg = json.load(json_file)
CAN_CONFIG = cfg["can"]

# Global constants
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
PRETRAINED_BACKBONE = True if CAN_CONFIG["pretrained_backbone"] == 1 else False
CHECKPOINT_PATH = f'checkpoints/{BACKBONE_TYPE}_can_best.pth' if not PRETRAINED_BACKBONE else f'checkpoints/p_{BACKBONE_TYPE}_can_best.pth'

# Modified process_img to accept numpy array and validate shapes
def process_img(image, convert_to_rgb=False):
    """
    Process a numpy array image: binarize, ensure black background, resize, and apply padding.

    Args:
        image: Numpy array (grayscale)
        convert_to_rgb: Whether to convert to RGB

    Returns:
        Processed image and crop information, or None if invalid
    """
    def is_effectively_binary(img, threshold_percentage=0.9):
        dark_pixels = np.sum(img < 20)
        bright_pixels = np.sum(img > 235)
        total_pixels = img.size
        return (dark_pixels + bright_pixels) / total_pixels > threshold_percentage

    def before_padding(image):
        if image.shape[0] < 2 or image.shape[1] < 2:
            return None, None  # Invalid image size

        # Ensure image is uint8
        if image.dtype != np.uint8:
            if image.max() <= 1.0:  # If image is normalized (0-1)
                image = (image * 255).astype(np.uint8)
            else:  # If image is in other float format
                image = np.clip(image, 0, 255).astype(np.uint8)

        edges = cv2.Canny(image, 50, 150)
        kernel = np.ones((7, 13), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=8)
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilated, connectivity=8)
        sorted_components = sorted(range(1, num_labels), key=lambda i: stats[i, cv2.CC_STAT_AREA], reverse=True)
        best_f1 = 0
        best_crop = (0, 0, image.shape[1], image.shape[0])
        total_white_pixels = np.sum(dilated > 0)
        current_mask = np.zeros_like(dilated)
        x_min, y_min = image.shape[1], image.shape[0]
        x_max, y_max = 0, 0

        for component_idx in sorted_components:
            component_mask = labels == component_idx
            current_mask = np.logical_or(current_mask, component_mask)
            comp_y, comp_x = np.where(component_mask)
            if len(comp_x) > 0 and len(comp_y) > 0:
                x_min = min(x_min, np.min(comp_x))
                y_min = min(y_min, np.min(comp_y))
                x_max = max(x_max, np.max(comp_x))
                y_max = max(y_max, np.max(comp_y))
            width = x_max - x_min + 1
            height = y_max - y_min + 1
            if width < 2 or height < 2:
                continue
            crop_area = width * height
            crop_mask = np.zeros_like(dilated)
            crop_mask[y_min:y_max + 1, x_min:x_max + 1] = 1
            white_in_crop = np.sum(np.logical_and(dilated > 0, crop_mask > 0))
            precision = white_in_crop / crop_area if crop_area > 0 else 0
            recall = white_in_crop / total_white_pixels if total_white_pixels > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            if f1 > best_f1:
                best_f1 = f1
                best_crop = (x_min, y_min, x_max, y_max)

        x_min, y_min, x_max, y_max = best_crop
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        if cropped_image.shape[0] < 2 or cropped_image.shape[1] < 2:
            return None, None
        if is_effectively_binary(cropped_image):
            _, thresh = cv2.threshold(cropped_image, 127, 255, cv2.THRESH_BINARY)
        else:
            thresh = cv2.adaptiveThreshold(cropped_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
        white = np.sum(thresh == 255)
        black = np.sum(thresh == 0)
        if white > black:
            thresh = 255 - thresh
        denoised = cv2.medianBlur(thresh, 3)
        for _ in range(3):
            denoised = cv2.medianBlur(denoised, 3)
        result = cv2.copyMakeBorder(denoised, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=0)
        return result, best_crop

    if len(image.shape) != 2:
        return None, None  # Expect grayscale image

    # Ensure image is uint8 before processing
    if image.dtype != np.uint8:
        if image.max() <= 1.0:  # If image is normalized (0-1)
            image = (image * 255).astype(np.uint8)
        else:  # If image is in other float format
            image = np.clip(image, 0, 255).astype(np.uint8)

    bin_img, best_crop = before_padding(image)
    if bin_img is None:
        return None, None
    h, w = bin_img.shape
    if h < 2 or w < 2:
        return None, None
    new_w = int((INPUT_HEIGHT / h) * w)

    if new_w > INPUT_WIDTH:
        resized_img = cv2.resize(bin_img, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_AREA)
    else:
        resized_img = cv2.resize(bin_img, (new_w, INPUT_HEIGHT), interpolation=cv2.INTER_AREA)
        padded_img = np.zeros((INPUT_HEIGHT, INPUT_WIDTH), dtype=np.uint8)
        x_offset = (INPUT_WIDTH - new_w) // 2
        padded_img[:, x_offset:x_offset + new_w] = resized_img
        resized_img = padded_img

    if convert_to_rgb:
        resized_img = cv2.cvtColor(resized_img, cv2.COLOR_GRAY2BGR)

    return resized_img, best_crop

# Load model and vocabulary
def load_checkpoint(checkpoint_path, device, pretrained_backbone=True, backbone='densenet'):
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    vocab = checkpoint.get('vocab')
    if vocab is None:
        vocab_path = os.path.join(os.path.dirname(checkpoint_path), 'hmer_vocab.pth')
        if os.path.exists(vocab_path):
            vocab_data = torch.load(vocab_path)
            vocab = Vocabulary()
            vocab.word2idx = vocab_data['word2idx']
            vocab.idx2word = vocab_data['idx2word']
            vocab.idx = vocab_data['idx']
            vocab.pad_token = vocab.word2idx['<pad>']
            vocab.start_token = vocab.word2idx['<start>']
            vocab.end_token = vocab.word2idx['<end>']
            vocab.unk_token = vocab.word2idx['<unk>']
        else:
            raise ValueError(f"Vocabulary not found in checkpoint and {vocab_path} does not exist")

    hidden_size = checkpoint.get('hidden_size', 256)
    embedding_dim = checkpoint.get('embedding_dim', 256)
    use_coverage = checkpoint.get('use_coverage', True)

    model = create_can_model(
        num_classes=len(vocab),
        hidden_size=hidden_size,
        embedding_dim=embedding_dim,
        use_coverage=use_coverage,
        pretrained_backbone=pretrained_backbone,
        backbone_type=backbone
    ).to(device)

    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, vocab

model, vocab = load_checkpoint(CHECKPOINT_PATH, DEVICE, PRETRAINED_BACKBONE, BACKBONE_TYPE)

# Image processing function for Gradio
def gradio_process_img(image, convert_to_rgb=False):
    # Convert Gradio image (PIL, numpy, or dict from Sketchpad) to grayscale numpy array
    if isinstance(image, dict):  # Handle Sketchpad input
        # The Sketchpad component returns a dict with 'background' and 'layers' keys
        # We need to combine the background and layers to get the final image
        background = np.array(image['background'])
        layers = image['layers']

        # Start with the background
        final_image = background.copy()

        # Add each layer on top
        for layer in layers:
            if layer is not None:  # Some layers might be None
                layer_img = np.array(layer)
                # Create a mask for non-transparent pixels
                mask = layer_img[..., 3] > 0
                # Replace pixels in final_image where mask is True, keeping the alpha channel
                final_image[mask] = layer_img[mask]

        # Convert to grayscale using the alpha channel
        if len(final_image.shape) == 3:
            # Use alpha channel to determine which pixels to keep
            alpha_mask = final_image[..., 3] > 0
            # Convert to grayscale using standard formula
            gray = np.dot(final_image[..., :3], [0.299, 0.587, 0.114])
            # Create a white background
            final_image = np.ones_like(gray) * 255
            # Apply the drawing where alpha > 0
            final_image[alpha_mask] = gray[alpha_mask]
            # Invert the image to get black on white
            final_image = 255 - final_image
    elif isinstance(image, Image.Image):
        image = np.array(image.convert('L'))
    elif isinstance(image, np.ndarray):
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        elif len(image.shape) != 2:
            raise ValueError("Invalid image format: Expected grayscale or RGB image")
    else:
        raise ValueError("Unsupported image input type")

    # For Sketchpad input, use the final_image we created
    if isinstance(image, dict):
        image = final_image

    # Apply modified process_img
    processed_img, best_crop = process_img(image, convert_to_rgb=False)
    if processed_img is None:
        raise ValueError("Image processing failed: Resulted in invalid image size")

    # Prepare for model input
    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),
        ToTensorV2()
    ])
    processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
    image_tensor = transform(image=processed_img)['image'].unsqueeze(0).to(DEVICE)

    return image_tensor, processed_img, best_crop

# Model inference
def recognize_image(image_tensor, processed_img, best_crop):
    with torch.no_grad():
        predictions, _ = model.recognize(
            image_tensor,
            max_length=150,
            start_token=vocab.start_token,
            end_token=vocab.end_token,
            beam_width=5
        )

    # Convert indices to LaTeX tokens
    latex_tokens = []
    for idx in predictions:
        if idx == vocab.end_token:
            break
        if idx != vocab.start_token:
            latex_tokens.append(vocab.idx2word[idx])

    latex = ' '.join(latex_tokens)

    # Format LaTeX for rendering
    rendered_latex = f"$${latex}$$"

    return latex, rendered_latex

# Gradio interface function
def process_draw(image):
    if image is None:
        return "Please draw an expression", ""

    try:
        # Process image
        image_tensor, processed_img, best_crop = gradio_process_img(image)

        # Recognize
        latex, rendered_latex = recognize_image(image_tensor, processed_img, best_crop)

        return latex, rendered_latex
    except Exception as e:
        return f"Error processing image: {str(e)}", ""

def process_upload(image):
    if image is None:
        return "Please upload an image", ""

    try:
        # Process image
        image_tensor, processed_img, best_crop = gradio_process_img(image)

        # Recognize
        latex, rendered_latex = recognize_image(image_tensor, processed_img, best_crop)

        return latex, rendered_latex
    except Exception as e:
        return f"Error processing image: {str(e)}", ""

# Enhanced custom CSS with expanded input areas
custom_css = """
/* Global styles */
.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
    font-family: 'Segoe UI', 'Roboto', sans-serif !important;
    padding: 1rem !important;
    box-sizing: border-box !important;
}

/* Header styling */
.header-title {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    background-clip: text !important;
    text-align: center !important;
    font-size: clamp(1.8rem, 5vw, 2.5rem) !important;
    font-weight: 700 !important;
    margin-bottom: 1.5rem !important;
    text-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
    padding: 0 1rem !important;
}

/* Main container styling */
.main-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%) !important;
    border-radius: 20px !important;
    padding: clamp(1rem, 3vw, 2rem) !important;
    box-shadow: 0 10px 30px rgba(0,0,0,0.1) !important;
    margin: 1rem 0 !important;
}

/* Input section styling - RESPONSIVE */
.input-section {
    background: white !important;
    border-radius: 15px !important;
    padding: clamp(1rem, 3vw, 2rem) !important;
    box-shadow: 0 4px 15px rgba(0,0,0,0.05) !important;
    border: 1px solid #e1e8ed !important;
    min-height: min(700px, 80vh) !important;
    width: 100% !important;
    box-sizing: border-box !important;
}

/* Output section styling - RESPONSIVE */
.output-section {
    background: white !important;
    border-radius: 15px !important;
    padding: clamp(1rem, 3vw, 1.5rem) !important;
    box-shadow: 0 4px 15px rgba(0,0,0,0.05) !important;
    border: 1px solid #e1e8ed !important;
    min-height: min(700px, 80vh) !important;
    width: 100% !important;
    box-sizing: border-box !important;
}

/* Tab styling - RESPONSIVE */
.tab-nav {
    background: #f8f9fa !important;
    border-radius: 10px !important;
    padding: 0.5rem !important;
    margin-bottom: 1.5rem !important;
    display: flex !important;
    flex-wrap: wrap !important;
    gap: 0.5rem !important;
}

.tab-nav button {
    border-radius: 8px !important;
    padding: clamp(0.5rem, 2vw, 0.75rem) clamp(1rem, 3vw, 1.5rem) !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
    border: none !important;
    background: transparent !important;
    color: #6c757d !important;
    font-size: clamp(0.9rem, 2vw, 1rem) !important;
    white-space: nowrap !important;
}

/* Sketchpad styling - RESPONSIVE */
.sketchpad-container {
    border: 3px dashed #667eea !important;
    border-radius: 15px !important;
    background: #fafbfc !important;
    transition: all 0.3s ease !important;
    overflow: hidden !important;
    min-height: min(500px, 60vh) !important;
    height: min(500px, 60vh) !important;
    width: 100% !important;
    box-sizing: border-box !important;
}

.sketchpad-container canvas {
    width: 100% !important;
    height: 100% !important;
    min-height: min(500px, 60vh) !important;
    touch-action: none !important;
}

/* Upload area styling - RESPONSIVE */
.upload-container {
    border: 3px dashed #667eea !important;
    border-radius: 15px !important;
    background: #fafbfc !important;
    padding: clamp(1.5rem, 5vw, 3rem) !important;
    text-align: center !important;
    transition: all 0.3s ease !important;
    min-height: min(500px, 60vh) !important;
    display: flex !important;
    flex-direction: column !important;
    justify-content: center !important;
    align-items: center !important;
    box-sizing: border-box !important;
}

.upload-container img {
    max-height: min(400px, 50vh) !important;
    max-width: 100% !important;
    object-fit: contain !important;
}

/* Button styling - RESPONSIVE */
.process-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    border-radius: 12px !important;
    padding: clamp(0.8rem, 2vw, 1.2rem) clamp(1.5rem, 4vw, 2.5rem) !important;
    font-size: clamp(1rem, 2.5vw, 1.2rem) !important;
    font-weight: 600 !important;
    color: white !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3) !important;
    text-transform: uppercase !important;
    letter-spacing: 0.5px !important;
    width: 100% !important;
    margin-top: 1.5rem !important;
    white-space: nowrap !important;
}

/* Output text styling - RESPONSIVE */
.latex-output {
    background: #f8f9fa !important;
    border: 1px solid #e9ecef !important;
    border-radius: 10px !important;
    padding: clamp(1rem, 3vw, 1.5rem) !important;
    font-family: 'Monaco', 'Consolas', monospace !important;
    font-size: clamp(0.9rem, 2vw, 1rem) !important;
    line-height: 1.6 !important;
    min-height: min(200px, 30vh) !important;
    overflow-x: auto !important;
    white-space: pre-wrap !important;
    word-break: break-word !important;
}

.rendered-output {
    background: white !important;
    border: 1px solid #e9ecef !important;
    border-radius: 10px !important;
    padding: clamp(1.5rem, 4vw, 2.5rem) !important;
    text-align: center !important;
    min-height: min(300px, 40vh) !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    font-size: clamp(1.2rem, 3vw, 1.8rem) !important;
    overflow-x: auto !important;
}

/* Instructions styling - RESPONSIVE */
.instructions {
    background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%) !important;
    border-radius: 12px !important;
    padding: clamp(1rem, 3vw, 1.5rem) !important;
    margin-bottom: clamp(1rem, 3vw, 2rem) !important;
    border-left: 4px solid #28a745 !important;
}

.instructions h3 {
    color: #155724 !important;
    margin-bottom: 0.8rem !important;
    font-weight: 600 !important;
    font-size: clamp(1rem, 2.5vw, 1.1rem) !important;
}

.instructions p {
    color: #155724 !important;
    margin: 0.5rem 0 !important;
    font-size: clamp(0.9rem, 2vw, 1rem) !important;
    line-height: 1.5 !important;
}

/* Drawing tips styling - RESPONSIVE */
.drawing-tips {
    background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%) !important;
    border-radius: 10px !important;
    padding: clamp(0.8rem, 2vw, 1rem) !important;
    margin-top: 1rem !important;
    border-left: 4px solid #fd7e14 !important;
}

.drawing-tips h4 {
    color: #8a4100 !important;
    margin-bottom: 0.5rem !important;
    font-weight: 600 !important;
    font-size: clamp(0.9rem, 2vw, 1rem) !important;
}

.drawing-tips ul {
    color: #8a4100 !important;
    margin: 0 !important;
    padding-left: clamp(1rem, 3vw, 1.5rem) !important;
}

.drawing-tips li {
    margin: 0.3rem 0 !important;
    font-size: clamp(0.8rem, 1.8vw, 0.9rem) !important;
}

/* Full-width layout adjustments - RESPONSIVE */
.input-output-container {
    display: grid !important;
    grid-template-columns: repeat(auto-fit, minmax(min(100%, 600px), 1fr)) !important;
    gap: clamp(1rem, 3vw, 2rem) !important;
    align-items: start !important;
    width: 100% !important;
    box-sizing: border-box !important;
}

/* Examples section - RESPONSIVE */
.examples-grid {
    display: grid !important;
    grid-template-columns: repeat(auto-fit, minmax(min(100%, 250px), 1fr)) !important;
    gap: clamp(1rem, 3vw, 1.5rem) !important;
    text-align: center !important;
}

.example-card {
    padding: clamp(1rem, 3vw, 1.5rem) !important;
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%) !important;
    border-radius: 12px !important;
    border: 2px solid #dee2e6 !important;
}

.example-card strong {
    color: #495057 !important;
    font-size: clamp(0.9rem, 2.5vw, 1.1rem) !important;
    display: block !important;
    margin-bottom: 0.5rem !important;
}

.example-card span {
    font-family: monospace !important;
    color: #6c757d !important;
    font-size: clamp(0.8rem, 2vw, 0.9rem) !important;
    line-height: 1.6 !important;
}

/* Performance metrics section - RESPONSIVE */
.metrics-grid {
    display: grid !important;
    grid-template-columns: repeat(auto-fit, minmax(min(100%, 200px), 1fr)) !important;
    gap: clamp(0.8rem, 2vw, 1rem) !important;
}

.metric-item {
    text-align: center !important;
    padding: clamp(0.5rem, 2vw, 1rem) !important;
}

.metric-item strong {
    color: #e65100 !important;
    font-size: clamp(0.9rem, 2.5vw, 1rem) !important;
    display: block !important;
    margin-bottom: 0.3rem !important;
}

.metric-item span {
    color: #bf360c !important;
    font-size: clamp(0.8rem, 2vw, 0.9rem) !important;
}

/* Responsive breakpoints */
@media (max-width: 1200px) {
    .gradio-container {
        padding: 0.8rem !important;
    }
}

@media (max-width: 768px) {
    .gradio-container {
        padding: 0.5rem !important;
    }

    .main-container {
        padding: 0.8rem !important;
        margin: 0.5rem 0 !important;
    }

    .input-section, .output-section {
        padding: 0.8rem !important;
    }

    .tab-nav {
        flex-direction: column !important;
    }

    .tab-nav button {
        width: 100% !important;
    }
}

@media (max-width: 480px) {
    .gradio-container {
        padding: 0.3rem !important;
    }

    .main-container {
        padding: 0.5rem !important;
        margin: 0.3rem 0 !important;
    }

    .input-section, .output-section {
        padding: 0.5rem !important;
    }

    .process-button {
        padding: 0.8rem 1.2rem !important;
        font-size: 0.9rem !important;
    }
}

/* Touch device optimizations */
@media (hover: none) {
    .process-button:hover {
        transform: none !important;
    }

    .sketchpad-container {
        touch-action: none !important;
        -webkit-touch-callout: none !important;
        -webkit-user-select: none !important;
        user-select: none !important;
    }

    .tab-nav button {
        padding: 1rem !important;
    }
}

/* Print styles */
@media print {
    .gradio-container {
        max-width: 100% !important;
        padding: 0 !important;
    }

    .input-section, .output-section {
        break-inside: avoid !important;
    }

    .process-button, .tab-nav {
        display: none !important;
    }
}
"""

# Create the enhanced Gradio interface with expanded input
with gr.Blocks(css=custom_css, title="Math Expression Recognition") as demo:
    gr.HTML('<h1 class="header-title">🧮 Handwritten Mathematical Expression Recognition</h1>')

    # Enhanced Instructions
    gr.HTML("""
    <div class="instructions">
        <h3>📝 How to use this expanded interface:</h3>
        <p><strong>✏️ Draw Tab:</strong> Use the large drawing canvas (800x500px) to draw mathematical expressions with your mouse or touch device</p>
        <p><strong>📁 Upload Tab:</strong> Upload high-resolution images containing handwritten mathematical expressions</p>
        <p><strong>🎯 Tips:</strong> Write clearly, use proper mathematical notation, and ensure good contrast between your writing and the background</p>
    </div>
    """)

    with gr.Row(elem_classes="input-output-container"):
        # Expanded Input Section
        with gr.Column(elem_classes="input-section"):
            gr.HTML('<h2 style="text-align: center; color: #667eea; margin-bottom: 1.5rem; font-size: 1.5rem;">📥 Input Area</h2>')

            with gr.Tabs():
                with gr.TabItem("✏️ Draw Expression"):
                    gr.HTML("""
                    <div class="drawing-tips">
                        <h4>🎨 Drawing Tips:</h4>
                        <ul>
                            <li>Use clear, legible handwriting</li>
                            <li>Draw symbols at reasonable sizes</li>
                            <li>Leave space between different parts</li>
                            <li>Use standard mathematical notation</li>
                            <li>Avoid overlapping symbols</li>
                        </ul>
                    </div>
                    """)

                    draw_input = gr.Sketchpad(
                        label="Draw your mathematical expression here",
                        elem_classes="sketchpad-container",
                        height=500,
                        width=800,
                        canvas_size=(800, 500)
                    )
                    draw_button = gr.Button("🚀 Recognize Drawn Expression", elem_classes="process-button")

                with gr.TabItem("📁 Upload Image"):
                    gr.HTML("""
                    <div class="drawing-tips">
                        <h4>📷 Upload Tips:</h4>
                        <ul>
                            <li>Use high-resolution images (minimum 300 DPI)</li>
                            <li>Ensure good lighting and contrast</li>
                            <li>Crop the image to focus on the expression</li>
                            <li>Avoid shadows or glare</li>
                            <li>Supported formats: PNG, JPG, JPEG</li>
                        </ul>
                    </div>
                    """)

                    upload_input = gr.Image(
                        label="Upload your mathematical expression image",
                        elem_classes="upload-container",
                        height=500,
                        type="pil"
                    )
                    upload_button = gr.Button("🚀 Recognize Uploaded Expression", elem_classes="process-button")

        # Output Section
        with gr.Column(elem_classes="output-section"):
            gr.HTML('<h2 style="text-align: center; color: #667eea; margin-bottom: 1.5rem; font-size: 1.5rem;">📤 Recognition Results</h2>')

            with gr.Tabs():
                with gr.TabItem("📄 LaTeX Code"):
                    latex_output = gr.Textbox(
                        label="Generated LaTeX Code",
                        elem_classes="latex-output",
                        lines=8,
                        placeholder="Your LaTeX code will appear here...\n\nThis is the raw LaTeX markup that represents your mathematical expression. You can copy this code and use it in any LaTeX document or LaTeX-compatible system.",
                        interactive=False
                    )

                with gr.TabItem("🎨 Rendered Expression"):
                    rendered_output = gr.Markdown(
                        label="Rendered Mathematical Expression",
                        elem_classes="rendered-output",
                        value="*Your beautifully rendered mathematical expression will appear here...*\n\n*Draw or upload an expression to see the magic happen!*"
                    )

    # Connect the buttons to their respective functions
    draw_button.click(
        fn=process_draw,
        inputs=[draw_input],
        outputs=[latex_output, rendered_output]
    )

    upload_button.click(
        fn=process_upload,
        inputs=[upload_input],
        outputs=[latex_output, rendered_output]
    )

if __name__ == "__main__":
-    demo.launch(share=True, inbrowser=True)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True,
+        inbrowser=True
+    )