# Dog_Breed_ImageWoof / data_setup.py
# Committed by Chukwuka — commit b8a2709 ("Making first Commit")
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Per-channel RGB normalisation statistics, stored as a (mean, std) pair.
# They differ from the usual ImageNet defaults, so presumably they were
# computed on this project's training images — TODO confirm.
stats = (
    (0.4862, 0.4561, 0.3941),  # mean
    (0.2202, 0.2142, 0.2160),  # std
)

# Inference-time preprocessing pipeline:
#   1. resize to 224x224 (A.Resize does not preserve aspect ratio),
#   2. normalise each channel with the mean/std above,
#   3. convert the HWC numpy image into a CHW torch tensor.
model_tsfm = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(mean=stats[0], std=stats[1]),
        ToTensorV2(),
    ]
)

# Human-readable breed names. Entry i is assumed to label model output
# index i, so keep this order in sync with the training label encoding.
classes = [
    'Australian terrier',
    'Border terrier',
    'Samoyed',
    'Beagle',
    'Shih-Tzu',
    'English foxhound',
    'Rhodesian ridgeback',
    'Dingo',
    'Golden retriever',
    'Old English sheepdog',
]
def main(argv=None):
    """Parse the command line and return the validated input image path.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The path given via ``-i/--Image``.

    Raises:
        SystemExit: raised by argparse (exit status 2) when ``-i/--Image``
            is missing or does not name an existing file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--Image',
                        help="input image path", required=True)
    args = vars(parser.parse_args(argv))
    print(args)
    img_path = args['Image']

    # Fail fast with a usage error rather than letting a later image load
    # blow up with a less helpful traceback.
    if not os.path.isfile(img_path):
        parser.error(f"image file not found: {img_path}")

    # TODO: run inference. An earlier draft loaded the image with
    # `get_image(img_path, model_tsfm)`, fed `.unsqueeze(0).to(device)` to an
    # `eff_b2` model, took `torch.argmax` of the output and printed the
    # matching entry of `classes`. `get_image`, `eff_b2` and `device` are
    # not defined in this module yet.
    return img_path


if __name__ == "__main__":
    main()