cvdetectors commited on
Commit
4a7a94f
·
verified ·
1 Parent(s): ccb37f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torchvision
4
  from torchvision import transforms
5
  import numpy as np
6
- from PIL import Image
7
 
8
  # Use GPU if available
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -33,21 +33,37 @@ def count_persons(image):
33
  with autocast():
34
  outputs = model(img_tensor)[0]
35
 
36
- # Count persons (label 1 in COCO)
37
- person_count = sum(
38
- 1 for label, score in zip(outputs['labels'], outputs['scores'])
39
- if label.item() == 1 and score.item() > 0.65
40
- )
 
 
 
 
41
 
42
- return f"Number of persons detected: {person_count}"
43
 
44
- # Gradio interface for image upload
 
 
 
 
 
 
 
 
 
45
  demo = gr.Interface(
46
  fn=count_persons,
47
  inputs=gr.Image(type="pil", label="Upload Image"),
48
- outputs=gr.Text(label="Person Count"),
 
 
 
49
  title="Person Counter in Image (Fast)",
50
- description="Upload an image to count the number of people using a fast MobileNet-based detector. GPU supported."
51
  )
52
 
53
  if __name__ == "__main__":
 
3
  import torchvision
4
  from torchvision import transforms
5
  import numpy as np
6
+ from PIL import Image, ImageDraw
7
 
8
  # Use GPU if available
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
33
  with autocast():
34
  outputs = model(img_tensor)[0]
35
 
36
+ # Filter for persons (label 1 in COCO) with score > threshold
37
+ threshold = 0.65
38
+ boxes = outputs['boxes']
39
+ labels = outputs['labels']
40
+ scores = outputs['scores']
41
+ keep_indices = [
42
+ i for i, (label, score) in enumerate(zip(labels, scores))
43
+ if label.item() == 1 and score.item() > threshold
44
+ ]
45
 
46
+ person_count = len(keep_indices)
47
 
48
+ # Draw bounding boxes on the image
49
+ annotated_image = image.convert("RGB")
50
+ draw = ImageDraw.Draw(annotated_image)
51
+ for i in keep_indices:
52
+ box = boxes[i].cpu().numpy()
53
+ draw.rectangle([ (box[0], box[1]), (box[2], box[3]) ], outline="red", width=2)
54
+
55
+ return annotated_image, f"Number of persons detected: {person_count}"
56
+
57
+ # Gradio interface for image upload and outputs
58
  demo = gr.Interface(
59
  fn=count_persons,
60
  inputs=gr.Image(type="pil", label="Upload Image"),
61
+ outputs=[
62
+ gr.Image(type="pil", label="Annotated Image"),
63
+ gr.Text(label="Person Count")
64
+ ],
65
  title="Person Counter in Image (Fast)",
66
+ description="Upload an image to count the number of people and see bounding boxes using a fast MobileNet-based detector. GPU supported."
67
  )
68
 
69
  if __name__ == "__main__":