# MASFNet / app.py — Hugging Face Space by PolarisFTL.
# Last commit: a4eef97 "Update app.py" (verified); original file size 1.17 kB.
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
# Select the compute device: the first CUDA GPU when available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_pretrained_resnet18():
    """Build a ResNet-18 with torchvision's ImageNet-pretrained weights.

    torchvision >= 0.13 replaced the ``pretrained=True`` flag with the
    ``weights=`` enum API and emits a deprecation warning for the old flag.
    ``ResNet18_Weights.DEFAULT`` resolves to the same IMAGENET1K_V1 weights
    that ``pretrained=True`` loaded, so the resulting model is unchanged.
    Older torchvision releases lack the enum, so fall back to the flag there.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        return models.resnet18(weights=weights_enum.DEFAULT)
    return models.resnet18(pretrained=True)


# Load the pretrained model, move it to the chosen device, and switch it to
# inference mode (eval() disables dropout / uses BatchNorm running stats).
model = _load_pretrained_resnet18().to(device)
model.eval()
# Image preprocessing applied to every input before it reaches the model:
# resize to 256, center-crop 224x224, convert the PIL image to a float
# tensor, then normalize each RGB channel. The mean/std values are the
# ImageNet statistics used with torchvision's pretrained weights.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# Load the class labels, one label per line, in file order (the line order
# defines the label index). The encoding is given explicitly so the result
# does not depend on the platform's locale codec; "utf-8-sig" also strips a
# leading BOM if the file was saved with one. Blank lines (e.g. a trailing
# empty line) are skipped so they do not turn into "" labels.
with open("model_data/rtts_classes.txt", encoding="utf-8-sig") as f:
    class_names = [name for name in (line.strip() for line in f) if name]
# Prediction function wired into the Gradio interface.
def predict(image):
    """Classify one image and return the predicted label as a string.

    Uses the module-level ``transform``, ``model``, ``device`` and
    ``class_names``.

    Args:
        image: A PIL image from the Gradio image input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A prompt asking the user to upload an image, if ``image`` is ``None``.
        Otherwise, ``class_names[i]``, where ``i`` is the index of the highest
        model output. If ``i`` has no entry in ``class_names``, the function
        returns ``"class <i>"`` instead of raising ``IndexError``.
    """
    if image is None:
        return "请先上传一张图像。"
    # The Normalize step expects exactly 3 channels. Grayscale, palette, or
    # RGBA images would fail there, so force RGB. Gradio's default image mode
    # already does this, but predict may be called directly.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # .item() yields a plain Python int (and moves it off the GPU) rather than
    # indexing a list with a 1-element tensor.
    index = int(logits.argmax(dim=1).item())
    # NOTE(review): the pretrained ResNet-18 has a 1000-way ImageNet head,
    # while class_names comes from an RTTS label file. Confirm the label file
    # actually matches the model's outputs; until then, guard the lookup.
    if 0 <= index < len(class_names):
        return class_names[index]
    return f"class {index}"
# Build the Gradio web UI: a PIL image input wired to ``predict`` and a text
# box showing the returned label. ``gr.inputs`` / ``gr.outputs`` were
# deprecated in Gradio 3 and removed in Gradio 4; the top-level
# ``gr.Image`` / ``gr.Textbox`` components work in both versions.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="图像分类器",
    description="上传一张图像,并让模型预测它的类别。",
)
def main():
    """Launch the Gradio app defined by ``iface``."""
    iface.launch()


# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()