adnlp commited on
Commit
7c22b3b
·
verified ·
1 Parent(s): bb74d11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -28
app.py CHANGED
@@ -9,10 +9,14 @@ import pickle
9
  import requests
10
 
11
  hf_token = {
12
- "clipqwen": os.environ["HF_CLIPQwenTimer_Token"],
13
- "clipllama": os.environ["HF_CLIPLLaMATimer_Token"],
14
- "blipqwen": os.environ["HF_BLIPQwenTimer_Token"],
15
- "blipllama": os.environ["HF_BLIPLLaMATimer_Token"]
 
 
 
 
16
  }
17
 
18
  with open('example/inputs.pkl', 'rb') as f:
@@ -38,7 +42,7 @@ context_length = {
38
  def selected_dataset(dataset):
39
  gallery_items = [(Image.open(f'example/img/{dataset.replace(" ", "_")}/{i}.png').convert('RGB'), str(i+1)) for i in range(3)]
40
  gallery_items.append((Image.open('example/img/custom.png').convert('RGB'), 'Custom Input'))
41
- return gr.Gallery(gallery_items, interactive=False, height="350px", object_fit="contain"), gr.Textbox(value=descriptions[dataset], label="Dataset Description", interactive=False)
42
 
43
  def selected_example(gallery, evt: gr.SelectData):
44
  if evt.index == len(gallery) -1:
@@ -77,17 +81,21 @@ def update_time_series_dataframe(dataset, example_index):
77
  if example_index is None:
78
  return None, None
79
  elif example_index == -1: # Custom Input
80
- return gr.File(label="Time Series CSV File", file_types=[".csv"], visible=True), gr.Dataframe(value=None, datatype="str", label="Time Series Input", interactive=True)
81
  else:
82
  df = inputs[dataset][example_index]
83
- return gr.File(value=None, visible=False), gr.Dataframe(value=df, label="Time Series Input", interactive=False)
84
 
85
- def load_csv(file):
86
- if file is None:
87
- return pd.DataFrame()
88
- return pd.read_csv(file.name)
 
 
 
 
89
 
90
- def predict(dataset, text, example_index, file, vision_encoder, text_encoder):
91
  if (dataset is None or example_index is None) or (example_index == -1 and file is None):
92
  return (
93
  gr.Markdown(
@@ -96,7 +104,7 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder):
96
  ),
97
  None
98
  )
99
- elif (vision_encoder is None or text_encoder is None):
100
  return (
101
  gr.Markdown(
102
  value=f"Please Select Pretrained Model For UniCast.",
@@ -118,9 +126,9 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder):
118
 
119
  text = None if text == '' else text
120
 
121
- unicast_model = f"{vision_encoder.lower()}{text_encoder.lower()}"
122
 
123
- url = f"https://adnlp-unicast-{unicast_model}timer.hf.space/predict"
124
  headers = {"Authorization": f"Bearer {hf_token[unicast_model]}"}
125
  payload = {
126
  "dataset": dataset,
@@ -128,13 +136,15 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder):
128
  "text": text
129
  }
130
  res = requests.post(url, headers=headers, json=payload)
131
-
132
- out = np.array(res.json()['prediction'])
 
 
 
133
 
134
  cl = context_length[dataset]
135
- out = out[:cl]
136
- # out = out.detach().cpu().numpy()
137
- out = out*std+mean
138
 
139
  input_dates_series = pd.to_datetime(df["Timestamp"])
140
  time_diff = input_dates_series.diff().mode()[0]
@@ -144,13 +154,13 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder):
144
  plt.style.use("seaborn-v0_8")
145
  fig, ax = plt.subplots()
146
  ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
147
- ax.plot(forecast_dates_series, out, color='C2', alpha=0.7, linewidth=3, label='Forecast')
148
  if example_index == -1: # Custom Input
149
- true = df["True"]
150
  else:
151
  true = targets[dataset][example_index].iloc[:, -1]
152
  if len(true) == context_length[dataset]:
153
- ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='True')
154
  ax.legend()
155
 
156
  return gr.Markdown(visible=False), fig
@@ -159,7 +169,7 @@ def add_example_gallery(dataset, gallery, example_index, file):
159
  if example_index == -1 and file:
160
  df = pd.read_csv(file.name)
161
  custom_input = df[["Timestamp", "Value"]]
