Huzaifa424 committed on
Commit
9e3edef
·
verified ·
1 Parent(s): 6177fe6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ import torch
4
+ from PIL import Image
5
+ import io
6
+
7
# Load the model and processor.
#
# Streamlit re-executes this script from top to bottom on every widget
# interaction. Wrapping the load in st.cache_resource keeps the loaded objects
# in memory, so from_pretrained is not called again on each rerun.
@st.cache_resource
def _load_detr():
    """Return the ``(processor, model)`` pair for ``facebook/detr-resnet-50``."""
    checkpoint = "facebook/detr-resnet-50"
    detr_processor = DetrImageProcessor.from_pretrained(checkpoint, revision="no_timm")
    detr_model = DetrForObjectDetection.from_pretrained(checkpoint, revision="no_timm")
    return detr_processor, detr_model


# Module-level names kept as-is because detect_objects() reads them as globals.
processor, model = _load_detr()
10
+
11
def detect_objects(image, object_types):
    """Detect the registered object types in an image.

    Uses the module-level ``processor`` and ``model`` objects.

    Args:
        image: PIL image to analyse. It is converted to RGB before inference
            so that RGBA, palette and grayscale uploads are accepted.
        object_types: Comma-separated label names to look for, e.g.
            ``"cat, dog"``. Matching is case-insensitive and ignores
            surrounding whitespace; empty entries are dropped.

    Returns:
        A ``(summary, picking_positions, total_count)`` tuple.

        * ``summary``: one ``"Object <n>: <Label>"`` line per match, or
          ``"No registered objects detected."`` when nothing matched. ``<n>``
          is the 1-based position among *all* detections above the score
          threshold, so the numbers may skip.
        * ``picking_positions``: list of ``(x, y)`` bounding-box centres, one
          per match, in the same order as the summary lines.
        * ``total_count``: number of matches.

        If anything raises, the exception text is returned as ``summary``
        together with ``[]`` and ``0`` so the UI can show it instead of
        crashing.
    """
    try:
        # Normalise the registered names once; a set gives O(1) membership.
        wanted = {
            name.strip().lower()
            for name in object_types.split(",")
            if name.strip()
        }

        # The processor normalises with 3-channel statistics, so images in any
        # other mode (RGBA, P, L, ...) are converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process the outputs to get the bounding boxes.
        # PIL reports size as (width, height); reverse it for target_sizes.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9
        )[0]

        detected_objects = []
        picking_positions = []

        detections = zip(results["labels"], results["boxes"])
        for idx, (label, box) in enumerate(detections, start=1):
            object_type = model.config.id2label[label.item()].lower()
            if object_type not in wanted:
                continue

            box = [round(coord, 2) for coord in box.tolist()]
            # Centre of the box. Averaging two 2-decimal values yields at most
            # three decimals, so rounding to 3 only strips float noise.
            picking_positions.append(
                (
                    round((box[0] + box[2]) / 2, 3),
                    round((box[1] + box[3]) / 2, 3),
                )
            )
            detected_objects.append(f"Object {idx}: {object_type.capitalize()}")

        total_count = len(detected_objects)

        if not detected_objects:
            return "No registered objects detected.", picking_positions, total_count

        return "\n".join(detected_objects), picking_positions, total_count

    except Exception as e:
        # Deliberate best-effort boundary: report the failure text to the UI.
        return str(e), [], 0
43
+
44
# Streamlit user interface
st.title("Object Detection")
st.write("Upload an image, register object types (comma-separated), and the app will detect, count, and find the best picking positions for the registered objects.")

# Inputs: the picture to analyse and the comma-separated labels of interest.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
object_types = st.text_input("Registered Object Types (comma separated, e.g., 'cat, dog')")

if uploaded_file is not None:
    # Preview the upload even before any object types have been entered.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Run detection only once at least one object type has been typed in.
    if object_types:
        summary, positions, count = detect_objects(image, object_types)
        report_lines = [
            f"{summary}",
            "",
            f"Picking Positions: {positions}",
            f"Total Count: {count}",
        ]
        st.text_area("Detection Results", value="\n".join(report_lines), height=200)
60
+