Spaces:
Build error
Build error
Commit
·
918d78a
1
Parent(s):
b4759d0
updated app
Browse files
app.py
CHANGED
|
@@ -58,7 +58,6 @@ from captum.attr._utils.visualization import visualize_image_attr
|
|
| 58 |
device = torch.device('cpu')
|
| 59 |
opt = get_args(is_train=False)
|
| 60 |
|
| 61 |
-
""" vocab / character number configuration """
|
| 62 |
if opt.sensitive:
|
| 63 |
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
| 64 |
|
|
@@ -125,7 +124,6 @@ if modelName=="vitstr":
|
|
| 125 |
model = torch.nn.DataParallel(model_obj).to(device)
|
| 126 |
modelCopy = copy.deepcopy(model)
|
| 127 |
|
| 128 |
-
""" evaluation """
|
| 129 |
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
| 130 |
super_pixel_model_singlechar = torch.nn.Sequential(
|
| 131 |
# super_pixler,
|
|
@@ -193,7 +191,25 @@ if opt.blackbg:
|
|
| 193 |
# x = st.slider('Select a value')
|
| 194 |
# st.write(x, 'squared is', x * x)
|
| 195 |
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
image = Image.open('demo_image/demo_ballys.jpg') #Brand logo image (optional)
|
| 199 |
image2 = Image.open('demo_image/demo_ronaldo.jpg') #Brand logo image (optional)
|
|
|
|
| 58 |
device = torch.device('cpu')
|
| 59 |
opt = get_args(is_train=False)
|
| 60 |
|
|
|
|
| 61 |
if opt.sensitive:
|
| 62 |
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
| 63 |
|
|
|
|
| 124 |
model = torch.nn.DataParallel(model_obj).to(device)
|
| 125 |
modelCopy = copy.deepcopy(model)
|
| 126 |
|
|
|
|
| 127 |
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
| 128 |
super_pixel_model_singlechar = torch.nn.Sequential(
|
| 129 |
# super_pixler,
|
|
|
|
| 191 |
# x = st.slider('Select a value')
|
| 192 |
# st.write(x, 'squared is', x * x)
|
| 193 |
|
| 194 |
+
### Acquire pixelwise attributions and replace them with ranked numbers averaged
|
| 195 |
+
### across segmentation with the largest contribution having the largest number
|
| 196 |
+
### and the smallest set to 1, which is the minimum number.
|
| 197 |
+
### attr - original attribution
|
| 198 |
+
### segm - image segmentations
|
| 199 |
+
def rankedAttributionsBySegm(attr, segm):
|
| 200 |
+
aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm)
|
| 201 |
+
totalSegm = len(sortedDict.keys()) # total segmentations
|
| 202 |
+
sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
|
| 203 |
+
sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score
|
| 204 |
+
currentRank = totalSegm
|
| 205 |
+
rankedSegmImg = torch.clone(attr)
|
| 206 |
+
for totalSegToHide in range(0, len(sortedKeys)):
|
| 207 |
+
currentSegmentToHide = sortedKeys[totalSegToHide]
|
| 208 |
+
rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank
|
| 209 |
+
currentRank -= 1
|
| 210 |
+
return rankedSegmImg
|
| 211 |
+
|
| 212 |
+
labels = st.text_input('You need to put the text of the image here (e.g. BALLYS)')
|
| 213 |
|
| 214 |
image = Image.open('demo_image/demo_ballys.jpg') #Brand logo image (optional)
|
| 215 |
image2 = Image.open('demo_image/demo_ronaldo.jpg') #Brand logo image (optional)
|