Spaces:
Build error
Build error
| import torch | |
| import torchvision.transforms as transforms | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import json | |
# Pick the device used for inference: the GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Sample image path from the original (Colab) dataset; only referenced by
# commented-out code in this script.
image_path = '/content/mogu_dataset/有毒类/Amanita/000_jY-KIKiHVjQ.jpg'

# Human-readable class names, indexed by the model's output index.
# NOTE(review): presumably answers "is this mushroom poisonous?" (the sample
# path above lives under a "poisonous" folder) — confirm against training data.
label_name_list = ['No', 'Yes']
num_classes = len(label_name_list)  # == 2; size of the model's output layer
# Step 1: model definition.
class BaseModel(nn.Module):
    """Small four-stage CNN image classifier.

    Each stage is an unpadded 3x3 convolution followed by ReLU and 2x2 max
    pooling. For a 224x224 RGB input the spatial size goes
    224 -> 111 -> 54 -> 26 -> 12, so the last stage yields 32 * 12 * 12 = 4608
    features, which is the input size of ``fc1``; other input sizes will not
    fit ``fc1``.

    The attribute names below become the ``state_dict`` keys
    (``conv1.weight``, ``fc3.bias``, ...) used when loading saved weights, so
    they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: Conv2d(in_channels, out_channels, kernel_size).
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)
        self.conv4 = nn.Conv2d(32, 32, 3)
        # Classifier head ending in one logit per class.
        self.fc1 = nn.Linear(4608, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
# Build the network and load the trained weights.
# map_location remaps the stored tensors onto `device`: a checkpoint saved
# from a GPU run otherwise fails to load on a CPU-only host (e.g. a free
# Hugging Face Space), because torch.load tries to restore CUDA tensors.
cnn_model = BaseModel().to(device)
cnn_model.load_state_dict(torch.load("cnn_model.pth", map_location=device))
model = cnn_model
# Inference mode (affects layers such as dropout/batch-norm if ever added).
model.eval()
# Step 2: image preprocessing pipeline used during training.
# Light random augmentation (small rotation, horizontal flip), then resize to
# the 224x224 size BaseModel.fc1 was sized for, convert to a tensor, and
# normalise with the standard ImageNet per-channel mean / std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# Deterministic preprocessing for inference: the same resize / to-tensor /
# normalise steps as the training pipeline, but WITHOUT the random rotation and
# flip, so the same image always produces the same prediction.
_inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(inp):
    """Classify one image and return the result as a string.

    Args:
        inp: A PIL image (as delivered by a Gradio ``Image(type="pil")``
            input), or ``None`` when the user submitted without an image.

    Returns:
        ``str()`` of a dict with keys ``pred_score`` (softmax probability of
        the predicted class, a Python float), ``pred_index`` (int) and
        ``pred_label`` (the matching entry of ``label_name_list``), or a short
        message when no image was given.
    """
    if inp is None:
        return "No image provided."

    # ToTensor keeps every channel of the source image, so RGBA or greyscale
    # uploads would not match conv1's 3 input channels; force RGB first.
    image = _inference_transform(inp.convert("RGB")).unsqueeze(0).to(device)

    # Run the model without building an autograd graph.
    with torch.no_grad():
        output = model(image)
    print(output)

    # Post-process: class probabilities for the single batch item. Move to the
    # CPU before converting, so this also works when the model runs on CUDA.
    probs = F.softmax(output[0], dim=0).cpu()
    pred_index = int(torch.argmax(probs).item())
    pred_label = label_name_list[pred_index]
    result_dict = {
        'pred_score': float(probs[pred_index]),
        'pred_index': pred_index,
        'pred_label': pred_label,
    }
    return str(result_dict)
# Web UI. The legacy gr.inputs.* namespace was removed in Gradio 4.x;
# gr.Image (available in 3.x and 4.x) with type="pil" hands predict() a PIL
# image.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    # NOTE(review): presumably an image file shipped alongside this script — verify.
    examples=[["111212.jpg"]],
)
demo.launch()