Spaces:
Running
Running
Liu Yiwen
commited on
Commit
·
e03ca4d
1
Parent(s):
0edb9ff
更新了选择target的功能
Browse files- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +18 -12
- utils.py +15 -10
__pycache__/utils.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
|
|
|
app.py
CHANGED
|
@@ -222,8 +222,10 @@ with gr.Blocks() as demo:
|
|
| 222 |
# componets = []
|
| 223 |
# for _ in range(TIME_PLOTS_NUM):
|
| 224 |
with gr.Row():
|
| 225 |
-
with gr.Column(scale=
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
with gr.Column(scale=1):
|
| 228 |
select_buttom = gr.Button("Show selected items")
|
| 229 |
with gr.Row():
|
|
@@ -232,7 +234,7 @@ with gr.Blocks() as demo:
|
|
| 232 |
with gr.Column(scale=3):
|
| 233 |
plot = gr.Plot()
|
| 234 |
user_input_text = gr.Textbox(placeholder="输入一些内容")
|
| 235 |
-
# componets.append({"
|
| 236 |
# "statistics_textbox": statistics_textbox,
|
| 237 |
# "user_input_text": user_input_text,
|
| 238 |
# "plot": plot})
|
|
@@ -248,7 +250,7 @@ with gr.Blocks() as demo:
|
|
| 248 |
cp_result: gr.update(visible=False, value=""),
|
| 249 |
}
|
| 250 |
|
| 251 |
-
def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str]) -> dict:
|
| 252 |
try:
|
| 253 |
ret = {}
|
| 254 |
if dataset != 'Salesforce/lotsa_data':
|
|
@@ -261,15 +263,17 @@ with gr.Blocks() as demo:
|
|
| 261 |
df_list, id_list = [], []
|
| 262 |
for i, page in enumerate(page):
|
| 263 |
df, max_page, info = get_page(dataset, config, split, page)
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
| 265 |
row = df.iloc[0]
|
| 266 |
id_list.append(row['item_id'])
|
| 267 |
# 将单行的DataFrame展开为新的DataFrame
|
| 268 |
df_without_index = row.drop('item_id').to_frame().T
|
| 269 |
df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
|
| 270 |
df_list.append(df_expanded)
|
| 271 |
-
|
| 272 |
-
tot_samples = max_page
|
| 273 |
return {
|
| 274 |
statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
|
| 275 |
plot: gr.update(value=create_plot(df_list, id_list)),
|
|
@@ -292,8 +296,9 @@ with gr.Blocks() as demo:
|
|
| 292 |
def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
|
| 293 |
try:
|
| 294 |
return {
|
| 295 |
-
**show_dataset_at_config_and_split_and_page(dataset, config, split, "1"),
|
| 296 |
-
|
|
|
|
| 297 |
# cp_page: gr.update(value="1", visible=True),
|
| 298 |
# cp_goto_page: gr.update(visible=True),
|
| 299 |
# cp_goto_next_page: gr.update(visible=True),
|
|
@@ -336,17 +341,18 @@ with gr.Blocks() as demo:
|
|
| 336 |
all_outputs = [cp_config, cp_split,
|
| 337 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
| 338 |
cp_result, cp_info, cp_error,
|
| 339 |
-
|
|
|
|
| 340 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
| 341 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
| 342 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
| 343 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
| 344 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
| 345 |
user_input_text.submit(save_to_file, inputs=user_input_text)
|
| 346 |
-
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split,
|
| 347 |
|
| 348 |
|
| 349 |
if __name__ == "__main__":
|
| 350 |
|
| 351 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 352 |
-
uvicorn.run(app, host="
|
|
|
|
| 222 |
# componets = []
|
| 223 |
# for _ in range(TIME_PLOTS_NUM):
|
| 224 |
with gr.Row():
|
| 225 |
+
with gr.Column(scale=2):
|
| 226 |
+
select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
|
| 227 |
+
with gr.Column(scale=2):
|
| 228 |
+
select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
|
| 229 |
with gr.Column(scale=1):
|
| 230 |
select_buttom = gr.Button("Show selected items")
|
| 231 |
with gr.Row():
|
|
|
|
| 234 |
with gr.Column(scale=3):
|
| 235 |
plot = gr.Plot()
|
| 236 |
user_input_text = gr.Textbox(placeholder="输入一些内容")
|
| 237 |
+
# componets.append({"select_sample_box": select_sample_box,
|
| 238 |
# "statistics_textbox": statistics_textbox,
|
| 239 |
# "user_input_text": user_input_text,
|
| 240 |
# "plot": plot})
|
|
|
|
| 250 |
cp_result: gr.update(visible=False, value=""),
|
| 251 |
}
|
| 252 |
|
| 253 |
+
def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]) -> dict:
|
| 254 |
try:
|
| 255 |
ret = {}
|
| 256 |
if dataset != 'Salesforce/lotsa_data':
|
|
|
|
| 263 |
df_list, id_list = [], []
|
| 264 |
for i, page in enumerate(page):
|
| 265 |
df, max_page, info = get_page(dataset, config, split, page)
|
| 266 |
+
global tot_samples, tot_targets
|
| 267 |
+
tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) else 1
|
| 268 |
+
|
| 269 |
+
df = clean_up_df(df, sub_targets)
|
| 270 |
row = df.iloc[0]
|
| 271 |
id_list.append(row['item_id'])
|
| 272 |
# 将单行的DataFrame展开为新的DataFrame
|
| 273 |
df_without_index = row.drop('item_id').to_frame().T
|
| 274 |
df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
|
| 275 |
df_list.append(df_expanded)
|
| 276 |
+
|
|
|
|
| 277 |
return {
|
| 278 |
statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
|
| 279 |
plot: gr.update(value=create_plot(df_list, id_list)),
|
|
|
|
| 296 |
def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
|
| 297 |
try:
|
| 298 |
return {
|
| 299 |
+
**show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
|
| 300 |
+
select_sample_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
|
| 301 |
+
select_subtarget_box: gr.update(choices=[i for i in range(tot_targets)]+['all'], value=[0]),
|
| 302 |
# cp_page: gr.update(value="1", visible=True),
|
| 303 |
# cp_goto_page: gr.update(visible=True),
|
| 304 |
# cp_goto_next_page: gr.update(visible=True),
|
|
|
|
| 341 |
all_outputs = [cp_config, cp_split,
|
| 342 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
| 343 |
cp_result, cp_info, cp_error,
|
| 344 |
+
select_sample_box, select_subtarget_box,
|
| 345 |
+
select_buttom, statistics_textbox, user_input_text, plot]
|
| 346 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
| 347 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
| 348 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
| 349 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
| 350 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
| 351 |
user_input_text.submit(save_to_file, inputs=user_input_text)
|
| 352 |
+
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
| 353 |
|
| 354 |
|
| 355 |
if __name__ == "__main__":
|
| 356 |
|
| 357 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 358 |
+
uvicorn.run(app, host="127.0.0.1", port=7860)
|
utils.py
CHANGED
|
@@ -33,22 +33,22 @@ def ndarray_to_base64(ndarray):
|
|
| 33 |
base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
| 34 |
return f"data:image/png;base64,{base64_str}"
|
| 35 |
|
| 36 |
-
def flatten_ndarray_column(df, column_name):
|
| 37 |
"""
|
| 38 |
-
将嵌套的np.ndarray
|
| 39 |
"""
|
| 40 |
-
def
|
| 41 |
if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
|
| 42 |
-
|
|
|
|
| 43 |
elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
|
| 44 |
return np.expand_dims(ndarray, axis=0)
|
| 45 |
return ndarray
|
| 46 |
|
| 47 |
-
|
| 48 |
-
max_length = max(flattened_data.apply(len))
|
| 49 |
|
| 50 |
-
for i in
|
| 51 |
-
df[f'{column_name}_{i}'] =
|
| 52 |
|
| 53 |
return df
|
| 54 |
|
|
@@ -110,16 +110,21 @@ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
|
|
| 110 |
combined_stats_df = pd.concat(stats_list, ignore_index=True)
|
| 111 |
return combined_stats_df
|
| 112 |
|
| 113 |
-
def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
|
| 114 |
"""
|
| 115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
| 116 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
| 118 |
start=row['start'],
|
| 119 |
periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
|
| 120 |
freq=row['freq']
|
| 121 |
).to_pydatetime().tolist(), axis=1)
|
| 122 |
-
df = flatten_ndarray_column(df, 'target')
|
| 123 |
# 删除原始的start和freq列
|
| 124 |
df.drop(columns=['start', 'freq', 'target'], inplace=True)
|
| 125 |
if 'past_feat_dynamic_real' in df.columns:
|
|
|
|
| 33 |
base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
| 34 |
return f"data:image/png;base64,{base64_str}"
|
| 35 |
|
| 36 |
+
def flatten_ndarray_column(df, column_name, rows_to_include):
|
| 37 |
"""
|
| 38 |
+
将嵌套的np.ndarray列展平为多列,并只保留指定的行。
|
| 39 |
"""
|
| 40 |
+
def select_and_flatten(ndarray):
|
| 41 |
if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
|
| 42 |
+
selected = [ndarray[i] for i in rows_to_include if i < len(ndarray)]
|
| 43 |
+
return np.concatenate([select_and_flatten(subarray) for subarray in selected])
|
| 44 |
elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
|
| 45 |
return np.expand_dims(ndarray, axis=0)
|
| 46 |
return ndarray
|
| 47 |
|
| 48 |
+
selected_data = df[column_name].apply(select_and_flatten)
|
|
|
|
| 49 |
|
| 50 |
+
for i in rows_to_include:
|
| 51 |
+
df[f'{column_name}_{i}'] = selected_data.apply(lambda x: x[i] if i < len(x) else np.nan)
|
| 52 |
|
| 53 |
return df
|
| 54 |
|
|
|
|
| 110 |
combined_stats_df = pd.concat(stats_list, ignore_index=True)
|
| 111 |
return combined_stats_df
|
| 112 |
|
| 113 |
+
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
|
| 114 |
"""
|
| 115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
| 116 |
"""
|
| 117 |
+
if 'all' in rows_to_include:
|
| 118 |
+
rows_to_include = list(range(len(df['target'][0]))) if isinstance(df['target'][0], np.ndarray) else 1
|
| 119 |
+
else:
|
| 120 |
+
rows_to_include = sorted(rows_to_include)
|
| 121 |
+
|
| 122 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
| 123 |
start=row['start'],
|
| 124 |
periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
|
| 125 |
freq=row['freq']
|
| 126 |
).to_pydatetime().tolist(), axis=1)
|
| 127 |
+
df = flatten_ndarray_column(df, 'target', rows_to_include)
|
| 128 |
# 删除原始的start和freq列
|
| 129 |
df.drop(columns=['start', 'freq', 'target'], inplace=True)
|
| 130 |
if 'past_feat_dynamic_real' in df.columns:
|