# Mushroom_V2 / app.py
# CHCZHC's picture
# Upload 3 files
# 22a27fa
import gradio as gr
import json
import torch
from torch.utils.data import Dataset
from torchvision import datasets,models
from torchvision.transforms import ToTensor
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
# Choose the compute device once at import time: CUDA when PyTorch reports a
# usable GPU, otherwise the CPU. The model and every input batch are moved to
# this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
# Genus names; predict() indexes this list with the argmax of the model
# output, so the order must stay as it is.
class_names = [
    'Agaricus',
    'Amanita',
    'Boletus',
    'Cortinarius',
    'Entoloma',
    'Hygrocybe',
    'Lactarius',
    'Russula',
    'Suillus',
]

# Size of the classifier's output layer (one logit per genus).
num_classes = len(class_names)

# Genera reported as "toxic"; every other genus is reported as "no-toxic".
# NOTE(review): genus-level labels are carried over unchanged from the
# original table -- confirm them against a mycological reference.
_toxic_genera = frozenset(
    ('Amanita', 'Boletus', 'Cortinarius', 'Entoloma', 'Hygrocybe')
)

# Maps each genus name to the toxicity tag shown to the user.
label_dict = {
    genus: ("toxic" if genus in _toxic_genera else "no-toxic")
    for genus in class_names
}
# Per-channel mean / standard deviation passed to Normalize. These are the
# values torchvision documents for its ImageNet-pretrained models.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Random augmentation: rotation of up to 5 degrees and a horizontal flip.
_augmentation_steps = [
    transforms.RandomRotation(5),
    transforms.RandomHorizontalFlip(),
]

# Deterministic tail: resize to 224x224, convert to a tensor, normalise.
_tensor_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]

# Full preprocessing pipeline: augmentation first, then tensor conversion.
train_transform = transforms.Compose(_augmentation_steps + _tensor_steps)
# Every backbone has its final layer replaced by a linear projection to a
# vector of this width, so the three outputs can be stacked together.
_feature_dim = 100
# Input widths of the layers being replaced (same literals as before).
_efficientnet_b0_head_in = 1280
_vit_b_16_head_in = 768

# EfficientNet-B0: replace the second entry of the classifier.
base_model_efficientnet_b0 = models.efficientnet_b0(pretrained=True)
base_model_efficientnet_b0.classifier[1] = nn.Linear(
    _efficientnet_b0_head_in, _feature_dim
)

# ResNet-18: replace the fully connected head, reusing its input width.
base_model_resnet18 = models.resnet18(pretrained=True)
num_ftrs = base_model_resnet18.fc.in_features
base_model_resnet18.fc = nn.Linear(num_ftrs, _feature_dim)

# ViT-B/16: replace the classification head.
# (An experiment that kept only the first two encoder layers is disabled:
#  base_model_vit_b_16.encoder.layers = base_model_vit_b_16.encoder.layers[:2])
base_model_vit_b_16 = models.vit_b_16(pretrained=True)
base_model_vit_b_16.heads.head = nn.Linear(_vit_b_16_head_in, _feature_dim)
class ModeleBig(nn.Module):
    """Ensemble classifier that fuses the outputs of three backbones.

    The module-level ``base_model_resnet18``, ``base_model_efficientnet_b0``
    and ``base_model_vit_b_16`` each turn the image batch into a 100-wide
    vector. The three vectors are stacked, run through a two-layer
    Transformer encoder, flattened to 300 values and reduced to
    ``num_classes`` logits by two linear layers.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become state-dict keys, so they must
        # stay as they are for the saved checkpoint to load.
        self.base_model1 = base_model_resnet18
        self.base_model2 = base_model_efficientnet_b0
        self.base_model3 = base_model_vit_b_16
        # NOTE(review): TransformerEncoderLayer defaults to
        # batch_first=False, while forward() stacks the features as
        # (batch, 3, 100). Presumably the checkpoint was trained with this
        # layout -- confirm before changing it.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=100, nhead=2),
            num_layers=2,
        )
        self.fc1 = nn.Linear(300, 50)
        self.fc2 = nn.Linear(50, num_classes)
        self.flatten = nn.Flatten()

    def forward(self, img_rgb):
        """Return the class logits for the image batch ``img_rgb``."""
        # Run the same batch through each backbone in turn.
        backbones = (self.base_model1, self.base_model2, self.base_model3)
        features = [backbone(img_rgb) for backbone in backbones]
        # Stack along a new dimension 1 and mix with the Transformer encoder.
        fused = self.transformer_encoder(torch.stack(features, dim=1))
        # Flatten the stacked features and apply the two-layer head.
        hidden = self.fc1(self.flatten(fused))
        return self.fc2(hidden)
# Build the ensemble and move it to the selected device (Module.to works in
# place), then restore the trained weights. The checkpoint is read onto the
# CPU; load_state_dict copies the values into the model's own parameters.
model_big = ModeleBig()
model_big.to(device)
_checkpoint = torch.load("model_big.pth", map_location=torch.device('cpu'))
model_big.load_state_dict(_checkpoint)
def _inference_transform():
    """Return the deterministic preprocessing pipeline used for prediction.

    It applies the same resize, tensor conversion and normalisation as
    ``train_transform`` but leaves out the random rotation and flip, which
    would make repeated predictions on the same image disagree.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def predict(inp):
    """Classify a mushroom image and report genus, toxicity tag and score.

    Args:
        inp: PIL image supplied by the Gradio image input.

    Returns:
        A string of the form ``"<genus>+<toxicity>+<score>"``, where
        ``<score>`` is the highest softmax probability cut to at most five
        characters.

    Raises:
        ValueError: If ``inp`` is ``None`` (no image was submitted).
    """
    if inp is None:
        raise ValueError("No image provided.")

    model = model_big
    # Inference mode: turns off dropout and uses stored batch-norm statistics.
    model.eval()

    # convert("RGB") guarantees three channels for Normalize, even if a
    # grayscale image or one with an alpha channel reaches this function.
    batch = _inference_transform()(inp.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Probabilities for the single image in the batch.
    probabilities = nn.functional.softmax(logits[0], dim=0)
    top_prob, top_index = torch.max(probabilities, dim=0)

    # .item() works on any device, whereas .numpy() raises on CUDA tensors.
    pred_score = str(top_prob.item())[:5]
    pred_label = class_names[top_index.item()]
    is_du = label_dict[pred_label]

    return f"{pred_label}+{is_du}+{pred_score}"
# Gradio UI: one PIL image in, the prediction string out.
# `examples` is a list of samples, and each sample is a list holding one value
# per input component. With a single image input, every sample therefore
# contains exactly one file path.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    outputs="text",
    examples=[['111.jpg'], ['222.png']],
)
demo.launch()