Spaces:
Build error
Build error
| import os | |
| import pickle | |
| from captum_improve_vitstr import rankedAttributionsBySegm | |
| import matplotlib.pyplot as plt | |
| from skimage.color import gray2rgb | |
| from captum.attr._utils.visualization import visualize_image_attr | |
| import torch | |
| import numpy as np | |
def attr_one_dataset(modelName="vitstr", datasetName="IIIT5k_3000",
                     dataRoot="/data/goo/strattr/attributionData",
                     imgRoot="/data/goo/strattr/attributionDataImgs"):
    """Render every saved attribution pickle of one dataset as a heat-map PNG.

    Walks ``<dataRoot>/<modelName>/<datasetName>/`` (including any
    subdirectories) for ``*.pkl`` files. Each pickle must contain the keys
    ``'attribution'``, ``'segmData'`` and ``'origImg'``. Pickles whose
    attribution array contains NaN are skipped. The others are ranked per
    segment with ``rankedAttributionsBySegm``, and the result is saved as a
    ``blended_heat_map`` to ``<imgRoot>/<modelName>/<datasetName>/`` under the
    same file name with a ``.png`` suffix. The output directory is flat, so
    identically named pickles from different subdirectories overwrite each
    other.

    Args:
        modelName: Model sub-directory name (default ``"vitstr"``).
        datasetName: Dataset sub-directory name (default ``"IIIT5k_3000"``).
        dataRoot: Root directory holding the attribution pickles.
        imgRoot: Root directory that receives the rendered PNGs.
    """
    rootDir = os.path.join(dataRoot, modelName, datasetName)
    attrOutputImgs = os.path.join(imgRoot, modelName, datasetName)
    os.makedirs(attrOutputImgs, exist_ok=True)

    # From a folder containing saved attribution pickle files, convert them into attribution images
    for dirPath, _subdirs, files in os.walk(rootDir):
        for name in sorted(files):
            # Only attribution pickles (e.g. "66_featablt.pkl") are convertible;
            # any other file would fail to unpickle or to parse.
            if not name.endswith('.pkl'):
                continue
            # Join with the directory os.walk is currently visiting (not rootDir)
            # so files found in subdirectories resolve to their real location.
            fullfilename = os.path.join(dirPath, name)
            print("fullfilename: ", fullfilename)
            attrImgName = name[:-len('.pkl')] + '.png'

            # NOTE: pickle.load can execute arbitrary code; only point this at
            # attribution files produced by this project.
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']
            if np.isnan(attributions).any():
                continue

            attributions = torch.from_numpy(attributions)
            rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
            # Take index [0][0] of the result — presumably batch and channel
            # axes; confirm against rankedAttributionsBySegm's output shape.
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)
            mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
            try:
                mplotfig.savefig(os.path.join(attrOutputImgs, attrImgName))
            finally:
                # Release the figure even if saving fails so long runs do not
                # accumulate open matplotlib figures.
                mplotfig.clear()
                plt.close(mplotfig)
def attr_all_dataset(modelName="vitstr", datasetNameList=None,
                     dataRoot="/data/goo/strattr/attributionData",
                     imgRoot="/data/goo/strattr/attributionDataImgs"):
    """Render the attribution pickles of several datasets as heat-map PNGs.

    For each dataset, walks ``<dataRoot>/<modelName>/<datasetName>/``
    (including any subdirectories) for ``*.pkl`` files holding the keys
    ``'attribution'``, ``'segmData'`` and ``'origImg'``. Pickles whose
    attribution contains NaN are skipped. The others are ranked per segment
    with ``rankedAttributionsBySegm`` and saved as a ``blended_heat_map`` PNG
    (same file name, ``.png`` suffix) in
    ``<imgRoot>/<modelName>/<datasetName>/``. That directory is created even
    when the dataset has no pickles.

    Args:
        modelName: Model sub-directory name (default ``"vitstr"``).
        datasetNameList: Dataset sub-directory names; ``None`` selects the ten
            benchmark sets listed below.
        dataRoot: Root directory holding the attribution pickles.
        imgRoot: Root directory that receives the rendered PNGs.
    """
    if datasetNameList is None:
        datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    for datasetName in datasetNameList:
        rootDir = os.path.join(dataRoot, modelName, datasetName)
        attrOutputImgs = os.path.join(imgRoot, modelName, datasetName)
        os.makedirs(attrOutputImgs, exist_ok=True)

        # From a folder containing saved attribution pickle files, convert them into attribution images
        for dirPath, _subdirs, files in os.walk(rootDir):
            for name in sorted(files):
                # Only attribution pickles (e.g. "66_featablt.pkl") are convertible.
                if not name.endswith('.pkl'):
                    continue
                # Join with the directory being walked so subdirectory files
                # resolve to their real location.
                fullfilename = os.path.join(dirPath, name)
                attrImgName = name[:-len('.pkl')] + '.png'

                # NOTE: pickle.load can execute arbitrary code; only point this
                # at attribution files produced by this project.
                with open(fullfilename, 'rb') as f:
                    pklData = pickle.load(f)
                attributions = pklData['attribution']
                segmDataNP = pklData['segmData']
                origImgNP = pklData['origImg']
                # Skip NaN attributions, matching the single-dataset converter.
                if np.isnan(attributions).any():
                    continue

                attributions = torch.from_numpy(attributions)
                rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
                # Take index [0][0] — presumably batch and channel axes;
                # confirm against rankedAttributionsBySegm's output shape.
                rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
                rankedAttr = gray2rgb(rankedAttr)
                mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
                try:
                    mplotfig.savefig(os.path.join(attrOutputImgs, attrImgName))
                finally:
                    # Release the figure even if saving fails.
                    mplotfig.clear()
                    plt.close(mplotfig)
if __name__ == "__main__":
    # Script entry point: convert only the default dataset. Call
    # attr_all_dataset() instead to process every benchmark dataset.
    attr_one_dataset()