Spaces:
Runtime error
Runtime error
import os

import cv2
import numpy as np
import streamlit as st
import torch
from segment_anything import SamPredictor, sam_model_registry
@st.cache_resource
def load_models(sam_checkpoint="sam_vit_b_01ec64.pth", model_type="vit_b"):
    """Load the SAM predictor and the MiDaS DPT_Large depth model.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded models in memory across reruns,
    so they are not reloaded from disk and torch.hub each time.

    Args:
        sam_checkpoint: Path to the SAM weights file matching ``model_type``.
        model_type: Key into ``sam_model_registry`` (for example ``"vit_b"``).

    Returns:
        Tuple ``(predictor, midas, transform)``:
        - a ``SamPredictor`` wrapping the SAM model;
        - the MiDaS DPT_Large model, in eval mode;
        - the MiDaS ``dpt_transform`` input transform.
        Both models are moved to CUDA when available, otherwise to CPU.

    Raises:
        FileNotFoundError: If ``sam_checkpoint`` does not exist.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fail early with an actionable message instead of an opaque load error.
    if not os.path.isfile(sam_checkpoint):
        raise FileNotFoundError(
            f"SAM checkpoint not found: {sam_checkpoint!r}. Download the "
            f"'{model_type}' checkpoint from the segment-anything repository "
            "and place it next to this app."
        )

    # Load SAM and wrap it in a predictor.
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
    predictor = SamPredictor(sam)

    # Load MiDaS DPT_Large plus the input transform that matches it.
    midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device)
    midas.eval()
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    transform = midas_transforms.dpt_transform
    return predictor, midas, transform
predictor, midas_model, midas_transform = load_models()

st.title("SAM + MiDaS Depth App")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_file:
    # getvalue() returns the whole upload regardless of the buffer's read position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals undecodable data by returning None, not by raising.
        st.error("Could not decode the uploaded file as an image.")
        st.stop()
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)

    # TODO: point-click segmentation with `predictor` is not implemented yet.
    # st.image cannot report click coordinates, so only depth estimation runs.

    # Run inference on whichever device load_models() placed the model on.
    device = next(midas_model.parameters()).device
    # MiDaS dpt_transform already returns a batched (1, 3, H, W) tensor.
    # Adding another unsqueeze(0) would make the input 5-D and break the forward pass.
    input_batch = midas_transform(image_rgb).to(device)
    with torch.no_grad():
        prediction = midas_model(input_batch)
        # Upsample the prediction back to the original image resolution.
        depth = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze().cpu().numpy()

    # MiDaS outputs unbounded relative inverse depth. Rescale it to [0, 1] so that
    # st.image shows a gradient instead of clamping nearly everything to white.
    d_min, d_max = float(depth.min()), float(depth.max())
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    else:
        depth = np.zeros_like(depth)
    st.image(depth, caption="Estimated Depth", use_column_width=True, clamp=True)