# Spaces: Sleeping
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image
# Load MiDaS model
@st.cache_resource
def load_model(model_type="DPT_Large"):
    """Load a MiDaS depth-estimation model and its matching input transform.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole script on
    every widget interaction, so without caching the model would be fetched
    through torch.hub and re-initialised on each rerun.

    Args:
        model_type: MiDaS torch.hub entry point. Defaults to "DPT_Large"
            (higher accuracy). Other hub names such as "DPT_Hybrid",
            "MiDaS" or "MiDaS_small" select their own preprocessing below.

    Returns:
        A ``(model, transform, device)`` tuple: the model in eval mode on
        ``device``, the preprocessing transform for that model family, and
        the torch device (CUDA when available, otherwise CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.hub.load("intel-isl/MiDaS", model_type)
    model.to(device)
    model.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    # Each MiDaS family needs its own resize/normalisation. DPT models must use
    # dpt_transform; default_transform is the MiDaS v2.1 preprocessing.
    if model_type in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    elif model_type == "MiDaS_small":
        transform = midas_transforms.small_transform
    else:
        transform = midas_transforms.default_transform
    return model, transform, device
def predict_depth(model, transform, device, image):
    """Run the depth model on an RGB PIL image and return an (H, W) numpy array.

    The output is resized back to the image's own height and width.
    """
    # The MiDaS torch.hub transforms take an HxWx3 RGB numpy array (not a PIL
    # image) and already return a batched tensor, so no extra unsqueeze here.
    input_batch = transform(np.asarray(image)).to(device)
    with torch.no_grad():
        prediction = model(input_batch)
        depth = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.size[::-1],  # PIL size is (w, h); torch wants (h, w)
            mode="bicubic",
            align_corners=False,
        )
    # Index instead of squeeze() so 1-pixel-tall/wide images stay 2-D.
    return depth[0, 0].cpu().numpy()


def depth_to_points(depth_map):
    """Convert an (H, W) depth map into an (H*W, 3) float64 point array.

    x is the column index, y the negated row index, and z the depth value.
    Each axis is centered on its mean and scaled independently into [-1, 1].
    """
    h, w = depth_map.shape
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    # Image rows grow downward; negate so the cloud is upright in y-up viewers.
    points = np.stack((xx, -yy, depth_map), axis=-1).reshape(-1, 3).astype(np.float64)
    points -= points.mean(axis=0)
    scale = np.abs(points).max(axis=0)
    # A constant axis (e.g. 1-pixel-wide image) would otherwise divide by zero.
    scale[scale == 0] = 1.0
    points /= scale
    return points


def point_cloud_to_ply_bytes(points):
    """Serialize an (N, 3) point array to PLY bytes via Open3D.

    Raises:
        RuntimeError: if Open3D reports that writing the file failed.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    # Use a private temp dir: a fixed relative path lives in the process CWD,
    # which every concurrent Streamlit session shares.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ply_path = os.path.join(tmp_dir, "point_cloud.ply")
        if not o3d.io.write_point_cloud(ply_path, pcd):
            raise RuntimeError(f"Open3D failed to write {ply_path}")
        with open(ply_path, "rb") as ply_file:
            return ply_file.read()


def main():
    """Render the 2D-to-3D converter page (Streamlit re-runs this per interaction)."""
    model, transform, device = load_model()

    # Streamlit app UI
    st.title("2D to 3D Image Converter")
    st.write("Upload a 2D image to generate its 3D depth map and point cloud.")

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return

    input_image = Image.open(uploaded_file).convert("RGB")
    st.image(input_image, caption="Uploaded Image", use_column_width=True)

    st.write("Generating Depth Map...")
    depth_map = predict_depth(model, transform, device, input_image)

    st.write("Depth Map:")
    fig, ax = plt.subplots()
    ax.imshow(depth_map, cmap="plasma")
    ax.axis("off")
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; one is created on every rerun.
    plt.close(fig)

    st.write("Generating 3D Point Cloud...")
    ply_bytes = point_cloud_to_ply_bytes(depth_to_points(depth_map))
    st.write("3D Point Cloud Generated!")
    st.download_button(
        label="Download 3D Point Cloud (.ply)",
        data=ply_bytes,
        file_name="3d_point_cloud.ply",
        mime="application/octet-stream",
    )


# Streamlit executes the script as __main__, so the app still runs on every rerun.
if __name__ == "__main__":
    main()