Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import tempfile | |
| from PIL import Image | |
| import traceback | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.config import get_cfg | |
| from detectron2 import model_zoo | |
| from detectron2.utils.visualizer import Visualizer | |
| from detectron2.data import MetadataCatalog | |
# Setup page config
st.set_page_config(
    page_title="Object Detection with Detectron2",
    page_icon="π",
    layout="wide",
)

# Optional debug panel, shown when the app is opened with ``?debug`` in the URL.
# Any visitor can add that query parameter, so the panel must not reveal
# secrets: only environment variable *names* are listed, never their values
# (hosting platforms such as Hugging Face Spaces expose configured secrets to
# the app as environment variables).
# ``st.query_params`` supersedes the deprecated ``st.experimental_get_query_params``;
# the latter is only used on older Streamlit versions that lack the former.
_query_params = (
    st.query_params if hasattr(st, "query_params") else st.experimental_get_query_params()
)
if "debug" in _query_params:
    st.write("Environment variable names (values hidden):", sorted(os.environ))
    st.write("Current working directory:", os.getcwd())
    st.write("Directory contents:", os.listdir())
    st.write("Temp directory:", tempfile.gettempdir())
    st.write("PyTorch CUDA available:", torch.cuda.is_available())
# Detectron2 model-zoo entry used for both the architecture config and the weights.
_MODEL_CONFIG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"


