keivalya commited on
Commit
98141d1
·
verified ·
1 Parent(s): 4ac96b2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as T
6
+ from model import HybridDepthModel
7
+
8
+ # Load model
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model = HybridDepthModel().to(device)
11
+ model.load_state_dict(torch.load("depth_model_all.pth", map_location=device))
12
+ model.eval()
13
+
14
+ # Preprocessing
15
+ transform = T.Compose([
16
+ T.Resize((32, 32)),
17
+ T.ToTensor(),
18
+ ])
19
+
20
+ def predict_depth(image):
21
+ img = transform(image).unsqueeze(0).to(device)
22
+ with torch.no_grad():
23
+ pred = model(img)[0, 0].cpu().numpy()
24
+ pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
25
+ pred_image = Image.fromarray((pred_normalized * 255).astype(np.uint8))
26
+ return [image, pred_image]
27
+
28
+ # Gradio UI
29
+ examples = [["example.png"]]
30
+
31
+ demo = gr.Interface(
32
+ fn=predict_depth,
33
+ inputs=gr.Image(type="pil", label="Input RGB Image"),
34
+ outputs=[
35
+ gr.Image(type="pil", label="Original Image"),
36
+ gr.Image(type="pil", label="Predicted Depth Map"),
37
+ ],
38
+ title="🔭 DepthStar: Light-weight Depth Estimation",
39
+ description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
40
+ examples=examples,
41
+ allow_flagging="never",
42
+ theme="huggingface",
43
+ )
44
+
45
+ if __name__ == "__main__":
46
+ demo.launch()