cwpkd commited on
Commit
42b3d17
·
verified ·
1 Parent(s): 19d00bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from transformers import pipeline
3
 
@@ -8,11 +9,14 @@ import matplotlib.patches as patches
8
 
9
  from random import choice
10
  import io
11
- import gradio as gr
12
 
13
  detector50 = pipeline(model="facebook/detr-resnet-50")
 
14
  detector101 = pipeline(model="facebook/detr-resnet-101")
15
 
 
 
 
16
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
17
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
18
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
@@ -25,10 +29,11 @@ fdic = {
25
  "weight" : "bold"
26
  }
27
 
 
28
  def get_figure(in_pil_img, in_results):
29
  plt.figure(figsize=(16, 10))
30
  plt.imshow(in_pil_img)
31
-
32
  ax = plt.gca()
33
 
34
  for prediction in in_results:
@@ -36,13 +41,17 @@ def get_figure(in_pil_img, in_results):
36
 
37
  x, y = prediction['box']['xmin'], prediction['box']['ymin'],
38
  w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
 
39
  ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
40
  ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
41
 
42
  plt.axis("off")
 
43
  return plt.gcf()
44
 
 
45
  def infer(model, in_pil_img):
 
46
  results = None
47
  if model == "detr-resnet-101":
48
  results = detector101(in_pil_img)
@@ -57,7 +66,8 @@ def infer(model, in_pil_img):
57
  output_pil_img = Image.open(buf)
58
 
59
  return output_pil_img
60
-
 
61
  with gr.Blocks(title="DETR Object Detection - ClassCat",
62
  css=".gradio-container {background:lightyellow;}"
63
  ) as demo:
@@ -93,5 +103,5 @@ with gr.Blocks(title="DETR Object Detection - ClassCat",
93
 
94
 
95
  #demo.queue()
96
- demo.launch(debug=True)
97
-
 
1
+
2
  import torch
3
  from transformers import pipeline
4
 
 
9
 
10
  from random import choice
11
  import io
 
12
 
13
  detector50 = pipeline(model="facebook/detr-resnet-50")
14
+
15
  detector101 = pipeline(model="facebook/detr-resnet-101")
16
 
17
+
18
+ import gradio as gr
19
+
20
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
21
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
22
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
 
29
  "weight" : "bold"
30
  }
31
 
32
+
33
  def get_figure(in_pil_img, in_results):
34
  plt.figure(figsize=(16, 10))
35
  plt.imshow(in_pil_img)
36
+ #pyplot.gcf()
37
  ax = plt.gca()
38
 
39
  for prediction in in_results:
 
41
 
42
  x, y = prediction['box']['xmin'], prediction['box']['ymin'],
43
  w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
44
+
45
  ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
46
  ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
47
 
48
  plt.axis("off")
49
+
50
  return plt.gcf()
51
 
52
+
53
  def infer(model, in_pil_img):
54
+
55
  results = None
56
  if model == "detr-resnet-101":
57
  results = detector101(in_pil_img)
 
66
  output_pil_img = Image.open(buf)
67
 
68
  return output_pil_img
69
+
70
+
71
  with gr.Blocks(title="DETR Object Detection - ClassCat",
72
  css=".gradio-container {background:lightyellow;}"
73
  ) as demo:
 
103
 
104
 
105
  #demo.queue()
106
+ demo.launch(debug=True)el == "detr-resnet-101":
107
+