adnlp commited on
Commit
47c5cf0
·
verified ·
1 Parent(s): 7c22b3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -16
app.py CHANGED
@@ -7,6 +7,7 @@ import io
7
  from PIL import Image
8
  import pickle
9
  import requests
 
10
 
11
  hf_token = {
12
  "clipqwentimer": os.environ["HF_CLIPQwenTimer_Token"],
@@ -95,6 +96,72 @@ def load_csv(example_index, file):
95
  else:
96
  return gr.skip()
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def predict(dataset, text, example_index, file, vision_encoder, text_encoder, tsfm):
99
  if (dataset is None or example_index is None) or (example_index == -1 and file is None):
100
  return (
@@ -138,10 +205,8 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
138
  res = requests.post(url, headers=headers, json=payload)
139
  res_json = res.json()
140
 
 
141
  prediction = np.array(res_json['prediction'])
142
- vision_attentions = np.array(res_json['vision_attentions'])
143
- time_series_attentions = np.array(res_json['time_series_attentions'])
144
-
145
  cl = context_length[dataset]
146
  prediction = prediction[:cl]
147
  prediction = prediction*std+mean
@@ -151,19 +216,56 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
151
  start_time = input_dates_series.iloc[-1] + time_diff
152
  forecast_dates_series = pd.date_range(start=start_time, periods=len(input_dates_series), freq=time_diff)
153
 
154
- plt.style.use("seaborn-v0_8")
155
- fig, ax = plt.subplots()
156
- ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
157
- ax.plot(forecast_dates_series, prediction, color='C2', alpha=0.7, linewidth=3, label='Forecast')
158
- if example_index == -1: # Custom Input
159
- true = df["Ground Truth"]
160
- else:
161
- true = targets[dataset][example_index].iloc[:, -1]
162
- if len(true) == context_length[dataset]:
163
- ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='Ground Truth')
164
- ax.legend()
 
 
 
 
 
 
 
 
 
165
 
166
- return gr.Markdown(visible=False), fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def add_example_gallery(dataset, gallery, example_index, file):
169
  if example_index == -1 and file:
@@ -245,8 +347,10 @@ with gr.Blocks() as demo:
245
  btn = gr.Button("Run")
246
  with gr.Column(scale=2):
247
  forecast_plot = gr.Plot(label="Forecast", format="png")
 
 
248
 
249
- btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, example_index, time_series_file, vision_encoder_radio, text_encoder_radio, tsfm_radio], outputs=[warning_markdown, forecast_plot])
250
  btn.click(add_example_gallery, inputs=[dataset_dropdown, example_gallery, example_index, time_series_file], outputs=[example_gallery])
251
 
252
  if __name__ == "__main__":
 
7
  from PIL import Image
8
  import pickle
9
  import requests
10
+ import cv2
11
 
