santhoshkammari committed on
Commit
cc52c7a
·
verified ·
1 Parent(s): 04e32f5

Upload split_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. split_model.py +651 -0
split_model.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset, DataLoader, Subset
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import os
8
+ import numpy as np
9
+ from bs4 import BeautifulSoup
10
+ import argparse
11
+ import logging
12
+ from torch.utils.tensorboard import SummaryWriter
13
+ from datetime import datetime
14
+ import json
15
+ from PIL import Image, ImageDraw
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
def get_ground_truth(image, cells, otsl, split_width=5, *, target_size=960):
    """
    Parse OTSL to derive fixed-width row/column split ground truth.

    Interior row splits are placed at the largest bottom edge (y2) of the
    bbox-carrying cells of each row, interior column splits at the largest
    right edge (x2) of each column. The outer edges of the table bounding
    box are marked as well. Every mark is ``split_width`` pixels wide.

    Args:
        image: object exposing ``size`` as (width, height), e.g. a PIL Image
        cells: nested list - cells[0] contains the cell dicts; each has a
            'bbox' entry indexed as [x1, y1, x2, y2] in original-image pixels
        otsl: OTSL token sequence ('fcel', 'ecel', 'lcel', 'ucel', 'xcel', 'nl')
        split_width: width of every marked region in pixels (default: 5)
        target_size: length of the returned arrays, i.e. the resolution the
            coordinates are rescaled to (keyword-only, default: 960)

    Returns:
        (horizontal_gt, vertical_gt): two lists of 0/1 ints, each of length
        ``target_size``; horizontal_gt is indexed by y, vertical_gt by x.
    """
    orig_width, orig_height = image.size

    # cells is nested - extract actual list
    cells_flat = cells[0]

    # parse OTSL to build 2D grid
    grid = []
    current_row = []
    cell_idx = 0  # only increments for fcel / ecel tokens (they own a bbox)
    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ('fcel', 'ecel'):
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ('lcel', 'ucel', 'xcel'):
            # merge tokens don't consume bboxes
            current_row.append({'type': token, 'cell_idx': None})
    if current_row:
        grid.append(current_row)

    # derive row splits - max y2 for each row
    # NOTE(review): cells continued by 'ucel' in the row below are still
    # included here, unlike the 'lcel' handling for columns - confirm intended.
    row_splits = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            row_splits.append(max(cells_flat[i]['bbox'][3] for i in row_cell_indices))

    # derive column splits - max x2 for each column (column count = first row)
    num_cols = len(grid[0]) if grid else 0
    col_splits = []
    for col_idx in range(num_cols):
        col_max_x = []
        for row in grid:
            if col_idx >= len(row) or row[col_idx]['cell_idx'] is None:
                continue
            # a cell followed by 'lcel' spans further right, so its x2 is
            # not the right edge of this column
            next_is_lcel = col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel'
            if not next_is_lcel:
                col_max_x.append(cells_flat[row[col_idx]['cell_idx']]['bbox'][2])
        if col_max_x:
            col_splits.append(max(col_max_x))

    # the last row / column edge is dropped from the interior splits; the
    # table bbox boundary is marked separately below
    row_splits = row_splits[:-1]
    col_splits = col_splits[:-1]

    # scale to target size
    y_scaled = [(y * target_size / orig_height) for y in row_splits]
    x_scaled = [(x * target_size / orig_width) for x in col_splits]

    # init ground truth arrays
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    def mark(gt_array, start, direction):
        """Mark split_width pixels from start, stepping by direction (+1/-1)."""
        for offset in range(split_width):
            pos = start + direction * offset
            if 0 <= pos < target_size:
                gt_array[pos] = 1

    # Mark table bbox boundaries (split_width pixels wide, growing inwards).
    # Skipped when there are no cells: min()/max() of an empty list would raise.
    if cells_flat:
        table_x1 = min(c['bbox'][0] for c in cells_flat)
        table_y1 = min(c['bbox'][1] for c in cells_flat)
        table_x2 = max(c['bbox'][2] for c in cells_flat)
        table_y2 = max(c['bbox'][3] for c in cells_flat)

        mark(horizontal_gt, int(round(table_y1 * target_size / orig_height)), 1)   # top
        mark(horizontal_gt, int(round(table_y2 * target_size / orig_height)), -1)  # bottom
        mark(vertical_gt, int(round(table_x1 * target_size / orig_width)), 1)      # left
        mark(vertical_gt, int(round(table_x2 * target_size / orig_width)), -1)     # right

    # mark interior split regions; a split whose start falls outside the
    # target range is skipped entirely
    for y in y_scaled:
        y_int = int(round(y))
        if 0 <= y_int < target_size:
            mark(horizontal_gt, y_int, 1)

    for x in x_scaled:
        x_int = int(round(x))
        if 0 <= x_int < target_size:
            mark(vertical_gt, x_int, 1)

    return horizontal_gt, vertical_gt
