Spaces:
Sleeping
Sleeping
separate load data and load model, set debug False
Browse files
scripts/scripts_utils/plotly_interface.py
CHANGED
|
@@ -244,7 +244,7 @@ def prediction_plot(
|
|
| 244 |
n_samples: int = 1,
|
| 245 |
use_biaser: bool = True,
|
| 246 |
) -> go.Figure:
|
| 247 |
-
range_radius =
|
| 248 |
if use_biaser:
|
| 249 |
risk_level = float(risk_level)
|
| 250 |
else:
|
|
@@ -262,8 +262,8 @@ def prediction_plot(
|
|
| 262 |
),
|
| 263 |
title_text="Road Scene",
|
| 264 |
hovermode="closest",
|
| 265 |
-
width=
|
| 266 |
-
height=
|
| 267 |
updatemenus=[
|
| 268 |
dict(
|
| 269 |
type="buttons",
|
|
@@ -332,8 +332,7 @@ def update_figure(
|
|
| 332 |
|
| 333 |
return fig
|
| 334 |
|
| 335 |
-
def
|
| 336 |
-
dataset = load_dataset(data_source, split="test")
|
| 337 |
config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
|
| 338 |
ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
|
| 339 |
cfg = Config.fromfile(config_file)
|
|
@@ -342,23 +341,31 @@ def load_from_huggingface(model_source: str = "TRI-ML/risk_biased_model", data_
|
|
| 342 |
predictor.eval()
|
| 343 |
predictor = predictor.to(device)
|
| 344 |
|
| 345 |
-
return predictor
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
def main(load_from=None, cfg_path=None):
|
| 348 |
# Define the device to use
|
| 349 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
# Do the same thing as above but using the gradio blocks API
|
| 352 |
with gr.Blocks() as interface:
|
| 353 |
-
|
| 354 |
-
predictor, dataset = load_from_huggingface(device=device)
|
| 355 |
-
|
| 356 |
-
if load_from is not None:
|
| 357 |
-
cfg = Config.fromfile(cfg_path)
|
| 358 |
-
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
|
| 359 |
-
predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
|
| 360 |
|
| 361 |
-
ui_update_fn = partial(update_figure, predictor, dataset)
|
| 362 |
gr.Markdown(
|
| 363 |
"""
|
| 364 |
# Risk-Aware Prediction
|
|
@@ -391,7 +398,7 @@ def main(load_from=None, cfg_path=None):
|
|
| 391 |
# n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
| 392 |
button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
| 393 |
|
| 394 |
-
interface.launch(debug=
|
| 395 |
|
| 396 |
|
| 397 |
if __name__ == "__main__":
|
|
|
|
| 244 |
n_samples: int = 1,
|
| 245 |
use_biaser: bool = True,
|
| 246 |
) -> go.Figure:
|
| 247 |
+
range_radius = 50
|
| 248 |
if use_biaser:
|
| 249 |
risk_level = float(risk_level)
|
| 250 |
else:
|
|
|
|
| 262 |
),
|
| 263 |
title_text="Road Scene",
|
| 264 |
hovermode="closest",
|
| 265 |
+
width=600,
|
| 266 |
+
height=300,
|
| 267 |
updatemenus=[
|
| 268 |
dict(
|
| 269 |
type="buttons",
|
|
|
|
| 332 |
|
| 333 |
return fig
|
| 334 |
|
| 335 |
+
def load_predictor_from_hf(model_source: str = "TRI-ML/risk_biased_model", config_name: str="learning_config.py", checkpoint_name: str = "last.ckpt", device: str = "cpu") -> Tuple[LitTrajectoryPredictor, Dataset]:
|
|
|
|
| 336 |
config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
|
| 337 |
ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
|
| 338 |
cfg = Config.fromfile(config_file)
|
|
|
|
| 341 |
predictor.eval()
|
| 342 |
predictor = predictor.to(device)
|
| 343 |
|
| 344 |
+
return predictor
|
| 345 |
+
|
| 346 |
+
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
|
| 347 |
+
dataset = load_dataset(data_source, split="test")
|
| 348 |
+
return dataset
|
| 349 |
|
| 350 |
def main(load_from=None, cfg_path=None):
|
| 351 |
# Define the device to use
|
| 352 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 353 |
+
print("Getting dataset")
|
| 354 |
+
dataset = load_dataset_from_hf()
|
| 355 |
+
|
| 356 |
+
if load_from is not None:
|
| 357 |
+
cfg = Config.fromfile(cfg_path)
|
| 358 |
+
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
|
| 359 |
+
predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
|
| 360 |
+
else:
|
| 361 |
+
print("Getting model.")
|
| 362 |
+
predictor = load_predictor_from_hf(device=device)
|
| 363 |
+
|
| 364 |
+
ui_update_fn = partial(update_figure, predictor, dataset)
|
| 365 |
|
| 366 |
# Do the same thing as above but using the gradio blocks API
|
| 367 |
with gr.Blocks() as interface:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
|
|
|
| 369 |
gr.Markdown(
|
| 370 |
"""
|
| 371 |
# Risk-Aware Prediction
|
|
|
|
| 398 |
# n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
| 399 |
button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
| 400 |
|
| 401 |
+
interface.launch(debug=False)
|
| 402 |
|
| 403 |
|
| 404 |
if __name__ == "__main__":
|