162
- custom_target = df[["Timestamp", "True"]]
163
 
164
 
165
  plt.style.use("seaborn-v0_8")
@@ -217,8 +227,8 @@ with gr.Blocks() as demo:
217
  guide_text_markdown = gr.Markdown(visible=False)
218
  sample_csv_file = gr.File(visible=False)
219
 
220
- time_series_file = gr.File(label="Time Series CSV File", file_types=[".csv"], visible=False)
221
- time_series_dataframe = gr.Dataframe(value=None, headers=["Timestamp", "Value"], label="Time Series Input", interactive=False)
222
 
223
  dataset_dropdown.change(selected_dataset, inputs=dataset_dropdown, outputs=[example_gallery, dataset_description_textbox])
224
  dataset_dropdown.change(update_guide_markdown, inputs=[dataset_dropdown, example_index], outputs=[guide_text_markdown, sample_csv_file])
@@ -226,16 +236,17 @@ with gr.Blocks() as demo:
226
  example_index.change(update_guide_markdown, inputs=[dataset_dropdown, example_index], outputs=[guide_text_markdown, sample_csv_file])
227
  example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=[time_series_file, time_series_dataframe])
228
 
229
- time_series_file.change(load_csv, inputs=time_series_file, outputs=time_series_dataframe)
230
  with gr.Column(scale=1):
231
  vision_encoder_radio = gr.Radio(["CLIP", "BLIP"], label="Vision Encoder")
232
  text_encoder_radio = gr.Radio(["Qwen", "LLaMA"], label="Text Encoder")
 
233
  warning_markdown = gr.Markdown(visible=False)
234
  btn = gr.Button("Run")
235
  with gr.Column(scale=2):
236
  forecast_plot = gr.Plot(label="Forecast", format="png")
237
 
238
- btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, example_index, time_series_file, vision_encoder_radio, text_encoder_radio], outputs=[warning_markdown, forecast_plot])
239
  btn.click(add_example_gallery, inputs=[dataset_dropdown, example_gallery, example_index, time_series_file], outputs=[example_gallery])
240
 
241
  if __name__ == "__main__":
 
9
  import requests
10
 
11
  hf_token = {
12
+ "clipqwentimer": os.environ["HF_CLIPQwenTimer_Token"],
13
+ "clipllamatimer": os.environ["HF_CLIPLLaMATimer_Token"],
14
+ "blipqwentimer": os.environ["HF_BLIPQwenTimer_Token"],
15
+ "blipllamatimer": os.environ["HF_BLIPLLaMATimer_Token"],
16
+ "clipqwenchronos": os.environ["HF_CLIPQwenChronos_Token"],
17
+ "clipllamachronos": os.environ["HF_CLIPLLaMAChronos_Token"],
18
+ "blipqwenchronos": os.environ["HF_BLIPQwenChronos_Token"],
19
+ "blipllamachronos": os.environ["HF_BLIPLLaMAChronos_Token"]
20
  }
21
 
22
  with open('example/inputs.pkl', 'rb') as f:
 
42
  def selected_dataset(dataset):
43
  gallery_items = [(Image.open(f'example/img/{dataset.replace(" ", "_")}/{i}.png').convert('RGB'), str(i+1)) for i in range(3)]
44
  gallery_items.append((Image.open('example/img/custom.png').convert('RGB'), 'Custom Input'))
45
+ return gr.Gallery(gallery_items, interactive=False, height="350px", object_fit="contain", preview=True), gr.Textbox(value=descriptions[dataset], label="Dataset Description", interactive=False)
46
 
47
  def selected_example(gallery, evt: gr.SelectData):
48
  if evt.index == len(gallery) -1:
 
81
  if example_index is None:
82
  return None, None
83
  elif example_index == -1: # Custom Input
84
+ return gr.File(label="Time Series CSV File", file_types=[".csv"], visible=True), gr.Dataframe(value=None, visible=False)
85
  else:
86
  df = inputs[dataset][example_index]
87
+ return gr.File(value=None, visible=False), gr.Dataframe(value=df, label="Time Series Input", interactive=False, visible=True)
88
 
89
+ def load_csv(example_index, file):
90
+ if example_index == -1:
91
+ if file is not None:
92
+ return gr.Dataframe(value=pd.read_csv(file.name), visible=True)
93
+ else:
94
+ return gr.Dataframe(value=None, visible=False)
95
+ else:
96
+ return gr.skip()
97
 
