"""Streamlit app: detect user-registered object types with DETR and report
their count and box-center "picking positions"."""
import io

import streamlit as st
import torch
from PIL import Image, UnidentifiedImageError
from transformers import DetrForObjectDetection, DetrImageProcessor

# Hugging Face checkpoint (and pinned revision) used for detection.
MODEL_NAME = "facebook/detr-resnet-50"
MODEL_REVISION = "no_timm"
# Minimum detection score kept by post-processing.
DEFAULT_THRESHOLD = 0.8


@st.cache_resource
def load_model():
    """Load the DETR processor and model once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would re-instantiate the model each time;
    ``st.cache_resource`` keeps one shared instance instead.
    """
    processor = DetrImageProcessor.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
    model = DetrForObjectDetection.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
    return processor, model


# Module-level names kept for any code that references them directly.
processor, model = load_model()


def detect_objects(image, object_types, threshold=DEFAULT_THRESHOLD):
    """Detect the registered object types in ``image``.

    Args:
        image: A PIL image. Non-RGB images (RGBA/palette PNGs, grayscale)
            are converted to RGB before inference.
        object_types: Comma-separated class names, e.g. ``"cat, dog"``.
            Matching is case-insensitive against the model's label names;
            empty entries are ignored.
        threshold: Minimum detection score kept by post-processing.

    Returns:
        ``(summary, picking_positions, total_count)``: ``summary`` is a
        newline-joined list of ``"Object N: Label"`` lines (N counts only
        matching objects) or a "no objects" message; ``picking_positions``
        holds the ``(x, y)`` center of each matching bounding box in pixel
        coordinates, rounded to 2 decimals; ``total_count`` is the number of
        matches. On any error the summary is the error message, the list is
        empty and the count is 0.
    """
    try:
        # Normalize the registered types; a set gives O(1) membership tests.
        wanted = {name.strip().lower() for name in object_types.split(",")}
        wanted.discard("")

        # The processor normalizes 3-channel images; RGBA/L/P inputs must be
        # converted first or inference fails.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]

        detected_objects = []
        picking_positions = []
        for label, box in zip(results["labels"], results["boxes"]):
            label_name = model.config.id2label[label.item()]
            if label_name.lower() not in wanted:
                continue
            x_min, y_min, x_max, y_max = box.tolist()
            # Box center, rounded after the arithmetic to avoid float noise.
            picking_positions.append(
                (round((x_min + x_max) / 2, 2), round((y_min + y_max) / 2, 2))
            )
            detected_objects.append(
                f"Object {len(detected_objects) + 1}: {label_name.capitalize()}"
            )

        total_count = len(detected_objects)
        if not detected_objects:
            return "No registered objects detected.", picking_positions, total_count
        return "\n".join(detected_objects), picking_positions, total_count
    except Exception as e:  # UI boundary: show the message instead of crashing the app
        return str(e), [], 0


# Streamlit app
st.title("Object Detection")
st.write("Upload an image, register object types (comma-separated), and the app will detect, count, and find the best picking positions for the registered objects.")

# Image upload
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
object_types = st.text_input("Registered Object Types (comma separated, e.g., 'cat, dog')")

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        # The extension filter does not guarantee the content is a valid image.
        st.error("The uploaded file could not be read as an image.")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if object_types.strip():
            detected_objects, picking_positions, total_count = detect_objects(image, object_types)
            result = f"{detected_objects}\n\nPicking Positions: {picking_positions}\nTotal Count: {total_count}"
            st.text_area("Detection Results", value=result, height=200)