PolarisFTL commited on
Commit
a4eef97
·
verified ·
1 Parent(s): 5dbd914

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -62
app.py CHANGED
@@ -1,72 +1,45 @@
1
- from PIL import Image
2
- from yolo import YOLO
3
  import gradio as gr
4
- import os
 
 
5
 
6
- # Initialize YOLO model
7
- yolo = YOLO()
8
 
9
- def detect_objects(image, crop=False, count=True):
10
- r_image = yolo.detect_image(image, crop=crop, count=count)
11
- return r_image
 
12
 
13
- def save_image(image, filename):
14
- if not os.path.exists("img_out"):
15
- os.makedirs("img_out")
16
- image.save(os.path.join("img_out", filename), quality=95, subsampling=0)
17
- return os.path.join("img_out", filename)
 
 
18
 
19
- # Gradio interface for single image prediction
20
- def predict(image):
21
- result_image = detect_objects(image)
22
- output_path = save_image(result_image, "output.png")
23
- return result_image
24
 
25
- # Function to list images in the 'img' folder
26
- def get_image_list():
27
- image_folder = "img"
28
- if not os.path.exists(image_folder):
29
- os.makedirs(image_folder)
30
- img_names = [f for f in os.listdir(image_folder) if f.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff'))]
31
- img_paths = [os.path.join(image_folder, img_name) for img_name in img_names]
32
- return img_paths
33
-
34
- # Gradio interface components
35
- image_output = gr.Image(type="pil", label="Output Image")
36
- image_select = gr.inputs.Radio(get_image_list, label="Select Image from 'img' Folder")
37
 
38
- # Gradio app for single image prediction
39
  iface = gr.Interface(
40
- fn=predict,
41
- inputs=image_select,
42
- outputs=image_output,
43
- title="YOLO Object Detection",
44
- description="Select an image from the 'img' folder to detect objects using YOLO model."
45
  )
46
 
47
- # Custom function to display images in a gallery format
48
- def gallery_update():
49
- image_list = get_image_list()
50
- return gr.Gallery.update(value=image_list)
51
-
52
- # Gradio gallery component for displaying images in the 'img' folder
53
- gallery = gr.Gallery(get_image_list(), label="Image Gallery")
54
-
55
- # Update gallery when Gradio interface loads
56
- gallery.update(gallery_update)
57
-
58
- # Combine both interfaces
59
- app = gr.Blocks()
60
-
61
- with app:
62
- with gr.Row():
63
- with gr.Column():
64
- gallery.render()
65
- with gr.Column():
66
- image = gr.Image(type="pil", label="Selected Image")
67
- predict_button = gr.Button("Predict")
68
- output_image = gr.Image(type="pil", label="Output Image")
69
- predict_button.click(predict, inputs=image, outputs=output_image)
70
-
71
- # Launch the app
72
- app.launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
 
6
+ # 定义设备
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
+ # 加载预训练的模型
10
+ model = models.resnet18(pretrained=True)
11
+ model = model.to(device)
12
+ model.eval()
13
 
14
+ # 图像预处理
15
+ transform = transforms.Compose([
16
+ transforms.Resize(256),
17
+ transforms.CenterCrop(224),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20
+ ])
21
 
22
+ # 加载类名称
23
+ with open("model_data/rtts_classes.txt") as f:
24
+ class_names = [line.strip() for line in f.readlines()]
 
 
25
 
26
+ # 定义预测函数
27
+ def predict(image):
28
+ image = transform(image).unsqueeze(0).to(device)
29
+ with torch.no_grad():
30
+ outputs = model(image)
31
+ _, predicted = outputs.max(1)
32
+ return class_names[predicted]
 
 
 
 
 
33
 
34
+ # 使用Gradio创建界面
35
  iface = gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.inputs.Image(type="pil"),
38
+ outputs=gr.outputs.Textbox(),
39
+ title="图像分类器",
40
+ description="上传一张图像,并让模型预测它的类别。",
41
  )
42
 
43
+ # 启动应用
44
+ if __name__ == "__main__":
45
+ iface.launch()