Spaces:
Runtime error
Runtime error
Move files to Clothes-Category-Classifier
Browse files- .gitignore +31 -0
- main.py +100 -0
- requirements.txt +6 -0
- src/datasets.py +91 -0
- src/model_jigsaw.py +235 -0
- src/utils.py +267 -0
.gitignore
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore Python bytecode files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Ignore virtual environments
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
myenv/
|
| 10 |
+
|
| 11 |
+
# Ignore Jupyter notebook checkpoints
|
| 12 |
+
.ipynb_checkpoints
|
| 13 |
+
|
| 14 |
+
# Ignore model files that are tracked by Git LFS
|
| 15 |
+
netBest.pth
|
| 16 |
+
|
| 17 |
+
# Ignore test images (if you don't want to include them in your repo)
|
| 18 |
+
test_images/
|
| 19 |
+
|
| 20 |
+
# Ignore log files and temporary files
|
| 21 |
+
*.log
|
| 22 |
+
*.tmp
|
| 23 |
+
*.bak
|
| 24 |
+
|
| 25 |
+
# Ignore system files
|
| 26 |
+
.DS_Store
|
| 27 |
+
Thumbs.db
|
| 28 |
+
|
| 29 |
+
# Ignore any other temporary or backup files
|
| 30 |
+
*.swp
|
| 31 |
+
*.swo
|
main.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision
|
| 6 |
+
from torchvision import transforms, datasets # Correct import for torchvision
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Add the 'src' folder to the system path for module imports
|
| 11 |
+
sys.path.append('./src')
|
| 12 |
+
|
| 13 |
+
# Import from the 'src' folder
|
| 14 |
+
from model_jigsaw import mae_vit_small_patch16 # Import directly from src folder
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Static configuration
MODEL_PATH = "model/netBest.pth"        # fine-tuned checkpoint (tracked via Git LFS)
TEST_IMAGE_FOLDER = "test_images/"      # example images displayed in the Gradio UI

# Clothing1M label index -> human-readable category name.
# Built from an ordered list so index/name pairs stay visually aligned.
class_names = dict(enumerate([
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Nightwear",
]))
|
| 38 |
+
|
| 39 |
+
# Image Preprocessing
|
| 40 |
+
def preprocess_image(image):
|
| 41 |
+
transform = transforms.Compose([
|
| 42 |
+
transforms.Resize((224, 224)),
|
| 43 |
+
transforms.ToTensor(),
|
| 44 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 45 |
+
])
|
| 46 |
+
|
| 47 |
+
image = image.convert('RGB')
|
| 48 |
+
image = transform(image).unsqueeze(0)
|
| 49 |
+
return image
|
| 50 |
+
|
| 51 |
+
# Load Model
# Build the ViT-Small jigsaw/classification model with 14 output classes
# (one per Clothing1M category) and load the fine-tuned weights.
# The checkpoint stores the state dict under the 'net' key; map_location
# keeps the Space working on CPU-only hardware.
model = mae_vit_small_patch16(nb_cls=14)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"))['net'])
model.eval()

# Load dataset mapping
# ImageFolder assigns indices to class-folder names in sorted order; invert
# that mapping so a predicted index can be translated back to the folder
# name (which here is the numeric Clothing1M label — see class_names).
# NOTE(review): this assumes test_images/ contains one subfolder per class,
# named by its numeric label — verify the folder layout matches.
val_dataset = torchvision.datasets.ImageFolder(root=TEST_IMAGE_FOLDER)
idx_to_class = {v: k for k, v in val_dataset.class_to_idx.items()}
|
| 59 |
+
|
| 60 |
+
def predict_single_image(image):
    """Classify one PIL image and return its clothing-category name.

    Mirrors the evaluation logic from eval.py: preprocess, run the
    classification head, then map the argmax index through the
    ImageFolder ordering to the Clothing1M category name.
    """
    tensor = preprocess_image(image)

    with torch.no_grad():
        logits = model.forward_cls(tensor)

    predicted_idx = torch.argmax(logits, dim=1).item()

    # ImageFolder index -> numeric folder label -> readable category name.
    folder_label = idx_to_class[predicted_idx]
    return class_names[int(folder_label)]
|
| 74 |
+
|
| 75 |
+
# Get all images from subfolders (recursively)
|
| 76 |
+
# Get all images from subfolders (recursively)
def get_image_paths(test_image_folder):
    """Recursively collect image file paths under *test_image_folder*.

    Fixes the original's case-sensitive extension check, which silently
    skipped files named like photo.JPG or photo.PNG, and adds the common
    .jpeg extension. Results are sorted so the Gradio example order is
    deterministic across runs (os.walk order is filesystem-dependent).

    Returns a sorted list of absolute-or-relative file paths (same root
    form as the argument).
    """
    image_paths = []
    for root, _dirs, files in os.walk(test_image_folder):
        for name in files:
            if name.lower().endswith((".jpg", ".jpeg", ".png")):
                image_paths.append(os.path.join(root, name))
    return sorted(image_paths)
