nomanmanzoor commited on
Commit
504b1c1
Β·
verified Β·
1 Parent(s): b572a4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -71
app.py CHANGED
@@ -1,92 +1,65 @@
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
4
- from torchvision import transforms
5
- from torchvision.models.detection import fasterrcnn_resnet50_fpn
6
  import torchvision
 
 
7
 
8
- # Load model
9
- model = fasterrcnn_resnet50_fpn(pretrained=True)
10
  model.eval()
11
 
12
- # Define class labels
13
  COCO_INSTANCE_CATEGORY_NAMES = [
14
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
15
- 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
16
  'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
17
- 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
18
- 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
19
- 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
20
- 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
21
- 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
22
- 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
23
- 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
24
- 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
25
- 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
26
- 'hair drier', 'toothbrush'
27
  ]
28
 
29
- def get_prediction(img, threshold):
30
- transform = transforms.Compose([transforms.ToTensor()])
31
- img = transform(img)
32
- pred = model([img])
33
- pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
34
- pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
35
- pred_score = list(pred[0]['scores'].detach().numpy())
36
- pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
37
- boxes = pred_boxes[:pred_t+1]
38
- classes = pred_classes[:pred_t+1]
39
- return boxes, classes
40
 
41
- # UI design
42
- st.set_page_config(page_title="AI Object Detector", layout="wide")
43
 
44
- st.markdown("""
45
- <style>
46
- .main {
47
- background-color: #f5f7fa;
48
- padding: 20px;
49
- border-radius: 10px;
50
- }
51
- h1 {
52
- color: #2c3e50;
53
- }
54
- .stButton>button {
55
- background-color: #008CBA;
56
- color: white;
57
- font-weight: bold;
58
- border-radius: 8px;
59
- padding: 10px 24px;
60
- }
61
- </style>
62
- """, unsafe_allow_html=True)
63
-
64
- st.title("πŸ” AI Object Detection App")
65
- st.markdown("Upload an image and let the AI detect what's in it!")
66
-
67
- img_file = st.file_uploader("πŸ“Έ Upload an Image", type=["jpg", "jpeg", "png"])
68
-
69
- confidence = st.slider("🎯 Confidence Threshold", 0.0, 1.0, 0.5)
70
-
71
- if img_file is not None:
72
- image = Image.open(img_file).convert("RGB")
73
  st.image(image, caption="Uploaded Image", use_column_width=True)
74
 
75
- boxes, classes = get_prediction(image, confidence)
 
 
76
 
77
- # Draw results
78
- import matplotlib.pyplot as plt
79
- import matplotlib.patches as patches
80
 
81
- fig, ax = plt.subplots(1, figsize=(12, 8))
 
82
  ax.imshow(image)
83
- for i in range(len(boxes)):
84
- box = boxes[i]
85
- label = classes[i]
86
- rect = patches.Rectangle(box[0], box[1][0]-box[0][0], box[1][1]-box[0][1],
87
- linewidth=2, edgecolor='blue', facecolor='none')
88
- ax.add_patch(rect)
89
- ax.text(box[0][0], box[0][1]-10, label, fontsize=12,
90
- color='black', bbox=dict(facecolor='lightblue', edgecolor='blue', boxstyle='round,pad=0.5'))
 
 
 
 
 
 
91
  st.pyplot(fig)
92
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
4
+ import torchvision.transforms as T
 
5
  import torchvision
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.patches as patches
8
 
9
+ # Load pre-trained object detection model
10
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
11
  model.eval()
12
 
13
+ # COCO class labels
14
  COCO_INSTANCE_CATEGORY_NAMES = [
15
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
16
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
17
  'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
18
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
19
+ 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
20
+ 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
21
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
22
+ 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
23
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table',
24
+ 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
25
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
26
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
 
27
  ]
28
 
29
+ # Streamlit UI
30
+ st.set_page_config(page_title="AI Object Detector", layout="centered")
31
+ st.title("🎯 Object Detection with AI")
32
+ st.markdown("Upload an image and let the AI detect objects with names!")
 
 
 
 
 
 
 
33
 
34
+ uploaded_file = st.file_uploader("πŸ“· Upload an image", type=["jpg", "png", "jpeg"])
 
35
 
36
+ if uploaded_file:
37
+ image = Image.open(uploaded_file).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  st.image(image, caption="Uploaded Image", use_column_width=True)
39
 
40
+ # Transform image for model input
41
+ transform = T.Compose([T.ToTensor()])
42
+ image_tensor = transform(image).unsqueeze(0)
43
 
44
+ with st.spinner("πŸ” Detecting objects..."):
45
+ predictions = model(image_tensor)[0]
 
46
 
47
+ # Draw bounding boxes
48
+ fig, ax = plt.subplots(1)
49
  ax.imshow(image)
50
+
51
+ threshold = 0.7 # confidence threshold
52
+ for idx in range(len(predictions["boxes"])):
53
+ score = predictions["scores"][idx].item()
54
+ if score > threshold:
55
+ box = predictions["boxes"][idx].detach().numpy()
56
+ label = COCO_INSTANCE_CATEGORY_NAMES[predictions["labels"][idx]]
57
+ x1, y1, x2, y2 = box
58
+ rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
59
+ linewidth=2, edgecolor='lime', facecolor='none')
60
+ ax.add_patch(rect)
61
+ ax.text(x1, y1 - 10, f"{label} ({score:.2f})", color='lime',
62
+ fontsize=10, backgroundcolor='black')
63
+
64
  st.pyplot(fig)
65