"""Streamlit demo: upload an image and display a MiDaS depth estimate.

A SAM predictor is loaded alongside MiDaS so that point-prompted
segmentation can be added later; at the moment only depth estimation runs.
"""

import streamlit as st
import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

SAM_CHECKPOINT = "sam_vit_b_01ec64.pth"
SAM_MODEL_TYPE = "vit_b"
MIDAS_REPO = "intel-isl/MiDaS"
MIDAS_MODEL = "DPT_Large"


@st.cache_resource
def load_models():
    """Load the SAM predictor, the MiDaS model and the MiDaS input transform.

    Decorated with ``st.cache_resource`` so the models are built once and
    reused across Streamlit reruns.

    Returns:
        A ``(predictor, midas, transform)`` tuple: the ``SamPredictor``, the
        MiDaS network in eval mode, and the matching DPT input transform.
        Both networks are moved to CUDA when it is available, else CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load SAM (vit_b)
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device)
    predictor = SamPredictor(sam)

    # Load MiDaS
    midas = torch.hub.load(MIDAS_REPO, MIDAS_MODEL).to(device)
    midas.eval()
    midas_transforms = torch.hub.load(MIDAS_REPO, "transforms")
    transform = midas_transforms.dpt_transform

    return predictor, midas, transform


def _decode_image(data):
    """Decode encoded image bytes into an RGB array, or return None on failure."""
    if not data:
        # cv2.imdecode raises on an empty buffer; treat it as "not an image".
        return None
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        # imdecode signals undecodable input by returning None.
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _estimate_depth(model, transform, image_rgb):
    """Run the depth model on an RGB image and return a (H, W) float array."""
    # Use the device the cached model actually lives on instead of
    # re-deriving it, so input and weights can never disagree.
    device = next(model.parameters()).device

    batch = transform(image_rgb).to(device)
    # The MiDaS hub transforms already return a batched tensor; only add the
    # batch dimension if a transform hands back a single (C, H, W) sample.
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)

    with torch.no_grad():
        prediction = model(batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )

    # Index explicitly rather than squeeze() so an image whose height or
    # width is 1 keeps both spatial dimensions.
    return prediction[0, 0].cpu().numpy()


def _normalize_depth(depth):
    """Min-max scale a depth map into [0, 1] for display."""
    depth = depth.astype(np.float32)
    lowest = float(depth.min())
    highest = float(depth.max())
    span = highest - lowest
    if span <= 0.0:
        # Constant prediction: avoid dividing by zero.
        return np.zeros_like(depth)
    return (depth - lowest) / span


predictor, midas_model, midas_transform = load_models()

st.title("SAM + MiDaS Depth App")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # getvalue() returns the full contents regardless of the stream position,
    # unlike read(), which yields nothing once the buffer has been consumed.
    image_rgb = _decode_image(uploaded_file.getvalue())

    if image_rgb is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        st.image(image_rgb, caption="Original Image", use_column_width=True)

        # Ask for click input
        # TODO: st.image only renders the picture and does not report click
        # coordinates; a click-capable component is needed before
        # `predictor` can be driven by a point prompt.
        st.write("Click a point for segmentation")
        st.image(image_rgb, use_column_width=True)

        # For now, run depth estimation directly
        depth = _estimate_depth(midas_model, midas_transform, image_rgb)

        # The raw prediction is unbounded relative depth; without scaling,
        # clamp=True would clip almost every pixel to the same value.
        st.image(
            _normalize_depth(depth),
            caption="Estimated Depth",
            use_column_width=True,
            clamp=True,
        )