158
+
159
+
160
def get_ground_truth_auto_gap(image, cells, otsl, *, target_size=960):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths.

    Instead of fixed-width marks, the whole empty band between two
    neighbouring rows (or columns) is marked, plus the margins between the
    image border and the first/last row (column). Rows or columns whose
    extents touch or overlap get no mark between them.

    Args:
        image: object exposing ``size`` as (width, height), e.g. a PIL Image
        cells: nested list - cells[0] contains the cell dicts; each has a
            'bbox' entry indexed as [x1, y1, x2, y2] in original-image pixels
        otsl: OTSL token sequence ('fcel', 'ecel', 'lcel', 'ucel', 'xcel', 'nl')
        target_size: length of the returned arrays, i.e. the resolution the
            coordinates are rescaled to (keyword-only, default: 960)

    Returns:
        (horizontal_gt, vertical_gt): two lists of 0/1 ints, each of length
        ``target_size``; horizontal_gt is indexed by y, vertical_gt by x.
    """
    orig_width, orig_height = image.size

    # cells is nested - extract actual list
    cells_flat = cells[0]

    # Parse OTSL to build 2D grid
    grid = []
    current_row = []
    cell_idx = 0  # only increments for fcel / ecel tokens (they own a bbox)
    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ('fcel', 'ecel'):
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ('lcel', 'ucel', 'xcel'):
            # merge tokens don't consume bboxes
            current_row.append({'type': token, 'cell_idx': None})
    if current_row:
        grid.append(current_row)

    # Row extents as (min y1, max y2) per row that owns at least one bbox
    # NOTE(review): cells continued by 'ucel' in the row below are still
    # included here, unlike the 'lcel' handling for columns - confirm intended.
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            row_boundaries.append((
                min(cells_flat[i]['bbox'][1] for i in row_cell_indices),
                max(cells_flat[i]['bbox'][3] for i in row_cell_indices),
            ))

    # Column extents as (min x1, max x2); column count comes from the first row
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx >= len(row) or row[col_idx]['cell_idx'] is None:
                continue
            # a cell followed by 'lcel' spans further right, so it does not
            # describe the extent of this single column
            next_is_lcel = col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel'
            if not next_is_lcel:
                col_cells.append(row[col_idx]['cell_idx'])
        if col_cells:
            col_boundaries.append((
                min(cells_flat[i]['bbox'][0] for i in col_cells),
                max(cells_flat[i]['bbox'][2] for i in col_cells),
            ))

    # Init ground truth arrays
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end inclusive (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(max(start_scaled, 0), min(end_scaled + 1, target_size)):
            gt_array[pos] = 1

    def mark_gaps(gt_array, boundaries, orig_dim):
        """Mark leading margin, inter-band gaps and trailing margin on one axis."""
        if not boundaries:
            return
        # 1. Gap from image border to the first band
        mark_range(gt_array, 0, boundaries[0][0], orig_dim)
        # 2. Gaps between consecutive bands (only when there is an actual gap)
        for (_, gap_start), (gap_end, _) in zip(boundaries, boundaries[1:]):
            if gap_end > gap_start:
                mark_range(gt_array, gap_start, gap_end, orig_dim)
        # 3. Gap from the last band to the opposite image border
        mark_range(gt_array, boundaries[-1][1], orig_dim, orig_dim)

    mark_gaps(horizontal_gt, row_boundaries, orig_height)  # between rows
    mark_gaps(vertical_gt, col_boundaries, orig_width)     # between columns

    return horizontal_gt, vertical_gt
275
+
276
+
277
def get_ground_truth_auto_gap_expand_min5pix_overlap_cells(image, cells, otsl, split_width=5, *, target_size=960):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths.

    Like the plain auto-gap variant, the empty band between neighbouring
    rows/columns is marked in full. When two neighbours touch or overlap
    (no positive gap), a band of ``split_width`` original-image pixels
    centred on the midpoint of the overlap is marked instead, so every
    boundary still receives a split.

    Args:
        image: object exposing ``size`` as (width, height), e.g. a PIL Image
        cells: nested list - cells[0] contains the cell dicts; each has a
            'bbox' entry indexed as [x1, y1, x2, y2] in original-image pixels
        otsl: OTSL token sequence ('fcel', 'ecel', 'lcel', 'ucel', 'xcel', 'nl')
        split_width: width of split when there's no gap (default: 5)
        target_size: length of the returned arrays, i.e. the resolution the
            coordinates are rescaled to (keyword-only, default: 960)

    Returns:
        (horizontal_gt, vertical_gt): two lists of 0/1 ints, each of length
        ``target_size``; horizontal_gt is indexed by y, vertical_gt by x.
    """
    orig_width, orig_height = image.size

    # cells is nested - extract actual list
    cells_flat = cells[0]

    # Parse OTSL to build 2D grid
    grid = []
    current_row = []
    cell_idx = 0  # only increments for fcel / ecel tokens (they own a bbox)
    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ('fcel', 'ecel'):
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ('lcel', 'ucel', 'xcel'):
            # merge tokens don't consume bboxes
            current_row.append({'type': token, 'cell_idx': None})
    if current_row:
        grid.append(current_row)

    # Row extents as (min y1, max y2) per row that owns at least one bbox
    # NOTE(review): cells continued by 'ucel' in the row below are still
    # included here, unlike the 'lcel' handling for columns - confirm intended.
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            row_boundaries.append((
                min(cells_flat[i]['bbox'][1] for i in row_cell_indices),
                max(cells_flat[i]['bbox'][3] for i in row_cell_indices),
            ))

    # Column extents as (min x1, max x2); column count comes from the first row
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx >= len(row) or row[col_idx]['cell_idx'] is None:
                continue
            # a cell followed by 'lcel' spans further right, so it does not
            # describe the extent of this single column
            next_is_lcel = col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel'
            if not next_is_lcel:
                col_cells.append(row[col_idx]['cell_idx'])
        if col_cells:
            col_boundaries.append((
                min(cells_flat[i]['bbox'][0] for i in col_cells),
                max(cells_flat[i]['bbox'][2] for i in col_cells),
            ))

    # Init ground truth arrays
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end inclusive (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(max(start_scaled, 0), min(end_scaled + 1, target_size)):
            gt_array[pos] = 1

    def mark_gaps(gt_array, boundaries, orig_dim):
        """Mark leading margin, inter-band splits and trailing margin on one axis."""
        if not boundaries:
            return
        # 1. Gap from image border to the first band
        mark_range(gt_array, 0, boundaries[0][0], orig_dim)
        # 2. Splits between consecutive bands
        for (_, gap_start), (gap_end, _) in zip(boundaries, boundaries[1:]):
            if gap_end > gap_start:
                # Actual gap exists - mark all of it
                mark_range(gt_array, gap_start, gap_end, orig_dim)
            else:
                # Touching or overlapping bands - mark fixed width at the
                # midpoint. gap_start / gap_end already are the extreme cell
                # edges of the two bands, so no further lookup is needed.
                split_pos = (gap_start + gap_end) / 2
                mark_range(gt_array, split_pos - split_width / 2, split_pos + split_width / 2, orig_dim)
        # 3. Gap from the last band to the opposite image border
        mark_range(gt_array, boundaries[-1][1], orig_dim, orig_dim)

    mark_gaps(horizontal_gt, row_boundaries, orig_height)  # between rows
    mark_gaps(vertical_gt, col_boundaries, orig_width)     # between columns

    return horizontal_gt, vertical_gt
424
+
425
+
426
class BasicBlock(nn.Module):
    """Two-conv residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN) with a skip connection.

    When the stride is not 1 or the channel count changes, the skip path is
    projected with a 1x1 conv + BN so both branches have matching shapes.
    """

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        # main branch; only the first conv applies the stride
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # skip branch: identity unless the shape changes
        shape_changes = stride != 1 or inplanes != planes
        if shape_changes:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
457
+
458
class ModifiedResNet18(nn.Module):
    """ResNet-18 style backbone without the stem max-pool and with half the usual channels.

    Channel widths are 32/64/128/256 (instead of 64/128/256/512). The stem
    and stages 2-4 each downsample by 2, so the output is 1/16 of the input
    resolution with 256 channels.
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7 stride-2 conv, halved channels (64 -> 32)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        # Deliberately no max-pool after the stem

        # Four residual stages of two blocks each, halved channels
        self.layer1 = self._make_layer(32, 32, 2, stride=1)    # Original: 64
        self.layer2 = self._make_layer(32, 64, 2, stride=2)    # Original: 128
        self.layer3 = self._make_layer(64, 128, 2, stride=2)   # Original: 256
        self.layer4 = self._make_layer(128, 256, 2, stride=2)  # Original: 512

    def _make_layer(self, inplanes, planes, blocks, stride=1):
        """Stack `blocks` BasicBlocks; only the first changes stride/channels."""
        stage = [BasicBlock(inplanes, planes, stride)]
        stage.extend(BasicBlock(planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        # Shapes below are for a [B, 3, 960, 960] input
        x = self.relu(self.bn1(self.conv1(x)))  # [B, 32, 480, 480]
        # No max-pool here - this is the key modification

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            # [B, 32, 480, 480] -> [B, 64, 240, 240] -> [B, 128, 120, 120] -> [B, 256, 60, 60]
            x = stage(x)
        return x
492
+
493
class FPN(nn.Module):
    """Feature neck: 1x1 channel projection followed by bilinear upsampling.

    Despite the name, this operates on a single feature level: it projects
    the input to ``out_channels`` and resizes it to ``out_size``. With the
    defaults it maps [B, 256, h, w] to [B, 128, 480, 480].

    Args:
        in_channels: channels of the incoming feature map (default: 256)
        out_channels: channels of the produced feature map (default: 128)
        out_size: (height, width) of the produced feature map
            (default: (480, 480))
    """

    def __init__(self, in_channels=256, out_channels=128, out_size=(480, 480)):
        super().__init__()
        self.out_size = tuple(out_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x is [B, in_channels, h, w]
        x = self.conv(x)  # [B, out_channels, h, w]
        # Resize to the fixed output resolution regardless of the input size
        return F.interpolate(x, size=self.out_size, mode='bilinear', align_corners=False)
505
+
506
class SplitModel(nn.Module):
    """Predicts per-row and per-column split probabilities for a table image.

    The backbone + FPN produce a [B, 128, 480, 480] feature map. For each
    axis, a global feature (learned weighted sum across the other axis) and
    a local feature (4x average pooling across the other axis, reduced to
    one channel) are concatenated into 248-dim tokens, enriched with learned
    positional embeddings, passed through a 3-layer transformer encoder and
    classified into one probability per position.
    """

    def __init__(self):
        super().__init__()
        self.backbone = ModifiedResNet18()
        self.fpn = FPN()

        # Learnable weights for the global weighted sum
        self.h_global_weight = nn.Parameter(torch.randn(480))  # applied over width
        self.v_global_weight = nn.Parameter(torch.randn(480))  # applied over height

        # Local branch: collapse the 128 channels to 1 so the pooled spatial
        # axis can serve as the feature axis
        self.h_local_conv = nn.Conv2d(128, 1, kernel_size=1)
        self.v_local_conv = nn.Conv2d(128, 1, kernel_size=1)

        # Token width: 128 global channels + 120 pooled positions = 248
        feature_dim = 128 + 120

        # Learned 1D positional embeddings, one row per sequence position
        self.h_pos_embed = nn.Parameter(torch.randn(480, feature_dim))
        self.v_pos_embed = nn.Parameter(torch.randn(480, feature_dim))

        # One encoder per axis
        self.h_transformer = self._build_encoder(feature_dim)
        self.v_transformer = self._build_encoder(feature_dim)

        # Classification heads: one logit per sequence position
        self.h_classifier = nn.Linear(feature_dim, 1)
        self.v_classifier = nn.Linear(feature_dim, 1)

    @staticmethod
    def _build_encoder(feature_dim):
        """Create a 3-layer, 8-head, batch-first transformer encoder."""
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=feature_dim, nhead=8, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        return nn.TransformerEncoder(encoder_layer, num_layers=3)

    def forward(self, x):
        # Input: [B, 3, 960, 960]
        backbone_out = self.backbone(x)      # [B, 256, 60, 60]
        half_res = self.fpn(backbone_out)    # [B, 128, 480, 480]

        # ---- HORIZONTAL FEATURES (row splitting): one token per image row ----
        # Global: weighted sum along width -> [B, 128, 480] -> [B, 480, 128]
        row_global = torch.einsum('bchw,w->bch', half_res, self.h_global_weight).transpose(1, 2)
        # Local: 1x4 avg-pool over width -> [B, 128, 480, 120], 1x1 conv -> [B, 1, 480, 120]
        row_local = self.h_local_conv(F.avg_pool2d(half_res, kernel_size=(1, 4))).squeeze(1)  # [B, 480, 120]
        # Concatenate to [B, 480, 248] and add positional embeddings
        row_tokens = torch.cat([row_global, row_local], dim=2) + self.h_pos_embed

        # ---- VERTICAL FEATURES (column splitting): one token per image column ----
        # Global: weighted sum along height -> [B, 128, 480] -> [B, 480, 128]
        col_global = torch.einsum('bchw,h->bcw', half_res, self.v_global_weight).transpose(1, 2)
        # Local: 4x1 avg-pool over height -> [B, 128, 120, 480], 1x1 conv -> [B, 1, 120, 480]
        col_local = self.v_local_conv(F.avg_pool2d(half_res, kernel_size=(4, 1))).squeeze(1)  # [B, 120, 480]
        col_local = col_local.transpose(1, 2)  # [B, 480, 120]
        # Concatenate to [B, 480, 248] and add positional embeddings
        col_tokens = torch.cat([col_global, col_local], dim=2) + self.v_pos_embed

        # Transformer processing keeps the token shape
        row_encoded = self.h_transformer(row_tokens)  # [B, 480, 248]
        col_encoded = self.v_transformer(col_tokens)  # [B, 480, 248]

        # Binary classification at 480 resolution
        h_logits = self.h_classifier(row_encoded).squeeze(-1)  # [B, 480]
        v_logits = self.v_classifier(col_encoded).squeeze(-1)  # [B, 480]

        # Probabilities are returned at 480 resolution; any upsampling to 960
        # is left to the caller
        return torch.sigmoid(h_logits), torch.sigmoid(v_logits)
597
+
598
def focal_loss(predictions, targets, alpha=1.0, gamma=2.0):
    """Mean binary focal loss on probabilities.

    Args:
        predictions: probabilities in [0, 1] (already passed through sigmoid)
        targets: tensor of the same shape holding 0/1 labels as floats
        alpha: constant scale applied to every element's weight
        gamma: focusing exponent; 0 reduces the loss to alpha * mean BCE

    Returns:
        Scalar tensor: mean over all elements of alpha * (1 - p_t)^gamma * BCE.
    """
    # p_t is the probability the model assigned to the true class
    prob_of_truth = torch.where(targets == 1, predictions, 1 - predictions)
    modulation = alpha * (1 - prob_of_truth) ** gamma
    per_element_bce = F.binary_cross_entropy(predictions, targets, reduction='none')
    weighted = modulation * per_element_bce
    return weighted.mean()
604
+
605
def post_process_predictions(h_pred, v_pred, threshold=0.5):
    """
    Binarize horizontal and vertical split predictions.

    A position becomes 1.0 when its value is strictly greater than
    `threshold`, otherwise 0.0. Returns (h_binary, v_binary) as float tensors.
    """
    h_binary, v_binary = (
        (scores > threshold).float() for scores in (h_pred, v_pred)
    )
    return h_binary, v_binary
613
+
614
class TableDataset(Dataset):
    """Wraps a dataset of table samples and yields image tensors with split ground truth.

    Each underlying item is read through the keys 'image', 'cells' and
    'otsl'. Items are returned as a 5-tuple:
    (image tensor, h_gt[480], v_gt[480], h_gt[960], v_gt[960]).
    """

    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset
        # Resize to the model input size, convert to tensor, then normalize
        preprocessing_steps = [
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        sample = self.hf_dataset[idx]
        source_image = sample['image']

        image_transformed = self.transform(source_image.convert('RGB'))

        # Ground truth at 960 resolution; the unconverted image is passed
        # because only its dimensions are needed
        h_gt_960, v_gt_960 = get_ground_truth_auto_gap(
            source_image,
            sample['cells'],
            sample['otsl'],
        )

        # Downsample to 480 for the loss by keeping every 2nd element.
        # NOTE(review): a mark that is a single pixel wide at an odd index is
        # dropped by this subsampling - confirm that is acceptable.
        h_gt_480 = h_gt_960[0:960:2]  # [480]
        v_gt_480 = v_gt_960[0:960:2]  # [480]

        def as_float(values):
            return torch.tensor(values, dtype=torch.float)

        return (
            image_transformed,
            as_float(h_gt_480),  # [480] for training loss
            as_float(v_gt_480),  # [480] for training loss
            as_float(h_gt_960),  # [960] for metrics
            as_float(v_gt_960),  # [960] for metrics
        )
650
+
651
+