"""Streamlit app: upload a 2D image, estimate depth with MiDaS, export a point cloud."""

import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image


# Load MiDaS model
@st.cache_resource
def load_model():
    """Load the MiDaS model and its matching preprocessing transform.

    Returns:
        A ``(model, transform, device)`` tuple: the model in eval mode and
        already moved to ``device``, the preprocessing transform that matches
        the chosen model type, and the torch device (CUDA when available,
        otherwise CPU).
    """
    model_type = "DPT_Large"  # Use DPT_Large for higher accuracy
    model = torch.hub.load("intel-isl/MiDaS", model_type)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    # The MiDaS hub usage pairs DPT models with `dpt_transform`; pairing
    # DPT_Large with another transform feeds it inputs normalised differently
    # from what the hub example prescribes.
    if model_type in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform
    return model, transform, device


def _estimate_depth(image, model, transform, device):
    """Return an (H, W) float depth map for a PIL RGB image."""
    # The MiDaS hub transforms operate on an RGB NumPy array and already
    # return a batched tensor, so no extra unsqueeze is applied here.
    input_batch = transform(np.asarray(image)).to(device)

    with torch.no_grad():
        prediction = model(input_batch)
        resized = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
            mode="bicubic",
            align_corners=False,
        )
    # Index the batch/channel axes explicitly: a bare squeeze() would also
    # drop H or W when the image is one pixel high or wide.
    return resized[0, 0].cpu().numpy()


def _depth_to_points(depth_map):
    """Convert an (H, W) depth map into centred, per-axis scaled (N, 3) points."""
    height, width = depth_map.shape
    xx, yy = np.meshgrid(np.arange(width), np.arange(height))
    points = np.stack((xx, yy, depth_map), axis=-1).reshape(-1, 3)
    points = points.astype(np.float64, copy=False)

    # Normalize points
    points -= points.mean(axis=0)
    scale = points.max(axis=0)
    # A constant axis (flat depth, or a 1-pixel-wide/high image) has a
    # post-centring maximum of 0; leave it unscaled instead of dividing by 0.
    scale[scale == 0] = 1.0
    points /= scale
    return points


def _points_to_ply_bytes(points):
    """Serialise (N, 3) points to PLY via Open3D and return the file bytes."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # Write into a per-call temporary directory so concurrent sessions do not
    # overwrite one shared file in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ply_path = os.path.join(tmp_dir, "point_cloud.ply")
        if not o3d.io.write_point_cloud(ply_path, pcd):
            raise RuntimeError("Failed to write point cloud to PLY file")
        with open(ply_path, "rb") as ply_file:
            return ply_file.read()


model, transform, device = load_model()

# Streamlit app UI
st.title("2D to 3D Image Converter")
st.write("Upload a 2D image to generate its 3D depth map and point cloud.")

# File uploader
uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Load the image
    input_image = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): `use_column_width` may be deprecated in newer Streamlit
    # releases - confirm against the installed version before changing it.
    st.image(input_image, caption="Uploaded Image", use_column_width=True)

    # Predict depth map
    st.write("Generating Depth Map...")
    depth_map = _estimate_depth(input_image, model, transform, device)

    # Display depth map
    st.write("Depth Map:")
    fig, ax = plt.subplots()
    ax.imshow(depth_map, cmap="plasma")
    ax.axis("off")
    st.pyplot(fig)
    # Close the figure so pyplot does not accumulate one per script rerun.
    plt.close(fig)

    # Generate 3D point cloud
    st.write("Generating 3D Point Cloud...")
    points = _depth_to_points(depth_map)

    # Save point cloud for download
    ply_bytes = _points_to_ply_bytes(points)

    st.write("3D Point Cloud Generated!")
    st.download_button(
        label="Download 3D Point Cloud (.ply)",
        data=ply_bytes,
        file_name="3d_point_cloud.ply",
        mime="application/octet-stream",
    )