Spaces:
No application file
No application file
| from torchvision import transforms | |
| from PIL import Image | |
| from PIL import ImageFile | |
| from torch.utils.data import Dataset,DataLoader | |
| from transformers import AutoImageProcessor, BitModel, AdamW | |
| import torch | |
| from datasets import load_dataset | |
| from torch import Tensor, nn | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| import os | |
| import numpy as np | |
| from sklearn import metrics, model_selection | |
| from collections import Counter | |
| # model class | |
| class mush_root_model(torch.nn.Module): | |
| def __init__(self, num_labels=1): | |
| super(mush_root_model, self).__init__() | |
| self.model = BitModel.from_pretrained("google/bit-50") | |
| self.classifier = nn.Sequential( | |
| nn.Flatten(), | |
| nn.Linear(2048, num_labels), | |
| ) | |
| def forward(self, input): | |
| outputs = self.model(**input).pooler_output | |
| #print(outputs.shape) | |
| logits = self.classifier(outputs) | |
| return logits | |
| # load model | |
| model_path="/kaggle/input/mush-room-model-class/mush_classifier_20230801.pth" | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| model = mush_root_model(num_labels=9) | |
| model.load_state_dict(torch.load(model_path)) | |
| model.to(device) | |
| image_processor = AutoImageProcessor.from_pretrained("google/bit-50") | |
| # label setting | |
| labels = ['Amanita', 'Suillus', 'Boletus', 'Lactarius', 'Agaricus', 'Hygrocybe', 'Cortinarius', 'Russula', 'Entoloma'] | |
| toxic_labels = {'Amanita': 1, 'Suillus': 0, 'Boletus': 0, 'Lactarius': 0, 'Agaricus': 0, 'Hygrocybe': 1, 'Cortinarius': 0, 'Russula': 0, 'Entoloma': 1} | |
| mushroom_address_list = [ | |
| "Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒", | |
| "Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒", | |
| "Boletus丽柄牛肝菌,伞菌目,牛肝菌科,牛肝菌属,分布于云南、陕西、甘肃、西藏等地,有毒", | |
| "Lactarius松乳菇,红菇目,红菇科,乳菇属,广泛分布于亚热带松林地,无毒", | |
| "Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒", | |
| "Hygrocybe浅黄褐湿伞,伞菌目,蜡伞科,湿伞属,分布于香港(见于松仔园),有毒", | |
| "Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长)", | |
| "Russula褪色红菇,伞菌目,红菇科,红菇属,分布于河北、吉林、四川、江苏、西藏等地,无毒", | |
| "Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒", | |
| ] | |
| def image_process(image_path): | |
| image = Image.open(image_path) | |
| image_pt = image_processor(image,return_tensors="pt") | |
| return image_pt | |
| def predict(image_path): | |
| image_pt = image_process(image_path) | |
| images = image_pt.to(device) | |
| #print(images['pixel_values'].shape) | |
| outputs = torch.squeeze(model(images)) | |
| output = torch.sigmoid(outputs).cpu().detach().numpy().tolist() | |
| label_id = np.argmax(output, axis=-1) | |
| label_score = output[label_id] | |
| return label_id, label_score | |
| image_path="/kaggle/input/mush-dataset/dataset1/无毒类/Cortinarius/000_Pw3qUBVmwN8.jpg" | |
| mushroom_class, confidence = predict(image_path) | |
| toxic = toxic_labels[labels[mushroom_class]] | |
| address = mushroom_address_list[mushroom_class] | |
| print(f"the class of mushroom is { labels[mushroom_class]}, its confidence is {confidence} and it is {bool(toxic)} toxic") | |