keivalya committed on
Commit
4c8e425
·
verified ·
1 Parent(s): e39d89b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from PIL import Image
5
  import torchvision.transforms as T
6
  from model import DepthSTAR
 
7
 
8
  # Load model
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -24,17 +25,24 @@ def predict_depth(image):
24
  img = transform(image).unsqueeze(0).to(device)
25
  with torch.no_grad():
26
  pred = model(img)[0, 0].cpu().numpy()
27
-
 
28
  pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
29
- pred_image = Image.fromarray((pred_normalized * 255).astype(np.uint8))
30
 
31
- # Resize both for better display (e.g., 512x512)
 
 
 
 
 
 
32
  upscale_size = (256, 256)
33
  image_resized = image.resize(upscale_size, resample=Image.NEAREST)
34
- pred_resized = pred_image.resize(upscale_size, resample=Image.NEAREST)
35
 
36
  return [image_resized, pred_resized]
37
 
 
38
  # Gradio UI
39
  # examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
40
  examples = [["example.png"]]
 
4
  from PIL import Image
5
  import torchvision.transforms as T
6
  from model import DepthSTAR
7
+ import matplotlib.cm as cm
8
 
9
  # Load model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
25
  img = transform(image).unsqueeze(0).to(device)
26
  with torch.no_grad():
27
  pred = model(img)[0, 0].cpu().numpy()
28
+
29
+ # Normalize prediction
30
  pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
 
31
 
32
+ # Apply inferno colormap
33
+ cmap = cm.get_cmap('inferno')
34
+ pred_colored = cmap(pred_normalized)[:, :, :3] # Drop alpha channel
35
+ pred_colored = (pred_colored * 255).astype(np.uint8)
36
+ pred_pil = Image.fromarray(pred_colored)
37
+
38
+ # Upscale for display
39
  upscale_size = (256, 256)
40
  image_resized = image.resize(upscale_size, resample=Image.NEAREST)
41
+ pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)
42
 
43
  return [image_resized, pred_resized]
44
 
45
+
46
  # Gradio UI
47
  # examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
48
  examples = [["example.png"]]