98
+ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, tsfm):
99
  if (dataset is None or example_index is None) or (example_index == -1 and file is None):
100
  return (
101
  gr.Markdown(
 
104
  ),
105
  None
106
  )
107
+ elif (vision_encoder is None or text_encoder is None or tsfm is None):
108
  return (
109
  gr.Markdown(
110
  value=f"Please Select Pretrained Model For UniCast.",
 
126
 
127
  text = None if text == '' else text
128
 
129
+ unicast_model = f"{vision_encoder.lower()}{text_encoder.lower()}{tsfm.lower()}"
130
 
131
+ url = f"https://adnlp-unicast-{unicast_model}.hf.space/predict"
132
  headers = {"Authorization": f"Bearer {hf_token[unicast_model]}"}
133
  payload = {
134
  "dataset": dataset,
 
136
  "text": text
137
  }
138
  res = requests.post(url, headers=headers, json=payload)
139
+ res_json = res.json()
140
+
141
+ prediction = np.array(res_json['prediction'])
142
+ vision_attentions = np.array(res_json['vision_attentions'])
143
+ time_series_attentions = np.array(res_json['time_series_attentions'])
144
 
145
  cl = context_length[dataset]
146
+ prediction = prediction[:cl]
147
+ prediction = prediction*std+mean
 
148
 
149
  input_dates_series = pd.to_datetime(df["Timestamp"])
150
  time_diff = input_dates_series.diff().mode()[0]
 
154
  plt.style.use("seaborn-v0_8")
155
  fig, ax = plt.subplots()
156
  ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
157
+ ax.plot(forecast_dates_series, prediction, color='C2', alpha=0.7, linewidth=3, label='Forecast')
158
  if example_index == -1: # Custom Input
159
+ true = df["Ground Truth"]
160
  else:
161
  true = targets[dataset][example_index].iloc[:, -1]
162
  if len(true) == context_length[dataset]:
163
+ ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='Ground Truth')
164
  ax.legend()
165
 
166
  return gr.Markdown(visible=False), fig
 
169
  if example_index == -1 and file:
170
  df = pd.read_csv(file.name)
171
  custom_input = df[["Timestamp", "Value"]]
172
+ custom_target = df[["Timestamp", "Ground Truth"]]
173
 
174
 
175
  plt.style.use("seaborn-v0_8")
 
227
  guide_text_markdown = gr.Markdown(visible=False)
228
  sample_csv_file = gr.File(visible=False)
229
 
230
+ time_series_file = gr.File(value=None, visible=False)
231
+ time_series_dataframe = gr.Dataframe(visible=False)
232
 
233
  dataset_dropdown.change(selected_dataset, inputs=dataset_dropdown, outputs=[example_gallery, dataset_description_textbox])
234
  dataset_dropdown.change(update_guide_markdown, inputs=[dataset_dropdown, example_index], outputs=[guide_text_markdown, sample_csv_file])
 
236
  example_index.change(update_guide_markdown, inputs=[dataset_dropdown, example_index], outputs=[guide_text_markdown, sample_csv_file])
237
  example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=[time_series_file, time_series_dataframe])
238
 
239
+ time_series_file.change(load_csv, inputs=[example_index, time_series_file], outputs=time_series_dataframe)
240
  with gr.Column(scale=1):
241
  vision_encoder_radio = gr.Radio(["CLIP", "BLIP"], label="Vision Encoder")
242
  text_encoder_radio = gr.Radio(["Qwen", "LLaMA"], label="Text Encoder")
243
+ tsfm_radio = gr.Radio(["Timer", "Chronos"], label="Time Series Foundation Model")
244
  warning_markdown = gr.Markdown(visible=False)
245
  btn = gr.Button("Run")
246
  with gr.Column(scale=2):
247
  forecast_plot = gr.Plot(label="Forecast", format="png")
248
 
249
+ btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, example_index, time_series_file, vision_encoder_radio, text_encoder_radio, tsfm_radio], outputs=[warning_markdown, forecast_plot])
250
  btn.click(add_example_gallery, inputs=[dataset_dropdown, example_gallery, example_index, time_series_file], outputs=[example_gallery])
251
 
252
  if __name__ == "__main__":