alaahilal commited on
Commit
d6102b8
·
verified ·
1 Parent(s): 3489591

changed the file app

Browse files
Files changed (1) hide show
  1. app.py +34 -79
app.py CHANGED
@@ -1,71 +1,13 @@
1
  import streamlit as st
 
 
 
2
  from PIL import Image
3
- import torch
4
- from torchvision.models import detection
5
- from torchvision import transforms
6
- import numpy as np
7
-
8
- def load_model():
9
- """Load the model directly from torchvision"""
10
- with st.spinner('Loading model...'):
11
- model = detection.retinanet_resnet50_fpn(pretrained=True)
12
- model.eval()
13
- return model
14
-
15
- def get_prediction(image, model, threshold=0.5):
16
- """Get predictions for the image"""
17
- # Transform the image
18
- transform = transforms.Compose([
19
- transforms.ToTensor()
20
- ])
21
-
22
- img_tensor = transform(image)
23
-
24
- # Get prediction
25
- with torch.no_grad():
26
- prediction = model([img_tensor])
27
-
28
- # Get all the predicted class labels
29
- pred_classes = [COCO_CLASSES[i] for i in prediction[0]['labels'].numpy()]
30
-
31
- # Get all the predicted bounding boxes
32
- pred_boxes = prediction[0]['boxes'].numpy()
33
-
34
- # Get the predicted scores
35
- pred_scores = prediction[0]['scores'].numpy()
36
-
37
- # Filter predictions based on threshold
38
- mask = pred_scores >= threshold
39
- boxes = pred_boxes[mask]
40
- classes = np.array(pred_classes)[mask]
41
- scores = pred_scores[mask]
42
-
43
- return boxes, classes, scores
44
-
45
- # COCO class labels
46
- COCO_CLASSES = [
47
- '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
48
- 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
49
- 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
50
- 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
51
- 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
52
- 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
53
- 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
54
- 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
55
- 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
56
- 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
57
- 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
58
- 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
59
- ]
60
 
61
  def main():
62
  st.title("Object Detection App")
63
  st.write("Upload an image and get object detections!")
64
 
65
- # Load model
66
- model = load_model()
67
- st.success("Model loaded successfully!")
68
-
69
  # File uploader
70
  uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
71
 
@@ -73,36 +15,49 @@ def main():
73
  # Display the uploaded image
74
  image = Image.open(uploaded_file)
75
  st.image(image, caption='Uploaded Image', use_column_width=True)
 
 
 
 
76
 
77
  # Button to start detection
78
  if st.button('Start Detection'):
79
- with st.spinner('Performing detection...'):
 
80
  try:
81
- # Get predictions
82
- boxes, classes, scores = get_prediction(image, model)
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Display results
85
  st.success('Detection Complete!')
86
 
87
- # Create a nice looking table for detections
88
- st.write("### Detected Objects:")
 
89
 
90
- # Create columns for better visualization
91
- col1, col2 = st.columns(2)
92
-
93
- with col1:
94
- st.write("Object")
95
- with col2:
96
- st.write("Confidence")
97
-
98
- for cls, score in zip(classes, scores):
99
- with col1:
100
- st.write(f"{cls}")
101
- with col2:
102
- st.write(f"{score*100:.2f}%")
103
 
104
  except Exception as e:
105
  st.error(f"An error occurred: {str(e)}")
106
 
 
 
 
 
 
 
107
  if __name__ == "__main__":
108
  main()
 
1
  import streamlit as st
2
+ from imageai.Detection import ObjectDetection
3
+ import os
4
+ import time
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def main():
8
  st.title("Object Detection App")
9
  st.write("Upload an image and get object detections!")
10
 
 
 
 
 
11
  # File uploader
12
  uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
13
 
 
15
  # Display the uploaded image
16
  image = Image.open(uploaded_file)
17
  st.image(image, caption='Uploaded Image', use_column_width=True)
18
+
19
+ # Save the uploaded file temporarily
20
+ with open("temp_image.jpg", "wb") as f:
21
+ f.write(uploaded_file.getbuffer())
22
 
23
  # Button to start detection
24
  if st.button('Start Detection'):
25
+ # Spinner while model loads
26
+ with st.spinner('Loading model and performing detection...'):
27
  try:
28
+ execution_path = os.getcwd()
29
+ detector = ObjectDetection()
30
+ detector.setModelTypeAsRetinaNet()
31
+ detector.setModelPath(os.path.join(execution_path, "retinanet_resnet50_fpn_coco-eeacb38b.pth"))
32
+ detector.loadModel()
33
+
34
+ # Perform detection
35
+ detections = detector.detectObjectsFromImage(
36
+ input_image="temp_image.jpg",
37
+ output_image_path="output_image.jpg",
38
+ minimum_percentage_probability=10
39
+ )
40
 
41
  # Display results
42
  st.success('Detection Complete!')
43
 
44
+ # Display detected image
45
+ detected_image = Image.open("output_image.jpg")
46
+ st.image(detected_image, caption='Detected Objects', use_column_width=True)
47
 
48
+ # Display detections with probabilities
49
+ st.write("### Detected Objects:")
50
+ for obj in detections:
51
+ st.write(f"- {obj['name']}: {obj['percentage_probability']:.2f}%")
 
 
 
 
 
 
 
 
 
52
 
53
  except Exception as e:
54
  st.error(f"An error occurred: {str(e)}")
55
 
56
+ # Clean up temporary files
57
+ if os.path.exists("temp_image.jpg"):
58
+ os.remove("temp_image.jpg")
59
+ if os.path.exists("output_image.jpg"):
60
+ os.remove("output_image.jpg")
61
+
62
  if __name__ == "__main__":
63
  main()