@st.cache_resource(show_spinner=False)
def _build_predictor():
    """Construct the Faster R-CNN ``DefaultPredictor`` once per server process.

    Streamlit re-executes the script on every widget interaction; caching the
    predictor avoids rebuilding the network and reloading its weights each time.
    A raised exception is not stored, so a failed build is retried next run.
    NOTE(review): the cached predictor is shared by all sessions; this assumes
    concurrent inference on one model is acceptable for this demo.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(_MODEL_CONFIG))
    # Minimum detection confidence the model itself keeps; fixed at build time.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_MODEL_CONFIG)
    # Use GPU if available, otherwise use CPU
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    return DefaultPredictor(cfg)


# Load the Detectron2 model
def load_model():
    """Return the (cached) Detectron2 predictor and report the device in the sidebar.

    Returns:
        A ``DefaultPredictor`` ready for inference, or ``None`` if building the
        model failed; in that case the error and traceback are shown in the app.
    """
    try:
        predictor = _build_predictor()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.error(traceback.format_exc())
        return None

    # UI messages live outside the cached builder so they are shown on every run.
    if predictor.cfg.MODEL.DEVICE == "cuda":
        st.sidebar.success("GPU is available! Using CUDA.")
    else:
        st.sidebar.info("GPU not available. Using CPU.")
    return predictor
# Function for image prediction
def predict_fn(image, predictor):
    """Run the detector on a PIL image.

    Args:
        image: PIL image in any mode; it is converted to RGB first.
        predictor: the ``DefaultPredictor`` returned by ``load_model``.

    Returns:
        ``(pred_classes, pred_boxes, scores, image_array, class_names)``:
        class ids, an ``(N, 4)`` array of ``x1, y1, x2, y2`` boxes and the
        confidences as NumPy arrays; the RGB image as a NumPy array; and the
        class-name list from the training dataset's metadata (``None`` if the
        metadata has no ``thing_classes``). On failure the error is shown in the
        app and ``(None, None, None, np.array(image), None)`` is returned.
    """
    try:
        # Convert the PIL image to an RGB array (kept in RGB for visualization).
        image_array = np.array(image.convert("RGB"))
        # DefaultPredictor expects BGR input (it converts to the model's
        # INPUT.FORMAT itself), so reverse the channel axis before inference.
        bgr_array = np.ascontiguousarray(image_array[:, :, ::-1])
        outputs = predictor(bgr_array)

        # Get the predicted classes and bounding boxes
        instances = outputs["instances"].to("cpu")
        pred_classes = instances.pred_classes.numpy()
        pred_boxes = instances.pred_boxes.tensor.numpy()
        scores = instances.scores.numpy()

        # Class names come from the predictor's own config; ``cfg`` is local to
        # load_model and is not visible here.
        train_sets = predictor.cfg.DATASETS.TRAIN
        metadata = MetadataCatalog.get(train_sets[0] if len(train_sets) else "coco_2017_val")
        class_names = metadata.get("thing_classes", None)
        return pred_classes, pred_boxes, scores, image_array, class_names
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        st.error(traceback.format_exc())
        return None, None, None, np.array(image), None
# Function to visualize predictions
def visualize_predictions(image_array, instances, class_names):
    """Draw predicted boxes, labels and scores onto an RGB image.

    Args:
        image_array: RGB image as a NumPy array (as returned by ``predict_fn``).
        instances: detectron2 ``Instances`` carrying ``pred_boxes``, ``scores``
            and ``pred_classes``.
        class_names: class names indexed by class id, or ``None`` to use the
            names registered for ``coco_2017_val``.

    Returns:
        The annotated image as a NumPy array (drawn at 1.2x scale), or the
        unmodified ``image_array`` if drawing fails (the error is shown in the app).
    """
    try:
        metadata = MetadataCatalog.get("coco_2017_val")
        registered = metadata.get("thing_classes", None)
        if class_names is not None and (registered is None or list(class_names) != list(registered)):
            # Catalog entries are global and write-once: detectron2 refuses to
            # reassign an existing attribute to a different value. Use a private
            # Metadata object for custom names instead of mutating the catalog.
            from detectron2.data.catalog import Metadata

            metadata = Metadata()
            metadata.thing_classes = list(class_names)

        v = Visualizer(image_array, metadata=metadata, scale=1.2)
        # Draw the predictions on the image
        vis_output = v.draw_instance_predictions(instances.to("cpu"))
        return vis_output.get_image()
    except Exception as e:
        st.error(f"Error during visualization: {e}")
        st.error(traceback.format_exc())
        return image_array
def _render_detections(uploaded_file, predictor, confidence_threshold, col1, col2):
    """Detect objects in one uploaded image and render original, annotated image and table.

    Only detections whose score is at least ``confidence_threshold`` are shown.
    """
    # Display original image
    image = Image.open(uploaded_file)
    with col1:
        st.subheader("Original Image")
        st.image(image, use_column_width=True)

    with st.spinner("Detecting objects..."):
        try:
            pred_classes, pred_boxes, scores, image_array, class_names = predict_fn(image, predictor)
            if pred_classes is None:
                # predict_fn has already shown the error; don't report "no objects".
                return

            # Apply the user's threshold to the detections the model returned.
            keep = scores >= confidence_threshold
            pred_classes, pred_boxes, scores = pred_classes[keep], pred_boxes[keep], scores[keep]
            if len(pred_classes) == 0:
                with col2:
                    st.info("No objects detected in the image.")
                return

            # Rebuild an Instances object from the filtered arrays for the Visualizer.
            from detectron2.structures import Instances

            h, w = image_array.shape[:2]
            instances = Instances((h, w))
            instances.pred_boxes = torch.as_tensor(pred_boxes)
            instances.scores = torch.as_tensor(scores)
            instances.pred_classes = torch.as_tensor(pred_classes)
            result_image = visualize_predictions(image_array, instances, class_names)

            with col2:
                st.subheader("Detected Objects")
                st.image(result_image, use_column_width=True)

            st.subheader("Detection Results")
            st.table([
                {
                    "Object": class_names[cls] if class_names is not None else f"Class {cls}",
                    "Confidence": f"{score:.2f}",
                    "Bounding Box": f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]",
                }
                for cls, box, score in zip(pred_classes, pred_boxes, scores)
            ])
        except Exception as e:
            st.error(f"Error processing image: {e}")
            st.error(traceback.format_exc())


# Main application
def main():
    """Streamlit entry point: load the model, accept an upload and show detections."""
    st.title("π Object Detection with Detectron2")
    st.markdown("""
    Upload an image to detect objects using Facebook AI Research's Detectron2 model.
    This demo uses the Faster R-CNN model with ResNet-50-FPN backbone.
    """)

    st.sidebar.title("Settings")

    with st.spinner("Loading model... This might take a minute on first run."):
        predictor = load_model()
    if predictor is None:
        st.error("Failed to load model. Please check the error messages above.")
        return

    # ROI_HEADS.SCORE_THRESH_TEST is read when the model is built, so editing
    # predictor.cfg afterwards does not change inference. Detections below that
    # build-time value are never returned, so it becomes the slider's minimum
    # and the chosen threshold is applied by filtering the predictions.
    min_threshold = float(predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST)
    confidence_threshold = st.sidebar.slider(
        "Confidence Threshold", min_threshold, 1.0, max(0.5, min_threshold), 0.05
    )

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    col1, col2 = st.columns(2)
    if uploaded_file is not None:
        _render_detections(uploaded_file, predictor, confidence_threshold, col1, col2)
    elif st.button("Use a sample image"):
        # Placeholder: a sample image bundled with the repository could be loaded here.
        st.info("Sample image option selected - this would load a demo image if implemented")

    # Footer
    st.markdown("---")
    st.markdown("Built with Streamlit and Facebook AI Research's Detectron2")


if __name__ == "__main__":
    main()