# SamaritanOCR/py3/e2e/e2e_model.py
import torch
import torch.nn as nn
from torch.autograd import Variable
import cv2
import numpy as np
from utils import string_utils, error_rates
from utils import transformation_utils
from . import handwriting_alignment_loss
from . import e2e_postprocessing
import copy
from scipy.optimize import linear_sum_assignment
import math
#from pynvml import *
# max_lines_per_image is the max lines in a batch for HW to process
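# The model chains three sub-networks: sol proposes start-of-line points on a resized
# page image, lf follows each detected line across the full-resolution page and emits a
# sampling grid, and hw transcribes the dewarped line crops.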
class E2EModel(nn.Module):
    def __init__(self, sol, lf, hw, dtype=torch.cuda.FloatTensor, max_lines_per_image=8, device="cuda"):
        super(E2EModel, self).__init__()
        self.dtype = dtype
        self.sol = sol
        self.lf = lf
        self.hw = hw
        self.line = None
        self.max_lines_per_image = max_lines_per_image
        self.device = device

    def train(self):
        self.sol.train()
        self.lf.train()
        self.hw.train()

    def eval(self):
        self.sol.eval()
        self.lf.eval()
        self.hw.eval()

    def forward(self, x, use_full_img=True, accpet_threshold=0.1, volatile=True, gt_lines=None,
                idx_to_char=None, HW_cuda=0, device="cuda"):

        if device != self.device:
            raise ValueError("Wrong device is set: param {}, self {}".format(device, self.device))

        sol_img = Variable(x['resized_img'].type(self.dtype), requires_grad=False)

        if use_full_img:
            img = Variable(x['full_img'].type(self.dtype), requires_grad=False)
            scale = x['resize_scale']
            results_scale = 1.0
        else:
            img = sol_img
            scale = 1.0
            results_scale = x['resize_scale']

        original_starts = self.sol(sol_img)
        start = original_starts

        # Take at least one point
        sorted_start, sorted_indices = torch.sort(start[..., 0:1], dim=1, descending=True)
        # print("sorted_start size", sorted_start.size())
        # print("sorted_start", sorted_start)
        min_threshold = sorted_start[0, 1, 0].data
        accpet_threshold = min(accpet_threshold, min_threshold)

        # There should not be more than 56 points to avoid out of memory
        if sorted_start.size()[1] > 56:
            accpet_threshold = max(accpet_threshold, sorted_start[0, 55, 0].data)
            # print('using accept_threshold', accpet_threshold, sorted_start[0, 55, 0].data)

        select = original_starts[..., 0:1] >= accpet_threshold
        select_idx = np.where(select.data.cpu().numpy())[1]

        select = select.expand(select.size(0), select.size(1), start.size(2))
        select = select.detach()

        start = start[select].view(start.size(0), -1, start.size(2))

        perform_forward = len(start.size()) == 3
        if not perform_forward:
            return None

        forward_img = img

        start = start.transpose(0, 1)
        positions = torch.cat([
            start[..., 1:3] * scale,
            start[..., 3:4],
            start[..., 4:5] * scale,
            start[..., 0:1]
        ], 2)
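        # Each positions row appears to hold [x, y, rotation, scale, confidence], with the
        # spatial terms rescaled to full-image coordinates when use_full_img is set.
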
        # print('positions size', positions.size())

        hw_out = []

        p_interval = positions.size(0)
        lf_xy_positions = None
        line_imgs = []
        # show_mem_status(1, "before for in FORWARD")
        for p in range(0, min(positions.size(0), np.inf), p_interval):
            sub_positions = positions[p:p+p_interval, 0, :]
            sub_select_idx = select_idx[p:p+p_interval]

            batch_size = sub_positions.size(0)
            sub_positions = [sub_positions]
            # print(sub_positions)
            # sys.exit()

            expand_img = forward_img.expand(sub_positions[0].size(0), img.size(1), img.size(2), img.size(3))

            step_size = 8  # 5
            extra_bw = 1  # 1
            forward_steps = 30  # 40
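            # Three LF passes: a short forward run from the SOL point, a reversed run
            # (negate_lw) to back up past the detected start, then the full forward pass
            # with allow_end_early so following stops at the end of the line.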
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, sub_positions, steps=step_size)
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size]], steps=step_size+extra_bw, negate_lw=True)
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size+extra_bw]], steps=forward_steps, allow_end_early=True)
            # show_mem_status(1, 'after lf')

            if lf_xy_positions is None:
                lf_xy_positions = xy_positions
            else:
                for i in range(len(lf_xy_positions)):
                    lf_xy_positions[i] = torch.cat([
                        lf_xy_positions[i],
                        xy_positions[i]
                    ])

            expand_img = expand_img.transpose(2, 3)

            hw_interval = p_interval
            for h in range(0, min(grid_line.size(0), np.inf), hw_interval):
                sub_out_positions = [o[h:h+hw_interval] for o in out_positions]
                sub_xy_positions = [o[h:h+hw_interval] for o in xy_positions]
                sub_sub_select_idx = sub_select_idx[h:h+hw_interval]
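                # grid_sample warps the page image along the LF grid, producing a
                # rectified crop of each followed text line for the HW network.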
                line = torch.nn.functional.grid_sample(expand_img[h:h+hw_interval].detach(), grid_line[h:h+hw_interval], align_corners=True)
                line = line.transpose(2, 3)

                for l in line:
                    l = l.transpose(0, 1).transpose(1, 2)
                    l = (l + 1) * 128
                    l_np = l.data.cpu().numpy()
                    line_imgs.append(l_np)
                    # cv2.imwrite("example_line_out.png", l_np)
                    # print("Saved!")
                    # raw_input()

                # Resize to 60 ht
                # Mehreen add: to avoid out-of-memory errors, a large batch has to be split up
                # for the HW network to process. This case arises when SOL finds too many lines
                # on a page.
                batch, channels, old_ht, old_width = line.size()
                line = line.detach().cpu()
                total_todo = batch
                # show_mem_status(0, '.... Before hw line')
                start_index = 0
                while total_todo > 0:
                    mini_batch_size = min(self.max_lines_per_image, total_todo)
                    partial_lines = line[start_index:start_index+mini_batch_size, :, :, :]
                    # print('start_index, end_index', start_index, start_index+mini_batch_size)
                    start_index += mini_batch_size
                    total_todo = total_todo - mini_batch_size
                    # print('partial_line size', partial_lines.size())
                    partial_lines = partial_lines.to(self.device)

                    out = self.hw(partial_lines)
                    if "cuda" in device:
                        torch.cuda.empty_cache()
                    out = out.transpose(0, 1)
                    hw_out.append(out)

                # print('batch size: ', batch)
                # new_ht = 60
                # new_width = int(old_width/old_ht*new_ht)
                # print('line type', type(line), line.size())
                # self.line = nn.functional.interpolate(line, size=(new_ht, new_width),
                #                                       mode='bilinear', align_corners=True)
                # Mehreen commented out for processing entire batch in one go:
                # out = self.hw(line)
                # out = out.transpose(0, 1)
                # hw_out.append(out)
                # show_mem_status(0, '.... After hw line')

        hw_out = torch.cat(hw_out, 0)
        # print(original_starts, positions, lf_xy_positions, hw_out, results_scale, line_imgs)

        return {
            "original_sol": original_starts,
            "sol": positions,
            "lf": lf_xy_positions,
            "hw": hw_out,
            "results_scale": results_scale,
            "line_imgs": line_imgs
        }
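

# Illustrative sketch only, not part of the original pipeline: it replays the SOL
# candidate-selection logic from forward() on a dummy tensor, so the thresholding and
# the 56-point cap can be sanity-checked without loading the SOL/LF/HW networks. The
# (1, N, 5) shape and confidence-first column order are assumptions taken from how
# `original_starts` is indexed above.
if __name__ == "__main__":
    dummy_starts = torch.rand(1, 100, 5)  # (batch, candidate start points, [conf, x, y, rot, scale])
    accept_threshold = 0.1

    # Sort candidates by confidence, as forward() does
    sorted_conf, _ = torch.sort(dummy_starts[..., 0:1], dim=1, descending=True)

    # Lower the threshold so the most confident points always survive ...
    accept_threshold = min(accept_threshold, sorted_conf[0, 1, 0].item())
    # ... but raise it so at most ~56 points are kept, to bound memory use
    if sorted_conf.size(1) > 56:
        accept_threshold = max(accept_threshold, sorted_conf[0, 55, 0].item())

    keep = dummy_starts[..., 0] >= accept_threshold
    print("kept", int(keep.sum().item()), "of", dummy_starts.size(1), "candidate start points")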