PolarisFTL commited on
Commit
b497c55
·
verified ·
1 Parent(s): 08f663c

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +75 -0
predict.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from yolo import YOLO
3
+ import gradio as gr
4
+ import os
5
+ from tqdm import tqdm
6
+
7
+ # Initialize YOLO model
8
+ yolo = YOLO()
9
+
10
+ def predict_image(image, crop=False, count=True):
11
+ """
12
+ Predict single image using YOLO model
13
+ """
14
+ try:
15
+ r_image = yolo.detect_image(image, crop=crop, count=count)
16
+ return r_image
17
+ except Exception as e:
18
+ print(f"Error: {e}")
19
+ return None
20
+
21
+ def predict_directory(input_dir, output_dir, crop=False, count=True):
22
+ """
23
+ Predict images in a directory using YOLO model and save results to another directory
24
+ """
25
+ img_names = os.listdir(input_dir)
26
+ results = []
27
+ for img_name in tqdm(img_names):
28
+ if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
29
+ image_path = os.path.join(input_dir, img_name)
30
+ image = Image.open(image_path)
31
+ r_image = yolo.detect_image(image, crop=crop, count=count)
32
+ if not os.path.exists(output_dir):
33
+ os.makedirs(output_dir)
34
+ output_path = os.path.join(output_dir, img_name.replace(".jpg", ".png"))
35
+ r_image.save(output_path, quality=95, subsampling=0)
36
+ results.append((img_name, output_path))
37
+ return results
38
+
39
+ def inference(image, mode='predict', crop=False, count=True, input_dir=None, output_dir=None):
40
+ if mode == 'predict':
41
+ return predict_image(image, crop=crop, count=count)
42
+ elif mode == 'dir_predict' and input_dir and output_dir:
43
+ return predict_directory(input_dir, output_dir, crop=crop, count=count)
44
+ else:
45
+ raise ValueError("Invalid mode or missing directories for 'dir_predict' mode.")
46
+
47
+ title = "YOLO Image Prediction"
48
+ description = "This demo allows you to perform image prediction using a YOLO model. You can either predict a single image or all images in a directory."
49
+
50
+ css = """
51
+ .image-frame img, .image-container img {
52
+ width: auto;
53
+ height: auto;
54
+ max-width: none;
55
+ }
56
+ """
57
+
58
+ demo = gr.Interface(
59
+ fn=inference,
60
+ inputs=[
61
+ gr.Image(type="pil", label="Input Image"),
62
+ gr.Radio(choices=["predict", "dir_predict"], label="Mode", value="predict"),
63
+ gr.Checkbox(value=False, label="Crop"),
64
+ gr.Checkbox(value=True, label="Count"),
65
+ gr.Textbox(placeholder="Input directory (for 'dir_predict' mode)", label="Input Directory", visible=False),
66
+ gr.Textbox(placeholder="Output directory (for 'dir_predict' mode)", label="Output Directory", visible=False),
67
+ ],
68
+ outputs=gr.Image(type="pil", label="Output Image"),
69
+ title=title,
70
+ description=description,
71
+ css=css,
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ demo.launch()