WizardForest commited on
Commit
ca23df0
·
verified ·
1 Parent(s): 356099f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -7,27 +7,32 @@ def predict(img):
7
 
8
  model = YOLO("model.pt")
9
 
10
- results = model.predict(source=img,save=False, show_labels=False, show_conf=False)
 
 
 
 
 
11
  count = 0
 
12
  for i in results:
13
  count = len(i.boxes)
14
  annot_img = i.plot(labels=False, masks=False)
 
 
 
15
 
16
- #result_img = os.path.join(cur_path,r"runs\segment\predict",f"{path}.jpg")
17
- #print("result img",result_img)
18
  return str(count), annot_img
19
 
20
  with gr.Blocks(title="Pill Counter") as demo:
21
-
22
  with gr.Row():
23
  with gr.Column():
24
- img = gr.Image(type="filepath",format="jpg", height=500, width=700)
25
-
26
  button = gr.Button()
27
  with gr.Column():
28
  data_output = gr.Textbox()
29
  img_output = gr.Image(type="numpy")
30
- button.click(fn=predict, inputs=img, outputs=[data_output,img_output])
31
 
32
  if __name__ == "__main__":
33
  demo.launch()
 
7
 
8
  model = YOLO("model.pt")
9
 
10
+ results = model.predict(source=img, save=False, show_labels=False, show_conf=False)
11
+
12
+ # Handle both single result and list of results
13
+ if not isinstance(results, (list, tuple)):
14
+ results = [results]
15
+
16
  count = 0
17
+ annot_img = None
18
  for i in results:
19
  count = len(i.boxes)
20
  annot_img = i.plot(labels=False, masks=False)
21
+ # Convert BGR to RGB if needed
22
+ if annot_img.shape[-1] == 3: # Check if image has 3 channels
23
+ annot_img = annot_img[..., ::-1] # Reverse the color channels
24
 
 
 
25
  return str(count), annot_img
26
 
27
  with gr.Blocks(title="Pill Counter") as demo:
 
28
  with gr.Row():
29
  with gr.Column():
30
+ img = gr.Image(type="filepath", format="jpg", height=500, width=700)
 
31
  button = gr.Button()
32
  with gr.Column():
33
  data_output = gr.Textbox()
34
  img_output = gr.Image(type="numpy")
35
+ button.click(fn=predict, inputs=img, outputs=[data_output, img_output])
36
 
37
  if __name__ == "__main__":
38
  demo.launch()