PolarisFTL commited on
Commit
57b2194
·
verified ·
1 Parent(s): d9126a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -8
app.py CHANGED
@@ -1,17 +1,52 @@
 
 
1
  import gradio as gr
 
2
  from yolo import YOLO
3
 
 
4
  yolo = YOLO()
5
 
6
- def predict(image):
7
- r_image = yolo.detect_image(image)
8
- return r_image
 
 
 
 
 
 
 
9
 
10
- title = "MASFNet: Multiscale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
11
- description = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- def reset_interface():
14
- return gr.update(value=None), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
15
 
16
  example_images = [
17
  "img/1.png",
@@ -25,6 +60,17 @@ example_images = [
25
  "img/10.png",
26
  ]
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  with gr.Blocks() as demo:
29
  gr.Markdown(f"### {title}")
30
  gr.Markdown(description)
@@ -36,7 +82,7 @@ with gr.Blocks() as demo:
36
  with gr.Column():
37
  output = gr.Image(type="pil", label="Prediction Result")
38
 
39
- submit_btn.click(fn=predict, inputs=img_input, outputs=output)
40
  demo.load(reset_interface, None, [output])
41
 
42
  with gr.Row():
 
1
+ import os
2
+ from PIL import Image
3
  import gradio as gr
4
+ from tqdm import tqdm
5
  from yolo import YOLO
6
 
7
+ # Initialize YOLO model
8
  yolo = YOLO()
9
 
10
+ def predict_image(image, crop=False, count=True):
11
+ """
12
+ Predict single image using YOLO model
13
+ """
14
+ try:
15
+ r_image = yolo.detect_image(image, crop=crop, count=count)
16
+ return r_image
17
+ except Exception as e:
18
+ print(f"Error: {e}")
19
+ return None
20
 
21
+ def predict_directory(input_dir, output_dir, crop=False, count=True):
22
+ """
23
+ Predict images in a directory using YOLO model and save results to another directory
24
+ """
25
+ img_names = os.listdir(input_dir)
26
+ results = []
27
+ for img_name in tqdm(img_names):
28
+ if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
29
+ image_path = os.path.join(input_dir, img_name)
30
+ image = Image.open(image_path)
31
+ r_image = yolo.detect_image(image, crop=crop, count=count)
32
+ if not os.path.exists(output_dir):
33
+ os.makedirs(output_dir)
34
+ output_path = os.path.join(output_dir, img_name.replace(".jpg", ".png"))
35
+ r_image.save(output_path, quality=95, subsampling=0)
36
+ results.append((img_name, output_path))
37
+ return results
38
 
39
+ def inference(image, mode='predict', crop=False, count=True, input_dir=None, output_dir=None):
40
+ if mode == 'predict':
41
+ return predict_image(image, crop=crop, count=count)
42
+ elif mode == 'dir_predict' and input_dir and output_dir:
43
+ return predict_directory(input_dir, output_dir, crop=crop, count=count)
44
+ else:
45
+ raise ValueError("Invalid mode or missing directories for 'dir_predict' mode.")
46
+
47
+ # Gradio interface setup
48
+ title = "YOLO Image Prediction"
49
+ description = "This demo allows you to perform image prediction using a YOLO model. You can either predict a single image or all images in a directory."
50
 
51
  example_images = [
52
  "img/1.png",
 
60
  "img/10.png",
61
  ]
62
 
63
+ def reset_interface():
64
+ return gr.update(value=None), gr.update(visible=False)
65
+
66
+ css = """
67
+ .image-frame img, .image-container img {
68
+ width: auto;
69
+ height: auto;
70
+ max-width: none;
71
+ }
72
+ """
73
+
74
  with gr.Blocks() as demo:
75
  gr.Markdown(f"### {title}")
76
  gr.Markdown(description)
 
82
  with gr.Column():
83
  output = gr.Image(type="pil", label="Prediction Result")
84
 
85
+ submit_btn.click(fn=inference, inputs=[img_input, 'predict', False, True], outputs=output)
86
  demo.load(reset_interface, None, [output])
87
 
88
  with gr.Row():