Ghmustafa11 commited on
Commit
ed78524
·
verified ·
1 Parent(s): 7882a07

Create App. Py

Browse files
Files changed (1) hide show
  1. App. Py +79 -0
App. Py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import torch
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import open3d as o3d
8
+
9
+ # Load MiDaS model
10
+ @st.cache_resource
11
+ def load_model():
12
+ model_type = "DPT_Large" # Use DPT_Large for higher accuracy
13
+ model = torch.hub.load("intel-isl/MiDaS", model_type)
14
+ model.eval()
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model.to(device)
17
+
18
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
19
+ transform = midas_transforms.default_transform
20
+ return model, transform, device
21
+
22
+ model, transform, device = load_model()
23
+
24
+ # Streamlit app UI
25
+ st.title("2D to 3D Image Converter")
26
+ st.write("Upload a 2D image to generate its 3D depth map and point cloud.")
27
+
28
+ # File uploader
29
+ uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
30
+ if uploaded_file:
31
+ # Load the image
32
+ input_image = Image.open(uploaded_file).convert("RGB")
33
+ st.image(input_image, caption="Uploaded Image", use_column_width=True)
34
+
35
+ # Preprocess the image
36
+ input_batch = transform(input_image).unsqueeze(0).to(device)
37
+
38
+ # Predict depth map
39
+ st.write("Generating Depth Map...")
40
+ with torch.no_grad():
41
+ prediction = model(input_batch)
42
+ depth_map = torch.nn.functional.interpolate(
43
+ prediction.unsqueeze(1),
44
+ size=input_image.size[::-1],
45
+ mode="bicubic",
46
+ align_corners=False,
47
+ ).squeeze().cpu().numpy()
48
+
49
+ # Display depth map
50
+ st.write("Depth Map:")
51
+ fig, ax = plt.subplots()
52
+ ax.imshow(depth_map, cmap="plasma")
53
+ ax.axis("off")
54
+ st.pyplot(fig)
55
+
56
+ # Generate 3D point cloud
57
+ st.write("Generating 3D Point Cloud...")
58
+ h, w = depth_map.shape
59
+ xx, yy = np.meshgrid(np.arange(w), np.arange(h))
60
+ points = np.stack((xx, yy, depth_map), axis=-1).reshape(-1, 3)
61
+
62
+ # Normalize points
63
+ points -= points.mean(axis=0)
64
+ points /= points.max(axis=0)
65
+
66
+ # Create Open3D point cloud
67
+ pcd = o3d.geometry.PointCloud()
68
+ pcd.points = o3d.utility.Vector3dVector(points)
69
+
70
+ # Save point cloud for download
71
+ output_file = "point_cloud.ply"
72
+ o3d.io.write_point_cloud(output_file, pcd)
73
+ st.write("3D Point Cloud Generated!")
74
+ st.download_button(
75
+ label="Download 3D Point Cloud (.ply)",
76
+ data=open(output_file, "rb").read(),
77
+ file_name="3d_point_cloud.ply",
78
+ mime="application/octet-stream",
79
+ )