PolarisFTL commited on
Commit
bb3ea01
·
verified ·
1 Parent(s): 320cf10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -19
app.py CHANGED
@@ -1,8 +1,11 @@
1
- import gradio as gr
2
  from PIL import Image
3
- import os
 
 
 
 
4
  from yolo import YOLO
5
- from tqdm import tqdm
6
 
7
  yolo = YOLO()
8
 
@@ -13,19 +16,53 @@ def predict_single_image(image, crop=False, count=True):
13
  except Exception as e:
14
  return str(e)
15
 
16
- demo = gr.Interface(
17
- fn=predict_single_image,
18
- inputs=[
19
- gr.Image(value="img/1.png", type="pil", label="Input Images"),
20
- ],
21
-
22
- outputs=gr.Image(label="Detection Images",
23
- type="pil",
24
- show_download_button=True,
25
- width=400,
26
- height=600
27
- ),
28
- )
29
-
30
- if __name__ == "__main__":
31
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.patches as patches
5
+ import io
6
+ from random import choice
7
+ import gradio as gr
8
  from yolo import YOLO
 
9
 
10
  yolo = YOLO()
11
 
 
16
  except Exception as e:
17
  return str(e)
18
 
19
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
20
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
21
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
22
+
23
+ fdic = {
24
+ "family" : "DejaVu Serif",
25
+ "style" : "normal",
26
+ "size" : 18,
27
+ "color" : "yellow",
28
+ "weight" : "bold"
29
+ }
30
+
31
+
32
+ def get_figure(in_pil_img, in_results):
33
+ plt.figure(figsize=(16, 10))
34
+ plt.imshow(in_pil_img)
35
+ ax = plt.gca()
36
+
37
+ for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
38
+ selected_color = choice(COLORS)
39
+
40
+ box_int = [i.item() for i in torch.round(box).to(torch.int32)]
41
+ x, y, w, h = box_int[0], box_int[1], box_int[2]-box_int[0], box_int[3]-box_int[1]
42
+ #x, y, w, h = torch.round(box[0]).item(), torch.round(box[1]).item(), torch.round(box[2]-box[0]).item(), torch.round(box[3]-box[1]).item()
43
+
44
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3, alpha=0.8))
45
+ ax.text(x, y, 'MASFNet')
46
+
47
+ plt.axis("off")
48
+
49
+ return plt.gcf()
50
+
51
+
52
+
53
+ with gr.Blocks(title="MASFNet Object Detection",
54
+ css=".gradio-container {background:lightyellow;}"
55
+ ) as demo:
56
+ #sample_index = gr.State([])
57
+
58
+
59
+ with gr.Row():
60
+ input_image = gr.Image(label="Input image", type="pil")
61
+ output_image = gr.Image(label="Output image with predicted instances", type="pil")
62
+
63
+ gr.Examples(['img/1.png', 'img/2.png'], inputs=input_image)
64
+
65
+ send_btn = gr.Button("Predict")
66
+
67
+ #demo.queue()
68
+ demo.launch(debug=True)