adnlp commited on
Commit
46511a6
·
verified ·
1 Parent(s): e15f8df

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -3
app.py CHANGED
@@ -166,7 +166,10 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
166
  value=f"Please Select Example or Provide CSV File.",
167
  visible=True
168
  ),
169
- None
 
 
 
170
  )
171
  elif (vision_encoder is None or text_encoder is None or tsfm is None):
172
  return (
@@ -174,7 +177,10 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
174
  value=f"Please Select Pretrained Model For UniCast.",
175
  visible=True
176
  ),
177
- None
 
 
 
178
  )
179
  else:
180
  pass
@@ -260,6 +266,7 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
260
  return (
261
  gr.Markdown(visible=False),
262
  fig,
 
263
  gr.Gallery(vision_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True),
264
  gr.Gallery(time_series_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True if time_series_heatmap_gallery_items else False)
265
  )
@@ -312,13 +319,16 @@ with gr.Blocks() as demo:
312
  )
313
  with gr.Row():
314
  with gr.Column(scale=2):
 
315
  dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
316
 
317
  dataset_description_textbox = gr.Textbox(label="Dataset Description", interactive=False)
318
 
 
319
  example_gallery = gr.Gallery(
320
  None,
321
- interactive=False
 
322
  )
323
  example_index = gr.State(value=None)
324
  example_gallery.select(selected_example, inputs=example_gallery, outputs=example_index)
@@ -336,14 +346,22 @@ with gr.Blocks() as demo:
336
  example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=[time_series_file, time_series_dataframe])
337
 
338
  time_series_file.change(load_csv, inputs=[example_index, time_series_file], outputs=time_series_dataframe)
 
339
  with gr.Column(scale=1):
 
340
  vision_encoder_radio = gr.Radio(["CLIP", "BLIP"], label="Vision Encoder")
341
  text_encoder_radio = gr.Radio(["Qwen", "LLaMA"], label="Text Encoder")
342
  tsfm_radio = gr.Radio(["Timer", "Chronos"], label="Time Series Foundation Model")
343
  warning_markdown = gr.Markdown(visible=False)
344
  btn = gr.Button("Run")
 
345
  with gr.Column(scale=2):
 
 
346
  forecast_plot = gr.Plot(label="Forecast", format="png")
 
 
 
347
  vision_heatmap_gallery = gr.Gallery(visible=False)
348
  time_series_heatmap_gallery = gr.Gallery(visible=False)
349
 
 
166
  value=f"Please Select Example or Provide CSV File.",
167
  visible=True
168
  ),
169
+ None,
170
+ gr.Markdown(visible=False), # Hide attention header
171
+ gr.Gallery(visible=False), # Hide vision heatmaps
172
+ gr.Gallery(visible=False) # Hide time series heatmaps
173
  )
174
  elif (vision_encoder is None or text_encoder is None or tsfm is None):
175
  return (
 
177
  value=f"Please Select Pretrained Model For UniCast.",
178
  visible=True
179
  ),
180
+ None,
181
+ gr.Markdown(visible=False), # Hide attention header
182
+ gr.Gallery(visible=False), # Hide vision heatmaps
183
+ gr.Gallery(visible=False) # Hide time series heatmaps
184
  )
185
  else:
186
  pass
 
266
  return (
267
  gr.Markdown(visible=False),
268
  fig,
269
+ gr.Markdown("## Attention Analysis", visible=True),
270
  gr.Gallery(vision_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True),
271
  gr.Gallery(time_series_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True if time_series_heatmap_gallery_items else False)
272
  )
 
319
  )
320
  with gr.Row():
321
  with gr.Column(scale=2):
322
+ gr.Markdown("## Dataset Selection")
323
  dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
324
 
325
  dataset_description_textbox = gr.Textbox(label="Dataset Description", interactive=False)
326
 
327
+ gr.Markdown("## Time Series Examples")
328
  example_gallery = gr.Gallery(
329
  None,
330
+ interactive=False,
331
+ label="Select Time Series Example"
332
  )
333
  example_index = gr.State(value=None)
334
  example_gallery.select(selected_example, inputs=example_gallery, outputs=example_index)
 
346
  example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=[time_series_file, time_series_dataframe])
347
 
348
  time_series_file.change(load_csv, inputs=[example_index, time_series_file], outputs=time_series_dataframe)
349
+
350
  with gr.Column(scale=1):
351
+ gr.Markdown("## Model Configuration")
352
  vision_encoder_radio = gr.Radio(["CLIP", "BLIP"], label="Vision Encoder")
353
  text_encoder_radio = gr.Radio(["Qwen", "LLaMA"], label="Text Encoder")
354
  tsfm_radio = gr.Radio(["Timer", "Chronos"], label="Time Series Foundation Model")
355
  warning_markdown = gr.Markdown(visible=False)
356
  btn = gr.Button("Run")
357
+
358
  with gr.Column(scale=2):
359
+ gr.Markdown("## Results")
360
+
361
  forecast_plot = gr.Plot(label="Forecast", format="png")
362
+
363
+ gr.Markdown("## Attention Analysis", visible=False)
364
+
365
  vision_heatmap_gallery = gr.Gallery(visible=False)
366
  time_series_heatmap_gallery = gr.Gallery(visible=False)
367