Mushroom_V3 / mushroom_class_load_predict.py
CHCZHC's picture
Upload 3 files
0612bb5
from torchvision import transforms
from PIL import Image
from PIL import ImageFile
from torch.utils.data import Dataset,DataLoader
from transformers import AutoImageProcessor, BitModel, AdamW
import torch
from datasets import load_dataset
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import os
import numpy as np
from sklearn import metrics, model_selection
from collections import Counter
# model class
class mush_root_model(torch.nn.Module):
def __init__(self, num_labels=1):
super(mush_root_model, self).__init__()
self.model = BitModel.from_pretrained("google/bit-50")
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(2048, num_labels),
)
def forward(self, input):
outputs = self.model(**input).pooler_output
#print(outputs.shape)
logits = self.classifier(outputs)
return logits
# load model
model_path="/kaggle/input/mush-room-model-class/mush_classifier_20230801.pth"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = mush_root_model(num_labels=9)
model.load_state_dict(torch.load(model_path))
model.to(device)
image_processor = AutoImageProcessor.from_pretrained("google/bit-50")
# label setting
labels = ['Amanita', 'Suillus', 'Boletus', 'Lactarius', 'Agaricus', 'Hygrocybe', 'Cortinarius', 'Russula', 'Entoloma']
toxic_labels = {'Amanita': 1, 'Suillus': 0, 'Boletus': 0, 'Lactarius': 0, 'Agaricus': 0, 'Hygrocybe': 1, 'Cortinarius': 0, 'Russula': 0, 'Entoloma': 1}
mushroom_address_list = [
"Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒",
"Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒",
"Boletus丽柄牛肝菌,伞菌目,牛肝菌科,牛肝菌属,分布于云南、陕西、甘肃、西藏等地,有毒",
"Lactarius松乳菇,红菇目,红菇科,乳菇属,广泛分布于亚热带松林地,无毒",
"Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒",
"Hygrocybe浅黄褐湿伞,伞菌目,蜡伞科,湿伞属,分布于香港(见于松仔园),有毒",
"Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长)",
"Russula褪色红菇,伞菌目,红菇科,红菇属,分布于河北、吉林、四川、江苏、西藏等地,无毒",
"Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒",
]
def image_process(image_path):
image = Image.open(image_path)
image_pt = image_processor(image,return_tensors="pt")
return image_pt
def predict(image_path):
image_pt = image_process(image_path)
images = image_pt.to(device)
#print(images['pixel_values'].shape)
outputs = torch.squeeze(model(images))
output = torch.sigmoid(outputs).cpu().detach().numpy().tolist()
label_id = np.argmax(output, axis=-1)
label_score = output[label_id]
return label_id, label_score
image_path="/kaggle/input/mush-dataset/dataset1/无毒类/Cortinarius/000_Pw3qUBVmwN8.jpg"
mushroom_class, confidence = predict(image_path)
toxic = toxic_labels[labels[mushroom_class]]
address = mushroom_address_list[mushroom_class]
print(f"the class of mushroom is { labels[mushroom_class]}, its confidence is {confidence} and it is {bool(toxic)} toxic")