Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from io import BytesIO | |
| from vit_model import vit_base_patch16_224_in21k as create_model | |
# Lazily populated by _get_inference_state(). Holds the objects that are
# expensive to build so they are created once per process instead of once
# per request.
_INFERENCE_STATE = None


def _get_inference_state():
    """Build (on first use) and return ``(device, preprocess, class_indict, model)``.

    Raises:
        FileNotFoundError: if ``./class_indices.json`` is missing.
    """
    global _INFERENCE_STATE
    if _INFERENCE_STATE is not None:
        return _INFERENCE_STATE

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Mapping from class index (as a string key) to class name.
    json_path = './class_indices.json'
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    model = create_model(num_classes=370, has_logits=False).to(device)
    model_weight_path = "./best_model.pth"
    # NOTE: torch.load unpickles the checkpoint; only point this at a trusted
    # file (consider `weights_only=True` if the installed torch supports it).
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    _INFERENCE_STATE = (device, preprocess, class_indict, model)
    return _INFERENCE_STATE


def classify_image(img):
    """Classify an image and return its five most probable classes as text.

    Args:
        img: A PIL image (the Gradio input is configured with ``type='pil'``),
            or ``None`` when nothing was uploaded.

    Returns:
        A string with one ``"class: <name> prob: <p>"`` line per prediction,
        ordered from most to least probable, each terminated by a newline.
        If ``img`` is ``None`` a short prompt to upload an image is returned.

    Raises:
        FileNotFoundError: if ``./class_indices.json`` is missing.
    """
    if img is None:
        return "Please upload an image first."

    device, preprocess, class_indict, model = _get_inference_state()

    # Normalize() is given three channel statistics, so force a 3-channel
    # image; grayscale or RGBA inputs would otherwise fail.
    # unsqueeze adds the batch dimension -> [N, C, H, W].
    batch = torch.unsqueeze(preprocess(img.convert("RGB")), dim=0).to(device)

    with torch.no_grad():
        # presumably the model returns [1, num_classes] logits -- squeeze
        # drops the batch dimension so softmax runs over the class axis.
        output = torch.squeeze(model(batch)).cpu()
        predict = torch.softmax(output, dim=0)

    # Look labels up by predicted index rather than by the iteration order of
    # the JSON keys, so the mapping is right however the file is ordered.
    top_k = min(5, predict.numel())
    top_probs, top_indices = torch.topk(predict, top_k)

    return "".join(
        "class: {:10} prob: {:.3}\n".format(class_indict[str(index)], prob)
        for index, prob in zip(top_indices.tolist(), top_probs.tolist())
    )
# Web UI: a single image upload wired to classify_image, with the result
# shown in a text box.
large_text_theme = gr.themes.Default(text_size="lg")
image_input = gr.Image(type="pil")
text_output = gr.Textbox()

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=text_output,
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify.",
    theme=large_text_theme,
)

# Started unconditionally (there is no `if __name__ == '__main__':` guard),
# so the app launches as soon as this module is executed.
iface.launch()