Spaces:
Runtime error
Runtime error
import io
import math
import os

import numpy as np
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn as nn
import torch.nn.functional as F
from obj2html import obj2html
from PIL import Image
from torchvision import models
from torchvision.transforms import ToTensor, Resize
| # DEPTH IMAGE TO OBJ | |
# Bounds used by predict() to clip the inverted network output before it is
# rescaled by maxDepth.
minDepth, maxDepth = 10, 1000
def my_DepthNorm(x, maxDepth):
    """Invert a depth value against the maximum depth.

    Args:
        x: Divisor; a scalar or anything supporting ``maxDepth / x``
            (predict() passes a NumPy array).
        maxDepth: Numerator of the division.

    Returns:
        The result of ``maxDepth / x``.
    """
    inverted = maxDepth / x
    return inverted
def vete(v, vt):
    """Format one OBJ face corner from a vertex and a texture index.

    Returns:
        ``str(v)`` when both indices are equal, otherwise ``"<v>/<vt>"``.
    """
    vertex = str(v)
    return vertex if v == vt else vertex + "/" + str(vt)
def create_obj(img, objPath='model.obj', mtlPath='model.mtl', matName='colored', useMaterial=False):
    """Write a depth map to a Wavefront OBJ mesh file.

    Every pixel becomes one vertex, pushed along its viewing ray through a
    pinhole camera with a fixed field of view of pi/4. Each 2x2 pixel
    neighbourhood whose four depths are all non-zero is emitted as two
    triangles.

    Args:
        img: 2-D array indexed as ``img[row, column]`` holding depth values.
            A value of exactly ``0.0`` excludes that pixel from every face.
        objPath: Destination of the OBJ file. Missing parent directories are
            created.
        mtlPath: Material library name written on the ``mtllib`` line when
            ``useMaterial`` is true. The MTL file itself is not written here.
        matName: Material name written on the ``usemtl`` line.
        useMaterial: Whether to emit the ``mtllib``/``usemtl`` header.
    """
    h, w = img.shape[0], img.shape[1]

    # Distance from the camera to the image plane, in pixels, derived from
    # the image height and the field of view.
    FOV = math.pi / 4
    D = (h / 2) / math.tan(FOV / 2)

    # Create the directory of the file we are about to open. It must come
    # from objPath itself: mtlPath may have no directory component at all,
    # in which case os.makedirs('') would raise.
    out_dir = os.path.dirname(objPath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(objPath, "w") as f:
        if useMaterial:
            f.write("mtllib " + mtlPath + "\n")
            f.write("usemtl " + matName + "\n")

        # ids[u, v] is the 1-based OBJ index of the vertex for column u,
        # row v; 0 marks a pixel that faces must skip.
        ids = np.zeros((w, h), int)
        vid = 1

        for u in range(w):
            for v in range(h - 1, -1, -1):
                d = img[v, u]

                # A vertex line is written for every pixel, so the running
                # index always advances; only the lookup entry is zeroed for
                # zero-depth pixels.
                ids[u, v] = 0 if d == 0.0 else vid
                vid += 1

                x = u - w / 2
                y = v - h / 2
                z = -D

                # Scale the ray through (x, y, z) so the vertex ends up at
                # depth d along the camera axis.
                norm = 1 / math.sqrt(x * x + y * y + z * z)
                t = d / (z * norm)

                x = -t * x * norm
                y = t * y * norm
                z = -t * z * norm

                f.write("v " + str(x) + " " + str(y) + " " + str(z) + "\n")

        # One texture coordinate per pixel, normalised by the image size.
        # NOTE: faces below pass identical vertex/texture indices to vete(),
        # which collapses them to the bare vertex index, so these vt records
        # are not referenced by any face.
        for u in range(w):
            for v in range(h):
                f.write("vt " + str(u / w) + " " + str(v / h) + "\n")

        for u in range(w - 1):
            for v in range(h - 1):
                v1 = ids[u, v]
                v3 = ids[u + 1, v]
                v2 = ids[u, v + 1]
                v4 = ids[u + 1, v + 1]

                # Skip the quad if any of its corners has zero depth.
                if v1 == 0 or v2 == 0 or v3 == 0 or v4 == 0:
                    continue

                f.write("f " + vete(v1, v1) + " " +
                        vete(v2, v2) + " " + vete(v3, v3) + "\n")
                f.write("f " + vete(v3, v3) + " " +
                        vete(v2, v2) + " " + vete(v4, v4) + "\n")
| # MODEL | |
class UpSample(nn.Sequential):
    """Decoder stage that upsamples, merges a skip tensor and convolves twice.

    Args:
        skip_input: Channel count of the concatenated tensor fed to ``convA``
            (upsampled input channels plus skip channels).
        output_features: Channel count produced by both convolutions.
    """

    def __init__(self, skip_input, output_features):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        # Registration order and attribute names are kept as-is because they
        # determine the state-dict keys.
        self.convA = nn.Conv2d(skip_input, output_features, **conv_kwargs)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(output_features, output_features, **conv_kwargs)
        self.leakyreluB = nn.LeakyReLU(0.2)

    def forward(self, x, concat_with):
        """Resize ``x`` to ``concat_with``'s spatial size, then fuse both.

        Args:
            x: Tensor to upsample bilinearly.
            concat_with: Skip tensor; supplies the target height/width and is
                concatenated along dimension 1.

        Returns:
            Output of ``leakyreluB(convB(convA(cat)))``.
        """
        target_size = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        merged = torch.cat([upsampled, concat_with], dim=1)
        # NOTE(review): leakyreluA is registered but not applied between the
        # two convolutions. Confirm against how the checkpoint was trained
        # before changing this.
        hidden = self.convA(merged)
        hidden = self.convB(hidden)
        return self.leakyreluB(hidden)
class Decoder(nn.Module):
    """Turn a list of encoder feature maps into a single-channel map.

    Args:
        num_features: Channel count of the deepest encoder feature map.
        decoder_width: Multiplier applied to ``num_features`` to obtain the
            decoder's base channel width.
    """

    def __init__(self, num_features=1664, decoder_width = 1.0):
        super().__init__()
        width = int(num_features * decoder_width)

        # 1x1 projection of the deepest feature map to the decoder width.
        self.conv2 = nn.Conv2d(num_features, width, kernel_size=1, stride=1, padding=0)

        # Each stage halves the channel count. The constant added to
        # skip_input is the channel count of the skip tensor concatenated in
        # that stage.
        self.up1 = UpSample(skip_input=width//1 + 256, output_features=width//2)
        self.up2 = UpSample(skip_input=width//2 + 128, output_features=width//4)
        self.up3 = UpSample(skip_input=width//4 + 64, output_features=width//8)
        self.up4 = UpSample(skip_input=width//8 + 64, output_features=width//16)

        # Final 3x3 convolution down to one output channel.
        self.conv3 = nn.Conv2d(width//16, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        """Decode ``features``, using its entries 3, 4, 6, 8 and 12.

        Args:
            features: Indexable collection of tensors; entry 12 is the
                deepest map, entries 8, 6, 4 and 3 are the skip connections
                consumed in that order.

        Returns:
            Output of the final convolution (one channel).
        """
        skip0 = features[3]
        skip1 = features[4]
        skip2 = features[6]
        skip3 = features[8]
        deepest = features[12]

        out = self.conv2(F.relu(deepest))
        for stage, skip in ((self.up1, skip3), (self.up2, skip2),
                            (self.up3, skip1), (self.up4, skip0)):
            out = stage(out, skip)
        return self.conv3(out)
class Encoder(nn.Module):
    """DenseNet-169 backbone that exposes every intermediate activation."""

    def __init__(self):
        super().__init__()
        # Built without pretrained weights; the caller is expected to load a
        # state dict afterwards.
        self.original_model = models.densenet169( pretrained=False )

    def forward(self, x):
        """Run ``x`` through each child of the backbone's ``features`` in turn.

        Returns:
            A list whose first element is ``x`` itself, followed by the
            output of every child module in registration order.
        """
        outputs = [x]
        for layer in self.original_model.features._modules.values():
            outputs.append(layer(outputs[-1]))
        return outputs
class PTModel(nn.Module):
    """Complete network: Encoder activations fed directly into Decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return the decoder output for the encoder's activations of ``x``."""
        encoded = self.encoder(x)
        return self.decoder(encoded)
# Build the network and load the published checkpoint. map_location pins the
# checkpoint tensors to the CPU so deserialisation cannot fail on a host
# without CUDA; the model is never moved off the CPU and predict() calls
# .numpy() on its output directly.
model = PTModel().float()
path = "https://github.com/nicolalandro/DenseDepth/releases/download/0.1/nyu.pth"
model.load_state_dict(
    torch.hub.load_state_dict_from_url(path, map_location="cpu", progress=True)
)
# Inference only: switch off training-mode behaviour.
model.eval()
def predict(inp):
    """Estimate depth for a PIL image and build its 3-D preview.

    Side effect: writes the mesh to ``model.obj`` in the current working
    directory (the caller is responsible for removing it).

    Args:
        inp: PIL image in any mode; it is converted to RGB before inference.

    Returns:
        Tuple ``(res_img, img, html_string)``:
        the resized RGB input, the depth map as an 8-bit PIL image, and the
        HTML returned by ``obj2html`` for the generated ``model.obj``.
    """
    # The DenseNet encoder takes 3-channel input. Uploaded PNGs may be RGBA,
    # palette or grayscale, which would otherwise fail inside the model.
    if inp.mode != "RGB":
        inp = inp.convert("RGB")

    width, height = inp.size
    # Scale so the shorter side becomes 640 px while keeping the aspect
    # ratio. Resize takes its size as (height, width).
    if width > height:
        target_size = (640, int((640*width)/height))
    else:
        target_size = (int((640*height)/width), 640)
    res_img = Resize(target_size)(inp)

    # Add the batch dimension expected by the model.
    images = ToTensor()(res_img).unsqueeze(0)
    with torch.no_grad():
        predictions = model(images)

    # Invert the raw prediction, clip it to [minDepth, maxDepth] and rescale
    # by maxDepth so the values lie in (0, 1].
    output = np.clip(my_DepthNorm(predictions.numpy(), maxDepth=maxDepth), minDepth, maxDepth) / maxDepth
    depth = output[0, 0, :, :]
    img = Image.fromarray(np.uint8(depth*255))

    create_obj(depth, 'model.obj')
    html_string = obj2html('model.obj', html_elements_only=True)
    return res_img, img, html_string
| # STREAMLIT | |
# Let the user upload an image; fall back to the bundled demo picture.
uploader = st.file_uploader('Wait the demo file to be rendered and upload your favourite image here.',type=['jpg','jpeg','png'])
if uploader is not None:
    pil_image = Image.open(uploader)
else:
    pil_image = Image.open('119_image.png')

with st.spinner("Waiting for the predictions..."):
    pil_scaled, pil_depth, html_string = predict(pil_image)

# Interactive 3-D preview of the generated mesh.
components.html(html_string)

# predict() leaves the mesh in model.obj. Read it into memory and delete the
# file only after its handle is closed, so the removal does not depend on
# being able to unlink an open file.
with open('model.obj') as f:
    obj_text = f.read()
os.remove('model.obj')

# Encode the depth image in memory instead of going through a fixed tmp.png
# in the working directory, which concurrent sessions would share.
png_buffer = io.BytesIO()
pil_depth.save(png_buffer, format='PNG')

col1, col2, col3 = st.columns(3)
with col1:
    st.image(pil_scaled)
with col2:
    st.image(pil_depth)
with col3:
    st.download_button('Download model.obj', obj_text, file_name="model.obj")
    st.download_button('Download depth.png', png_buffer.getvalue(), file_name="depth.png", mime="image/png")