Create app.py
Add an app.py that:
Loads PathfindingNetwork and your weights.
Lets users either:
Upload a .npz sample (voxel_data [1,3,32,32,32], positions [1,2,3]), or
Generate a random environment and run inference.
Displays decoded actions.
app.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from pathfinding_nn import PathfindingNetwork, create_voxel_input
|
| 8 |
+
|
| 9 |
+
# Human-readable names for the six movement actions, indexed by action id.
# Ids outside [0, 6) are treated as non-action tokens and skipped by decode().
ACTION_NAMES = [
    'FORWARD',
    'BACK',
    'LEFT',
    'RIGHT',
    'UP',
    'DOWN',
]
|
| 10 |
+
|
| 11 |
+
def load_model():
    """Instantiate PathfindingNetwork and load its trained weights.

    Checkpoint resolution order:
      1. Local file ``training_outputs/final_model.pth``.
      2. Hugging Face Hub, via the MODEL_REPO_ID / MODEL_FILENAME env vars.

    Returns:
        (model, device): the network in eval mode on the selected device.

    Raises:
        FileNotFoundError: if no checkpoint can be located.
    """
    def _resolve_checkpoint():
        # Prefer a checkpoint shipped alongside the app.
        local_ckpt = Path('training_outputs/final_model.pth')
        if local_ckpt.exists():
            return str(local_ckpt)
        # Otherwise pull from the Hub when a repo is configured.
        repo_id = os.getenv('MODEL_REPO_ID', '')  # e.g. "your-username/voxel-pathfinder"
        if repo_id:
            filename = os.getenv('MODEL_FILENAME', 'final_model.pth')
            return hf_hub_download(repo_id=repo_id, filename=filename)
        return None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = PathfindingNetwork().to(device).eval()

    ckpt_path = _resolve_checkpoint()
    if ckpt_path is None:
        raise FileNotFoundError("Model checkpoint not found. Upload to training_outputs/final_model.pth or set MODEL_REPO_ID+MODEL_FILENAME env vars.")

    payload = torch.load(ckpt_path, map_location=device)
    # Checkpoints may wrap the weights under 'model_state_dict' (full training
    # checkpoint) or be a bare state dict; handle both.
    net.load_state_dict(payload.get('model_state_dict', payload))
    return net, device
|
| 34 |
+
|
| 35 |
+
# Load the model once at import time so every Gradio request reuses it.
MODEL, DEVICE = load_model()
|
| 36 |
+
|
| 37 |
+
def decode(actions):
    """Map raw action ids to their names, skipping ids outside [0, 6)."""
    names = []
    for action_id in actions:
        if 0 <= action_id < 6:
            names.append(ACTION_NAMES[action_id])
    return names
|
| 39 |
+
|
| 40 |
+
def infer_random(obstacle_prob=0.2, seed=None):
    """Generate a random voxel environment and run pathfinding inference.

    Args:
        obstacle_prob: probability that each voxel cell is an obstacle.
        seed: optional integer seed for reproducible environments.

    Returns:
        A JSON-serializable dict with the start/goal cells and the decoded
        action sequence, or an ``{"error": ...}`` dict on degenerate input.
    """
    # Use a local Generator instead of np.random.seed so we never mutate
    # NumPy's global RNG state as a side effect of a UI request.
    rng = np.random.default_rng(None if seed is None else int(seed))

    voxel_dim = MODEL.voxel_dim  # (32, 32, 32)
    D, H, W = voxel_dim
    obstacles = (rng.random((D, H, W)) < float(obstacle_prob)).astype(np.float32)

    free = np.argwhere(obstacles == 0)
    if len(free) < 2:
        return {"error": "Not enough free cells; lower obstacle_prob."}
    s_idx, g_idx = rng.choice(len(free), size=2, replace=False)
    # Cast numpy ints to plain ints so the result dict is JSON-serializable.
    start = tuple(int(v) for v in free[s_idx])
    goal = tuple(int(v) for v in free[g_idx])

    voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
    voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)  # (1,3,32,32,32)
    pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE)  # (1,2,3)

    with torch.no_grad():
        actions = MODEL(voxel, pos)[0].tolist()
    return {
        "start": start,
        "goal": goal,
        "num_actions": len([a for a in actions if 0 <= a < 6]),
        "actions_ids": actions,
        "actions_decoded": decode(actions)[:50],
    }
|
| 66 |
+
|
| 67 |
+
def infer_npz(npz_file):
    """Run inference on an uploaded .npz sample.

    Expects keys 'voxel_data' ([3,32,32,32] or batched [1,3,32,32,32]) and
    'positions' ([2,3] or batched [1,2,3]).

    Returns:
        A JSON-serializable dict with the decoded actions, or an
        ``{"error": ...}`` dict when the upload is missing or malformed.
    """
    if npz_file is None:
        return {"error": "Please upload a .npz with keys 'voxel_data' and 'positions'."}
    # Gradio may hand us a tempfile-like wrapper or a plain path string.
    path = npz_file if isinstance(npz_file, str) else npz_file.name
    data = np.load(path)
    missing = [k for k in ('voxel_data', 'positions') if k not in data]
    if missing:
        return {"error": f"Missing keys in .npz: {missing}"}

    voxel = torch.from_numpy(data['voxel_data']).float()
    pos = torch.from_numpy(data['positions']).long()
    # Add the batch dim only when the arrays were saved unbatched; the
    # documented sample format ([1,3,32,32,32] / [1,2,3]) already has it, so
    # an unconditional unsqueeze would produce a spurious extra dimension.
    if voxel.dim() == 4:
        voxel = voxel.unsqueeze(0)  # -> (1,3,32,32,32)
    if pos.dim() == 2:
        pos = pos.unsqueeze(0)      # -> (1,2,3)
    voxel = voxel.to(DEVICE)
    pos = pos.to(DEVICE)

    with torch.no_grad():
        actions = MODEL(voxel, pos)[0].tolist()
    return {
        "num_actions": len([a for a in actions if 0 <= a < 6]),
        "actions_ids": actions,
        "actions_decoded": decode(actions)[:50],
    }
|
| 80 |
+
|
| 81 |
+
# Two-tab UI: synthesize a random environment, or score an uploaded sample.
with gr.Blocks(title="Voxel Path Finder") as demo:
    gr.Markdown("## 3D Voxel Path Finder — Inference")

    with gr.Tab("Random environment"):
        obstacle_slider = gr.Slider(0.0, 0.9, value=0.2, step=0.05, label="Obstacle probability")
        seed_box = gr.Number(value=None, label="Seed (optional)")
        run_random = gr.Button("Run inference")
        random_result = gr.JSON(label="Result")
        run_random.click(infer_random, inputs=[obstacle_slider, seed_box], outputs=random_result)

    with gr.Tab("Upload .npz sample"):
        sample_file = gr.File(file_types=[".npz"], label="Upload sample (voxel_data, positions)")
        run_upload = gr.Button("Run inference")
        upload_result = gr.JSON(label="Result")
        run_upload.click(infer_npz, inputs=sample_file, outputs=upload_result)
|
| 95 |
+
|
| 96 |
+
# Script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()
|