Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # Swin Transformer | |
| # Copyright (c) 2021 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Written by Ze Liu | |
| # -------------------------------------------------------- | |
| import os | |
| import zipfile | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| from PIL import ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| def is_zip_path(img_or_path): | |
| """judge if this is a zip path""" | |
| return '.zip@' in img_or_path | |
| class ZipReader(object): | |
| """A class to read zipped files""" | |
| zip_bank = dict() | |
| def __init__(self): | |
| super(ZipReader, self).__init__() | |
| def get_zipfile(path): | |
| zip_bank = ZipReader.zip_bank | |
| if path not in zip_bank: | |
| zfile = zipfile.ZipFile(path, 'r') | |
| zip_bank[path] = zfile | |
| return zip_bank[path] | |
| def split_zip_style_path(path): | |
| pos_at = path.index('@') | |
| assert pos_at != -1, "character '@' is not found from the given path '%s'" % path | |
| zip_path = path[0: pos_at] | |
| folder_path = path[pos_at + 1:] | |
| folder_path = str.strip(folder_path, '/') | |
| return zip_path, folder_path | |
| def list_folder(path): | |
| zip_path, folder_path = ZipReader.split_zip_style_path(path) | |
| zfile = ZipReader.get_zipfile(zip_path) | |
| folder_list = [] | |
| for file_foler_name in zfile.namelist(): | |
| file_foler_name = str.strip(file_foler_name, '/') | |
| if file_foler_name.startswith(folder_path) and \ | |
| len(os.path.splitext(file_foler_name)[-1]) == 0 and \ | |
| file_foler_name != folder_path: | |
| if len(folder_path) == 0: | |
| folder_list.append(file_foler_name) | |
| else: | |
| folder_list.append(file_foler_name[len(folder_path) + 1:]) | |
| return folder_list | |
| def list_files(path, extension=None): | |
| if extension is None: | |
| extension = ['.*'] | |
| zip_path, folder_path = ZipReader.split_zip_style_path(path) | |
| zfile = ZipReader.get_zipfile(zip_path) | |
| file_lists = [] | |
| for file_foler_name in zfile.namelist(): | |
| file_foler_name = str.strip(file_foler_name, '/') | |
| if file_foler_name.startswith(folder_path) and \ | |
| str.lower(os.path.splitext(file_foler_name)[-1]) in extension: | |
| if len(folder_path) == 0: | |
| file_lists.append(file_foler_name) | |
| else: | |
| file_lists.append(file_foler_name[len(folder_path) + 1:]) | |
| return file_lists | |
| def read(path): | |
| zip_path, path_img = ZipReader.split_zip_style_path(path) | |
| zfile = ZipReader.get_zipfile(zip_path) | |
| data = zfile.read(path_img) | |
| return data | |
| def imread(path): | |
| zip_path, path_img = ZipReader.split_zip_style_path(path) | |
| zfile = ZipReader.get_zipfile(zip_path) | |
| data = zfile.read(path_img) | |
| try: | |
| im = Image.open(io.BytesIO(data)) | |
| except: | |
| print("ERROR IMG LOADED: ", path_img) | |
| random_img = np.random.rand(224, 224, 3) * 255 | |
| im = Image.fromarray(np.uint8(random_img)) | |
| return im | |