import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import src.depth_pro


# Build the depth model and its matching preprocessing transform once at
# import time so every request reuses them.
# NOTE: the file does `import src.depth_pro`, which binds only the name `src`;
# the bare name `depth_pro` is never defined, so the package is referenced
# through its fully qualified path here.
model, transform = src.depth_pro.create_model_and_transforms()

# Switch the model to evaluation mode; this script only runs inference.
model.eval()


def predict_depth(input_image):
    """Predict a depth map for an uploaded image and render it as a PNG.

    Args:
        input_image: Location of the uploaded image. The Gradio input is
            declared with ``type="filepath"``, so this is normally a ``str``
            path; an object exposing the path through a ``name`` attribute
            is accepted as well.

    Returns:
        A ``(output_path, caption)`` tuple: ``output_path`` is the path of a
        PNG file showing the min-max normalised depth map, and ``caption`` is
        the text ``"Focal length: <value> pixels"``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before requesting a depth map.")

    # `type="filepath"` hands over a plain string, which has no `.name`;
    # fall back to the value itself when that attribute is absent.
    image_path = getattr(input_image, "name", input_image)

    # load_rgb returns a sequence whose first item is the image and whose last
    # item is the focal length in pixels (taken as-is from the loader).
    result = src.depth_pro.load_rgb(image_path)
    image = result[0]
    f_px = result[-1]
    image = transform(image)

    prediction = model.infer(image, f_px=f_px)
    depth = prediction["depth"]
    # float() accepts both a Python number and a one-element tensor, so the
    # ":.2f" formatting below works for either.
    focallength_px = float(prediction["focallength_px"])

    # Presumably a torch tensor (TODO confirm); move it to host memory so
    # NumPy and matplotlib can consume it. Plain arrays pass straight through.
    if hasattr(depth, "detach"):
        depth = depth.detach().cpu().numpy()
    depth = np.asarray(depth)

    # Min-max normalise to [0, 1]; a constant depth map would otherwise
    # divide by zero, so it is rendered as all zeros instead.
    depth_min = np.min(depth)
    depth_span = np.max(depth) - depth_min
    if depth_span > 0:
        depth_normalized = (depth - depth_min) / depth_span
    else:
        depth_normalized = np.zeros_like(depth, dtype=float)

    # Each call writes to its own file so simultaneous requests cannot
    # overwrite one another's output. The file is intentionally left on disk
    # (delete=False) because the caller reads it after this function returns.
    with tempfile.NamedTemporaryFile(
        prefix="depth_map_", suffix=".png", delete=False
    ) as tmp_file:
        output_path = tmp_file.name

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        heatmap = ax.imshow(depth_normalized, cmap='viridis')
        fig.colorbar(heatmap, ax=ax, label='Depth')
        ax.set_title('Predicted Depth Map')
        ax.axis('off')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    return output_path, f"Focal length: {focallength_px:.2f} pixels"


# Gradio UI wiring: one uploaded image in (passed to the handler as a file
# path), the rendered depth-map image plus a focal-length text box out.
_image_input = gr.Image(type="filepath")
_result_outputs = [
    gr.Image(type="filepath", label="Depth Map"),
    gr.Textbox(label="Focal Length"),
]

iface = gr.Interface(
    fn=predict_depth,
    inputs=_image_input,
    outputs=_result_outputs,
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length.",
)


if __name__ == "__main__":
    # Start the web UI only when this file is executed as a script, so that
    # importing the module (e.g. to reuse predict_depth) does not launch a
    # server as a side effect.
    iface.launch()