Pattira commited on
Commit
c188015
·
verified ·
1 Parent(s): 8436542

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -1,18 +1,21 @@
1
  import torch
2
  from transformers import pipeline
3
 
4
- from PIT import Image
5
 
6
- import matplotlib.pyplot as plt
7
  import matplotlib.patches as patches
8
 
9
  from random import choice
10
  import io
11
- import gradio as gr
12
 
13
  detector50 = pipeline(model="facebook/detr-resnet-50")
 
14
  detector101 = pipeline(model="facebook/detr-resnet-101")
15
 
 
 
 
16
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
17
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
18
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
@@ -25,40 +28,46 @@ fdic = {
25
  "weight" : "bold"
26
  }
27
 
 
28
  def get_figure(in_pil_img, in_results):
29
- plt.figure(figsize=(16,10))
30
  plt.imshow(in_pil_img)
31
-
32
  ax = plt.gca()
33
 
34
  for prediction in in_results:
35
- select_color = choice(COLORS)
36
 
37
- x, y = prediction['box']['xmin'], predictiion['box']['ymin'],
38
- w, h = prediction['box']['xmax'] - prediction['box']['xmax'], prediction['box']['ymax'] - prediction['box']['ymin']
39
- ax.add_patch(plt.Rectangle((x,y), w, h, fill=False, color=selected_color, linewidth=3))
 
40
  ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
41
-
42
  plt.axis("off")
 
43
  return plt.gcf()
44
-
 
45
  def infer(model, in_pil_img):
 
46
  results = None
47
  if model == "detr-resnet-101":
48
  results = detector101(in_pil_img)
49
  else:
50
  results = detector50(in_pil_img)
51
-
52
- figure = get_figure(in_pil_img)
53
 
54
- buf = io.BytesIO():
 
 
55
  figure.savefig(buf, bbox_inches='tight')
56
  buf.seek(0)
57
  output_pil_img = Image.open(buf)
58
 
59
  return output_pil_img
60
 
61
- with gr.Blocks(title="DETR Object Detection - ClassCat",
 
62
  css=".gradio-container {background:lightyellow;}"
63
  ) as demo:
64
  #sample_index = gr.State([])
@@ -94,5 +103,6 @@ def infer(model, in_pil_img):
94
 
95
  #demo.queue()
96
  demo.launch(debug=True)
97
-
98
-
 
 
1
  import torch
2
  from transformers import pipeline
3
 
4
+ from PIL import Image
5
 
6
+ import matplotlib.pyplot as plt
7
  import matplotlib.patches as patches
8
 
9
  from random import choice
10
  import io
 
11
 
12
  detector50 = pipeline(model="facebook/detr-resnet-50")
13
+
14
  detector101 = pipeline(model="facebook/detr-resnet-101")
15
 
16
+
17
+ import gradio as gr
18
+
19
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
20
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
21
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
 
28
  "weight" : "bold"
29
  }
30
 
31
+
32
  def get_figure(in_pil_img, in_results):
33
+ plt.figure(figsize=(16, 10))
34
  plt.imshow(in_pil_img)
35
+ #pyplot.gcf()
36
  ax = plt.gca()
37
 
38
  for prediction in in_results:
39
+ selected_color = choice(COLORS)
40
 
41
+ x, y = prediction['box']['xmin'], prediction['box']['ymin'],
42
+ w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
43
+
44
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
45
  ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
46
+
47
  plt.axis("off")
48
+
49
  return plt.gcf()
50
+
51
+
52
  def infer(model, in_pil_img):
53
+
54
  results = None
55
  if model == "detr-resnet-101":
56
  results = detector101(in_pil_img)
57
  else:
58
  results = detector50(in_pil_img)
 
 
59
 
60
+ figure = get_figure(in_pil_img, results)
61
+
62
+ buf = io.BytesIO()
63
  figure.savefig(buf, bbox_inches='tight')
64
  buf.seek(0)
65
  output_pil_img = Image.open(buf)
66
 
67
  return output_pil_img
68
 
69
+
70
+ with gr.Blocks(title="DETR Object Detection - ClassCat",
71
  css=".gradio-container {background:lightyellow;}"
72
  ) as demo:
73
  #sample_index = gr.State([])
 
103
 
104
  #demo.queue()
105
  demo.launch(debug=True)
106
+
107
+
108
+ ### EOF ###