Reaper200 commited on
Commit
30222f3
·
verified ·
1 Parent(s): 8773442

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -10
app.py CHANGED
@@ -2,16 +2,21 @@ import streamlit as st
2
  from transformers import DetrForObjectDetection, DetrImageProcessor
3
  from PIL import Image
4
  import torch
 
 
5
 
6
- # Load model and processor
7
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
8
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
9
 
10
  st.title("Context-Aware Object Detection")
 
11
 
12
  # Upload an image
13
- uploaded_file = st.file_uploader("Choose an image...", type="jpg")
 
14
  if uploaded_file is not None:
 
15
  image = Image.open(uploaded_file)
16
  st.image(image, caption="Uploaded Image", use_column_width=True)
17
 
@@ -19,15 +24,30 @@ if uploaded_file is not None:
19
  inputs = processor(images=image, return_tensors="pt")
20
  outputs = model(**inputs)
21
 
22
- # Extract and display bounding boxes and labels
23
- logits = outputs.logits.softmax(-1)
24
- boxes = outputs.pred_boxes
25
 
26
- # Define a confidence threshold
27
  threshold = 0.9
28
- for logit, box in zip(logits[0], boxes[0]):
29
- score, label = logit.max(0)
 
 
 
 
 
 
 
30
  if score > threshold:
31
- st.write(f"Detected object with confidence {score:.2f}")
 
 
 
 
 
 
 
 
 
32
 
33
- st.write("Detection complete!")
 
2
  from transformers import DetrForObjectDetection, DetrImageProcessor
3
  from PIL import Image
4
  import torch
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as patches
7
 
8
+ # Load the model and processor
9
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
10
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
11
 
12
  st.title("Context-Aware Object Detection")
13
+ st.write("Upload an image to detect objects with contextual awareness.")
14
 
15
  # Upload an image
16
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
17
+
18
  if uploaded_file is not None:
19
+ # Open the uploaded image
20
  image = Image.open(uploaded_file)
21
  st.image(image, caption="Uploaded Image", use_column_width=True)
22
 
 
24
  inputs = processor(images=image, return_tensors="pt")
25
  outputs = model(**inputs)
26
 
27
+ # Get logits and bounding boxes
28
+ logits = outputs.logits.softmax(-1)[0]
29
+ boxes = outputs.pred_boxes[0]
30
 
31
+ # Set a confidence threshold for displaying boxes
32
  threshold = 0.9
33
+ labels = processor.tokenizer.convert_ids_to_tokens(logits.argmax(-1))
34
+ scores = logits.max(-1).values
35
+
36
+ # Display the image with bounding boxes
37
+ fig, ax = plt.subplots(1)
38
+ ax.imshow(image)
39
+
40
+ # Plot each detected object if it meets the confidence threshold
41
+ for score, label, box in zip(scores, labels, boxes):
42
  if score > threshold:
43
+ # Convert bounding box coordinates to absolute pixel values
44
+ x, y, w, h = box * torch.tensor([image.width, image.height, image.width, image.height])
45
+ x0, y0 = x - w / 2, y - h / 2
46
+
47
+ # Draw the bounding box
48
+ rect = patches.Rectangle((x0, y0), w, h, linewidth=2, edgecolor='r', facecolor='none')
49
+ ax.add_patch(rect)
50
+ ax.text(x0, y0, f"{label}: {score:.2f}", color='red', fontsize=8, weight='bold')
51
+
52
+ st.pyplot(fig)
53