|
| 83 |
+
|
| 84 |
+
test_image_files = get_image_paths(TEST_IMAGE_FOLDER)
|
| 85 |
+
|
| 86 |
+
def load_test_image(selected_image):
    """Open the chosen example image file and return it as a PIL image."""
    opened = Image.open(selected_image)
    return opened
|
| 89 |
+
|
| 90 |
+
# Create Gradio Interface with dynamic examples from test_images folder
# Single-function UI: upload (or pick an example) image -> predicted
# clothing category as text.
demo = gr.Interface(
    fn=predict_single_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothes Category Classifier",
    description="Upload an image to classify its clothing category.",
    examples=test_image_files  # Use the correct paths for example images
)

# Blocks until the server is stopped; must be the last statement.
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
Pillow
|
| 5 |
+
einops
|
| 6 |
+
timm
|
src/datasets.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import PIL
|
| 13 |
+
|
| 14 |
+
from torchvision import datasets, transforms
|
| 15 |
+
|
| 16 |
+
from timm.data import create_transform
|
| 17 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_dataset(is_train, args):
    """Build an ImageFolder dataset for the configured benchmark.

    Args:
        is_train: True for the training split, False for the eval split.
        args: namespace with at least data_set, data_path, and the
            transform-related fields consumed by build_transform.

    Returns:
        (dataset, nb_cls): the torchvision ImageFolder and its class count.

    Raises:
        ValueError: if args.data_set is not a supported dataset name.
            (The original fell through and crashed later with a confusing
            NameError on 'root'.)
    """
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR10':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        nb_cls = 10
    elif args.data_set == 'CIFAR100':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        nb_cls = 100
    elif args.data_set == 'Animal10N':
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
        nb_cls = 10
    elif args.data_set == 'Clothing1M':
        # we use a randomly selected balanced training subset
        root = os.path.join(args.data_path, 'noisy_rand_subtrain' if is_train else 'clean_val')
        nb_cls = 14
    elif args.data_set == 'Food101N':
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
        nb_cls = 101
    else:
        raise ValueError('Unknown data_set: {!r}'.format(args.data_set))

    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset, nb_cls
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_transform(is_train, args):
    """Build the train or eval torchvision transform for args.data_set."""
    # CIFAR has its own channel statistics; every other dataset uses ImageNet's.
    if args.data_set in ('CIFAR10', 'CIFAR100'):
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    else:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD

    resize_im = args.input_size > 32

    if is_train:
        # Dispatches to timm's transforms_imagenet_train pipeline.
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            # Small (<=32px) images: swap RandomResizedCropAndInterpolation
            # for a plain padded RandomCrop.
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    # Eval pipeline: resize keeping the standard 224/256 crop ratio,
    # center-crop to the target size, then tensorize and normalize.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)
    eval_steps = [
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(eval_steps)
|
src/model_jigsaw.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 9 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import einops
|
| 17 |
+
|
| 18 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
| 19 |
+
|
| 20 |
+
import utils
|
| 21 |
+
|
| 22 |
+
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone

    Encoder-only variant with two output heads sharing the backbone:
    - a jigsaw head that predicts each kept patch's original grid index
      (forward_jigsaw), and
    - a linear classification head on the cls token (forward_cls).
    """
    def __init__(self,
                 nb_cls=10,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=1024,
                 depth=24,
                 num_heads=16,
                 mlp_ratio=4.,
                 norm_layer=nn.LayerNorm):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # Classification head applied to the cls token.
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        # Jigsaw head: 3-layer MLP, one logit per possible patch position.
        self.jigsaw = torch.nn.Sequential(*[torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(embed_dim, self.num_patches)])
        # Canonical patch ordering 0..num_patches-1.
        # NOTE(review): only referenced by commented-out code in
        # random_masking — appears to be kept for reference.
        self.target = torch.arange(self.num_patches)


        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = utils.get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))


        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Module-wise init hook used by self.apply() above.
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Split square RGB images into flattened non-overlapping patches.

        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        # Requires square images whose side is a multiple of the patch size.
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        # Reorder so each patch's pixels become one contiguous row.
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        Inverse of patchify: reassemble flattened patches into images.

        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        # Patch count must form a square grid.
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns (x_masked, target_masked) where x_masked holds the kept
        patch embeddings and target_masked their original grid indices
        (used as jigsaw classification targets).
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # target = einops.repeat(self.target, 'L -> N L', N=N)
        # target = target.to(x.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]  # N, len_keep
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # The kept indices themselves are the positions to be predicted.
        target_masked = ids_keep

        return x_masked, target_masked

    def forward_jigsaw(self, x, mask_ratio):
        # Jigsaw branch: predict each kept patch's original position.
        # Note that pos_embed is deliberately NOT added here — position
        # information would make the jigsaw task trivial.
        # embed patches
        x = self.patch_embed(x)

        # masking: length -> length * mask_ratio
        x, target = self.random_masking(x, mask_ratio)

        # append cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Drop the cls token; classify each patch token over positions.
        x = self.jigsaw(x[:, 1:])
        # Flatten to (N*len_keep, num_patches) logits and matching targets.
        return x.reshape(-1, self.num_patches), target.reshape(-1)

    def forward_cls(self, x):
        # Classification branch: standard ViT forward on the cls token.
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Class logits from the cls token only.
        x = self.head(x[:, 0])
        return x

    def forward(self, x_jigsaw, x_cls, mask_ratio):
        """Run both branches; returns (jigsaw logits, jigsaw targets, class logits)."""
        pred_jigsaw, targets_jigsaw = self.forward_jigsaw(x_jigsaw, mask_ratio)
        pred_cls = self.forward_cls(x_cls)
        return pred_jigsaw, targets_jigsaw, pred_cls
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def mae_vit_small_patch16(nb_cls, **kwargs):
    """Build the ViT-Small/16 variant (384-dim, 12 blocks, 6 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def mae_vit_base_patch16(nb_cls, **kwargs):
    """Build the ViT-Base/16 variant (768-dim, 12 blocks, 12 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def mae_vit_large_patch16(nb_cls, **kwargs):
    """Build the ViT-Large/16 variant (1024-dim, 24 blocks, 16 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def create_model(arch, nb_cls):
    """Factory: build a jigsaw+classification ViT by architecture name.

    Args:
        arch: one of 'vit_small_patch16', 'vit_base_patch16',
            'vit_large_patch16'.
        nb_cls: number of classification classes.

    Returns:
        The constructed MaskedAutoencoderViT.

    Raises:
        ValueError: for an unknown *arch*. (The original silently returned
        None, which surfaced later as a confusing AttributeError.)
    """
    builders = {
        'vit_small_patch16': mae_vit_small_patch16,
        'vit_base_patch16': mae_vit_base_patch16,
        'vit_large_patch16': mae_vit_large_patch16,
    }
    if arch not in builders:
        raise ValueError('Unknown architecture: {!r}'.format(arch))
    return builders[arch](nb_cls)
|
| 227 |
+
|
| 228 |
+
if __name__ == '__main__':
    # Smoke test: build the small model and run one jigsaw forward pass.
    net = create_model(arch='vit_small_patch16', nb_cls=10)
    net = net.cpu()
    # Bug fix: the original allocated the input with torch.cuda.FloatTensor
    # while the model lives on the CPU — that crashes outright on hosts
    # without CUDA, and would mix devices even with a GPU. Use a CPU
    # tensor instead (torch.rand also avoids uninitialized memory).
    img = torch.rand(6, 3, 224, 224)
    mask_ratio = 0.75
    with torch.no_grad():
        x, target = net.forward_jigsaw(img, mask_ratio)
|
src/utils.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''Some helper functions for PyTorch, including:
|
| 2 |
+
- get_mean_and_std: calculate the mean and std value of dataset.
|
| 3 |
+
- msr_init: net parameter initialization.
|
| 4 |
+
- progress_bar: progress bar mimic xlua.progress.
|
| 5 |
+
'''
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
from torch.nn import Parameter
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_logger(out_dir):
    """Build the shared 'Exp' logger writing to a file and to stdout.

    When *out_dir* is an existing directory, the log file is a timestamped
    run_<ts>.log inside it; otherwise *out_dir* is treated as a
    checkpoint-style path and its '.pth.tar' suffix is stripped to form
    the log file name.
    """
    logger = logging.getLogger('Exp')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    # Filesystem-safe timestamp, e.g. 2024_01_31_12_30_05.
    stamp = str(datetime.now()).split(".")[0].replace(" ", "_")
    stamp = stamp.replace(":", "_").replace("-", "_")
    if os.path.isdir(out_dir):
        log_path = os.path.join(out_dir, "run_{}.log".format(stamp))
    else:
        log_path = out_dir.replace('.pth.tar', '')

    # Same formatter on both sinks: the log file and the console.
    for handler in (logging.FileHandler(log_path),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AverageMeter(object):
    """Computes and stores the average and current value"""
    # Public attributes (read directly by callers):
    #   val   - most recently recorded value
    #   sum   - weighted sum of all recorded values
    #   count - total weight (number of observations)
    #   avg   - running mean, sum / count
    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all statistics back to their initial state.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # Record *val* observed *n* times and refresh the running mean.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) logits or scores.
        target: (batch,) integer class labels.
        topk: tuple of k values to report.

    Returns:
        List of 1-element tensors, each the top-k accuracy in percent,
        in the same order as *topk*.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size()[0]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # .reshape(-1) instead of the original .view(-1): slicing the
            # transposed comparison tensor can yield a non-contiguous view,
            # and .view() raises RuntimeError on non-contiguous tensors in
            # modern PyTorch (same fix as in pytorch/examples).
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def update_lr(iteration, warmup_iter, total_iter, max_lr, min_lr):
    """Return the LR for *iteration*: linear warm-up, then cosine decay.

    During warm-up the rate ramps linearly from 0 to max_lr; afterwards it
    follows a half-cycle cosine from max_lr down to min_lr at total_iter.
    """
    if iteration < warmup_iter:
        return max_lr * iteration / warmup_iter
    progress = (iteration - warmup_iter) / (total_iter - warmup_iter)
    cosine_factor = 0.5 * (1. + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine_factor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def adjust_learning_rate(optimizer, iteration, warmup_iter, total_iter, max_lr, min_lr):
    """Decay the learning rate with half-cycle cosine after warmup.

    Writes the schedule value into every optimizer param group, honoring a
    per-group 'lr_scale' multiplier (e.g. layer-wise LR decay) when present.
    Returns the unscaled rate.
    """
    current_lr = update_lr(iteration, warmup_iter, total_iter, max_lr, min_lr)
    for group in optimizer.param_groups:
        if "lr_scale" in group:
            group["lr"] = current_lr * group["lr_scale"]
        else:
            group["lr"] = current_lr
    return current_lr
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.meshgrid(axis, axis)  # here w goes first
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Prepend an all-zero row as the class token's position embedding.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin-cos embeddings of both grid axes; (H*W, embed_dim)."""
    assert embed_dim % 2 == 0

    # Half the channels encode the h axis, the other half the w axis.
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Frequencies 1 / 10000^(2i/D), as in "Attention Is All You Need".
    freqs = np.arange(embed_dim // 2, dtype=float)
    freqs /= embed_dim / 2.
    freqs = 1. / 10000**freqs  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat_pos, freqs)  # (M, D/2), outer product

    # First half sine, second half cosine of the same angles.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 140 |
+
|
| 141 |
+
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 'pos_embed' in place to match *model*'s patch grid.

    Used when fine-tuning at a different input resolution than pre-training:
    the per-patch position tokens are bicubically interpolated to the new
    grid while the extra (class) tokens are kept unchanged. Mutates
    checkpoint_model['pos_embed']; no return value.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # Tokens that are not patch positions (e.g. the cls token).
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, L, D) -> (1, D, H, W) so F.interpolate can resize spatially.
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # Back to the (1, H*W, D) token layout.
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def load_my_state_dict(net, state_dict):
    """Copy matching parameters from *state_dict* into *net*, skipping extras.

    Strips DataParallel's 'module.' prefix from checkpoint keys and silently
    ignores keys the target network does not have (partial loading).
    """
    own_state = net.state_dict()
    for name, param in state_dict.items():
        key = name.replace('module.', '')
        if key not in own_state:
            continue
        # backwards compatibility for serialized parameters
        value = param.data if isinstance(param, Parameter) else param
        own_state[key].copy_(value)
|
| 180 |
+
|
| 181 |
+
import shutil

# Fix for non-interactive environments: the original shelled out to
# `stty size`, which spawns a process, fails on Windows, and only worked
# because an unpacking ValueError was caught. shutil.get_terminal_size is
# the portable stdlib equivalent and falls back to 80 columns when stdout
# is not attached to a real terminal (CI, notebooks, pipes).
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

# Progress-bar state shared with progress_bar() below.
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def progress_bar(current, total, msg=None):
    """Render an in-place console progress bar (mimics xlua.progress).

    current: zero-based index of the step just finished.
    total: total number of steps.
    msg: optional extra text appended after the timing info.

    Uses the module-level last_time/begin_time globals for timing and
    term_width/TOTAL_BAR_LENGTH for layout.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # reset for new bar.

    # Split the bar into a completed '=' segment and a remaining '.' segment.
    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    # Per-step and cumulative wall-clock timings.
    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad the rest of the line with spaces so leftovers from a longer
    # previous message get erased.
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Overwrite the same line on the next call; newline after the last step.
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def format_time(seconds):
    """Render a duration in seconds as a compact string like '1h3m' or '0ms'.

    At most the two most significant non-zero units are shown; a zero
    duration renders as '0ms'.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_seconds = int(remaining)
    millis = int((remaining - whole_seconds) * 1000)

    parts = []
    for amount, suffix in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                           (whole_seconds, 's'), (millis, 'ms')):
        # Stop after two units, matching the original's i <= 2 counter.
        if amount > 0 and len(parts) < 2:
            parts.append(str(amount) + suffix)
    return ''.join(parts) if parts else '0ms'
|