Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| import torch | |
| from PIL import Image | |
| import io | |
# Load the DETR processor and model once per server process. Streamlit re-runs
# this whole script on every widget interaction; without st.cache_resource the
# weights would be reloaded from disk on each rerun.
@st.cache_resource
def _load_detr(checkpoint="facebook/detr-resnet-50", revision="no_timm"):
    """Return the ``(processor, model)`` pair for *checkpoint* at *revision*."""
    loaded_processor = DetrImageProcessor.from_pretrained(checkpoint, revision=revision)
    loaded_model = DetrForObjectDetection.from_pretrained(checkpoint, revision=revision)
    return loaded_processor, loaded_model


processor, model = _load_detr()
def detect_objects(image, object_types, threshold=0.8):
    """Detect the registered object types in an image with DETR.

    Args:
        image: A PIL image. It is converted to RGB before inference so that
            RGBA, palette and grayscale uploads (e.g. many PNGs) are accepted.
        object_types: Comma-separated object names, e.g. ``"cat, dog"``.
            Matching against the model's labels is case-insensitive and
            surrounding whitespace is ignored.
        threshold: Minimum detection score for a box to be kept
            (defaults to 0.8, the value previously hard-coded).

    Returns:
        A ``(summary, picking_positions, total_count)`` tuple.
        ``summary`` holds one ``"Object k: Label"`` line per matching
        detection, or ``"No registered objects detected."``.
        ``picking_positions`` lists the ``(x, y)`` centre of each matching
        box, in the same order as the summary lines.
        ``total_count`` is the number of matches. If anything raises, the
        exception text is returned as ``summary`` together with ``[]`` and
        ``0``, so the UI can show it instead of crashing.
    """
    try:
        # Normalise the registered names once. A set gives O(1) membership
        # tests, and blank entries (e.g. from "cat,,dog" or a trailing comma)
        # are dropped.
        registered = {
            name.strip().lower() for name in object_types.split(",") if name.strip()
        }

        # DETR expects 3-channel input; RGBA/"P"/"L" images would otherwise fail.
        rgb_image = image.convert("RGB")
        inputs = processor(images=rgb_image, return_tensors="pt")
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); post-processing wants (height, width).
        target_sizes = torch.tensor([rgb_image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]

        detected_objects = []
        picking_positions = []
        for label, box in zip(results["labels"], results["boxes"]):
            label_name = model.config.id2label[label.item()]
            if label_name.lower() not in registered:
                continue
            x_min, y_min, x_max, y_max = (round(coord, 2) for coord in box.tolist())
            picking_positions.append(((x_min + x_max) / 2, (y_min + y_max) / 2))
            # Number only the reported objects, so "Object k" corresponds to
            # the k-th entry of picking_positions.
            detected_objects.append(
                f"Object {len(picking_positions)}: {label_name.capitalize()}"
            )

        total_count = len(detected_objects)
        if not detected_objects:
            return "No registered objects detected.", picking_positions, total_count
        return "\n".join(detected_objects), picking_positions, total_count
    except Exception as e:
        # UI boundary: surface the error text in the results box rather than
        # crashing the Streamlit page.
        return str(e), [], 0
# Streamlit app: upload an image, enter object types, show detection results.
st.title("Object Detection")
st.write("Upload an image, register object types (comma-separated), and the app will detect, count, and find the best picking positions for the registered objects.")
# Image upload and the comma-separated list of object types to look for.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
object_types = st.text_input("Registered Object Types (comma separated, e.g., 'cat, dog')")
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding here so truncated files fail now
        # rather than later inside st.image or the detector.
        image.load()
    except OSError as err:
        # Covers PIL.UnidentifiedImageError (an OSError subclass) for corrupt
        # or mislabelled uploads; report it instead of showing a traceback.
        st.error(f"Could not read the uploaded image: {err}")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)
        # Skip the expensive model run when the field is empty or whitespace.
        if object_types.strip():
            detected_objects, picking_positions, total_count = detect_objects(image, object_types)
            result = f"{detected_objects}\n\nPicking Positions: {picking_positions}\nTotal Count: {total_count}"
            st.text_area("Detection Results", value=result, height=200)