jflo's picture
Initial push
91b7cf9 verified
import functools

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
# Preprocessing pipeline applied to every input image before inference.
# Built once at import time instead of being rebuilt on every call.
# NOTE(review): Resize((224, 224)) already produces exactly 224x224, so the
# CenterCrop step is a no-op. It is kept so the pipeline stays byte-for-byte
# the one the model was presumably traced with — confirm before removing.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    # Per-channel RGB mean/std (the commonly used ImageNet statistics).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def transform_img(img):
    """Convert an image into the normalized tensor the classifier expects.

    PIL images whose mode is not "RGB" (e.g. grayscale "L", "RGBA" with an
    alpha channel, or palette "P") are converted to RGB first, because the
    Normalize step above is defined for exactly three channels.

    Args:
        img: A PIL image (other inputs accepted by ``ToTensor`` are passed
            through unchanged).

    Returns:
        torch.Tensor of shape (C, 224, 224); C is 3 for PIL input.
    """
    # ToTensor() keeps the source channel count (1 for "L", 4 for "RGBA"),
    # which would make the 3-channel Normalize and the model fail.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    return _TRANSFORM(img)
# Class labels in the order of the model's output logits: index i of the
# model output corresponds to CLASS_NAMES[i].
# NOTE(review): "CARBERATOR" is kept verbatim because it is shown to users
# and presumably matches the training label set; rename only together with it.
CLASS_NAMES = (
    'AIR COMPRESSOR',
    'ALTERNATOR',
    'BATTERY',
    'BRAKE CALIPER',
    'BRAKE PAD',
    'BRAKE ROTOR',
    'CAMSHAFT',
    'CARBERATOR',
    'CLUTCH PLATE',
    'COIL SPRING',
    'CRANKSHAFT',
    'CYLINDER HEAD',
    'DISTRIBUTOR',
    'ENGINE BLOCK',
    'ENGINE VALVE',
    'FUEL INJECTOR',
    'FUSE BOX',
    'GAS CAP',
    'HEADLIGHTS',
    'IDLER ARM',
    'IGNITION COIL',
    'INSTRUMENT CLUSTER',
    'LEAF SPRING',
    'LOWER CONTROL ARM',
    'MUFFLER',
    'OIL FILTER',
    'OIL PAN',
    'OIL PRESSURE SENSOR',
    'OVERFLOW TANK',
    'OXYGEN SENSOR',
    'PISTON',
    'PRESSURE PLATE',
    'RADIATOR',
    'RADIATOR FAN',
    'RADIATOR HOSE',
    'RADIO',
    'RIM',
    'SHIFT KNOB',
    'SIDE MIRROR',
    'SPARK PLUG',
    'SPOILER',
    'STARTER',
    'TAILLIGHTS',
    'THERMOSTAT',
    'TORQUE CONVERTER',
    'TRANSMISSION',
    'VACUUM BRAKE BOOSTER',
    'VALVE LIFTER',
    'WATER PUMP',
    'WINDOW REGULATOR',
)

# TorchScript model file; a relative path, so it is resolved against the
# current working directory.
MODEL_PATH = "car_part_traced_classifier_resnet50.ptl"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the TorchScript classifier on first use and cache it.

    The original code re-read the model from disk on every request; caching
    makes that a one-time cost. eval() is applied once here.
    """
    model = torch.jit.load(MODEL_PATH)
    model.eval()
    return model


def classify_img(img):
    """Classify a car-part image and return its three most likely classes.

    Args:
        img: A PIL image (as delivered by ``gr.Image(type='pil')``), or None
            when the form is submitted without an image.

    Returns:
        dict mapping class name -> softmax probability (Python float) for the
        top 3 predictions, in the format ``gr.Label`` accepts. An empty dict
        is returned when ``img`` is None.
    """
    if img is None:
        return {}

    model = _load_model()
    # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W).
    batch = transform_img(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    # Top 3 probabilities and their class indices for the single batch item.
    top_prob, top_idx = torch.topk(probs, 3)
    return {
        CLASS_NAMES[idx]: prob
        for idx, prob in zip(top_idx[0].tolist(), top_prob[0].tolist())
    }
# Gradio UI: one PIL image input and a label output listing the top 3 classes.
demo = gr.Interface(
    fn=classify_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=3),
)

# Start the server only when this file is run as a script (e.g.
# `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    demo.launch()