Girishug commited on
Commit
6ceca29
·
verified ·
1 Parent(s): 8e98acf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -2
app.py CHANGED
@@ -23,9 +23,29 @@ transform = transforms.Compose([
23
  transforms.ToTensor(),
24
  ])
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Prediction function
27
  def predict(image):
28
  try:
 
 
 
 
29
  # Transform the image
30
  image_tensor = transform(image).unsqueeze(0) # Add batch dimension
31
  with torch.no_grad():
@@ -44,11 +64,12 @@ def predict(image):
44
  # Convert the input image to a NumPy array
45
  image_np = np.array(image)
46
 
47
- # Draw boxes on the image
48
  for box, label in zip(boxes, labels):
49
  x1, y1, x2, y2 = box.astype(int)
50
  image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
51
- image_np = cv2.putText(image_np, str(label), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
 
52
 
53
  # Ensure the output is in the correct format
54
  return Image.fromarray(image_np.astype(np.uint8))
 
23
  transforms.ToTensor(),
24
  ])
25
 
26
+ # COCO class names
27
+ COCO_CLASSES = [
28
+ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
29
+ "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
30
+ "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
31
+ "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
32
+ "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
33
+ "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
34
+ "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
35
+ "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
36
+ "potted plant", "bed", "dining table", "toilet", "TV", "laptop", "mouse",
37
+ "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
38
+ "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
39
+ "toothbrush"
40
+ ]
41
+
42
  # Prediction function
43
  def predict(image):
44
  try:
45
+ # Convert to RGB if the image has an alpha channel
46
+ if image.mode != 'RGB':
47
+ image = image.convert('RGB')
48
+
49
  # Transform the image
50
  image_tensor = transform(image).unsqueeze(0) # Add batch dimension
51
  with torch.no_grad():
 
64
  # Convert the input image to a NumPy array
65
  image_np = np.array(image)
66
 
67
+ # Draw boxes and labels on the image
68
  for box, label in zip(boxes, labels):
69
  x1, y1, x2, y2 = box.astype(int)
70
  image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
71
+ class_name = COCO_CLASSES[label] # Get the class name
72
+ image_np = cv2.putText(image_np, class_name, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
73
 
74
  # Ensure the output is in the correct format
75
  return Image.fromarray(image_np.astype(np.uint8))