stevenbucaille commited on
Commit
e0a0083
·
1 Parent(s): 4f82037

feat: new space for LWDETR

Browse files
Files changed (2) hide show
  1. app.py +57 -95
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,109 +1,71 @@
1
-
2
  import torch
3
- from transformers import pipeline
4
-
5
- from PIL import Image
6
-
7
- import matplotlib.pyplot as plt
8
- import matplotlib.patches as patches
9
-
10
- from random import choice
11
- import io
12
-
13
- detector50 = pipeline(model="facebook/detr-resnet-50")
14
-
15
- detector101 = pipeline(model="facebook/detr-resnet-101")
16
-
17
-
18
  import gradio as gr
 
19
 
20
- COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
21
- "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
22
- "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
23
-
24
- fdic = {
25
- "family" : "Impact",
26
- "style" : "italic",
27
- "size" : 15,
28
- "color" : "yellow",
29
- "weight" : "bold"
30
- }
31
-
32
-
33
- def get_figure(in_pil_img, in_results):
34
- plt.figure(figsize=(16, 10))
35
- plt.imshow(in_pil_img)
36
- #pyplot.gcf()
37
- ax = plt.gca()
38
-
39
- for prediction in in_results:
40
- selected_color = choice(COLORS)
41
-
42
- x, y = prediction['box']['xmin'], prediction['box']['ymin'],
43
- w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
44
-
45
- ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
46
- ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
47
-
48
- plt.axis("off")
49
 
50
- return plt.gcf()
 
 
 
51
 
 
 
 
 
 
52
 
53
- def infer(model, in_pil_img):
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- results = None
56
- if model == "detr-resnet-101":
57
- results = detector101(in_pil_img)
58
- else:
59
- results = detector50(in_pil_img)
60
 
61
- figure = get_figure(in_pil_img, results)
62
-
63
- buf = io.BytesIO()
64
- figure.savefig(buf, bbox_inches='tight')
65
- buf.seek(0)
66
- output_pil_img = Image.open(buf)
67
-
68
- return output_pil_img
69
-
70
-
71
- with gr.Blocks(title="DETR Object Detection - ClassCat",
72
- css=".gradio-container {background:lightyellow;}"
73
- ) as demo:
74
- #sample_index = gr.State([])
75
-
76
- gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")
77
-
78
- gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
79
-
80
- model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
81
-
82
- gr.HTML("""<br/>""")
83
- gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
84
- gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
85
 
86
  with gr.Row():
87
- input_image = gr.Image(label="Input image", type="pil")
88
- output_image = gr.Image(label="Output image with predicted instances", type="pil")
89
-
90
- gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)
91
-
92
- gr.HTML("""<br/>""")
93
- gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
94
-
95
- send_btn = gr.Button("Infer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])
97
 
98
- gr.HTML("""<br/>""")
99
- gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
100
- gr.HTML("""<ul>""")
101
- gr.HTML("""<li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a>""")
102
- gr.HTML("""</ul>""")
103
 
104
-
105
- #demo.queue()
106
  demo.launch(debug=True)
107
-
108
-
109
- ### EOF ###
 
 
1
  import torch
2
+ from transformers import AutoImageProcessor, LwDetrForObjectDetection
3
+ import supervision as sv
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import gradio as gr
5
+ import spaces
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ @spaces.GPU
9
+ def infer(model_name, image):
10
+ label_annotator = sv.LabelAnnotator(text_padding=4, smart_position=True)
11
+ box_annotator = sv.BoxAnnotator()
12
 
13
+ model_name = f"AnnaZhang/{model_name}"
14
+ processor = AutoImageProcessor.from_pretrained(model_name)
15
+ model = LwDetrForObjectDetection.from_pretrained(model_name)
16
+ inputs = processor(images=image, return_tensors="pt")
17
+ outputs = model(**inputs)
18
 
19
+ # convert outputs (bounding boxes and class logits) to COCO API
20
+ # let's only keep detections with score > 0.7
21
+ target_sizes = torch.tensor([image.size[::-1]])
22
+ results = processor.post_process_object_detection(
23
+ outputs, target_sizes=target_sizes, threshold=0.7
24
+ )[0]
25
+ detections = sv.Detections.from_transformers(
26
+ transformers_results=results, id2label=model.config.id2label
27
+ )
28
+ image = box_annotator.annotate(image, detections)
29
+ image = label_annotator.annotate(image, detections)
30
+ return image
31
 
 
 
 
 
 
32
 
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("# LWDETR Object Detection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  with gr.Row():
37
+ with gr.Column():
38
+ model = gr.Radio(
39
+ [
40
+ "lwdetr_tiny_30e_objects365",
41
+ "lwdetr_small_30e_objects365",
42
+ "lwdetr_medium_30e_objects365",
43
+ "lwdetr_large_30e_objects365",
44
+ "lwdetr_xlarge_30e_objects365",
45
+ "lwdetr_tiny_60e_coco",
46
+ "lwdetr_small_60e_coco",
47
+ "lwdetr_medium_60e_coco",
48
+ "lwdetr_large_60e_coco",
49
+ "lwdetr_xlarge_60e_coco",
50
+ ],
51
+ value="lwdetr_xlarge_60e_coco",
52
+ label="Model",
53
+ )
54
+ input_image = gr.Image(label="Input Image", type="pil")
55
+ send_btn = gr.Button("Infer", variant="primary")
56
+ with gr.Column():
57
+ output_image = gr.Image(label="Output Image", type="pil")
58
+
59
+ gr.Examples(
60
+ examples=[
61
+ "samples/cats.jpg",
62
+ "samples/detectron2.png",
63
+ "samples/cat.jpg",
64
+ "samples/hotdog.jpg",
65
+ ],
66
+ inputs=input_image,
67
+ )
68
  send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])
69
 
 
 
 
 
 
70
 
 
 
71
  demo.launch(debug=True)
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  torch
2
- transformers[timm]
 
 
 
1
  torch
2
+ transformers[timm] @ git+https://github.com/huggingface/transformers.git
3
+ supervision
4
+ spaces