Spaces:
Build error
Build error
Commit
·
6ccedcd
1
Parent(s):
dffa77d
updated app
Browse files
app.py
CHANGED
|
@@ -80,6 +80,90 @@ elif modelName=="parseq":
|
|
| 80 |
opt.scorer = "mean"
|
| 81 |
opt.blackbg = True
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# x = st.slider('Select a value')
|
| 84 |
# st.write(x, 'squared is', x * x)
|
| 85 |
|
|
@@ -99,6 +183,7 @@ if uploaded_file is not None:
|
|
| 99 |
# To read file as bytes:
|
| 100 |
bytes_data = uploaded_file.getvalue()
|
| 101 |
pilImg = Image.open(uploaded_file)
|
|
|
|
| 102 |
|
| 103 |
orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
|
| 104 |
img1 = orig_img_tensors.to(device)
|
|
|
|
| 80 |
opt.scorer = "mean"
|
| 81 |
opt.blackbg = True
|
| 82 |
|
| 83 |
+
segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
|
| 84 |
+
max_dist=200, ratio=0.2,
|
| 85 |
+
random_seed=random.randint(0, 1000))
|
| 86 |
+
|
| 87 |
+
if modelName=="vitstr":
|
| 88 |
+
if opt.Transformer:
|
| 89 |
+
converter = TokenLabelConverter(opt)
|
| 90 |
+
elif 'CTC' in opt.Prediction:
|
| 91 |
+
converter = CTCLabelConverter(opt.character)
|
| 92 |
+
else:
|
| 93 |
+
converter = AttnLabelConverter(opt.character)
|
| 94 |
+
opt.num_class = len(converter.character)
|
| 95 |
+
if opt.rgb:
|
| 96 |
+
opt.input_channel = 3
|
| 97 |
+
model_obj = Model(opt)
|
| 98 |
+
|
| 99 |
+
model = torch.nn.DataParallel(model_obj).to(device)
|
| 100 |
+
modelCopy = copy.deepcopy(model)
|
| 101 |
+
|
| 102 |
+
""" evaluation """
|
| 103 |
+
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
| 104 |
+
super_pixel_model_singlechar = torch.nn.Sequential(
|
| 105 |
+
# super_pixler,
|
| 106 |
+
# numpy2torch_converter,
|
| 107 |
+
modelCopy,
|
| 108 |
+
scoring_singlechar
|
| 109 |
+
).to(device)
|
| 110 |
+
modelCopy.eval()
|
| 111 |
+
scoring_singlechar.eval()
|
| 112 |
+
super_pixel_model_singlechar.eval()
|
| 113 |
+
|
| 114 |
+
# Single Char Attribution Averaging
|
| 115 |
+
# enableSingleCharAttrAve - set to True
|
| 116 |
+
scoring = STRScore(opt=opt, converter=converter, device=device)
|
| 117 |
+
super_pixel_model = torch.nn.Sequential(
|
| 118 |
+
# super_pixler,
|
| 119 |
+
# numpy2torch_converter,
|
| 120 |
+
model,
|
| 121 |
+
scoring
|
| 122 |
+
).to(device)
|
| 123 |
+
model.eval()
|
| 124 |
+
scoring.eval()
|
| 125 |
+
super_pixel_model.eval()
|
| 126 |
+
|
| 127 |
+
elif modelName=="parseq":
|
| 128 |
+
model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True)
|
| 129 |
+
# checkpoint = torch.hub.load_state_dict_from_url('https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', map_location="cpu")
|
| 130 |
+
# # state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
|
| 131 |
+
# model.load_state_dict(checkpoint)
|
| 132 |
+
model = model.to(device)
|
| 133 |
+
model_obj = model
|
| 134 |
+
converter = TokenLabelConverter(opt)
|
| 135 |
+
modelCopy = copy.deepcopy(model)
|
| 136 |
+
|
| 137 |
+
""" evaluation """
|
| 138 |
+
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True, model=modelCopy)
|
| 139 |
+
super_pixel_model_singlechar = torch.nn.Sequential(
|
| 140 |
+
# super_pixler,
|
| 141 |
+
# numpy2torch_converter,
|
| 142 |
+
modelCopy,
|
| 143 |
+
scoring_singlechar
|
| 144 |
+
).to(device)
|
| 145 |
+
modelCopy.eval()
|
| 146 |
+
scoring_singlechar.eval()
|
| 147 |
+
super_pixel_model_singlechar.eval()
|
| 148 |
+
|
| 149 |
+
# Single Char Attribution Averaging
|
| 150 |
+
# enableSingleCharAttrAve - set to True
|
| 151 |
+
scoring = STRScore(opt=opt, converter=converter, device=device, model=model)
|
| 152 |
+
super_pixel_model = torch.nn.Sequential(
|
| 153 |
+
# super_pixler,
|
| 154 |
+
# numpy2torch_converter,
|
| 155 |
+
model,
|
| 156 |
+
scoring
|
| 157 |
+
).to(device)
|
| 158 |
+
model.eval()
|
| 159 |
+
scoring.eval()
|
| 160 |
+
super_pixel_model.eval()
|
| 161 |
+
|
| 162 |
+
if opt.blackbg:
|
| 163 |
+
shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32)
|
| 164 |
+
trainList = np.array(shapImgLs)
|
| 165 |
+
background = torch.from_numpy(trainList).to(device)
|
| 166 |
+
|
| 167 |
# x = st.slider('Select a value')
|
| 168 |
# st.write(x, 'squared is', x * x)
|
| 169 |
|
|
|
|
| 183 |
# To read file as bytes:
|
| 184 |
bytes_data = uploaded_file.getvalue()
|
| 185 |
pilImg = Image.open(uploaded_file)
|
| 186 |
+
pilImg = pilImg.resize((opt.imgW, opt.imgH))
|
| 187 |
|
| 188 |
orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
|
| 189 |
img1 = orig_img_tensors.to(device)
|