engajify committed on
Commit
1882592
·
verified ·
1 Parent(s): 514e0eb

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +29 -0
  2. app.py +165 -0
  3. gitattributes +35 -0
  4. requirements.txt +15 -0
README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Object Detection Video
3
+ emoji: 📚
4
+ colorFrom: purple
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.31.5
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # Object Detection in Video
16
+
17
+ This is a Gradio interface that allows users to upload a video and an image to detect if the object in the image is present in the video. The app uses ResNet and OWL-ViT models for object detection and similarity measurement.
18
+
19
+ ## How to Use
20
+
21
+ 1. Upload a video.
22
+ 2. Upload a query image containing the object you want to detect.
23
+ 3. Adjust the skip frames and threshold sliders as needed.
24
+ 4. The app will process the video and display the results.
25
+
26
+ ## Example
27
+
28
+ For instance, to check if a specific object is present in a video, upload the video and the image of the object. Adjust the parameters and view the results.
29
+
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection, ResNetModel
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import cv2
8
+ import torch.nn.functional as F
9
+ import tempfile
10
+ import os
11
+
12
+ # Load models
13
+ resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
14
+ resnet.eval()
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ resnet = resnet.to(device)
17
+
18
+ mixin = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
19
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
20
+ model = mixin.to(device)
21
+
22
+ # Preprocess the image
23
+ def preprocess_image(image):
24
+ transform = transforms.Compose([
25
+ transforms.Resize((224, 224)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
28
+ ])
29
+ return transform(image).unsqueeze(0)
30
+
31
+ def extract_embedding(image):
32
+ image_tensor = preprocess_image(image).to(device)
33
+ with torch.no_grad():
34
+ output = resnet(image_tensor)
35
+ embedding = output.pooler_output
36
+ return embedding
37
+
38
+ def cosine_similarity(embedding1, embedding2):
39
+ return F.cosine_similarity(embedding1, embedding2)
40
+
41
+ def l2_distance(embedding1, embedding2):
42
+ return torch.norm(embedding1 - embedding2, p=2)
43
+
44
+ def save_array_to_temp_image(arr):
45
+ rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
46
+ img = Image.fromarray(rgb_arr)
47
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
48
+ temp_file_name = temp_file.name
49
+ temp_file.close()
50
+ img.save(temp_file_name)
51
+ return temp_file_name
52
+
53
+ def detect_and_crop(target_image, query_image, threshold=0.6, nms_threshold=0.3):
54
+ target_sizes = torch.Tensor([target_image.size[::-1]])
55
+ inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
56
+ with torch.no_grad():
57
+ outputs = model.image_guided_detection(**inputs)
58
+
59
+ img = cv2.cvtColor(np.array(target_image), cv2.COLOR_BGR2RGB)
60
+ outputs.logits = outputs.logits.cpu()
61
+ outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
62
+
63
+ results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
64
+ boxes, scores = results[0]["boxes"], results[0]["scores"]
65
+
66
+ if len(boxes) == 0:
67
+ return []
68
+
69
+ filtered_boxes = []
70
+ for box in boxes:
71
+ x1, y1, x2, y2 = [int(i) for i in box.tolist()]
72
+ cropped_img = img[y1:y2, x1:x2]
73
+ if cropped_img.size != 0:
74
+ filtered_boxes.append(cropped_img)
75
+
76
+ return filtered_boxes
77
+
78
+ def process_video(video_path, query_image, skipframes=0):
79
+ cap = cv2.VideoCapture(video_path)
80
+ if not cap.isOpened():
81
+ return
82
+
83
+ frame_count = 0
84
+ all_results = []
85
+ while True:
86
+ ret, frame = cap.read()
87
+ if not ret:
88
+ break
89
+ if frame_count % (skipframes + 1) == 0:
90
+ frame_file = save_array_to_temp_image(frame)
91
+ result_frames = detect_and_crop(Image.open(frame_file), query_image)
92
+ for res in result_frames:
93
+ saved_res = save_array_to_temp_image(res)
94
+ embedding1 = extract_embedding(query_image)
95
+ embedding2 = extract_embedding(Image.open(saved_res))
96
+ dist = l2_distance(embedding1, embedding2).item()
97
+ cos = cosine_similarity(embedding1, embedding2).item()
98
+ all_results.append({'l2_dist': dist, 'cos': cos})
99
+ frame_count += 1
100
+ cap.release()
101
+ return all_results
102
+
103
+ def process_videos_and_compare(image, video, skipframes=5, threshold=0.47):
104
+ def median(values):
105
+ n = len(values)
106
+ return (values[n // 2 - 1] + values[n // 2]) / 2 if n % 2 == 0 else values[n // 2]
107
+
108
+ results = process_video(video, image, skipframes)
109
+ if results:
110
+ l2_dists = [item['l2_dist'] for item in results]
111
+ cosines = [item['cos'] for item in results]
112
+ avg_l2_dist = sum(l2_dists) / len(l2_dists)
113
+ avg_cos = sum(cosines) / len(cosines)
114
+ median_l2_dist = median(sorted(l2_dists))
115
+ median_cos = median(sorted(cosines))
116
+ result = {
117
+ "avg_l2_dist": avg_l2_dist,
118
+ "avg_cos": avg_cos,
119
+ "median_l2_dist": median_l2_dist,
120
+ "median_cos": median_cos,
121
+ "avg_cos_dist": 1 - avg_cos,
122
+ "median_cos_dist": 1 - median_cos,
123
+ "is_present": avg_cos >= threshold
124
+ }
125
+ else:
126
+ result = {
127
+ "avg_l2_dist": float('inf'),
128
+ "avg_cos": 0,
129
+ "median_l2_dist": float('inf'),
130
+ "median_cos": 0,
131
+ "avg_cos_dist": float('inf'),
132
+ "median_cos_dist": float('inf'),
133
+ "is_present": False
134
+ }
135
+ return result
136
+
137
+ def interface(video, image, skipframes, threshold):
138
+ result = process_videos_and_compare(image, video, skipframes, threshold)
139
+ return result
140
+
141
+ iface = gr.Interface(
142
+ fn=interface,
143
+ inputs=[
144
+ gr.Video(label="Upload a Video"),
145
+ gr.Image(type="pil", label="Upload a Query Image"),
146
+ gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Skip Frames"),
147
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.47, label="Threshold")
148
+ ],
149
+ outputs=[
150
+ gr.JSON(label="Result")
151
+ ],
152
+ title="Object Detection in Video",
153
+ description="""
154
+ **Instructions:**
155
+
156
+ 1. **Upload a Video**: Select a video file to upload.
157
+ 2. **Upload a Query Image**: Select an image file that contains the object you want to detect in the video.
158
+ 3. **Set Skip Frames**: Adjust the slider to set the number of frames to skip between each processing.
159
+ 4. **Set Threshold**: Adjust the slider to set the threshold for cosine similarity to determine if the object is present in the video.
160
+ 5. **View Results**: The result will show the average and median distances and similarities, and whether the object is present in the video based on the threshold.
161
+ """
162
+ )
163
+
164
+ if __name__ == "__main__":
165
+ iface.launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ opencv-python
6
+ requests
7
+ matplotlib
8
+ numpy
9
+ fastapi
10
+ uvicorn
11
+ gunicorn
12
+ pathlib
13
+ argparse
14
+ ipython
15
+ moviepy