Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import numpy as np
|
|
| 4 |
from PIL import Image
|
| 5 |
import torchvision.transforms as T
|
| 6 |
from model import DepthSTAR
|
|
|
|
| 7 |
|
| 8 |
# Load model
|
| 9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -24,17 +25,24 @@ def predict_depth(image):
|
|
| 24 |
img = transform(image).unsqueeze(0).to(device)
|
| 25 |
with torch.no_grad():
|
| 26 |
pred = model(img)[0, 0].cpu().numpy()
|
| 27 |
-
|
|
|
|
| 28 |
pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
|
| 29 |
-
pred_image = Image.fromarray((pred_normalized * 255).astype(np.uint8))
|
| 30 |
|
| 31 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
upscale_size = (256, 256)
|
| 33 |
image_resized = image.resize(upscale_size, resample=Image.NEAREST)
|
| 34 |
-
pred_resized =
|
| 35 |
|
| 36 |
return [image_resized, pred_resized]
|
| 37 |
|
|
|
|
| 38 |
# Gradio UI
|
| 39 |
# examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
|
| 40 |
examples = [["example.png"]]
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
import torchvision.transforms as T
|
| 6 |
from model import DepthSTAR
|
| 7 |
+
import matplotlib.cm as cm
|
| 8 |
|
| 9 |
# Load model
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 25 |
img = transform(image).unsqueeze(0).to(device)
|
| 26 |
with torch.no_grad():
|
| 27 |
pred = model(img)[0, 0].cpu().numpy()
|
| 28 |
+
|
| 29 |
+
# Normalize prediction
|
| 30 |
pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
|
|
|
|
| 31 |
|
| 32 |
+
# Apply inferno colormap
|
| 33 |
+
cmap = cm.get_cmap('inferno')
|
| 34 |
+
pred_colored = cmap(pred_normalized)[:, :, :3] # Drop alpha channel
|
| 35 |
+
pred_colored = (pred_colored * 255).astype(np.uint8)
|
| 36 |
+
pred_pil = Image.fromarray(pred_colored)
|
| 37 |
+
|
| 38 |
+
# Upscale for display
|
| 39 |
upscale_size = (256, 256)
|
| 40 |
image_resized = image.resize(upscale_size, resample=Image.NEAREST)
|
| 41 |
+
pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)
|
| 42 |
|
| 43 |
return [image_resized, pred_resized]
|
| 44 |
|
| 45 |
+
|
| 46 |
# Gradio UI
|
| 47 |
# examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
|
| 48 |
examples = [["example.png"]]
|