12
  hf_token = {
13
  "clipqwentimer": os.environ["HF_CLIPQwenTimer_Token"],
 
96
  else:
97
  return gr.skip()
98
 
99
+ def vision_attention_rollout(attentions, start_layer=0, end_layer=12):
100
+ seq_len = attentions.shape[-1]
101
+ result = np.eye(seq_len)
102
+
103
+ for attn in attentions[start_layer:end_layer]:
104
+ attn_heads = attn.mean(axis=0)
105
+ attn_aug = attn_heads + np.eye(seq_len)
106
+ attn_aug = attn_aug / attn_aug.sum(axis=-1, keepdims=True)
107
+ result = attn_aug @ result
108
+
109
+ return result[0, -49:]
110
+
111
+ def plot_vision_heatmap(image, rollout_attention, alpha=0.5, cmap='jet'):
112
+ num_patches = rollout_attention.shape[0]
113
+ grid_size = int(np.sqrt(num_patches))
114
+
115
+ attn_grid = rollout_attention.reshape(grid_size, grid_size)
116
+
117
+ H, W = image.shape[:2]
118
+ attn_map = cv2.resize(attn_grid, (W, H), interpolation=cv2.INTER_CUBIC)
119
+ attn_map = attn_map / attn_map.max()
120
+
121
+ plt.figure(figsize=(6,6))
122
+ plt.imshow(image)
123
+ plt.imshow(attn_map, cmap=cmap, alpha=alpha)
124
+ plt.axis('off')
125
+ buf = io.BytesIO()
126
+ plt.savefig(buf, format='png')
127
+ buf.seek(0)
128
+ plot_img = Image.open(buf).convert('RGB')
129
+ plt.clf()
130
+
131
+ return plot_img
132
+
133
+ def time_series_attention_sum(attentions, context_length, start_layer=0, end_layer=12):
134
+ import math
135
+ seq_len = attentions.shape[-1]
136
+ result = np.zeros(seq_len)
137
+ for attn in attentions[start_layer:end_layer]:
138
+ attn_heads = attn.mean(0).squeeze()
139
+ result += attn_heads
140
+ att_len = math.ceil(context_length/16)
141
+ return result[-att_len:]
142
+
143
+ def plot_time_series_heatmap(context, attention, time_steps):
144
+ fig, ax1 = plt.subplots(figsize=(8, 4))
145
+
146
+ attention = attention/attention.max()
147
+ cmap = plt.get_cmap("coolwarm")
148
+ colors = cmap(attention)
149
+ colors[:, -1] = attention
150
+ ax1.bar([16*i for i in range(len(attention))], attention, width=[16 if time_steps-16*(i+1)>0 else time_steps-16*i for i in range(len(attention))], align="edge", color=colors)
151
+ ax1.set_yticks([])
152
+
153
+ ax2 = ax1.twinx()
154
+ ax2.plot(context, color="black", linewidth=2)
155
+ ax2.yaxis.set_ticks_position("left")
156
+
157
+ buf = io.BytesIO()
158
+ plt.savefig(buf, format='png')
159
+ buf.seek(0)
160
+ plot_img = Image.open(buf).convert('RGB')
161
+ plt.clf()
162
+
163
+ return plot_img
164
+
165
  def predict(dataset, text, example_index, file, vision_encoder, text_encoder, tsfm):
166
  if (dataset is None or example_index is None) or (example_index == -1 and file is None):
167
  return (
 
205
  res = requests.post(url, headers=headers, json=payload)
206
  res_json = res.json()
207
 
208
+ # Forecast Plot
209
  prediction = np.array(res_json['prediction'])
 
 
 
210
  cl = context_length[dataset]
211
  prediction = prediction[:cl]
212
  prediction = prediction*std+mean
 
216
  start_time = input_dates_series.iloc[-1] + time_diff
217
  forecast_dates_series = pd.date_range(start=start_time, periods=len(input_dates_series), freq=time_diff)
218
 
219
+ plt.close()
220
+ with plt.style.context("seaborn-v0_8"):
221
+ fig, ax = plt.subplots(figsize=(10,4))
222
+ ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
223
+ ax.plot(forecast_dates_series, prediction, color='C2', alpha=0.7, linewidth=3, label='Forecast')
224
+ if example_index == -1: # Custom Input
225
+ true = df["Ground Truth"]
226
+ else:
227
+ true = targets[dataset][example_index].iloc[:, -1]
228
+ if len(true) == context_length[dataset]:
229
+ ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='Ground Truth')
230
+ ax.legend()
231
+
232
+ # Vision Heatmap
233
+ plt.figure(figsize=(384/100, 384/100), dpi=100)
234
+ plt.plot(time_series_normalized, color="black", linestyle="-", linewidth=1, marker="*", markersize=1)
235
+ plt.xticks([])
236
+ plt.yticks([])
237
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
238
+ plt.margins(0,0)
239
 
240
+ buf = io.BytesIO()
241
+ plt.savefig(buf, format='png')
242
+ buf.seek(0)
243
+ context_image = np.array(Image.open(buf).convert('RGB'))
244
+
245
+ vision_attentions = np.array(res_json['vision_attentions'])
246
+ vision_heatmap_gallery_items = []
247
+ for i in range(0, 12, 3):
248
+ vis_attn = vision_attention_rollout(vision_attentions, i, i+3)
249
+ vision_heatmap = plot_vision_heatmap(context_image, vis_attn)
250
+ vision_heatmap_gallery_items.append((vision_heatmap, f"Heatmap from Layer{i}:{i+3}"))
251
+
252
+ # Time Series Heatmap
253
+ if tsfm == "Chronos":
254
+ time_series_attentions = np.array(res_json['time_series_attentions'])
255
+ time_series_heatmap_gallery_items = []
256
+ for i in range(0, 12, 3):
257
+ ts_attn = time_series_attention_sum(time_series_attentions, cl, i, i+3)
258
+ time_series_heatmap = plot_time_series_heatmap(time_series, ts_attn, cl)
259
+ time_series_heatmap_gallery_items.append((time_series_heatmap, f"Heatmap from Layer{i}:{i+3}"))
260
+ else:
261
+ time_series_heatmap_gallery_items = None
262
+
263
+ return (
264
+ gr.Markdown(visible=False),
265
+ fig,
266
+ gr.Gallery(vision_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True),
267
+ gr.Gallery(time_series_heatmap_gallery_items, interactive=False, height="350px", object_fit="contain", visible=True if time_series_heatmap_gallery_items else False)
268
+ )
269
 
270
  def add_example_gallery(dataset, gallery, example_index, file):
271
  if example_index == -1 and file:
 
347
  btn = gr.Button("Run")
348
  with gr.Column(scale=2):
349
  forecast_plot = gr.Plot(label="Forecast", format="png")
350
+ vision_heatmap_gallery = gr.Gallery(visible=False)
351
+ time_series_heatmap_gallery = gr.Gallery(visible=False)
352
 
353
+ btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, example_index, time_series_file, vision_encoder_radio, text_encoder_radio, tsfm_radio], outputs=[warning_markdown, forecast_plot, vision_heatmap_gallery, time_series_heatmap_gallery])
354
  btn.click(add_example_gallery, inputs=[dataset_dropdown, example_gallery, example_index, time_series_file], outputs=[example_gallery])
355
 
356
  if __name__ == "__main__":