"""Gradio demo that classifies car-part photos with a traced ResNet-50 model."""

import functools
import logging

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms

logger = logging.getLogger(__name__)

# TorchScript model file, loaded lazily (once) by ``_load_model``.
MODEL_PATH = "car_part_traced_classifier_resnet50.ptl"

# Number of highest-probability classes returned; also drives the Label widget.
TOP_K = 3

# Class labels in the index order of the model's output logits.
# NOTE(review): 'CARBERATOR' is misspelled, but it is kept verbatim because it
# presumably matches the training dataset's class name — confirm before renaming.
CLASS_NAMES = [
    'AIR COMPRESSOR', 'ALTERNATOR', 'BATTERY', 'BRAKE CALIPER', 'BRAKE PAD',
    'BRAKE ROTOR', 'CAMSHAFT', 'CARBERATOR', 'CLUTCH PLATE', 'COIL SPRING',
    'CRANKSHAFT', 'CYLINDER HEAD', 'DISTRIBUTOR', 'ENGINE BLOCK',
    'ENGINE VALVE', 'FUEL INJECTOR', 'FUSE BOX', 'GAS CAP', 'HEADLIGHTS',
    'IDLER ARM', 'IGNITION COIL', 'INSTRUMENT CLUSTER', 'LEAF SPRING',
    'LOWER CONTROL ARM', 'MUFFLER', 'OIL FILTER', 'OIL PAN',
    'OIL PRESSURE SENSOR', 'OVERFLOW TANK', 'OXYGEN SENSOR', 'PISTON',
    'PRESSURE PLATE', 'RADIATOR', 'RADIATOR FAN', 'RADIATOR HOSE', 'RADIO',
    'RIM', 'SHIFT KNOB', 'SIDE MIRROR', 'SPARK PLUG', 'SPOILER', 'STARTER',
    'TAILLIGHTS', 'THERMOSTAT', 'TORQUE CONVERTER', 'TRANSMISSION',
    'VACUUM BRAKE BOOSTER', 'VALVE LIFTER', 'WATER PUMP', 'WINDOW REGULATOR',
]

# Preprocessing pipeline, built once instead of on every request.
# The original also applied CenterCrop((224, 224)) after resizing to exactly
# 224x224, which is a no-op and has been dropped. Normalization uses the
# standard ImageNet channel mean/std.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def transform_img(img):
    """Convert a PIL image into a normalized 3x224x224 float tensor.

    The image is converted to RGB first, so RGBA (e.g. PNG with alpha),
    grayscale and palette images all yield exactly three channels. Without
    this, the model input reshape fails.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    return _TRANSFORM(img)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the TorchScript classifier once, switch it to eval mode, and cache it."""
    # map_location keeps parameters on the CPU, matching the CPU input tensor
    # built in classify_img, even if the model was traced on a GPU.
    model = torch.jit.load(MODEL_PATH, map_location="cpu")
    model.eval()
    return model


def classify_img(img):
    """Return the top predicted car-part classes for *img*.

    Args:
        img: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        dict mapping class name -> softmax probability for the ``TOP_K`` most
        likely classes (the format ``gr.Label`` accepts); an empty dict when
        *img* is ``None``.
    """
    if img is None:
        return {}

    model = _load_model()
    batch = transform_img(img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    with torch.no_grad():
        logits = model(batch)

    probs = F.softmax(logits, dim=1)
    k = min(TOP_K, probs.size(1))
    top_probs, top_ids = torch.topk(probs, k)

    model_output = {
        CLASS_NAMES[idx]: prob
        for prob, idx in zip(top_probs[0].tolist(), top_ids[0].tolist())
    }
    logger.debug("Prediction: %s", model_output)
    return model_output


# Module-level ``demo`` is kept so Gradio tooling/hosting can discover it.
demo = gr.Interface(
    classify_img,
    gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=TOP_K),
)

if __name__ == "__main__":
    demo.launch()