Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
| from torchvision.transforms import InterpolationMode | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
# Intel OpenMP setting: tolerate more than one OpenMP runtime being loaded
# into the process instead of aborting at startup.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels, one per line of classes.txt. The line order defines the index
# of each class in the model output. The file is opened in a context manager
# so the handle is closed deterministically, and decoded as UTF-8 so that
# non-ASCII label names do not depend on the platform's locale.
with open("classes.txt", encoding="utf-8") as _classes_file:
    class_names = [line.strip() for line in _classes_file]

# Build the ViT-B/16 architecture without pretrained weights; the trained
# parameters come from the checkpoint below.
model = torchvision.models.vit_b_16(weights=None)

# Replace the classifier head with a single linear layer sized to the label
# list (768 input features, one output per class).
model.heads = nn.Linear(in_features=768, out_features=len(class_names))

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. Tensors are mapped to the CPU first and moved to `device`
# afterwards, so the checkpoint loads on machines without a GPU.
checkpoint = torch.load(
    '08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth',
    map_location=torch.device('cpu'),
)

# strict=False tolerates key mismatches, which would otherwise leave parts of
# the network at their random initial values without any notice. Keep the
# lenient load, but report what did not line up so bad checkpoints are visible
# in the logs.
_load_result = model.load_state_dict(checkpoint, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "Warning: checkpoint does not fully match the model. "
        f"Missing keys: {list(_load_result.missing_keys)}; "
        f"unexpected keys: {list(_load_result.unexpected_keys)}"
    )

model = model.to(device)
model.eval()  # inference mode for layers such as dropout
# Preprocessing applied to every incoming image before it reaches the model:
# resize to 256 with bilinear interpolation, take the central 224x224 crop,
# convert to a tensor, then normalise each channel. The mean/std values match
# the widely used ImageNet channel statistics.
_preprocessing_steps = [
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)
| # Prediction function | |
def predict(image):
    """Classify an image and chart the most probable classes.

    Args:
        image: The uploaded picture as a NumPy array, as delivered by the
            ``gr.Image(type="numpy")`` input component. ``None`` when the
            user submits without providing an image.

    Returns:
        A ``(top_class, figure)`` tuple: the name of the most probable class
        and a Matplotlib figure with a horizontal bar chart of the highest
        probabilities (up to five classes).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from PIL.
        raise gr.Error("Please provide an image to classify.")

    # Force three channels: the normalisation step uses 3-element mean/std,
    # so grayscale or RGBA input would otherwise fail.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(dim=0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        # Never ask for more entries than there are classes; topk raises if
        # k exceeds the size of the dimension.
        k = min(5, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k=k)

    # Index the single batch row explicitly. A bare squeeze() would also drop
    # the class dimension when k == 1 and yield a non-iterable 0-d array.
    top_probs = top_probs[0].cpu().numpy()
    top_indices = top_indices[0].cpu().tolist()
    top_classes = [class_names[i] for i in top_indices]

    # Plot the probabilities as a horizontal bar chart.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=top_probs, y=top_classes, palette="viridis", ax=ax)
    ax.set_xlabel('Probability')
    ax.set_ylabel('Class')
    ax.set_title(f'Top {k} Predictions')
    ax.set_xlim(0, 1)

    # Write each probability just past the end of its bar, vertically centred
    # on the bar whatever its height.
    for bar in ax.patches:
        ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2,
                f'{bar.get_width():.2f}',
                ha='center', va='center', fontsize=10, color='black')

    # Address this figure/axes directly rather than pyplot's "current" figure,
    # which may belong to another request being handled concurrently.
    sns.despine(ax=ax, left=True, bottom=True)
    fig.tight_layout()

    # Unregister the figure from pyplot so a long-running server does not
    # accumulate one open figure per request. The Figure object itself stays
    # valid and can still be rendered by the caller.
    plt.close(fig)

    return top_classes[0], fig
# Assemble the Gradio UI around `predict`: one image input handed over as a
# NumPy array, and two outputs - a textbox for the top prediction and a plot
# for the bar chart.
_image_input = gr.Image(type="numpy")
_prediction_outputs = [gr.Textbox(label="Top Prediction"), gr.Plot()]

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_outputs,
)

# Start serving the app.
iface.launch()