PolarisFTL commited on
Commit
1efa86d
·
verified ·
1 Parent(s): 00fb1c0

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +75 -39
predict.py CHANGED
@@ -1,39 +1,75 @@
1
- from PIL import Image
2
- from yolo import YOLO
3
-
4
- if __name__ == "__main__":
5
- mode = 'predict'
6
- crop = False
7
- count = True
8
- dir_origin_path = "img/vs"
9
- dir_save_path = "img_out"
10
-
11
- yolo = YOLO()
12
-
13
- if mode == "predict":
14
- while True:
15
- img = input('Input image filename:')
16
- try:
17
- image = Image.open(img)
18
- except:
19
- print('Open Error! Try again!')
20
- continue
21
- else:
22
- r_image = yolo.detect_image(image, crop = crop, count=count)
23
- r_image.show()
24
-
25
- elif mode == "dir_predict":
26
- import os
27
- from tqdm import tqdm
28
-
29
- img_names = os.listdir(dir_origin_path)
30
- for img_name in tqdm(img_names):
31
- if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
32
- image_path = os.path.join(dir_origin_path, img_name)
33
- image = Image.open(image_path)
34
- r_image = yolo.detect_image(image)
35
- if not os.path.exists(dir_save_path):
36
- os.makedirs(dir_save_path)
37
- r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
38
- else:
39
- raise AssertionError("Please specify the correct mode: 'predict', 'dir_predict'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from yolo import YOLO
3
+ import gradio as gr
4
+ import os
5
+ from tqdm import tqdm
6
+
7
+ # Initialize YOLO model
8
+ yolo = YOLO()
9
+
10
+ def predict_image(image, crop=False, count=True):
11
+ """
12
+ Predict single image using YOLO model
13
+ """
14
+ try:
15
+ r_image = yolo.detect_image(image, crop=crop, count=count)
16
+ return r_image
17
+ except Exception as e:
18
+ print(f"Error: {e}")
19
+ return None
20
+
21
+ def predict_directory(input_dir, output_dir, crop=False, count=True):
22
+ """
23
+ Predict images in a directory using YOLO model and save results to another directory
24
+ """
25
+ img_names = os.listdir(input_dir)
26
+ results = []
27
+ for img_name in tqdm(img_names):
28
+ if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
29
+ image_path = os.path.join(input_dir, img_name)
30
+ image = Image.open(image_path)
31
+ r_image = yolo.detect_image(image, crop=crop, count=count)
32
+ if not os.path.exists(output_dir):
33
+ os.makedirs(output_dir)
34
+ output_path = os.path.join(output_dir, img_name.replace(".jpg", ".png"))
35
+ r_image.save(output_path, quality=95, subsampling=0)
36
+ results.append((img_name, output_path))
37
+ return results
38
+
39
+ def inference(image, mode='predict', crop=False, count=True, input_dir=None, output_dir=None):
40
+ if mode == 'predict':
41
+ return predict_image(image, crop=crop, count=count)
42
+ elif mode == 'dir_predict' and input_dir and output_dir:
43
+ return predict_directory(input_dir, output_dir, crop=crop, count=count)
44
+ else:
45
+ raise ValueError("Invalid mode or missing directories for 'dir_predict' mode.")
46
+
47
+ title = "YOLO Image Prediction"
48
+ description = "This demo allows you to perform image prediction using a YOLO model. You can either predict a single image or all images in a directory."
49
+
50
+ css = """
51
+ .image-frame img, .image-container img {
52
+ width: auto;
53
+ height: auto;
54
+ max-width: none;
55
+ }
56
+ """
57
+
58
+ demo = gr.Interface(
59
+ fn=inference,
60
+ inputs=[
61
+ gr.Image(type="pil", label="Input Image"),
62
+ gr.Radio(choices=["predict", "dir_predict"], label="Mode", value="predict"),
63
+ gr.Checkbox(value=False, label="Crop"),
64
+ gr.Checkbox(value=True, label="Count"),
65
+ gr.Textbox(placeholder="Input directory (for 'dir_predict' mode)", label="Input Directory", visible=False),
66
+ gr.Textbox(placeholder="Output directory (for 'dir_predict' mode)", label="Output Directory", visible=False),
67
+ ],
68
+ outputs=gr.Image(type="pil", label="Output Image"),
69
+ title=title,
70
+ description=description,
71
+ css=css,
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ demo.launch()