jisujang commited on
Commit
a005c19
·
1 Parent(s): a454ab2
.gitignore ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
173
+ # *.json
174
+
175
+ assets
176
+ DevMACS-AI-solution-devmacs
177
+ Research-AI-research-t2v_f1score_evaluator
178
+ .env
179
+ enviroments/abnormal-situation-leaderboard-3ca42d06719e.json
180
+ leaderboard_test
181
+ enviroments/deep-byte-352904-a072fdf439e7.json
app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from pathlib import Path
from leaderboard_ui.tab.submit_tab import submit_tab
from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
from leaderboard_ui.tab.dataset_visual_tab import visual_tab
from leaderboard_ui.tab.metric_visaul_tab import metric_visual_tab

# Directory containing this file (kept for resolving relative assets).
abs_path = Path(__file__).parent

# Top-level Gradio app: a single Blocks page with four tabs, each tab
# rendered by its own module under leaderboard_ui/tab/.
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🥇 PIA_leaderboard
    """)
    with gr.Tabs():
        leaderboard_tab()
        submit_tab()
        visual_tab()
        metric_visual_tab()

if __name__ == "__main__":
    demo.launch()
enviroments/.gitkeep ADDED
File without changes
enviroments/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd

# Directory names skipped when scanning benchmark folders.
EXCLUDE_DIRS = {"@eaDir", 'temp'}

# gradio datatype for each leaderboard column, in column order.
TYPES = [
    "markdown",
    "markdown",
    "number",
    "number",
    "number",
    "number",
    "number",
    "number",
    "number",
    "str",
    "str",
    "str",
    "str",
    "bool",
    "str",
    "number",
    "number",
    "bool",
    "str",
    "bool",
    "bool",
    "str",
]

# Columns selected by default when the leaderboard first loads.
ON_LOAD_COLUMNS = [
    "TASK",
    "Model",
    "PIA"  # model name
]

# Columns the user cannot deselect.
OFF_LOAD_COLUMNS = ["Model link", "PIA", "PIA * 100" , "Model name" ]

# Columns never rendered in the table.
HIDE_COLUMNS = ["PIA * 100"]

# Columns exposed as filters.
FILTER_COLUMNS = ["T"]

# Columns treated as numeric for range filtering.
NUMERIC_COLUMNS = ["PIA"]

# Right-closed buckets for numeric range filtering.
# NOTE(review): the labels look like model-size buckets (billions of
# parameters?) — confirm the intended units.
NUMERIC_INTERVALS = {
    "?": pd.Interval(-1, 0, closed="right"),
    "~1.5": pd.Interval(0, 2, closed="right"),
    "~3": pd.Interval(2, 4, closed="right"),
    "~7": pd.Interval(4, 9, closed="right"),
    "~13": pd.Interval(9, 20, closed="right"),
    "~35": pd.Interval(20, 45, closed="right"),
    "~60": pd.Interval(45, 70, closed="right"),
    "70+": pd.Interval(70, 10000, closed="right"),
}
enviroments/convert.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from dotenv import load_dotenv
4
+ load_dotenv()
5
+
6
def get_json_from_env_var(env_var_name):
    """Read a JSON document stored in an environment variable.

    Intended for secrets such as Google service-account keys whose
    ``private_key`` value contains literal newline characters when the
    JSON is pasted into a ``.env`` file.

    :param env_var_name: name of the environment variable to read
    :return: parsed JSON as a dict
    :raises EnvironmentError: if the variable is unset or empty
    :raises ValueError: if the value cannot be parsed as JSON
    """
    json_string = os.getenv(env_var_name)
    if not json_string:
        raise EnvironmentError(f"환경 변수 '{env_var_name}'가 설정되지 않았습니다.")

    # Escape raw newlines so multi-line values (e.g. PEM keys) survive
    # json.loads: a literal newline inside a string value becomes the
    # two-character escape sequence "\n".
    json_string = json_string.replace("\n", "\\n")

    try:
        json_data = json.loads(json_string)
    except json.JSONDecodeError as e:
        # Chain the decode error so the original position info survives.
        raise ValueError(f"JSON 변환 실패: {e}") from e

    return json_data
26
+
27
+
28
+
29
def json_to_env_var(json_file_path, env_var_name="JSON_ENV_VAR"):
    """Print the contents of a JSON file as a KEY=value line for a .env file.

    :param json_file_path: path to the JSON file to convert
    :param env_var_name: variable name to emit (default: JSON_ENV_VAR)
    :return: None
    """
    try:
        with open(json_file_path, 'r') as json_file:
            payload = json.load(json_file)
    except FileNotFoundError:
        print(f"파일을 찾을 수 없습니다: {json_file_path}")
        return
    except json.JSONDecodeError:
        print(f"유효한 JSON 파일이 아닙니다: {json_file_path}")
        return

    # Serialize compactly and present it ready to paste into .env.
    serialized = json.dumps(payload)
    print("\n환경 변수로 사용할 수 있는 출력값:\n")
    print(f'{env_var_name}={serialized}')
    print("\n위 값을 .env 파일에 복사하여 붙여넣으세요.")
54
+
leaderboard_ui/tab/dataset_visual_tab.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from leaderboard_ui.tab.submit_tab import submit_tab
4
+ from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
5
+ abs_path = Path(__file__).parent
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import pandas as pd
9
+ import numpy as np
10
+ from utils.bench_meta import process_videos_in_directory
11
+ # Mock 데이터 생성
12
def create_mock_data():
    """Build a random mock dataset of per-video metadata.

    Returns one row per fake video, spread across three benchmarks and
    five content categories; row counts per benchmark are random (50-99).
    """
    benchmark_names = ['VQA-2023', 'ImageQuality-2024', 'VideoEnhance-2024']
    category_names = ['Animation', 'Game', 'Movie', 'Sports', 'Vlog']

    rows = []
    for bench in benchmark_names:
        for _ in range(np.random.randint(50, 100)):
            chosen_category = np.random.choice(category_names)
            rows.append({
                "video_name": f"video_{np.random.randint(1000, 9999)}.mp4",
                "resolution": np.random.choice(["1920x1080", "3840x2160", "1280x720"]),
                "video_duration": f"{np.random.randint(0, 10)}:{np.random.randint(0, 60)}",
                "category": chosen_category,
                "benchmark": bench,
                "duration_seconds": np.random.randint(30, 600),
                "total_frames": np.random.randint(1000, 10000),
                "file_format": ".mp4",
                "file_size_mb": round(np.random.uniform(10, 1000), 2),
                "aspect_ratio": 16 / 9,
                "fps": np.random.choice([24, 30, 60]),
            })

    return pd.DataFrame(rows)
38
+
39
# Mock data generation / benchmark metadata load (runs at import time).
# df = process_videos_in_directory("/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench")
df = pd.read_csv("sample.csv")
# NOTE(review): these debug prints execute on every import of this module;
# consider removing them or switching to the logging module.
print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)
print("DataFrame head:\n", df.head())
45
def create_category_pie_chart(df, selected_benchmark, selected_categories=None):
    """Build a donut chart of video counts per category for one benchmark.

    :param df: video-metadata DataFrame with 'benchmark' and 'category' columns
    :param selected_benchmark: benchmark name to filter on
    :param selected_categories: optional category subset; falsy means all
    :return: plotly Figure
    """
    subset = df[df['benchmark'] == selected_benchmark]
    if selected_categories:
        subset = subset[subset['category'].isin(selected_categories)]

    counts = subset['category'].value_counts()

    fig = px.pie(
        values=counts.values,
        names=counts.index,
        title=f'{selected_benchmark} - Video Distribution by Category',
        hole=0.3,
    )
    fig.update_traces(textposition='inside', textinfo='percent+label')
    return fig
62
+
63
###TODO: decide how to handle the case where the selected column contains string values
64
+
65
def create_bar_chart(df, selected_benchmark, selected_categories, selected_column):
    """Build a horizontal bar chart of `selected_column` for each video.

    Rows are filtered to one benchmark and, optionally, a category subset;
    bars are colored by category.
    """
    subset = df[df['benchmark'] == selected_benchmark]
    if selected_categories:
        subset = subset[subset['category'].isin(selected_categories)]

    fig = px.bar(
        subset,
        x=selected_column,
        y='video_name',
        color='category',
        title=f'{selected_benchmark} - Video {selected_column}',
        orientation='h',
        color_discrete_sequence=px.colors.qualitative.Set3,
    )

    # Scale the figure with the row count so labels stay legible, sort
    # bars by value, leave room for long video names, and park a
    # horizontal legend above the plot area.
    fig.update_layout(
        height=max(400, len(subset) * 30),
        yaxis={'categoryorder': 'total ascending'},
        margin=dict(l=200),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )

    return fig
98
+
99
def submit_tab():
    # NOTE(review): this stub shadows the submit_tab imported from
    # leaderboard_ui.tab.submit_tab at the top of this module — anything
    # importing submit_tab from *this* module gets this minimal version.
    with gr.Tab("🚀 Submit here! "):
        with gr.Row():
            gr.Markdown("# ✉️✨ Submit your Result here!")
103
+
104
def visual_tab():
    """Render the "Bench Info" tab.

    Shows a per-benchmark category pie chart and a per-video bar chart of
    a user-chosen metadata column.  Relies on the module-level ``df``
    loaded from sample.csv at import time.
    """
    with gr.Tab("📊 Bench Info"):
        with gr.Row():
            benchmark_dropdown = gr.Dropdown(
                choices=sorted(df['benchmark'].unique().tolist()),
                value=sorted(df['benchmark'].unique().tolist())[0],
                label="Select Benchmark",
                interactive=True
            )

            category_multiselect = gr.CheckboxGroup(
                choices=sorted(df['category'].unique().tolist()),
                label="Select Categories (empty for all)",
                interactive=True
            )

        # Pie chart output.
        pie_plot_output = gr.Plot(label="pie")

        # Metadata columns the user may compare in the bar chart.
        column_options = [
            "video_duration", "duration_seconds", "total_frames",
            "file_size_mb", "aspect_ratio", "fps", "file_format"
        ]

        column_dropdown = gr.Dropdown(
            choices=column_options,
            value=column_options[0],
            label="Select Data to Compare",
            interactive=True
        )

        # Bar chart output.
        bar_plot_output = gr.Plot(label="video")

        def update_plots(benchmark, categories, selected_column):
            # Rebuild both plots from the current control values.
            pie_chart = create_category_pie_chart(df, benchmark, categories)
            bar_chart = create_bar_chart(df, benchmark, categories, selected_column)
            return pie_chart, bar_chart

        # Any control change refreshes both plots.
        benchmark_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        category_multiselect.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        column_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
160
+
leaderboard_ui/tab/leaderboard_tab.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
3
+ import enviroments.config as config
4
+ from sheet_manager.sheet_loader.sheet2df import sheet2df
5
+
6
def leaderboard_tab():
    """Render the "Leaderboard" tab backed by the Google-Sheet data.

    Builds a gradio_leaderboard.Leaderboard from sheet2df() with column
    selection, search, hidden columns, and filters from enviroments.config,
    plus a manual refresh button that re-reads the sheet.
    """
    with gr.Tab("🏆Leaderboard"):
        leaderboard = Leaderboard(
            value=sheet2df(),
            select_columns=SelectColumns(
                default_selection=config.ON_LOAD_COLUMNS,
                cant_deselect=config.OFF_LOAD_COLUMNS,
                label="Select Columns to Display:",
                info="Check"
            ),

            search_columns=SearchColumns(
                primary_column="Model name",
                secondary_columns=["TASK"],
                placeholder="Search",
                label="Search"
            ),
            hide_columns=config.HIDE_COLUMNS,
            filter_columns=[
                ColumnFilter(
                    column= "TASK",
                ),
                ColumnFilter(
                    column="PIA * 100",
                    type="slider",
                    min=0,  # observed data minimum is ~77
                    max=100,  # observed data maximum is ~92
                    # default=[min_val, max_val],
                    default = [77 ,92],
                    label="PIA"  # displayed as 100x the raw PIA value
                )
            ],

            datatype=config.TYPES,
            # column_widths=["33%", "10%"],
        )
        refresh_button = gr.Button("🔄 Refresh Leaderboard")

        def refresh_leaderboard():
            # Re-read the sheet so the table reflects the latest rows.
            return sheet2df()

        refresh_button.click(
            refresh_leaderboard,
            inputs=[],
            outputs=leaderboard,
        )
+ )
52
+
leaderboard_ui/tab/metric_visaul_tab.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ abs_path = Path(__file__).parent
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ import pandas as pd
7
+ import numpy as np
8
+ from sheet_manager.sheet_loader.sheet2df import sheet2df
9
+ from sheet_manager.sheet_convert.json2sheet import str2json
10
+ # Mock 데이터 생성
11
def calculate_avg_metrics(df):
    """
    Compute each model's metric values averaged over the three event
    categories (falldown / violence / fire).

    :param df: sheet DataFrame with 'Model name' and a JSON-string 'PIA' column
    :return: DataFrame with one row per usable model ('model_name' plus one
        column per averaged metric); rows with unusable PIA data are skipped
        with a printed diagnostic
    """
    metrics_data = []

    for _, row in df.iterrows():
        model_name = row['Model name']

        # Skip rows whose PIA cell is empty or not a JSON string.
        if pd.isna(row['PIA']) or not isinstance(row['PIA'], str):
            print(f"Skipping model {model_name}: Invalid PIA data")
            continue

        try:
            metrics = str2json(row['PIA'])

            # Skip when the decoded value is missing or not a dict.
            if not metrics or not isinstance(metrics, dict):
                print(f"Skipping model {model_name}: Invalid JSON format")
                continue

            # All three event categories must be present.
            required_categories = ['falldown', 'violence', 'fire']
            if not all(cat in metrics for cat in required_categories):
                print(f"Skipping model {model_name}: Missing required categories")
                continue

            # Metrics to average across the categories.
            required_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                                'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

            avg_metrics = {}
            for metric in required_metrics:
                try:
                    values = [metrics[cat][metric] for cat in required_categories
                              if metric in metrics[cat]]
                    if values:  # average only when at least one value exists
                        avg_metrics[metric] = sum(values) / len(values)
                    else:
                        avg_metrics[metric] = 0  # fallback default
                except (KeyError, TypeError) as e:
                    print(f"Error calculating {metric} for {model_name}: {str(e)}")
                    avg_metrics[metric] = 0  # default when lookup fails

            metrics_data.append({
                'model_name': model_name,
                **avg_metrics
            })

        except Exception as e:
            # Best-effort per-row processing: log and move on.
            print(f"Error processing model {model_name}: {str(e)}")
            continue

    return pd.DataFrame(metrics_data)
66
+
67
def create_performance_chart(df, selected_metrics):
    """Build a grouped horizontal bar chart comparing models on the
    selected metrics.

    :param df: DataFrame with a 'model_name' column plus one numeric
        column per metric
    :param selected_metrics: metric column names to plot
    :return: plotly Figure
    """
    fig = go.Figure()

    # Widen the left margin with the longest model name (7 px/char,
    # capped at 500).  default=0 keeps an empty DataFrame from raising
    # "max() arg is an empty sequence".
    max_name_length = max((len(name) for name in df['model_name']), default=0)
    left_margin = min(max_name_length * 7, 500)

    for metric in selected_metrics:
        fig.add_trace(go.Bar(
            name=metric,
            y=df['model_name'],  # model names on the y axis
            x=df[metric],        # metric values on the x axis
            text=[f'{val:.3f}' for val in df[metric]],
            textposition='auto',
            orientation='h'      # horizontal bars
        ))

    fig.update_layout(
        title='Model Performance Comparison',
        yaxis_title='Model Name',
        xaxis_title='Performance',
        barmode='group',
        height=max(400, len(df) * 40),  # scale height with model count
        margin=dict(l=left_margin, r=50, t=50, b=50),  # dynamic left margin
        showlegend=True,
        legend=dict(
            orientation="h",  # horizontal legend above the plot
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        yaxis={'categoryorder': 'total ascending'}  # sort by score
    )

    # Shrink tick labels so long model names fit.
    fig.update_yaxes(tickfont=dict(size=10))

    return fig
109
def create_confusion_matrix(metrics_data, selected_category):
    """Build a 2x2 confusion-matrix heatmap for one category.

    :param metrics_data: dict keyed by category, each value holding
        'tp'/'tn'/'fp'/'fn' counts
    :param selected_category: category key to visualize
    :return: plotly Figure
    """
    # Confusion counts for the selected category.
    tp = metrics_data[selected_category]['tp']
    tn = metrics_data[selected_category]['tn']
    fp = metrics_data[selected_category]['fp']
    fn = metrics_data[selected_category]['fn']

    # Matrix layout: rows = actual, columns = predicted.
    z = [[tn, fp], [fn, tp]]
    x = ['Negative', 'Positive']
    y = ['Negative', 'Positive']

    # Heatmap with each cell's count rendered as text.
    fig = go.Figure(data=go.Heatmap(
        z=z,
        x=x,
        y=y,
        colorscale=[[0, '#f7fbff'], [1, '#08306b']],
        showscale=False,
        text=[[str(val) for val in row] for row in z],
        texttemplate="%{text}",
        textfont={"color": "black", "size": 16},  # pin the text color to black
    ))

    # Layout: centered title, square figure, white background.
    fig.update_layout(
        title={
            'text': f'Confusion Matrix - {selected_category}',
            'y':0.9,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Predicted',
        yaxis_title='Actual',
        width=600,   # enlarged width
        height=600,  # enlarged height
        margin=dict(l=80, r=80, t=100, b=80),  # margin tuning
        paper_bgcolor='white',
        plot_bgcolor='white',
        font=dict(size=14)  # overall font size
    )

    # Axis tick placement and font size.
    fig.update_xaxes(side="bottom", tickfont=dict(size=14))
    fig.update_yaxes(side="left", tickfont=dict(size=14))

    return fig
158
+
159
def get_metrics_for_model(df, model_name, benchmark_name):
    """Look up the decoded PIA metrics for one model/benchmark pair.

    Returns the metrics dict parsed from the row's 'PIA' cell, or None
    when no matching row exists.
    """
    mask = (df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)
    matches = df[mask]
    if matches.empty:
        return None
    return str2json(matches['PIA'].iloc[0])
166
+
167
def metric_visual_tab():
    """Render the performance-visualization tab (superseded version).

    NOTE(review): DEAD CODE — this function is redefined later in this
    module, so this earlier definition is shadowed and never called.
    """
    # Load sheet data and pre-compute per-model category averages.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All metrics the user can toggle on.
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # nothing selected initially
                label="Select Performance Metrics",
                interactive=True
            )

        # Comparison chart (empty until metrics are chosen).
        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            if not selected_metrics:  # nothing selected -> clear plot
                return None

            try:
                # Sort models by accuracy before plotting.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        # Redraw whenever the metric selection changes.
        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )
206
+
207
def create_category_metrics_chart(metrics_data, selected_metrics):
    """Bar-chart the chosen metrics for each of the three event categories.

    :param metrics_data: dict keyed by category ('falldown'/'violence'/'fire'),
        each value mapping metric name to score
    :param selected_metrics: metric names to plot
    :return: plotly Figure
    """
    categories = ['falldown', 'violence', 'fire']
    fig = go.Figure()

    for metric_name in selected_metrics:
        scores = [metrics_data[cat][metric_name] for cat in categories]
        fig.add_trace(go.Bar(
            name=metric_name,
            x=categories,
            y=scores,
            text=[f'{s:.3f}' for s in scores],
            textposition='auto',
        ))

    # Grouped bars with a horizontal legend above the plot area.
    fig.update_layout(
        title='Performance Metrics by Category',
        xaxis_title='Category',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig
244
+
245
def metric_visual_tab():
    """Render the "Performance Visualization" tab (active version).

    Top half: grouped bar chart comparing models on user-selected averaged
    metrics.  Bottom half ("Detailed Model Analysis"): per-model confusion
    matrix and per-category metric bars driven by three dropdowns.

    NOTE(review): this definition shadows the earlier metric_visual_tab in
    this module; this is the version actually used.
    """
    # Load sheet data and pre-compute per-model category averages.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All metrics the user can toggle on.
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # nothing selected initially
                label="Select Performance Metrics",
                interactive=True
            )

        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            # Empty selection -> clear the plot.
            if not selected_metrics:
                return None
            try:
                # Sort models by accuracy before plotting.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )

        # Second visualization section.
        gr.Markdown("## Detailed Model Analysis")

        with gr.Row():
            # Model picker.
            model_dropdown = gr.Dropdown(
                choices=sorted(df['Model name'].unique().tolist()),
                label="Select Model",
                interactive=True
            )

            # Metric-column picker (every column except the model name).
            column_dropdown = gr.Dropdown(
                choices=[col for col in df.columns if col != 'Model name'],
                label="Select Metric Column",
                interactive=True
            )

            # Category picker.
            category_dropdown = gr.Dropdown(
                choices=['falldown', 'violence', 'fire'],
                label="Select Category",
                interactive=True
            )

        # Confusion-matrix plot centered between empty spacer columns.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer

        with gr.Column(scale=2):
            # Metric toggles for the per-category bar chart.
            metrics_select = gr.CheckboxGroup(
                choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
                         'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                value=['accuracy'],  # default selection
                label="Select Metrics to Display",
                interactive=True
            )
            category_metrics_plot = gr.Plot()

        def update_visualizations(model, column, category, selected_metrics):
            # category is only required for the confusion matrix.
            if not all([model, column]):
                return None, None

            try:
                # Fetch and decode the selected model's metric JSON.
                selected_data = df[df['Model name'] == model][column].iloc[0]
                metrics = str2json(selected_data)

                if not metrics:
                    return None, None

                # Confusion matrix (left panel).
                confusion_fig = create_confusion_matrix(metrics, category) if category else None

                # Per-category metric bars (right panel).
                if not selected_metrics:
                    selected_metrics = ['accuracy']
                category_fig = create_category_metrics_chart(metrics, selected_metrics)

                return confusion_fig, category_fig

            except Exception as e:
                print(f"Error updating visualizations: {str(e)}")
                return None, None

        # Any control change refreshes both detailed plots.
        for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
            input_component.change(
                fn=update_visualizations,
                inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
                outputs=[confusion_matrix_plot, category_metrics_plot]
            )
359
+ # def update_confusion_matrix(model, column, category):
360
+ # if not all([model, column, category]):
361
+ # return None
362
+
363
+ # try:
364
+ # # 선택된 모델의 데이터 가져오기
365
+ # selected_data = df[df['Model name'] == model][column].iloc[0]
366
+ # metrics = str2json(selected_data)
367
+
368
+ # if metrics and category in metrics:
369
+ # category_data = metrics[category]
370
+
371
+ # # 혼동 행렬 데이터
372
+ # confusion_data = {
373
+ # 'tp': category_data['tp'],
374
+ # 'tn': category_data['tn'],
375
+ # 'fp': category_data['fp'],
376
+ # 'fn': category_data['fn']
377
+ # }
378
+
379
+ # # 히트맵 생성
380
+ # z = [[confusion_data['tn'], confusion_data['fp']],
381
+ # [confusion_data['fn'], confusion_data['tp']]]
382
+
383
+ # fig = go.Figure(data=go.Heatmap(
384
+ # z=z,
385
+ # x=['Negative', 'Positive'],
386
+ # y=['Negative', 'Positive'],
387
+ # text=[[str(val) for val in row] for row in z],
388
+ # texttemplate="%{text}",
389
+ # textfont={"size": 16},
390
+ # colorscale='Blues',
391
+ # showscale=False
392
+ # ))
393
+
394
+ # fig.update_layout(
395
+ # title=f'Confusion Matrix - {category}',
396
+ # xaxis_title='Predicted',
397
+ # yaxis_title='Actual',
398
+ # width=500,
399
+ # height=500
400
+ # )
401
+
402
+ # return fig
403
+
404
+ # except Exception as e:
405
+ # print(f"Error updating confusion matrix: {str(e)}")
406
+ # return None
407
+
408
+ # # 이벤트 핸들러 연결
409
+ # for dropdown in [model_dropdown, column_dropdown, category_dropdown]:
410
+ # dropdown.change(
411
+ # fn=update_confusion_matrix,
412
+ # inputs=[model_dropdown, column_dropdown, category_dropdown],
413
+ # outputs=confusion_matrix_plot
414
+ # )
415
+
416
+
417
+
418
+
leaderboard_ui/tab/submit_tab.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
3
+ import pandas as pd
4
+
5
def list_to_dataframe(data):
    """Convert a flat list into a single-row DataFrame.

    Each element becomes its own column, headed "Queue 0", "Queue 1", …
    (The previous docstring claimed one row per value; the actual — and
    intended — layout is one row with one column per value, matching how
    the submit tab renders the pending-model queue.)

    :param data: list of values
    :return: pandas.DataFrame with one row and len(data) columns
    :raises ValueError: if data is not a list
    """
    if not isinstance(data, list):
        raise ValueError("입력 데이터는 리스트 형태여야 합니다.")

    # One "Queue i" header per element, all values on a single row.
    headers = [f"Queue {i}" for i in range(len(data))]
    df = pd.DataFrame([data], columns=headers)
    return df
20
+
21
def model_submit(model_id , benchmark_name, prompt_cfg_name):
    """Push a submission (model id, benchmark, prompt config) onto the
    evaluation-queue sheet and return the queue as a one-row DataFrame.
    """
    # Keep only the model-name part of an "org/model" huggingface id.
    model_id = model_id.split("/")[-1]
    sheet_manager = SheetManager()
    sheet_manager.push(model_id)
    # Snapshot the queue right after pushing the model id.
    # NOTE(review): this snapshot is taken BEFORE benchmark_name and
    # prompt_cfg_name are pushed below — confirm that ordering is intended.
    model_q = list_to_dataframe(sheet_manager.get_all_values())
    sheet_manager.change_column("benchmark_name")
    sheet_manager.push(benchmark_name)
    sheet_manager.change_column("prompt_cfg_name")
    sheet_manager.push(prompt_cfg_name)

    return model_q
32
+
33
def read_queue():
    """Return the current evaluation queue as a one-row DataFrame."""
    manager = SheetManager()
    queued_values = manager.get_all_values()
    return list_to_dataframe(queued_values)
36
+
37
def submit_tab():
    """Render the "Submit here!" tab.

    The "Model" sub-tab collects a huggingface id, benchmark name, and
    prompt config name, pushes them to the evaluation queue via
    model_submit/SheetManager, and displays the pending queue.  The
    "Prompt" sub-tab is a placeholder with unwired controls.
    """
    with gr.Tab("🚀 Submit here! "):
        with gr.Row():
            gr.Markdown("# ✉️✨ Submit your Result here!")

        with gr.Row():
            with gr.Tab("Model"):
                with gr.Row():
                    with gr.Column():
                        # Submission inputs.
                        model_id_textbox = gr.Textbox(
                            label="huggingface_id",
                            placeholder="PIA-SPACE-LAB/T2V_CLIP4Clip",
                            interactive = True
                        )
                        benchmark_name_textbox = gr.Textbox(
                            label="benchmark_name",
                            placeholder="PiaFSV",
                            interactive = True,
                            value="PIA"
                        )
                        prompt_cfg_name_textbox = gr.Textbox(
                            label="prompt_cfg_name",
                            placeholder="topk",
                            interactive = True,
                            value="topk"
                        )
                    with gr.Column():
                        gr.Markdown("## 평가를 받아보세요 반드시 허깅페이스에 업로드된 모델이어야 합니다.")
                        gr.Markdown("#### 현재 평가 대기중 모델입니다.")
                        # Pending-queue table plus manual refresh.
                        model_queue = gr.Dataframe()
                        refresh_button = gr.Button("refresh")
                        refresh_button.click(
                            fn=read_queue,
                            outputs=model_queue
                        )
                with gr.Row():
                    # Push the three fields to the sheet and show the queue.
                    model_submit_button = gr.Button("Submit Eval")
                    model_submit_button.click(
                        fn=model_submit,
                        inputs=[model_id_textbox,
                                benchmark_name_textbox ,
                                prompt_cfg_name_textbox],
                        outputs=model_queue
                    )
            with gr.Tab("Prompt"):
                with gr.Row():
                    with gr.Column():
                        # Placeholder controls — not wired to any handler yet.
                        prompt_cfg_selector = gr.Dropdown(
                            choices=["전부"],
                            label="Prompt_CFG",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                        weight_type = gr.Dropdown(
                            choices=["전부"],
                            label="Weights type",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                    with gr.Column():
                        gr.Markdown("## 평가를 받아보세요 반드시 허깅페이스에 업로드된 모델이어야 합니다.")

                with gr.Row():
                    # NOTE(review): this button has no click handler attached.
                    prompt_submit_button = gr.Button("Submit Eval")
103
+
main.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import os
4
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
5
+ from sheet_manager.sheet_monitor.sheet_sync import SheetMonitor, MainLoop
6
+ import time
7
+ from pia_bench.pipe_line.piepline import BenchmarkPipeline, PipelineConfig
8
+ from sheet_manager.sheet_convert.json2sheet import update_benchmark_json
9
+ import os
10
+ import shutil
11
+ import json
12
+
13
def calculate_total_accuracy(metrics: dict) -> float:
    """
    Calculate the average accuracy across all categories excluding 'micro_avg'.

    Args:
        metrics (dict): Mapping of category name -> metrics dict; each value
            may carry an 'accuracy' key.

    Returns:
        float: The mean of every per-category 'accuracy' value.

    Raises:
        ValueError: If no category contains an 'accuracy' value.
    """
    scores = [
        values["accuracy"]
        for category, values in metrics.items()
        if category != "micro_avg" and "accuracy" in values  # 'micro_avg' is an aggregate, not a category
    ]

    if not scores:
        raise ValueError("No accuracy values found in the provided metrics dictionary.")

    return sum(scores) / len(scores)
38
+
39
def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
    """Run the benchmark pipeline for one model and write scores back to the sheet.

    Args:
        huggingface_id: Full HF repo id, e.g. "PIA-SPACE-LAB/T2V_CLIP4Clip".
        benchmark_name: Benchmark name; also the sheet column that is updated.
        prompt_cfg_name: Name of the prompt-config JSON under the benchmark's CFG dir.
    """
    # The local model directory name is the repo id without the org prefix.
    model_name = huggingface_id.split("/")[-1]

    config = PipelineConfig(
        model_name=model_name,
        benchmark_name=benchmark_name,
        cfg_target_path=f"/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench/{benchmark_name}/CFG/{prompt_cfg_name}.json",
        base_path="/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench",
    )
    pipeline = BenchmarkPipeline(config)
    pipeline.run()

    result = pipeline.bench_result_dict
    value = calculate_total_accuracy(result)
    print("---" * 50)

    # Record the aggregate score on the "model" worksheet, in the row
    # whose "Model name" matches this model.
    sheet = SheetManager()
    sheet.change_worksheet("model")
    sheet.update_cell_by_condition(
        condition_column="Model name",
        condition_value=model_name,
        target_column=benchmark_name,
        target_value=value,
    )

    # Also persist the full per-category result blob.
    update_benchmark_json(
        model_name=model_name,
        benchmark_data=result,
        target_column=benchmark_name,  # target-column parameter
    )

    print(f"\n파이프라인 실행 결과:")
66
+
67
# Wire the sheet monitor to the benchmark callback and run until interrupted.
sheet_manager = SheetManager()
monitor = SheetMonitor(sheet_manager, check_interval=60.0)
main_loop = MainLoop(sheet_manager, monitor, callback_function=my_custom_function)

try:
    main_loop.start()
    while True:
        time.sleep(5)  # keep the main thread alive; workers run in the background
except KeyboardInterrupt:
    main_loop.stop()
pia_bench/bench.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from devmacs_core.devmacs_core import DevMACSCore
4
+ import json
5
+ from typing import Dict, List, Tuple
6
+ from pathlib import Path
7
+ import pandas as pd
8
+ from utils.except_dir import cust_listdir
9
def load_config(config_path: str) -> Dict:
    """Read a JSON config file and return its contents as a dict."""
    with open(config_path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
13
+
14
# Directory-layout constants shared by the benchmark pipeline.
DATA_SET = "dataset"              # raw videos + label JSON/CSV files
CFG = "CFG"                       # prompt-config JSON directory
VECTOR = "vector"                 # extracted feature vectors
TEXT = "text"                     # text-vector subdirectory
VIDEO = "video"                   # video-vector subdirectory
EXECPT = ["@eaDir", "README.md"]  # entries to skip (name kept as-is for compatibility)
ALRAM = "alarm"                   # alarm outputs (name kept as-is for compatibility)
METRIC = "metric"                 # metric outputs
MSRVTT = "MSRVTT"                 # default model name
MODEL = "models"                  # per-model root directory
24
+
25
class PiaBenchMark:
    """Prepares a benchmark directory tree and frame-level labels for one model.

    Expected layout (rooted at ``benchmark_path``):
        dataset/<category>/*.mp4|*.json|*.csv
        CFG/<cfg>.json
        models/<model>/CFG/<cfg>/{alarm,metric}
        models/<model>/vector/{text,video}
    """

    def __init__(self, benchmark_path, cfg_target_path: str = None, model_name: str = MSRVTT, token: str = None):
        self.benchmark_path = benchmark_path
        self.token = token  # HF token used by extract_visual_vector
        self.model_name = model_name
        self.devmacs_core = None  # lazily created in extract_visual_vector
        self.cfg_target_path = cfg_target_path
        self.cfg_name = Path(cfg_target_path).stem
        self.cfg_dict = load_config(self.cfg_target_path)

        # Benchmark-level locations.
        self.dataset_path = os.path.join(benchmark_path, DATA_SET)
        self.cfg_path = os.path.join(benchmark_path, CFG)

        # Per-model locations: models/<name>/CFG/<cfg>/{alarm,metric}.
        self.model_path = os.path.join(self.benchmark_path, MODEL)
        self.model_name_path = os.path.join(self.model_path, self.model_name)
        self.model_name_cfg_path = os.path.join(self.model_name_path, CFG)
        self.model_name_cfg_name_path = os.path.join(self.model_name_cfg_path, self.cfg_name)
        self.alram_path = os.path.join(self.model_name_cfg_name_path, ALRAM)
        self.metric_path = os.path.join(self.model_name_cfg_name_path, METRIC)

        # Vector outputs: models/<name>/vector/{text,video}.
        self.vector_path = os.path.join(self.model_name_path, VECTOR)
        self.vector_text_path = os.path.join(self.vector_path, TEXT)
        self.vector_video_path = os.path.join(self.vector_path, VIDEO)

        # Filled by preprocess_structure / preprocess_label_to_csv.
        self.categories = []

    def _create_frame_labels(self, label_data: Dict, total_frames: int) -> pd.DataFrame:
        """Build a frame-indexed 0/1 label table from a clip-level label dict."""
        header = ['frame'] + sorted(self.categories)
        df = pd.DataFrame(0, index=range(total_frames), columns=header)
        df['frame'] = range(total_frames)

        for clip_info in label_data['clips'].values():
            clip_category = clip_info['category']
            # Only categories known to this benchmark are marked.
            if clip_category in self.categories:
                start_frame, end_frame = clip_info['timestamp']
                df.loc[start_frame:end_frame, clip_category] = 1  # inclusive span

        return df

    def preprocess_label_to_csv(self):
        """Convert every JSON label in the dataset to a frame-based CSV (idempotent)."""
        json_files = []
        csv_files = []

        # Populate categories only when preprocess_structure has not done so.
        if not self.categories:
            for cate in cust_listdir(self.dataset_path):
                if os.path.isdir(os.path.join(self.dataset_path, cate)):
                    self.categories.append(cate)

        for category in self.categories:
            category_path = os.path.join(self.dataset_path, category)
            entries = list(cust_listdir(category_path))
            json_files.extend(os.path.join(category, name) for name in entries if name.endswith('.json'))
            csv_files.extend(os.path.join(category, name) for name in entries if name.endswith('.csv'))

        if not json_files:
            raise ValueError("No JSON files found in any category directory")

        # One CSV per JSON means a previous run already finished.
        if len(json_files) == len(csv_files):
            print("All JSON files have already been processed to CSV. No further processing needed.")
            return

        for json_file in json_files:
            json_path = os.path.join(self.dataset_path, json_file)
            video_name = os.path.splitext(json_file)[0]  # keeps the category subdir

            label_info = load_config(json_path)
            total_frames = label_info['video_info']['total_frame']

            frame_df = self._create_frame_labels(label_info, total_frames)

            output_path = os.path.join(self.dataset_path, f"{video_name}.csv")
            frame_df.to_csv(output_path, index=False)
        print("Complete !")

    def preprocess_structure(self):
        """Create the expected directory tree and move raw category folders into dataset/."""
        os.makedirs(self.dataset_path, exist_ok=True)
        os.makedirs(self.cfg_path, exist_ok=True)
        os.makedirs(self.vector_text_path, exist_ok=True)
        os.makedirs(self.vector_video_path, exist_ok=True)
        os.makedirs(self.alram_path, exist_ok=True)
        os.makedirs(self.metric_path, exist_ok=True)
        os.makedirs(self.model_name_cfg_name_path, exist_ok=True)

        has_category_dirs = os.path.exists(self.dataset_path) and any(
            os.path.isdir(os.path.join(self.dataset_path, d)) for d in cust_listdir(self.dataset_path)
        )
        if has_category_dirs:
            # Already structured: read the categories back from dataset/.
            self.categories = [
                d for d in cust_listdir(self.dataset_path)
                if os.path.isdir(os.path.join(self.dataset_path, d))
            ]
        else:
            # First run: move each raw category folder under dataset/.
            reserved = [METRIC, "README.md", MODEL, CFG, DATA_SET, VECTOR, ALRAM]
            for item in cust_listdir(self.benchmark_path):
                item_path = os.path.join(self.benchmark_path, item)
                if item.startswith("@") or item in reserved or not os.path.isdir(item_path):
                    continue
                target_path = os.path.join(self.dataset_path, item)
                if not os.path.exists(target_path):
                    shutil.move(item_path, target_path)
                self.categories.append(item)

        # Mirror the category layout under the model's video-vector directory.
        for category in self.categories:
            os.makedirs(os.path.join(self.vector_video_path, category), exist_ok=True)

        print("Folder preprocessing completed.")

    def extract_visual_vector(self):
        """Download the model from HuggingFace and dump visual vectors for the dataset."""
        self.devmacs_core = DevMACSCore.from_huggingface(token=self.token, repo_id=f"PIA-SPACE-LAB/{self.model_name}")
        self.devmacs_core.save_visual_results(
            vid_dir=self.dataset_path,
            result_dir=self.vector_video_path,
        )
142
+
143
if __name__ == "__main__":
    # Manual smoke test: structure + label preprocessing on a local benchmark copy.
    from dotenv import load_dotenv
    import os
    load_dotenv()

    access_token = os.getenv("ACCESS_TOKEN")
    model_name = "T2V_CLIP4CLIP_MSRVTT"

    benchmark_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA"
    cfg_target_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA/CFG/topk.json"

    pia_benchmark = PiaBenchMark(benchmark_path, model_name=model_name, cfg_target_path=cfg_target_path, token=access_token)
    pia_benchmark.preprocess_structure()
    pia_benchmark.preprocess_label_to_csv()
    print("Categories identified:", pia_benchmark.categories)
pia_bench/checker/bench_checker.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Dict, Optional, Tuple
4
+ from pathlib import Path
5
+ import json
6
+ import numpy as np
7
+ logging.basicConfig(level=logging.INFO)
8
+
9
class BenchChecker:
    """Filesystem sanity checks for a benchmark/model/CFG layout under a base path."""

    def __init__(self, base_path: str):
        """Initialize BenchChecker with base assets path.

        Args:
            base_path (str): Base path to assets directory containing benchmark folders
        """
        self.base_path = Path(base_path)
        self.logger = logging.getLogger(__name__)

    def check_benchmark_exists(self, benchmark_name: str) -> bool:
        """Check if benchmark folder exists."""
        benchmark_dir = self.base_path / benchmark_name
        found = benchmark_dir.exists() and benchmark_dir.is_dir()
        if found:
            self.logger.info(f"Found benchmark directory: {benchmark_name}")
        else:
            self.logger.error(f"Benchmark directory not found: {benchmark_name}")
        return found

    def get_video_list(self, benchmark_name: str) -> List[str]:
        """Get list of videos from benchmark's dataset directory. Return empty list if no videos found."""
        dataset_dir = self.base_path / benchmark_name / "dataset"
        if not dataset_dir.exists():
            self.logger.info(f"Dataset directory exists but no videos found for {benchmark_name}")
            return []

        # Collect *.mp4 stems one category-directory deep.
        stems = [
            clip.stem
            for entry in dataset_dir.glob("*")
            if entry.is_dir()
            for clip in entry.glob("*.mp4")
        ]
        self.logger.info(f"Found {len(stems)} videos in {benchmark_name} dataset")
        return stems

    def check_model_exists(self, benchmark_name: str, model_name: str) -> bool:
        """Check if model directory exists in benchmark's models directory."""
        model_dir = self.base_path / benchmark_name / "models" / model_name
        found = model_dir.exists() and model_dir.is_dir()
        if found:
            self.logger.info(f"Found model directory: {model_name}")
        else:
            self.logger.error(f"Model directory not found: {model_name}")
        return found

    def check_cfg_files(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Tuple[bool, bool]:
        """Check if CFG files/directories exist in both benchmark and model directories."""
        # Benchmark-level CFG is a JSON file; model-level CFG is a directory.
        benchmark_cfg = self.base_path / benchmark_name / "CFG" / f"{cfg_prompt}.json"
        benchmark_cfg_exists = benchmark_cfg.exists() and benchmark_cfg.is_file()

        model_cfg = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt
        model_cfg_exists = model_cfg.exists() and model_cfg.is_dir()

        if benchmark_cfg_exists:
            self.logger.info(f"Found benchmark CFG file: {cfg_prompt}.json")
        else:
            self.logger.error(f"Benchmark CFG file not found: {cfg_prompt}.json")

        if model_cfg_exists:
            self.logger.info(f"Found model CFG directory: {cfg_prompt}")
        else:
            self.logger.error(f"Model CFG directory not found: {cfg_prompt}")

        return benchmark_cfg_exists, model_cfg_exists

    def check_vector_files(self, benchmark_name: str, model_name: str, video_list: List[str]) -> bool:
        """Check if video vectors match with dataset."""
        vector_dir = self.base_path / benchmark_name / "models" / model_name / "vector" / "video"

        # An empty dataset can never validate.
        if not video_list:
            self.logger.error("No videos found in dataset - cannot proceed")
            return False

        if not vector_dir.exists():
            self.logger.error("Vector directory doesn't exist")
            return False

        # Vectors may live in category subdirectories, so search recursively.
        vector_stems = [npy.stem for npy in vector_dir.rglob("*.npy")]

        missing = set(video_list) - set(vector_stems)
        extras = set(vector_stems) - set(video_list)

        if missing:
            self.logger.error(f"Missing vectors for videos: {missing}")
            return False
        if extras:
            self.logger.error(f"Extra vectors found: {extras}")
            return False

        self.logger.info(f"Vector status: videos={len(video_list)}, vectors={len(vector_stems)}")
        return len(video_list) == len(vector_stems)

    def check_metrics_file(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> bool:
        """Check if overall_metrics.json exists in the model's CFG/metrics directory."""
        metrics_file = (self.base_path / benchmark_name / "models" / model_name
                        / "CFG" / cfg_prompt / "metric" / "overall_metrics.json")
        found = metrics_file.exists() and metrics_file.is_file()
        if found:
            self.logger.info(f"Found overall metrics file for {model_name}")
        else:
            self.logger.error(f"Overall metrics file not found for {model_name}")
        return found

    def check_benchmark(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Dict[str, bool]:
        """
        Perform all benchmark checks and return status.

        Checks short-circuit: missing benchmark/model/CFG stops further checks,
        and metrics are only checked when vectors match.
        """
        status = {
            'benchmark_exists': False,
            'model_exists': False,
            'cfg_files_exist': False,
            'vectors_match': False,
            'metrics_exist': False,
        }

        status['benchmark_exists'] = self.check_benchmark_exists(benchmark_name)
        if not status['benchmark_exists']:
            return status

        video_list = self.get_video_list(benchmark_name)

        status['model_exists'] = self.check_model_exists(benchmark_name, model_name)
        if not status['model_exists']:
            return status

        benchmark_cfg, model_cfg = self.check_cfg_files(benchmark_name, model_name, cfg_prompt)
        status['cfg_files_exist'] = benchmark_cfg and model_cfg
        if not status['cfg_files_exist']:
            return status

        status['vectors_match'] = self.check_vector_files(benchmark_name, model_name, video_list)

        if status['vectors_match']:
            status['metrics_exist'] = self.check_metrics_file(benchmark_name, model_name, cfg_prompt)

        return status

    def get_benchmark_status(self, check_status: Dict[str, bool]) -> str:
        """Determine which execution path to take based on check results."""
        prerequisites = ('benchmark_exists', 'model_exists', 'cfg_files_exist')
        if not all(check_status[name] for name in prerequisites):
            return "cannot_execute"
        if check_status['vectors_match'] and check_status['metrics_exist']:
            return "all_passed"
        if not check_status['vectors_match']:
            return "no_vectors"
        # Vectors exist but metrics were never produced.
        return "no_metrics"
171
+
172
# Example usage
if __name__ == "__main__":
    bench_checker = BenchChecker("assets")
    status = bench_checker.check_benchmark(
        benchmark_name="huggingface_benchmarks_dataset",
        model_name="MSRVTT",
        cfg_prompt="topk",
    )

    execution_path = bench_checker.get_benchmark_status(status)
    print(f"Checks completed. Execution path: {execution_path}")
    print(f"Status: {status}")
pia_bench/checker/sheet_checker.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Set, Tuple
2
+ import logging
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ from typing import Optional, List
6
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
7
+
8
+ load_dotenv()
9
class SheetChecker:
    """Checks and updates model/benchmark entries on the 'model' worksheet."""

    def __init__(self, sheet_manager):
        """Initialize SheetChecker with a sheet manager instance."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Initialize sheet manager for the model sheet."""
        # Same manager class as the caller's, bound to the "model" worksheet.
        manager_cls = type(self.sheet_manager)
        self.bench_sheet_manager = manager_cls(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name",
        )

    def add_benchmark_column(self, column_name: str):
        """Add a new benchmark column to the sheet (no-op when already present)."""
        try:
            headers = self.bench_sheet_manager.get_available_columns()
            if column_name in headers:
                return

            # Write the header into the first empty column of row 1.
            target_cell = gspread.utils.rowcol_to_a1(1, len(headers) + 1)
            # gspread expects a 2-D array of values.
            self.bench_sheet_manager.sheet.update(target_cell, [[column_name]])
            self.logger.info(f"Added new benchmark column: {column_name}")

            # Refresh cached headers; skip column validation during reconnect.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)
        except Exception as e:
            self.logger.error(f"Error adding benchmark column {column_name}: {str(e)}")
            raise

    def validate_benchmark_columns(self, benchmark_columns: List[str]) -> Tuple[List[str], List[str]]:
        """
        Validate benchmark columns and add missing ones.

        Args:
            benchmark_columns: List of benchmark column names to validate

        Returns:
            Tuple[List[str], List[str]]: (valid columns, invalid columns)
        """
        available = self.bench_sheet_manager.get_available_columns()
        valid_columns = []
        invalid_columns = []

        for col in benchmark_columns:
            if col in available:
                valid_columns.append(col)
                continue
            try:
                self.add_benchmark_column(col)
                valid_columns.append(col)
                self.logger.info(f"Added new benchmark column: {col}")
            except Exception as e:
                invalid_columns.append(col)
                self.logger.error(f"Failed to add benchmark column '{col}': {str(e)}")

        return valid_columns, invalid_columns

    def check_model_and_benchmarks(self, model_name: str, benchmark_columns: List[str]) -> Dict[str, List[str]]:
        """
        Check model existence and which benchmarks need to be filled.

        Args:
            model_name: Name of the model to check
            benchmark_columns: List of benchmark column names to check

        Returns:
            Dict with keys:
                'status': 'model_not_found' or 'model_exists'
                'empty_benchmarks': List of benchmark columns that need to be filled
                'filled_benchmarks': List of benchmark columns that are already filled
                'invalid_benchmarks': List of benchmark columns that don't exist
        """
        result = {
            'status': '',
            'empty_benchmarks': [],
            'filled_benchmarks': [],
            'invalid_benchmarks': [],
        }

        if not self.check_model_exists(model_name):
            result['status'] = 'model_not_found'
            return result
        result['status'] = 'model_exists'

        valid_columns, invalid_columns = self.validate_benchmark_columns(benchmark_columns)
        result['invalid_benchmarks'] = invalid_columns
        if not valid_columns:
            return result

        # Locate the model's row: +2 accounts for the header row and 1-based indexing.
        self.bench_sheet_manager.change_column("Model name")
        all_values = self.bench_sheet_manager.get_all_values()
        row_index = all_values.index(model_name) + 2

        for column in valid_columns:
            try:
                self.bench_sheet_manager.change_column(column)
                cell_value = self.bench_sheet_manager.sheet.cell(
                    row_index, self.bench_sheet_manager.col_index
                ).value
                if not cell_value or not cell_value.strip():
                    result['empty_benchmarks'].append(column)
                else:
                    result['filled_benchmarks'].append(column)
            except Exception as e:
                # Treat unreadable cells as empty so they get (re)filled.
                self.logger.error(f"Error checking column {column}: {str(e)}")
                result['empty_benchmarks'].append(column)

        return result

    def update_model_info(self, model_name: str, model_info: Dict[str, str]):
        """Update basic model information columns."""
        try:
            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)
            self.logger.info(f"Successfully added new model: {model_name}")
        except Exception as e:
            self.logger.error(f"Error updating model info: {str(e)}")
            raise

    def update_benchmarks(self, model_name: str, benchmark_values: Dict[str, str]):
        """
        Update benchmark values.

        Args:
            model_name: Name of the model
            benchmark_values: Dictionary of benchmark column names and their values
        """
        try:
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            row_index = all_values.index(model_name) + 2  # header row + 1-based index

            for column, value in benchmark_values.items():
                self.bench_sheet_manager.change_column(column)
                self.bench_sheet_manager.sheet.update_cell(row_index, self.bench_sheet_manager.col_index, value)
                self.logger.info(f"Updated benchmark {column} for model {model_name}")
        except Exception as e:
            self.logger.error(f"Error updating benchmarks: {str(e)}")
            raise

    def check_model_exists(self, model_name: str) -> bool:
        """Check if model exists in the sheet."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            return model_name in self.bench_sheet_manager.get_all_values()
        except Exception as e:
            self.logger.error(f"Error checking model existence: {str(e)}")
            return False
173
+
174
+
175
def process_model_benchmarks(
    model_name: str,
    bench_checker: SheetChecker,
    model_info_func,
    benchmark_processor_func: callable,
    benchmark_columns: List[str],
    cfg_prompt: str
) -> None:
    """
    Process model benchmarks according to the specified workflow.

    Args:
        model_name: Name of the model to process
        bench_checker: SheetChecker instance
        model_info_func: Function that returns model info (name, link, etc.)
        benchmark_processor_func: Function that processes empty benchmarks and returns values
        benchmark_columns: List of benchmark columns to check
        cfg_prompt: Prompt configuration passed through to the processor
    """
    try:
        check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        # Columns that could not be created are skipped, not fatal.
        if check_result['invalid_benchmarks']:
            bench_checker.logger.warning(
                f"Skipping invalid benchmark columns: {', '.join(check_result['invalid_benchmarks'])}"
            )

        # Register the model first when it is missing, then re-check.
        if check_result['status'] == 'model_not_found':
            model_info = model_info_func(model_name)
            bench_checker.update_model_info(model_name, model_info)
            bench_checker.logger.info(f"Added new model: {model_name}")
            check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        if check_result['filled_benchmarks']:
            bench_checker.logger.info(
                f"Skipping filled benchmark columns: {', '.join(check_result['filled_benchmarks'])}"
            )

        if check_result['empty_benchmarks']:
            bench_checker.logger.info(
                f"Processing empty benchmark columns: {', '.join(check_result['empty_benchmarks'])}"
            )
            # Measure only the benchmarks that are still blank.
            benchmark_values = benchmark_processor_func(
                model_name,
                check_result['empty_benchmarks'],
                cfg_prompt
            )
            bench_checker.update_benchmarks(model_name, benchmark_values)
        else:
            bench_checker.logger.info("No empty benchmark columns to process")

    except Exception as e:
        bench_checker.logger.error(f"Error processing model {model_name}: {str(e)}")
        raise
+ raise
236
+
237
def get_model_info(model_name: str) -> Dict[str, str]:
    """Build the spreadsheet row fields (name, link, HTML anchor) for a model."""
    link = f"https://huggingface.co/PIA-SPACE-LAB/{model_name}"
    anchor = (
        f'<a target="_blank" href="{link}" style="color: var(--link-text-color); '
        f'text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
    )
    return {
        "Model name": model_name,
        "Model link": link,
        "Model": anchor,
    }
244
+
245
def process_benchmarks(
    model_name: str,
    empty_benchmarks: List[str],
    cfg_prompt: str
) -> Dict[str, str]:
    """
    Measure benchmark scores for given model with specific configuration.

    Args:
        model_name: Name of the model to evaluate
        empty_benchmarks: List of benchmarks to measure
        cfg_prompt: Prompt configuration for evaluation

    Returns:
        Dict[str, str]: Dictionary mapping benchmark names to their scores
    """
    # Placeholder scores until real measurement is wired in:
    # score = measure_benchmark(model_name, benchmark, cfg_prompt)
    placeholder_scores = {"COCO": 0.5, "ImageNet": 15.0}

    result = {}
    for benchmark in empty_benchmarks:
        # BUG FIX: the original left `score` unbound for any benchmark other
        # than "COCO"/"ImageNet", raising NameError; default to 0.0 instead.
        score = placeholder_scores.get(benchmark, 0.0)
        result[benchmark] = str(score)
    return result
271
# Example usage
if __name__ == "__main__":
    sheet_manager = SheetManager()
    bench_checker = SheetChecker(sheet_manager)

    process_model_benchmarks(
        "test-model",
        bench_checker,
        get_model_info,
        process_benchmarks,
        ["COCO", "ImageNet"],
        "cfg_prompt_value",
    )
pia_bench/event_alarm.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from typing import Dict, List, Tuple
5
+ from devmacs_core.devmacs_core import DevMACSCore
6
+ # from devmacs_core.devmacs_core_copy import DevMACSCore
7
+
8
+ from devmacs_core.utils.common.cal import loose_similarity
9
+ from utils.parser import load_config, PromptManager
10
+ import json
11
+ import pandas as pd
12
+ from tqdm import tqdm
13
+ import logging
14
+ from datetime import datetime
15
+ from utils.except_dir import cust_listdir
16
+
17
+ class EventDetector:
18
+ def __init__(self, config_path: str , model_name:str = None, token:str = None):
19
+ self.config = load_config(config_path)
20
+ self.macs = DevMACSCore.from_huggingface(token=token, repo_id=f"PIA-SPACE-LAB/{model_name}")
21
+ # self.macs = DevMACSCore(model_type="clip4clip_web")
22
+
23
+ self.prompt_manager = PromptManager(config_path)
24
+ self.sentences = self.prompt_manager.sentences
25
+ self.text_vectors = self.macs.get_text_vector(self.sentences)
26
+
27
+ def process_and_save_predictions(self, vector_base_dir: str, label_base_dir: str, save_base_dir: str):
28
+ """비디오 벡터를 처리하고 결과를 CSV로 저장"""
29
+
30
+ # 전체 비디오 파일 수 계산
31
+ total_videos = sum(len([f for f in cust_listdir(os.path.join(vector_base_dir, d))
32
+ if f.endswith('.npy')])
33
+ for d in cust_listdir(vector_base_dir)
34
+ if os.path.isdir(os.path.join(vector_base_dir, d)))
35
+ pbar = tqdm(total=total_videos, desc="Processing videos")
36
+
37
+ for category in cust_listdir(vector_base_dir):
38
+ category_path = os.path.join(vector_base_dir, category)
39
+ if not os.path.isdir(category_path):
40
+ continue
41
+
42
+ # 저장 디렉토리 생성
43
+ save_category_dir = os.path.join(save_base_dir, category)
44
+ os.makedirs(save_category_dir, exist_ok=True)
45
+
46
+ for file in cust_listdir(category_path):
47
+ if file.endswith('.npy'):
48
+ video_name = os.path.splitext(file)[0]
49
+ vector_path = os.path.join(category_path, file)
50
+
51
+ # 라벨 파일 읽기
52
+ label_path = os.path.join(label_base_dir, category, f"{video_name}.json")
53
+ with open(label_path, 'r') as f:
54
+ label_data = json.load(f)
55
+ total_frames = label_data['video_info']['total_frame']
56
+
57
+ # 예측 결과 생성 및 저장
58
+ self._process_and_save_single_video(
59
+ vector_path=vector_path,
60
+ total_frames=total_frames,
61
+ save_path=os.path.join(save_category_dir, f"{video_name}.csv")
62
+ )
63
+ pbar.update(1)
64
+ pbar.close()
65
+
66
+ def _process_and_save_single_video(self, vector_path: str, total_frames: int, save_path: str):
67
+ """단일 비디오 처리 및 저장"""
68
+ # 기본 예측 수행
69
+ sparse_predictions = self._process_single_vector(vector_path)
70
+
71
+ # 데이터프레임으로 변환 및 확장
72
+ df = self._expand_predictions(sparse_predictions, total_frames)
73
+
74
+ # CSV로 저장
75
+ df.to_csv(save_path, index=False)
76
+
77
+ def _process_single_vector(self, vector_path: str) -> Dict:
78
+ """기존 예측 로직"""
79
+ video_vector = np.load(vector_path)
80
+ processed_vectors = []
81
+ frame_interval = 15
82
+
83
+ for vector in video_vector:
84
+ v = vector.squeeze(0) # numpy array
85
+ v = torch.from_numpy(v).unsqueeze(0).cuda() # torch tensor로 변환 후 GPU로
86
+ processed_vectors.append(v)
87
+
88
+ frame_results = {}
89
+ for vector_idx, v in enumerate(processed_vectors):
90
+ actual_frame = vector_idx * frame_interval
91
+ sim_scores = loose_similarity(
92
+ sequence_output=self.text_vectors.cuda(),
93
+ visual_output=v.unsqueeze(1)
94
+ )
95
+ frame_results[actual_frame] = self._calculate_alarms(sim_scores)
96
+
97
+ return frame_results
98
+
99
+ def _expand_predictions(self, sparse_predictions: Dict, total_frames: int) -> pd.DataFrame:
100
+ """예측을 전체 프레임으로 확장"""
101
+ # 카테고리 목록 추출 (첫 번째 프레임의 알람 결과에서)
102
+ first_frame = list(sparse_predictions.keys())[0]
103
+ categories = list(sparse_predictions[first_frame].keys())
104
+
105
+ # 전체 프레임 생성
106
+ df = pd.DataFrame({'frame': range(total_frames)})
107
+
108
+ # 각 카테고리에 대한 알람 값 초기화
109
+ for category in categories:
110
+ df[category] = 0
111
+
112
+ # 예측값 채우기
113
+ frame_keys = sorted(sparse_predictions.keys())
114
+ for i in range(len(frame_keys)):
115
+ current_frame = frame_keys[i]
116
+ next_frame = frame_keys[i + 1] if i + 1 < len(frame_keys) else total_frames
117
+
118
+ # 각 카테고리의 알람 값 설정
119
+ for category in categories:
120
+ alarm_value = sparse_predictions[current_frame][category]['alarm']
121
+ df.loc[current_frame:next_frame-1, category] = alarm_value
122
+
123
+ return df
124
+
125
+
126
+ def _calculate_alarms(self, sim_scores: torch.Tensor) -> Dict:
127
+ """유사도 점수를 기반으로 각 이벤트의 알람 상태 계산"""
128
+ # 로거 설정
129
+ log_filename = f"alarm_calculation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
130
+ logging.basicConfig(
131
+ filename=log_filename,
132
+ level=logging.ERROR,
133
+ format='%(asctime)s - %(message)s',
134
+ datefmt='%Y-%m-%d %H:%M:%S'
135
+ )
136
+ logger = logging.getLogger(__name__)
137
+
138
+ event_alarms = {}
139
+
140
+ for event_config in self.config['PROMPT_CFG']:
141
+ event = event_config['event']
142
+ top_k = event_config['top_candidates']
143
+ threshold = event_config['alert_threshold']
144
+
145
+ # logger.info(f"\nProcessing event: {event}")
146
+ # logger.info(f"Top K: {top_k}, Threshold: {threshold}")
147
+
148
+ event_prompts = self._get_event_prompts(event)
149
+
150
+ # logger.debug(f"\nEvent Prompts Debug for {event}:")
151
+ # logger.debug(f"Indices: {event_prompts['indices']}")
152
+ # logger.debug(f"Types: {event_prompts['types']}")
153
+ # logger.debug(f"\nSim Scores Debug:")
154
+ # logger.debug(f"Shape: {sim_scores.shape}")
155
+ # logger.debug(f"Raw scores: {sim_scores}")
156
+
157
+ # event_scores = sim_scores[event_prompts['indices']]
158
+ event_scores = sim_scores[event_prompts['indices']].squeeze(-1) # shape 변경
159
+
160
+ # logger.debug(f"Event scores shape: {event_scores.shape}")
161
+ # logger.debug(f"Event scores: {event_scores}")
162
+ # 각 프롬프트와 점수 출력
163
+ # logger.info("\nDEBUG VALUES:")
164
+ # logger.info(f"event_scores: {event_scores}")
165
+ # logger.info(f"indices: {event_prompts['indices']}")
166
+ # logger.info(f"types: {event_prompts['types']}")
167
+
168
+ # logger.info("\nAll prompts and scores:")
169
+ # for idx, (score, prompt_type) in enumerate(zip(event_scores, event_prompts['types'])):
170
+ # logger.info(f"Type: {prompt_type}, Score: {score.item():.4f}")
171
+
172
+ top_k_values, top_k_indices = torch.topk(event_scores, min(top_k, len(event_scores)))
173
+
174
+ # logger.info(f"top_k_values: {top_k_values}")
175
+ # logger.info(f"top_k_indices (raw): {top_k_indices}")
176
+ # Top K 결과 출력
177
+ # logger.info(f"\nTop {top_k} selections:")
178
+ for idx, (value, index) in enumerate(zip(top_k_values, top_k_indices)):
179
+ # indices[index]가 아닌 index를 직접 사용
180
+ prompt_type = event_prompts['types'][index] # 수정된 부분
181
+ # logger.info(f"DEBUG: index={index}, types={event_prompts['types']}, selected_type={prompt_type}")
182
+ # logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
183
+
184
+ abnormal_count = sum(1 for idx in top_k_indices
185
+ if event_prompts['types'][idx] == 'abnormal') # 수정된 부분
186
+ # for idx, (value, orig_idx) in enumerate(zip(top_k_values, top_k_indices)):
187
+ # prompt_type = event_prompts['types'][orig_idx.item()]
188
+ # logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
189
+
190
+ # abnormal_count = sum(1 for idx in top_k_indices
191
+ # if event_prompts['types'][idx.item()] == 'abnormal')
192
+
193
+ # 알람 결정 과정 출력
194
+ # logger.info(f"\nAbnormal count: {abnormal_count}")
195
+ alarm_result = 1 if abnormal_count >= threshold else 0
196
+ # logger.info(f"Final alarm decision: {alarm_result}")
197
+ # logger.info("-" * 50)
198
+
199
+ event_alarms[event] = {
200
+ 'alarm': alarm_result,
201
+ 'scores': top_k_values.tolist(),
202
+ 'top_k_types': [event_prompts['types'][idx.item()] for idx in top_k_indices]
203
+ }
204
+
205
+ # 로거 종료
206
+ logging.shutdown()
207
+
208
+ return event_alarms
209
+
210
+ def _get_event_prompts(self, event: str) -> Dict:
211
+ indices = []
212
+ types = []
213
+ current_idx = 0
214
+
215
+ for event_config in self.config['PROMPT_CFG']:
216
+ if event_config['event'] == event:
217
+ for status in ['normal', 'abnormal']:
218
+ for _ in range(len(event_config['prompts'][status])):
219
+ indices.append(current_idx)
220
+ types.append(status)
221
+ current_idx += 1
222
+
223
+ return {'indices': indices, 'types': types}
224
+
225
+
pia_bench/metric.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
5
+ from typing import Dict, List
6
+ import json
7
+ from utils.except_dir import cust_listdir
8
+
9
class MetricsEvaluator:
    """Compare per-frame prediction CSVs against label CSVs and report metrics."""

    def __init__(self, pred_dir: str, label_dir: str, save_dir: str):
        """
        Args:
            pred_dir: directory containing the prediction CSV files
            label_dir: directory containing the ground-truth CSV files
            save_dir: directory where the evaluation results are written
        """
        self.pred_dir = pred_dir
        self.label_dir = label_dir
        self.save_dir = save_dir
20
+
21
+ def evaluate(self) -> Dict:
22
+ """전체 평가 수행"""
23
+ category_metrics = {} # 카테고리별 평균 성능 저장
24
+ all_metrics = { # 모든 카테고리 통합 메트릭
25
+ 'falldown': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
26
+ 'violence': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
27
+ 'fire': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []}
28
+ }
29
+
30
+ # 모든 카테고리의 metrics를 저장할 DataFrame 리스트
31
+ all_categories_metrics = []
32
+
33
+ for category in cust_listdir(self.pred_dir):
34
+ if not os.path.isdir(os.path.join(self.pred_dir, category)):
35
+ continue
36
+
37
+ pred_category_path = os.path.join(self.pred_dir, category)
38
+ label_category_path = os.path.join(self.label_dir, category)
39
+ save_category_path = os.path.join(self.save_dir, category)
40
+ os.makedirs(save_category_path, exist_ok=True)
41
+
42
+ # 결과 저장을 위한 데이터프레임 생성
43
+ metrics_df = self._evaluate_category(category, pred_category_path, label_category_path)
44
+
45
+ metrics_df['category'] = category
46
+
47
+ metrics_df.to_csv(os.path.join(save_category_path, f"{category}_metrics.csv"), index=False)
48
+
49
+ all_categories_metrics.append(metrics_df)
50
+
51
+ # 카테고리별 평균 성능 저장
52
+ category_metrics[category] = metrics_df.iloc[-1].to_dict() # 마지막 row(평균)
53
+
54
+ # 전체 평균을 위한 메트릭 수집
55
+ # for col in metrics_df.columns:
56
+ # if col != 'video_name':
57
+ # event_type, metric_type = col.split('_')
58
+ # all_metrics[event_type][metric_type].append(category_metrics[category][col])
59
+
60
+ for col in metrics_df.columns:
61
+ if col != 'video_name':
62
+ try:
63
+ # 첫 번째 언더스코어를 기준으로 이벤트 타입과 메트릭 타입 분리
64
+ parts = col.split('_', 1) # maxsplit=1로 첫 번째 언더스코어에서만 분리
65
+ if len(parts) == 2:
66
+ event_type, metric_type = parts
67
+ if event_type in all_metrics and metric_type in all_metrics[event_type]:
68
+ all_metrics[event_type][metric_type].append(category_metrics[category][col])
69
+ except Exception as e:
70
+ print(f"Warning: Could not process column {col}: {str(e)}")
71
+ continue
72
+
73
+ # 각 DataFrame에서 마지막 행(average)을 제거
74
+ all_categories_metrics_without_avg = [df.iloc[:-1] for df in all_categories_metrics]
75
+ # 모든 카테고리의 metrics를 하나의 DataFrame으로 합치기
76
+ combined_metrics_df = pd.concat(all_categories_metrics_without_avg, ignore_index=True)
77
+ # 합쳐진 metrics를 json 파일과 같은 위치에 저장
78
+ combined_metrics_df.to_csv(os.path.join(self.save_dir, "all_categories_metrics.csv"), index=False)
79
+ # 결과 출력
80
+ # print("\nCategory-wise Average Metrics:")
81
+ # for category, metrics in category_metrics.items():
82
+ # print(f"\n{category}:")
83
+ # for metric_name, value in metrics.items():
84
+ # if metric_name != "video_name":
85
+ # print(f"{metric_name}: {value:.3f}")
86
+
87
+ print("\nCategory-wise Average Metrics:")
88
+ for category, metrics in category_metrics.items():
89
+ print(f"\n{category}:")
90
+ for metric_name, value in metrics.items():
91
+ if metric_name != "video_name":
92
+ try:
93
+ if isinstance(value, str):
94
+ print(f"{metric_name}: {value}")
95
+ elif metric_name in ['tp', 'tn', 'fp', 'fn']:
96
+ print(f"{metric_name}: {int(value)}")
97
+ else:
98
+ print(f"{metric_name}: {float(value):.3f}")
99
+ except (ValueError, TypeError):
100
+ print(f"{metric_name}: {value}")
101
+ # 전체 평균 계산 및 출력
102
+ print("\n" + "="*50)
103
+ print("Overall Average Metrics Across All Categories:")
104
+ print("="*50)
105
+
106
+ # for event_type in all_metrics:
107
+ # print(f"\n{event_type}:")
108
+ # for metric_type, values in all_metrics[event_type].items():
109
+ # avg_value = np.mean(values)
110
+ # print(f"{metric_type}: {avg_value:.3f}")
111
+
112
+ for event_type in all_metrics:
113
+ print(f"\n{event_type}:")
114
+ for metric_type, values in all_metrics[event_type].items():
115
+ avg_value = np.mean(values)
116
+ if metric_type in ['tp', 'tn', 'fp', 'fn']: # 정수 값
117
+ print(f"{metric_type}: {int(avg_value)}")
118
+ else: # 소수점 값
119
+ print(f"{metric_type}: {avg_value:.3f}")
120
+ ##################################################################################################
121
+ # 최종 결과를 저장할 딕셔너리
122
+ final_results = {
123
+ "category_metrics": {},
124
+ "overall_metrics": {}
125
+ }
126
+ # 카테고리별 메트릭 저장
127
+
128
+ for category, metrics in category_metrics.items():
129
+ final_results["category_metrics"][category] = {}
130
+ for metric_name, value in metrics.items():
131
+ if metric_name != "video_name":
132
+ if isinstance(value, (int, float)):
133
+ final_results["category_metrics"][category][metric_name] = float(value)
134
+
135
+ # 전체 평균 계산 및 저장
136
+ for event_type in all_metrics:
137
+ # print(f"\n{event_type}:")
138
+ final_results["overall_metrics"][event_type] = {}
139
+ for metric_type, values in all_metrics[event_type].items():
140
+ avg_value = float(np.mean(values))
141
+ # print(f"{metric_type}: {avg_value:.3f}")
142
+ final_results["overall_metrics"][event_type][metric_type] = avg_value
143
+
144
+ # JSON 파일로 저장
145
+ json_path = os.path.join(self.save_dir, "overall_metrics.json")
146
+ with open(json_path, 'w', encoding='utf-8') as f:
147
+ json.dump(final_results, f, indent=4)
148
+
149
+ # return category_metrics
150
+
151
+ # 누적 메트릭 계산
152
+ accumulated_metrics = self.calculate_accumulated_metrics(combined_metrics_df)
153
+
154
+ # JSON에 누적 메트릭 추가
155
+ final_results["accumulated_metrics"] = accumulated_metrics
156
+
157
+ # 누적 메트릭만 따로 저장
158
+ accumulated_json_path = os.path.join(self.save_dir, "accumulated_metrics.json")
159
+ with open(accumulated_json_path, 'w', encoding='utf-8') as f:
160
+ json.dump(accumulated_metrics, f, indent=4)
161
+
162
+ return accumulated_metrics
163
+
164
    def _evaluate_category(self, category: str, pred_path: str, label_path: str) -> pd.DataFrame:
        """Evaluate every video of one category.

        Args:
            category: category name (used only by the caller; not read here)
            pred_path: directory with this category's prediction CSVs
            label_path: directory with this category's ground-truth CSVs

        Returns:
            DataFrame with one row per video ('video_name' plus
            '<event>_<metric>' columns) and a final 'average' row.
        """
        results = []
        metrics_columns = ['video_name']

        for pred_file in cust_listdir(pred_path):
            if not pred_file.endswith('.csv'):
                continue

            video_name = os.path.splitext(pred_file)[0]
            pred_df = pd.read_csv(os.path.join(pred_path, pred_file))

            # Load the matching ground-truth CSV for this video.
            label_file = f"{video_name}.csv"
            label_path_full = os.path.join(label_path, label_file)

            if not os.path.exists(label_path_full):
                print(f"Warning: Label file not found for {video_name}")
                continue

            label_df = pd.read_csv(label_path_full)

            # Compute metrics for each event column (every column but 'frame').
            video_metrics = {'video_name': video_name}
            categories = [col for col in pred_df.columns if col != 'frame']

            for cat in categories:
                # Ground truth and prediction as 0/1 arrays.
                y_true = label_df[cat].values
                y_pred = pred_df[cat].values

                # Per-event metric computation.
                metrics = self._calculate_metrics(y_true, y_pred)

                # Store results; grow the column list on first sight.
                for metric_name, value in metrics.items():
                    col_name = f"{cat}_{metric_name}"
                    video_metrics[col_name] = value
                    if col_name not in metrics_columns:
                        metrics_columns.append(col_name)

            results.append(video_metrics)

        # Convert the per-video results to a DataFrame.
        metrics_df = pd.DataFrame(results, columns=metrics_columns)

        # Append the per-category average row.
        avg_metrics = {'video_name': 'average'}
        for col in metrics_columns[1:]:  # every column except video_name
            avg_metrics[col] = metrics_df[col].mean()

        metrics_df = pd.concat([metrics_df, pd.DataFrame([avg_metrics])], ignore_index=True)

        return metrics_df
218
+
219
+ # def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
220
+ # """성능 지표 계산"""
221
+ # tn = np.sum((y_true == 0) & (y_pred == 0))
222
+ # fp = np.sum((y_true == 0) & (y_pred == 1))
223
+
224
+ # metrics = {
225
+ # 'f1': f1_score(y_true, y_pred, zero_division=0),
226
+ # 'accuracy': accuracy_score(y_true, y_pred),
227
+ # 'precision': precision_score(y_true, y_pred, zero_division=0),
228
+ # 'recall': recall_score(y_true, y_pred, zero_division=0),
229
+ # 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0
230
+ # }
231
+
232
+ # return metrics
233
+
234
+
235
+ def calculate_accumulated_metrics(self, all_categories_metrics_df: pd.DataFrame) -> Dict:
236
+ """누적된 혼동행렬로 각 카테고리별 성능 지표 계산"""
237
+ accumulated_results = {"micro_avg": {}}
238
+ categories = ['falldown', 'violence', 'fire']
239
+
240
+ for category in categories:
241
+ # 해당 카테고리의 혼동행렬 값들 누적
242
+ tp = all_categories_metrics_df[f'{category}_tp'].sum()
243
+ tn = all_categories_metrics_df[f'{category}_tn'].sum()
244
+ fp = all_categories_metrics_df[f'{category}_fp'].sum()
245
+ fn = all_categories_metrics_df[f'{category}_fn'].sum()
246
+
247
+ # 기본 메트릭 계산
248
+ metrics = {
249
+ 'tp': int(tp),
250
+ 'tn': int(tn),
251
+ 'fp': int(fp),
252
+ 'fn': int(fn),
253
+ 'accuracy': (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0,
254
+ 'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
255
+ 'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
256
+ 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
257
+ 'f1': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
258
+ }
259
+
260
+ # 추가 메트릭 계산
261
+ tpr = metrics['recall'] # TPR = recall
262
+ tnr = metrics['specificity'] # TNR = specificity
263
+
264
+ # Balanced Accuracy
265
+ metrics['balanced_accuracy'] = (tpr + tnr) / 2
266
+
267
+ # G-Mean
268
+ metrics['g_mean'] = np.sqrt(tpr * tnr) if (tpr * tnr) > 0 else 0
269
+
270
+ # MCC (Matthews Correlation Coefficient)
271
+ numerator = (tp * tn) - (fp * fn)
272
+ denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
273
+ metrics['mcc'] = numerator / denominator if denominator > 0 else 0
274
+
275
+ # NPV (Negative Predictive Value)
276
+ metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
277
+
278
+ # FAR (False Alarm Rate) = FPR = 1 - specificity
279
+ metrics['far'] = 1 - metrics['specificity']
280
+
281
+ accumulated_results[category] = metrics
282
+
283
+ # 전체 카테고리의 누적 값으로 계산
284
+ total_tp = sum(accumulated_results[cat]['tp'] for cat in categories)
285
+ total_tn = sum(accumulated_results[cat]['tn'] for cat in categories)
286
+ total_fp = sum(accumulated_results[cat]['fp'] for cat in categories)
287
+ total_fn = sum(accumulated_results[cat]['fn'] for cat in categories)
288
+
289
+ # micro average 계산 (전체 누적 값으로 계산)
290
+ accumulated_results["micro_avg"] = {
291
+ 'tp': int(total_tp),
292
+ 'tn': int(total_tn),
293
+ 'fp': int(total_fp),
294
+ 'fn': int(total_fn),
295
+ 'accuracy': (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn),
296
+ 'precision': total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0,
297
+ 'recall': total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0,
298
+ 'f1': 2 * total_tp / (2 * total_tp + total_fp + total_fn) if (2 * total_tp + total_fp + total_fn) > 0 else 0,
299
+ # ... (다른 메트릭들도 동일한 방식으로 계산)
300
+ }
301
+
302
+ return accumulated_results
303
+ def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
304
+ """성능 지표 계산"""
305
+ tn = np.sum((y_true == 0) & (y_pred == 0))
306
+ fp = np.sum((y_true == 0) & (y_pred == 1))
307
+ fn = np.sum((y_true == 1) & (y_pred == 0))
308
+ tp = np.sum((y_true == 1) & (y_pred == 1))
309
+
310
+ metrics = {
311
+ 'f1': f1_score(y_true, y_pred, zero_division=0),
312
+ 'accuracy': accuracy_score(y_true, y_pred),
313
+ 'precision': precision_score(y_true, y_pred, zero_division=0),
314
+ 'recall': recall_score(y_true, y_pred, zero_division=0),
315
+ 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
316
+ 'tp': int(tp),
317
+ 'tn': int(tn),
318
+ 'fp': int(fp),
319
+ 'fn': int(fn)
320
+ }
321
+
322
+ return metrics
pia_bench/pipe_line/piepline.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): the original import block bound SheetChecker, SheetManager and
# BenchChecker twice (the later binding silently shadowed the earlier one) and
# repeated the typing imports three times. Duplicates were removed, keeping
# the bindings that were effective at runtime
# (sheet_manager.sheet_checker.sheet_check.SheetChecker, etc.).
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from dotenv import load_dotenv

from pia_bench.bench import PiaBenchMark
from pia_bench.checker.bench_checker import BenchChecker
from pia_bench.event_alarm import EventDetector
from pia_bench.metric import MetricsEvaluator
from sheet_manager.sheet_checker.sheet_check import SheetChecker
from sheet_manager.sheet_crud.sheet_crud import SheetManager

# Load .env so ACCESS_TOKEN (read in BenchmarkPipeline.__init__) is available.
load_dotenv()

logging.basicConfig(level=logging.INFO)
21
+
22
@dataclass
class PipelineConfig:
    """Configuration values for one benchmark pipeline run."""
    # Name of the model to benchmark (also used as the sheet row key)
    model_name: str
    # Name of the benchmark dataset (e.g. "PIA")
    benchmark_name: str
    # Path to the prompt/threshold config JSON; its basename (without
    # extension) becomes the prompt-config identifier (cfg_prompt)
    cfg_target_path: str
    # Root directory that contains every benchmark dataset
    base_path: str = "/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
29
+
30
class BenchmarkPipelineStatus:
    """Tracks the state and results of a benchmark pipeline run."""

    def __init__(self):
        # (model_added, benchmark_exists) as reported by the sheet check
        self.sheet_status: Tuple[bool, bool] = (False, False)
        # per-check results of the local benchmark environment check
        self.bench_status: Dict[str, bool] = {}
        # summary string from the benchmark checker (e.g. "all_passed")
        self.bench_result: str = ""
        # name of the pipeline stage currently running / last reached
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """Whether the pipeline as a whole succeeded."""
        model_added, benchmark_exists = self.sheet_status
        # Success = model already known, benchmark present, all checks passed.
        return (not model_added) and benchmark_exists and self.bench_result == "all_passed"

    def __str__(self) -> str:
        parts = (
            f"Current Stage: {self.current_stage}",
            f"Sheet Status: {self.sheet_status}",
            f"Bench Status: {self.bench_status}",
            f"Bench Result: {self.bench_result}",
        )
        return "\n".join(parts)
49
+
50
class BenchmarkPipeline:
    """Orchestrates sheet checks, local benchmark checks and benchmark execution.

    Fixes: the benchmark path was hard-coded identically in three execution
    methods, silently ignoring config.base_path — it is now derived from the
    config; the duplicated PiaBenchMark / EventDetector / MetricsEvaluator
    construction was extracted into private helpers.
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        # Access token for downloading benchmark assets (from .env).
        self.access_token = os.getenv("ACCESS_TOKEN")
        # Prompt-config identifier = config file basename without extension.
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]

        # Initialize checkers
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)

        # Filled with the MetricsEvaluator result after execution.
        self.bench_result_dict = None

    def run(self) -> BenchmarkPipelineStatus:
        """Run the whole pipeline; always returns the (possibly failed) status."""
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()

            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("벤치마크가 이미 존재하여 추가 작업이 필요하지 않습니다.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()

            self.status.current_stage = "completed"
            return self.status

        except Exception as e:
            self.logger.error(f"파이프라인 실행 중 에러 발생: {str(e)}")
            self.status.current_stage = "error"
            return self.status

    def _check_sheet(self) -> bool:
        """Check the Google sheet; returns True only when a benchmark run is needed."""
        self.logger.info("시트 상태 체크 시작")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("새로운 모델이 추가되었습니다")
        if not benchmark_exists:
            self.logger.info("벤치마크 측정이 필요합니다")
            return True  # proceed only when the benchmark is missing

        self.logger.info("이미 벤치마크가 존재합니다. 파이프라인을 종료합니다.")
        return False  # benchmark already recorded: stop here

    def _check_bench(self) -> bool:
        """Check the local benchmark environment; stores the summary in status."""
        self.logger.info("벤치마크 환경 체크 시작")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )

        # "no bench": the benchmark has never been run and the folder
        # structure does not exist.
        if self.status.bench_result == "no bench":
            self.logger.error("벤치마크 실행에 필요한 기본 폴더구조가 없습니다.")
            # NOTE(review): this branch also returns True (same as the success
            # path) — kept as-is to preserve behavior; _execute_based_on_status
            # falls back to vector generation for unknown states.
            return True

        return True  # proceed to execution

    def _create_benchmark(self) -> PiaBenchMark:
        """Build a PiaBenchMark rooted at the configured base path.

        Fix: previously the path was hard-coded in three places and
        config.base_path was ignored.
        """
        return PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token,
        )

    def _run_event_detection(self, pia_benchmark) -> None:
        """Run event detection over extracted vectors and save alarm CSVs."""
        detector = EventDetector(
            config_path=self.config.cfg_target_path,
            model_name=self.config.model_name,
            token=pia_benchmark.token,
        )
        detector.process_and_save_predictions(
            pia_benchmark.vector_video_path,
            pia_benchmark.dataset_path,
            pia_benchmark.alram_path,
        )

    def _evaluate_metrics(self, pia_benchmark) -> None:
        """Evaluate predictions against labels; stores the result dict."""
        metric = MetricsEvaluator(
            pred_dir=pia_benchmark.alram_path,
            label_dir=pia_benchmark.dataset_path,
            save_dir=pia_benchmark.metric_path,
        )
        self.bench_result_dict = metric.evaluate()

    def _execute_based_on_status(self):
        """Dispatch to the right execution path for the checked bench state."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            self._execute_vector_generation()
            self.logger.warning("폴더구조가 없습니다")

    def _execute_full_pipeline(self):
        """Execution path when every prerequisite already exists."""
        self.logger.info("전체 파이프라인 실행 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        print("Categories identified:", pia_benchmark.categories)
        self._evaluate_metrics(pia_benchmark)

    def _execute_vector_generation(self):
        """Execution path when visual vectors must be (re)generated."""
        self.logger.info("벡터 생성 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        pia_benchmark.extract_visual_vector()
        self._run_event_detection(pia_benchmark)
        self._evaluate_metrics(pia_benchmark)

    def _execute_metrics_generation(self):
        """Execution path when vectors exist but metrics are missing."""
        self.logger.info("메트릭 생성 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        self._run_event_detection(pia_benchmark)
        self._evaluate_metrics(pia_benchmark)
211
+
212
+
213
if __name__ == "__main__":
    # Pipeline configuration for a single benchmark run.
    run_config = PipelineConfig(
        model_name="T2V_CLIP4CLIP_MSRVTT",
        benchmark_name="PIA",
        cfg_target_path="topk.json",
        base_path="/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench",
    )

    # Execute the pipeline and report its final status.
    final_status = BenchmarkPipeline(run_config).run()

    print(f"\n파이프라인 실행 결과:")
    print(str(final_status))
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ oauth2client
2
+ gspread
3
+ gradio
4
+ python-dotenv
5
+ APScheduler
6
+ black
7
+ gradio[oauth]
8
+ gradio_leaderboard==0.0.13
9
+ gradio_client
10
+ huggingface-hub>=0.18.0
11
+ matplotlib
12
+ numpy
13
+ pandas
14
+ python-dateutil
15
+ tqdm
sample.csv ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ video_name,resolution,video_duration,category,benchmark,duration_seconds,total_frames,file_format,file_size_mb,aspect_ratio,fps
2
+ 417-2_cam02_assault01_place03_night_spring.mp4,3840x2160,5:15,violence,PIA,315.14816666666667,9445,.mp4,351.66,1.78,29.97002997002997
3
+ 439-5_cam01_assault01_place03_day_summer.mp4,3840x2160,5:19,violence,PIA,318.6516666666667,9550,.mp4,342.09,1.78,29.97002997002997
4
+ 24-2_cam01_assault01_place09_night_winter.mp4,3840x2160,4:33,violence,PIA,272.7725,8175,.mp4,299.35,1.78,29.97002997002997
5
+ 6-1_cam01_assault01_place03_night_summer.mp4,3840x2160,5:19,violence,PIA,318.6516666666667,9550,.mp4,342.81,1.78,29.97002997002997
6
+ fight_0026.mp4,640x360,3:20,violence,PIA,199.93306666666666,5992,.mp4,16.19,1.78,29.97002997002997
7
+ 10-1_cam01_assault03_place07_night_winter.mp4,3840x2160,5:10,violence,PIA,310.07643333333334,9293,.mp4,333.36,1.78,29.97002997002997
8
+ 22-2_cam01_assault01_place07_night_winter.mp4,3840x2160,5:5,violence,PIA,305.305,9150,.mp4,368.69,1.78,29.97002997002997
9
+ 407-6_cam01_assault01_place04_day_winter.mp4,3840x2160,5:13,violence,PIA,312.77913333333333,9374,.mp4,747.29,1.78,29.97002997002997
10
+ 407-6_cam01_assault01_place04_day_spring.mp4,3840x2160,5:9,violence,PIA,309.10880000000003,9264,.mp4,738.48,1.78,29.97002997002997
11
+ 411-3_cam01_assault01_place08_night_winter.mp4,3840x2160,5:15,violence,PIA,314.7477666666667,9433,.mp4,337.82,1.78,29.97002997002997
12
+ 412-1_cam01_assault01_place09_night_winter.mp4,3840x2160,5:30,violence,PIA,330.26326666666665,9898,.mp4,380.09,1.78,29.97002997002997
13
+ 416-5_cam03_assault01_place02_night_spring.mp4,3840x2160,5:9,violence,PIA,308.97533333333337,9260,.mp4,342.81,1.78,29.97002997002997
14
+ 12-1_cam01_assault01_place09_day_summer.mp4,3840x2160,4:54,violence,PIA,294.02706666666666,8812,.mp4,712.79,1.78,29.97002997002997
15
+ 13-4_cam02_assault02_place08_day_spring.mp4,3840x2160,4:56,violence,PIA,295.96233333333333,8870,.mp4,711.64,1.78,29.97002997002997
16
+ 16-3_cam01_assault01_place02_night_summer.mp4,3840x2160,5:0,violence,PIA,300.3333666666667,9001,.mp4,358.8,1.78,29.97002997002997
17
+ 17-2_cam01_assault03_place03_night_spring.mp4,3840x2160,5:13,violence,PIA,312.8125,9375,.mp4,758.24,1.78,29.97002997002997
18
+ 2-3_cam01_assault01_place04_night_spring.mp4,3840x2160,5:12,violence,PIA,312.4454666666667,9364,.mp4,750.96,1.78,29.97002997002997
19
+ 20-3_cam02_assault01_place02_night_summer.mp4,3840x2160,5:14,violence,PIA,314.2139,9417,.mp4,450.17,1.78,29.97002997002997
20
+ 23-3_cam01_assault01_place02_night_summer.mp4,3840x2160,5:0,violence,PIA,300.3,9000,.mp4,358.84,1.78,29.97002997002997
21
+ 406-1_cam01_assault01_place03_day_summer.mp4,3840x2160,5:0,violence,PIA,300.3333666666667,9001,.mp4,428.29,1.78,29.97002997002997
22
+ 8-1_cam01_assault03_place05_day_spring.mp4,3840x2160,5:0,violence,PIA,300.3,9000,.mp4,717.83,1.78,29.97002997002997
23
+ fight_0035.mp4,406x720,1:47,violence,PIA,107.27383333333333,3215,.mp4,11.04,0.56,29.97002997002997
24
+ fight_0062.mp4,640x360,1:12,violence,PIA,71.70496666666666,2149,.mp4,3.05,1.78,29.97002997002997
25
+ fight_0051.mp4,1280x720,2:0,violence,PIA,120.43333333333334,3599,.mp4,30.26,1.78,29.88375311375588
26
+ fight_0097.mp4,1280x720,1:8,violence,PIA,68.00126666666667,2038,.mp4,6.47,1.78,29.97002997002997
27
+ fight_0125.mp4,1280x720,1:26,violence,PIA,86.01926666666667,2578,.mp4,9.09,1.78,29.97002997002997
28
+ fight_0141.mp4,1280x720,3:11,violence,PIA,191.42456666666666,5737,.mp4,19.21,1.78,29.97002997002997
29
+ fight_0147.mp4,640x360,2:35,violence,PIA,154.73333333333335,4624,.mp4,9.45,1.78,29.88367083153813
30
+ fight_0156.mp4,1280x720,1:16,violence,PIA,76.0,2280,.mp4,10.58,1.78,30.0
31
+ fight_0162.mp4,1280x720,1:8,violence,PIA,67.86666666666666,2036,.mp4,8.28,1.78,30.0
32
+ 20190102_013314A.mp4,3840x2160,15:0,fire,PIA,900.39,27013,.mp4,2477.09,1.78,30.001443818789635
33
+ 화재 - 불피우기.mp4,1920x1080,2:40,fire,PIA,159.535675,4786,.mp4,76.23,1.78,29.999559659618452
34
+ 화재 - 토치.mp4,1920x1080,8:44,fire,PIA,523.871349,15716,.mp4,249.98,1.78,29.999731861648346
35
+ Video34.mp4,292x240,15:1,fire,PIA,901.0485436893204,7734,.mp4,1.71,1.22,8.583333333333334
36
+ Video5.mp4,320x240,3:7,fire,PIA,187.33333333333334,4496,.mp4,4.82,1.33,24.0
37
+ Video49.mp4,320x240,1:9,fire,PIA,69.08333333333333,1658,.mp4,2.7,1.33,24.0
38
+ Video149.mp4,854x480,0:30,fire,PIA,29.996663329996665,899,.mp4,5.37,1.78,29.97
39
+ Video261.mp4,292x240,15:0,fire,PIA,900.1199999999999,7501,.mp4,4.09,1.22,8.333333333333334
40
+ fire_general-fire_rgb_0002_cctv1.mp4,1920x1080,0:7,fire,PIA,7.0,189,.mp4,5.58,1.78,27.0
41
+ fire_general-fire_rgb_0065_cctv4.mp4,1920x1080,0:9,fire,PIA,8.525191858525192,511,.mp4,47.03,1.78,59.94
42
+ fire_general-fire_rgb_0070_cctv3.mp4,1920x1080,0:12,fire,PIA,12.479145812479146,748,.mp4,73.64,1.78,59.94
43
+ fire_general-fire_rgb_0083_cctv2.mp4,1920x1080,0:10,fire,PIA,9.743076409743077,584,.mp4,55.54,1.78,59.94
44
+ fire_general-fire_rgb_0559_cctv1.mp4,1920x1080,0:10,fire,PIA,10.477143810477145,628,.mp4,81.13,1.78,59.94
45
+ fire_general-fire_rgb_0556_cctv2.mp4,1920x1080,0:9,fire,PIA,9.376042709376042,562,.mp4,65.8,1.78,59.94
46
+ fire_general-fire_rgb_0530_cctv1.mp4,1920x1080,0:12,fire,PIA,12.028695362028696,721,.mp4,87.26,1.78,59.94
47
+ fire_general-fire_rgb_0514_cctv2.mp4,1920x1080,0:8,fire,PIA,8.241574908241574,494,.mp4,59.32,1.78,59.94
48
+ fire_general-fire_rgb_0475_cctv1.mp4,1920x1080,0:9,fire,PIA,9.426092759426092,565,.mp4,73.66,1.78,59.94
49
+ fire_general-fire_rgb_0460_cctv2.mp4,1920x1080,0:11,fire,PIA,11.16116116116116,669,.mp4,80.93,1.78,59.94
50
+ fire_general-fire_rgb_0356_cctv1.mp4,1920x1080,0:9,fire,PIA,9.376042709376042,562,.mp4,79.11,1.78,59.94
51
+ fire_general-fire_rgb_0331_cctv2.mp4,1920x1080,0:7,fire,PIA,6.773440106773441,406,.mp4,52.11,1.78,59.94
52
+ fire_general-fire_rgb_0305_cctv3.mp4,1920x1080,0:10,fire,PIA,10.226893560226895,613,.mp4,84.28,1.78,59.94
53
+ fire_general-fire_rgb_0291_cctv1.mp4,1920x1080,0:4,fire,PIA,4.170837504170838,250,.mp4,22.08,1.78,59.94
54
+ fire_general-fire_rgb_0280_cctv4.mp4,1920x1080,0:4,fire,PIA,4.087420754087421,245,.mp4,21.92,1.78,59.94
55
+ fire_general-fire_rgb_0562_cctv7.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,7.62,1.78,29.0
56
+ fire_general-fire_rgb_0337_cctv2.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,5.37,1.78,29.0
57
+ fire_general-fire_rgb_0289_cctv2.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,7.74,1.78,29.0
58
+ fire_oil-fire_rgb_0002_cctv2.mp4,1920x1080,0:6,fire,PIA,6.0,180,.mp4,5.28,1.78,30.0
59
+ fire_oil-fire_rgb_0222_cctv4.mp4,1920x1080,0:6,fire,PIA,6.0,180,.mp4,6.37,1.78,30.0
60
+ fire_oil-fire_rgb_0445_cctv7.mp4,1920x1080,0:6,fire,PIA,6.0,174,.mp4,4.19,1.78,29.0
61
+ Explosion004_x264.mp4,320x240,1:3,fire,PIA,63.4,1902,.mp4,8.04,1.33,30.0
62
+ Explosion005_x264.mp4,320x240,0:23,fire,PIA,23.1,693,.mp4,5.46,1.33,30.0
63
+ Explosion009_x264.mp4,320x240,0:37,fire,PIA,36.7,1101,.mp4,9.13,1.33,30.0
64
+ Explosion010_x264.mp4,320x240,1:23,fire,PIA,83.26666666666667,2498,.mp4,11.58,1.33,30.0
65
+ Explosion013_x264.mp4,320x240,1:51,fire,PIA,110.56666666666666,3317,.mp4,15.66,1.33,30.0
66
+ Explosion014_x264.mp4,320x240,0:43,fire,PIA,43.06666666666667,1292,.mp4,6.24,1.33,30.0
67
+ Explosion017_x264.mp4,320x240,0:55,fire,PIA,54.766666666666666,1643,.mp4,12.31,1.33,30.0
68
+ Explosion002_x264.mp4,320x240,2:14,fire,PIA,133.76666666666668,4013,.mp4,18.32,1.33,30.0
69
+ Explosion051_x264.mp4,320x240,1:34,fire,PIA,94.0,2820,.mp4,13.52,1.33,30.0
70
+ 119-1_cam01_swoon01_place03_day_summer.mp4,3840x2160,4:60,falldown,PIA,299.9997,8991,.mp4,810.22,1.78,29.97002997002997
71
+ 100-5_cam02_swoon01_place02_day_summer.mp4,3840x2160,5:21,falldown,PIA,321.38773333333336,9632,.mp4,451.85,1.78,29.97002997002997
72
+ FILE210101-012606F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.23,1.78,29.97002997002997
73
+ 118-2_cam01_swoon02_place10_day_spring.mp4,3840x2160,5:0,falldown,PIA,300.3,9000,.mp4,442.05,1.78,29.97002997002997
74
+ 108-5_cam02_swoon01_place06_night_spring.mp4,3840x2160,4:56,falldown,PIA,296.22926666666666,8878,.mp4,718.68,1.78,29.97002997002997
75
+ FILE210101-003713F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.27,1.78,29.97002997002997
76
+ FILE210101-010727F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.04,1.78,29.97002997002997
77
+ 245-5_cam02_swoon01_place04_night_spring.mp4,3840x2160,5:7,falldown,PIA,306.97333333333336,9200,.mp4,369.43,1.78,29.97002997002997
78
+ 115-1_cam01_swoon01_place02_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.2746333333333,9239,.mp4,362.95,1.78,29.97002997002997
79
+ 110-2_cam01_swoon01_place01_day_spring.mp4,3840x2160,4:55,falldown,PIA,295.06143333333335,8843,.mp4,709.46,1.78,29.97002997002997
80
+ FILE210101-005228F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.06,1.78,29.97002997002997
81
+ 114-2_cam01_swoon03_place03_day_spring.mp4,3840x2160,5:6,falldown,PIA,306.306,9180,.mp4,438.87,1.78,29.97002997002997
82
+ 117-1_cam02_swoon01_place04_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.4081,9243,.mp4,747.54,1.78,29.97002997002997
83
+ 104-2_cam01_swoon01_place04_day_spring.mp4,3840x2160,5:6,falldown,PIA,305.7054,9162,.mp4,742.15,1.78,29.97002997002997
84
+ 103-1_cam01_swoon01_place04_day_spring.mp4,3840x2160,5:8,falldown,PIA,307.5072,9216,.mp4,746.19,1.78,29.97002997002997
85
+ 99-4_cam03_swoon01_place03_day_winter.mp4,3840x2160,5:7,falldown,PIA,306.6396666666667,9190,.mp4,329.27,1.78,29.97002997002997
86
+ 240-3_cam01_swoon01_place02_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.04106666666667,9232,.mp4,330.92,1.78,29.97002997002997
87
+ 115-5_cam02_swoon01_place02_night_spring.mp4,3840x2160,5:10,falldown,PIA,309.57593333333335,9278,.mp4,750.26,1.78,29.97002997002997
88
+ 640-2_cam01_swoon01_place01_day_summer.mp4,3840x2160,5:0,falldown,PIA,300.3333666666667,9001,.mp4,425.3,1.78,29.97002997002997
89
+ 517-3_cam03_swoon03_place04_night_winter.mp4,3840x2160,5:7,falldown,PIA,306.9066,9198,.mp4,330.13,1.78,29.97002997002997
90
+ 107-1_cam01_swoon01_place06_night_spring.mp4,3840x2160,4:0,falldown,PIA,240.24,7200,.mp4,580.49,1.78,29.97002997002997
91
+ 106-1_cam02_swoon01_place05_day_spring.mp4,3840x2160,5:1,falldown,PIA,300.56693333333334,9008,.mp4,430.69,1.78,29.97002997002997
92
+ 110-1_cam01_swoon01_place01_day_spring.mp4,3840x2160,4:55,falldown,PIA,294.6276666666667,8830,.mp4,708.43,1.78,29.97002997002997
93
+ 116-4_cam02_swoon01_place03_day_summer.mp4,3840x2160,5:7,falldown,PIA,306.53956666666664,9187,.mp4,340.57,1.78,29.97002997002997
94
+ 109-5_cam02_swoon02_place01_night_summer.mp4,3840x2160,5:12,falldown,PIA,312.312,9360,.mp4,757.09,1.78,29.97002997002997
95
+ 105-5_cam01_swoon01_place05_night_summer.mp4,3840x2160,5:11,falldown,PIA,311.34436666666664,9331,.mp4,333.82,1.78,29.97002997002997
96
+ 112-6_cam01_swoon02_place01_night_summer.mp4,3840x2160,4:0,falldown,PIA,240.24,7200,.mp4,582.44,1.78,29.97002997002997
97
+ 108-2_cam01_swoon01_place06_night_summer.mp4,3840x2160,4:60,falldown,PIA,299.7327666666667,8983,.mp4,726.91,1.78,29.97002997002997
98
+ 120-1_cam02_swoon02_place06_day_summer.mp4,3840x2160,5:22,falldown,PIA,322.2219,9657,.mp4,781.36,1.78,29.97002997002997
99
+ 113-2_cam02_swoon01_place08_day_summer.mp4,3840x2160,5:14,falldown,PIA,313.7134,9402,.mp4,450.0,1.78,29.97002997002997
sheet_manager/sheet_checker/sheet_check.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+ import logging
3
+ import gspread
4
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
5
+
6
class SheetChecker:
    """Keep the "model" worksheet in sync: model rows and benchmark columns.

    Wraps the provided SheetManager and derives a second manager of the same
    concrete type bound to the "model" worksheet (keyed by "Model name").
    """

    def __init__(self, sheet_manager: SheetManager):
        """Initialize the checker and connect to the "model" worksheet.

        Args:
            sheet_manager: Manager for the primary worksheet; its spreadsheet
                URL and concrete class are reused for the "model" worksheet.
        """
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None  # assigned in _init_bench_sheet()
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Create a sheet manager for the "model" worksheet."""
        # Instantiate the same concrete manager class the caller supplied.
        self.bench_sheet_manager = type(self.sheet_manager)(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )

    def add_benchmark_column(self, column_name: str):
        """Add a benchmark column plus its "<name>*100" companion column.

        No-op when the column header already exists.

        Args:
            column_name: Header text of the benchmark column to create.

        Raises:
            Exception: Any error from the sheet API is logged and re-raised.
        """
        try:
            headers = self.bench_sheet_manager.get_available_columns()

            if column_name in headers:
                return  # already present

            # Write the new header into the first free column of row 1.
            new_col_index = len(headers) + 1
            cell = gspread.utils.rowcol_to_a1(1, new_col_index)
            self.bench_sheet_manager.sheet.update(cell, [[column_name]])

            # Companion column intended to hold the score scaled by 100.
            next_col_index = new_col_index + 1
            next_cell = gspread.utils.rowcol_to_a1(1, next_col_index)
            self.bench_sheet_manager.sheet.update(next_cell, [[f"{column_name}*100"]])

            self.logger.info(f"새로운 벤치마크 컬럼들 추가됨: {column_name}, {column_name}*100")
            # Reconnect so cached headers include the freshly added columns.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"벤치마크 컬럼 {column_name} 추가 중 오류 발생: {str(e)}")
            raise

    def check_model_and_benchmark(self, model_name: str, benchmark_name: str) -> Tuple[bool, bool]:
        """Ensure the model row and benchmark column exist, then report status.

        Args:
            model_name: Model to look up (its row is added when missing).
            benchmark_name: Benchmark to look up (its column is created when missing).

        Returns:
            Tuple[bool, bool]: (model was newly added, benchmark value already exists)
        """
        try:
            # Add the model row when it is not present yet.
            model_exists = self._check_model_exists(model_name)
            model_added = False
            if not model_exists:
                self._add_new_model(model_name)
                model_added = True
                self.logger.info(f"새로운 모델 추가됨: {model_name}")

            # Create the benchmark column when it is not present yet.
            available_columns = self.bench_sheet_manager.get_available_columns()
            if benchmark_name not in available_columns:
                self.add_benchmark_column(benchmark_name)
                self.logger.info(f"새로운 벤치마크 컬럼 추가됨: {benchmark_name}")

            # Finally, check whether the benchmark cell already holds a value.
            benchmark_exists = self._check_benchmark_exists(model_name, benchmark_name)

            return model_added, benchmark_exists

        except Exception as e:
            self.logger.error(f"모델/벤치마크 확인 중 오류 발생: {str(e)}")
            raise

    def _check_model_exists(self, model_name: str) -> bool:
        """Return True when the model name appears in the "Model name" column."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            values = self.bench_sheet_manager.get_all_values()
            return model_name in values
        except Exception as e:
            self.logger.error(f"모델 존재 여부 확인 중 오류 발생: {str(e)}")
            raise

    def _add_new_model(self, model_name: str):
        """Append name, link, and HTML-link cells for a new model row."""
        try:
            model_info = {
                "Model name": model_name,
                "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
                "Model": f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
            }

            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)

        except Exception as e:
            self.logger.error(f"모델 정보 추가 중 오류 발생: {str(e)}")
            raise

    def _check_benchmark_exists(self, model_name: str, benchmark_name: str) -> bool:
        """Return True when the model's benchmark cell holds a non-empty value."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()

            # BUGFIX: list.index() raised ValueError when the model row was
            # missing; treat a missing model as "no benchmark recorded".
            if model_name not in all_values:
                return False
            # +1 for the header row, +1 because sheet rows are 1-based.
            # NOTE(review): get_all_values() drops blank cells, so this mapping
            # assumes the column has no gaps above the model row — verify.
            row_index = all_values.index(model_name) + 2

            self.bench_sheet_manager.change_column(benchmark_name)
            value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value

            return bool(value and value.strip())

        except Exception as e:
            self.logger.error(f"벤치마크 존재 여부 확인 중 오류 발생: {str(e)}")
            raise
128
+
129
# Usage example
if __name__ == "__main__":
    checker = SheetChecker(SheetManager())

    was_added, has_benchmark = checker.check_model_and_benchmark(
        model_name="test-model",
        benchmark_name="COCO",
    )

    print(f"Model added: {was_added}")
    print(f"Benchmark exists: {has_benchmark}")
sheet_manager/sheet_convert/json2sheet.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
3
+ import json
4
+ from typing import Optional, Dict
5
+
6
def update_benchmark_json(
    model_name: str,
    benchmark_data: dict,
    worksheet_name: str = "metric",
    target_column: str = "benchmark"
):
    """Serialize benchmark data to JSON and store it on the model's row.

    Args:
        model_name (str): Model whose row should be updated.
        benchmark_data (dict): Benchmark payload to serialize.
        worksheet_name (str): Worksheet to operate on (default: "metric").
        target_column (str): Column receiving the JSON string (default: "benchmark").
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    # Keep non-ASCII characters readable in the stored JSON.
    payload = json.dumps(benchmark_data, ensure_ascii=False)

    # Locate the row by model name and overwrite the target cell.
    matched_row = manager.update_cell_by_condition(
        condition_column="Model name",
        condition_value=model_name,
        target_column=target_column,
        target_value=payload,
    )

    if matched_row:
        print(f"Successfully updated {target_column} data for model: {model_name}")
    else:
        print(f"Model {model_name} not found in the sheet")
38
+
39
+
40
+
41
def get_benchmark_dict(
    model_name: str,
    worksheet_name: str = "metric",
    target_column: str = "benchmark",
    save_path: Optional[str] = None
    ) -> Dict:
    """Fetch a model's benchmark JSON from the sheet and return it as a dict.

    Args:
        model_name (str): Model to look up.
        worksheet_name (str): Worksheet to read from (default: "metric").
        target_column (str): Column holding the JSON string (default: "benchmark").
        save_path (str, optional): If given, also write the dict to this JSON file.

    Returns:
        Dict: Parsed benchmark data; {} when the model or data is missing
        or JSON parsing fails.
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    try:
        rows = manager.sheet.get_all_records()

        # First row whose "Model name" matches wins.
        matched = next(
            (row for row in rows if row.get("Model name") == model_name),
            None
        )

        if not matched:
            print(f"Model {model_name} not found in the sheet")
            return {}

        raw_json = matched.get(target_column)

        if not raw_json:
            print(f"No data found in {target_column} for model: {model_name}")
            return {}

        benchmark = json.loads(raw_json)

        # Optionally persist the parsed dict for offline use.
        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(benchmark, f, ensure_ascii=False, indent=2)
            print(f"Successfully saved dictionary to: {save_path}")

        return benchmark

    except json.JSONDecodeError:
        print(f"Failed to parse JSON data for model: {model_name}")
        return {}
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return {}
99
+
100
def str2json(json_str):
    """Parse a JSON-formatted string into a Python object.

    Args:
        json_str (str): JSON-formatted string.

    Returns:
        dict: Parsed JSON object, or None when parsing fails.
    """
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as decode_err:
        print(f"JSON Parsing Error: {decode_err}")
        return None
    except Exception as unexpected_err:
        print(f"Unexpected Error: {unexpected_err}")
        return None
    return parsed
sheet_manager/sheet_crud/create_col.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from huggingface_hub import HfApi
5
+ import os
6
+ from dotenv import load_dotenv
7
+ from enviroments.convert import get_json_from_env_var
8
+
9
+ load_dotenv()
10
+
11
def push_model_names_to_sheet(spreadsheet_url, sheet_name, access_token, organization):
    """
    Fetches model names from Hugging Face and updates a Google Sheet with the names, links, and HTML links.

    Credentials are read from the GOOGLE_CREDENTIALS environment variable
    (service-account JSON). Only models not already present in the sheet's
    "Model name" column are appended.

    Args:
        spreadsheet_url (str): URL of the Google Spreadsheet.
        sheet_name (str): Name of the sheet to update.
        access_token (str): Hugging Face access token.
        organization (str): Organization name on Hugging Face.
    """
    # Authorize Google Sheets API
    scope = ['https://spreadsheets.google.com/feeds',
             'https://www.googleapis.com/auth/drive']
    json_key_dict =get_json_from_env_var("GOOGLE_CREDENTIALS")
    credential = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
    gc = gspread.authorize(credential)

    # Open the Google Spreadsheet
    doc = gc.open_by_url(spreadsheet_url)
    sheet = doc.worksheet(sheet_name)

    # Fetch existing data from the sheet
    existing_data = pd.DataFrame(sheet.get_all_records())

    # Fetch models from Hugging Face
    api = HfApi()
    models = api.list_models(author=organization, use_auth_token=access_token)

    # Extract model names, links, and HTML links
    # NOTE(review): modelId.split("/")[1] assumes every id is "org/name" — verify.
    model_details = [{
        "Model name": model.modelId.split("/")[1],
        "Model link": f"https://huggingface.co/{model.modelId}",
        "Model": f"<a target=\"_blank\" href=\"https://huggingface.co/{model.modelId}\" style=\"color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;\">{model.modelId}</a>"
    } for model in models]

    new_data_df = pd.DataFrame(model_details)

    # Check for duplicates and update only new model names
    if "Model name" in existing_data.columns:
        existing_model_names = existing_data["Model name"].tolist()
    else:
        existing_model_names = []

    new_data_df = new_data_df[~new_data_df["Model name"].isin(existing_model_names)]

    if not new_data_df.empty:
        # Append new model names, links, and HTML links to the existing data
        updated_data = pd.concat([existing_data, new_data_df], ignore_index=True)

        # Push updated data back to the sheet
        updated_data = updated_data.replace([float('inf'), float('-inf')], None)  # convert +/-Infinity to None
        updated_data = updated_data.fillna('')  # convert NaN to empty strings
        sheet.update([updated_data.columns.values.tolist()] + updated_data.values.tolist())
        print("New model names, links, and HTML links successfully added to Google Sheet.")
    else:
        print("No new model names to add.")
69
+ # Example usage
70
+ if __name__ == "__main__":
71
+ spreadsheet_url = os.getenv("SPREADSHEET_URL")
72
+ access_token = os.getenv("ACCESS_TOKEN")
73
+ sheet_name = "시트1"
74
+ organization = "PIA-SPACE-LAB"
75
+
76
+ push_model_names_to_sheet(spreadsheet_url, sheet_name, access_token, organization)
sheet_manager/sheet_crud/sheet_crud.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ from enviroments.convert import get_json_from_env_var
6
+ from typing import Optional, List
7
+
8
+ load_dotenv(override=True)
9
+
10
class SheetManager:
    """CRUD wrapper around one column of one Google Sheets worksheet.

    Holds a live gspread connection plus the cached header row; all column
    operations go through `col_index`, resolved from `column_name`.
    """

    def __init__(self, spreadsheet_url: Optional[str] = None,
                 worksheet_name: str = "flag",
                 column_name: str = "huggingface_id"):
        """
        Initialize SheetManager with Google Sheets credentials and connection.

        Args:
            spreadsheet_url (str, optional): URL of the Google Spreadsheet.
                If None, takes from environment variable.
            worksheet_name (str): Name of the worksheet to operate on.
                Defaults to "flag".
            column_name (str): Name of the column to operate on.
                Defaults to "huggingface_id".

        Raises:
            ValueError: If no spreadsheet URL is available.
        """
        self.spreadsheet_url = spreadsheet_url or os.getenv("SPREADSHEET_URL")
        if not self.spreadsheet_url:
            raise ValueError("Spreadsheet URL not provided and not found in environment variables")

        self.worksheet_name = worksheet_name
        self.column_name = column_name

        # Initialize credentials and client
        self._init_google_client()

        # Initialize sheet connection
        self.doc = None
        self.sheet = None
        self.col_index = None
        self._connect_to_sheet(validate_column=True)

    def _init_google_client(self):
        """Initialize Google Sheets client with credentials from GOOGLE_CREDENTIALS."""
        scope = ['https://spreadsheets.google.com/feeds',
                 'https://www.googleapis.com/auth/drive']
        json_key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
        credentials = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
        self.client = gspread.authorize(credentials)

    def _connect_to_sheet(self, validate_column: bool = True):
        """
        Connect to the specified Google Sheet and initialize necessary attributes.

        Args:
            validate_column (bool): Whether to validate the column name exists

        Raises:
            ValueError: Worksheet or column cannot be resolved.
            ConnectionError: Any other failure while connecting.
        """
        try:
            self.doc = self.client.open_by_url(self.spreadsheet_url)

            # Try to get the worksheet
            try:
                self.sheet = self.doc.worksheet(self.worksheet_name)
            except gspread.exceptions.WorksheetNotFound:
                raise ValueError(f"Worksheet '{self.worksheet_name}' not found in spreadsheet")

            # Get headers (row 1 is always treated as the header row)
            self.headers = self.sheet.row_values(1)

            # Validate column only if requested
            if validate_column:
                try:
                    self.col_index = self.headers.index(self.column_name) + 1
                except ValueError:
                    # If column not found, fall back to the first available column
                    if self.headers:
                        self.column_name = self.headers[0]
                        self.col_index = 1
                        # NOTE(review): self.column_name was reassigned above, so this
                        # message prints the fallback name, not the missing one.
                        print(f"Column '{self.column_name}' not found. Using first available column: '{self.headers[0]}'")
                    else:
                        raise ValueError("No columns found in worksheet")

        except Exception as e:
            # Re-raise our own ValueErrors untouched; wrap everything else.
            if isinstance(e, ValueError):
                raise e
            raise ConnectionError(f"Failed to connect to sheet: {str(e)}")

    def change_worksheet(self, worksheet_name: str, column_name: Optional[str] = None):
        """
        Change the current worksheet and optionally the column.

        On any failure the previous worksheet/column state is restored.

        Args:
            worksheet_name (str): Name of the worksheet to switch to
            column_name (str, optional): Name of the column to switch to
        """
        # Remember current state so we can roll back on failure.
        old_worksheet = self.worksheet_name
        old_column = self.column_name

        try:
            self.worksheet_name = worksheet_name
            if column_name:
                self.column_name = column_name

            # First connect without column validation
            self._connect_to_sheet(validate_column=False)

            # Then validate the column if specified
            if column_name:
                self.change_column(column_name)
            else:
                # Validate existing column in new worksheet
                try:
                    self.col_index = self.headers.index(self.column_name) + 1
                except ValueError:
                    # If column not found, use first available column
                    if self.headers:
                        self.column_name = self.headers[0]
                        self.col_index = 1
                        print(f"Column '{old_column}' not found in new worksheet. Using first available column: '{self.headers[0]}'")
                    else:
                        raise ValueError("No columns found in worksheet")

            print(f"Successfully switched to worksheet: {worksheet_name}, using column: {self.column_name}")

        except Exception as e:
            # Restore previous state on error
            self.worksheet_name = old_worksheet
            self.column_name = old_column
            self._connect_to_sheet()
            raise e

    def change_column(self, column_name: str):
        """
        Change the target column.

        Args:
            column_name (str): Name of the column to switch to

        Raises:
            ValueError: If the column is not present in the header row.
        """
        # Refresh the header cache if it is empty.
        if not self.headers:
            self.headers = self.sheet.row_values(1)

        try:
            self.col_index = self.headers.index(column_name) + 1
            self.column_name = column_name
            print(f"Successfully switched to column: {column_name}")
        except ValueError:
            raise ValueError(f"Column '{column_name}' not found in worksheet. Available columns: {', '.join(self.headers)}")

    def get_available_worksheets(self) -> List[str]:
        """Get list of all available worksheets in the spreadsheet."""
        return [worksheet.title for worksheet in self.doc.worksheets()]

    def get_available_columns(self) -> List[str]:
        """Get list of all available columns in the current worksheet."""
        return self.headers if self.headers else self.sheet.row_values(1)

    def _reconnect_if_needed(self):
        """Reconnect to the sheet if the connection is lost (probe via row 1)."""
        try:
            self.sheet.row_values(1)
        except (gspread.exceptions.APIError, AttributeError):
            self._init_google_client()
            self._connect_to_sheet()

    def _fetch_column_data(self) -> List[str]:
        """Fetch all data from the configured column (header excluded)."""
        values = self.sheet.col_values(self.col_index)
        return values[1:]  # Exclude header

    def _update_sheet(self, data: List[str]):
        """Overwrite the configured column (below the header) with `data`."""
        try:
            # Prepare the range for update (excluding header)
            start_cell = gspread.utils.rowcol_to_a1(2, self.col_index)  # Start from row 2
            # NOTE(review): rows 2..len(data)+2 spans one row more than len(data)
            # values; len(data)+1 would be exact — confirm gspread pads quietly.
            end_cell = gspread.utils.rowcol_to_a1(len(data) + 2, self.col_index)
            range_name = f"{start_cell}:{end_cell}"

            # Convert data to 2D array format required by gspread
            cells = [[value] for value in data]

            # Update the range
            self.sheet.update(range_name, cells)
        except Exception as e:
            print(f"Error updating sheet: {str(e)}")
            raise

    def push(self, text: str) -> int:
        """
        Push a text value to the next empty cell in the configured column.

        Args:
            text (str): Text to push to the sheet

        Returns:
            int: The row number where the text was pushed
        """
        try:
            self._reconnect_if_needed()

            # Get all values in the configured column (index 0 is the header)
            column_values = self.sheet.col_values(self.col_index)

            # Find the next empty row
            next_row = None
            for i in range(1, len(column_values)):
                if not column_values[i].strip():
                    next_row = i + 1  # list index -> 1-based sheet row
                    break

            # If no empty row found, append to the end
            if next_row is None:
                next_row = len(column_values) + 1

            # Update the cell
            self.sheet.update_cell(next_row, self.col_index, text)
            print(f"Successfully pushed value: {text} to row {next_row}")
            return next_row

        except Exception as e:
            print(f"Error pushing to sheet: {str(e)}")
            raise

    def pop(self) -> Optional[str]:
        """Remove and return the topmost value (row 2) of the configured column.

        Returns:
            Optional[str]: The removed value, or None when the column is empty.
        """
        try:
            self._reconnect_if_needed()
            data = self._fetch_column_data()

            if not data or not data[0].strip():
                return None

            value = data.pop(0)  # Remove first value
            data.append("")  # Add empty string at the end to maintain sheet size

            self._update_sheet(data)
            print(f"Successfully popped value: {value}")
            return value

        except Exception as e:
            print(f"Error popping from sheet: {str(e)}")
            raise

    def delete(self, value: str) -> List[int]:
        """Delete all occurrences of a value from the configured column.

        Returns:
            List[int]: 1-based data-row indices (header excluded) that matched;
            empty when the value was not found.
        """
        try:
            self._reconnect_if_needed()
            data = self._fetch_column_data()

            # Find all indices before deletion (whitespace-insensitive match)
            indices = [i + 1 for i, v in enumerate(data) if v.strip() == value.strip()]
            if not indices:
                print(f"Value '{value}' not found in sheet")
                return []

            # Remove matching values and add empty strings at the end
            data = [v for v in data if v.strip() != value.strip()]
            data.extend([""] * len(indices))  # Add empty strings to maintain sheet size

            self._update_sheet(data)
            print(f"Successfully deleted value '{value}' from rows: {indices}")
            return indices

        except Exception as e:
            print(f"Error deleting from sheet: {str(e)}")
            raise

    def update_cell_by_condition(self, condition_column: str, condition_value: str, target_column: str, target_value: str) -> Optional[int]:
        """
        Update the value of a cell based on a condition in another column.

        Only the first matching row is updated.

        Args:
            condition_column (str): The column to check the condition on.
            condition_value (str): The value to match in the condition column.
            target_column (str): The column where the value should be updated.
            target_value (str): The new value to set in the target column.

        Returns:
            Optional[int]: The row number where the value was updated, or None if no matching row was found.

        Raises:
            ValueError: If either column name is missing from the header row.
        """
        try:
            self._reconnect_if_needed()

            # Get all column headers
            headers = self.sheet.row_values(1)

            # Find the indices for the condition and target columns
            try:
                condition_col_index = headers.index(condition_column) + 1
            except ValueError:
                raise ValueError(f"조건 칼럼 '{condition_column}'이(가) 없습니다.")

            try:
                target_col_index = headers.index(target_column) + 1
            except ValueError:
                raise ValueError(f"목표 칼럼 '{target_column}'이(가) 없습니다.")

            # Get all rows of data
            data = self.sheet.get_all_records()

            # Find the row that matches the condition
            for i, row in enumerate(data):
                if row.get(condition_column) == condition_value:
                    # Update the target column in the matching row
                    row_number = i + 2  # Row index starts at 2 (1 is header)
                    self.sheet.update_cell(row_number, target_col_index, target_value)
                    print(f"Updated row {row_number}: Set {target_column} to '{target_value}' where {condition_column} is '{condition_value}'")
                    return row_number

            print(f"조건에 맞는 행을 찾을 수 없습니다: {condition_column} = '{condition_value}'")
            return None

        except Exception as e:
            print(f"Error updating cell by condition: {str(e)}")
            raise

    def get_all_values(self) -> List[str]:
        """Get all non-empty values from the configured column (header excluded)."""
        self._reconnect_if_needed()
        return [v for v in self._fetch_column_data() if v.strip()]
318
+
319
# Example usage
if __name__ == "__main__":
    manager = SheetManager()

    # Demo: overwrite the "pia" cell on the row whose "model" column is "msr".
    updated_row = manager.update_cell_by_condition(
        condition_column="model",
        condition_value="msr",
        target_column="pia",
        target_value="new_value",
    )
347
+
sheet_manager/sheet_loader/sheet2df.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ import os
6
+ import json
7
+ from enviroments.convert import get_json_from_env_var
8
+ load_dotenv()
9
+
10
+ def sheet2df(sheet_name:str = "model"):
11
+ """
12
+ Reads data from a specified Google Spreadsheet and converts it into a Pandas DataFrame.
13
+
14
+ Steps:
15
+ 1. Authenticate using a service account JSON key.
16
+ 2. Open the spreadsheet by its URL.
17
+ 3. Select the worksheet to read.
18
+ 4. Convert the worksheet data to a Pandas DataFrame.
19
+ 5. Clean up the DataFrame:
20
+ - Rename columns using the first row of data.
21
+ - Drop the first row after renaming columns.
22
+
23
+ Returns:
24
+ pd.DataFrame: A Pandas DataFrame containing the cleaned data from the spreadsheet.
25
+
26
+ Note:
27
+ - The following variables must be configured before using this function:
28
+ - `json_key_path`: Path to the service account JSON key file.
29
+ - `spreadsheet_url`: URL of the Google Spreadsheet.
30
+ - `sheet_name`: Name of the worksheet to load.
31
+
32
+ Dependencies:
33
+ - pandas
34
+ - gspread
35
+ - oauth2client
36
+ """
37
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
38
+ json_key_dict =get_json_from_env_var("GOOGLE_CREDENTIALS")
39
+ credential = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
40
+ gc = gspread.authorize(credential)
41
+
42
+ spreadsheet_url = os.getenv("SPREADSHEET_URL")
43
+ doc = gc.open_by_url(spreadsheet_url)
44
+ sheet = doc.worksheet(sheet_name)
45
+
46
+ # Convert to DataFrame
47
+ df = pd.DataFrame(sheet.get_all_values())
48
+ # Clean DataFrame
49
+ df.rename(columns=df.iloc[0], inplace=True)
50
+ df.drop(df.index[0], inplace=True)
51
+
52
+ return df
sheet_manager/sheet_monitor/sheet_sync.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+ from typing import Optional, Callable
4
+ import logging
5
+
6
class SheetMonitor:
    def __init__(self, sheet_manager, check_interval: float = 1.0):
        """
        Watch a sheet manager's column for data on a background thread.

        Args:
            sheet_manager: Object exposing ``get_all_values()`` and ``column_name``.
            check_interval: Seconds between sheet polls.
        """
        self.sheet_manager = sheet_manager
        self.check_interval = check_interval

        # Threading control
        self.monitor_thread = None
        self.is_running = threading.Event()        # set while the loop should run
        self.pause_monitoring = threading.Event()  # set => the loop must idle
        self.monitor_paused = threading.Event()    # set by the loop once it is idle

        # Queue status: set whenever the sheet contains at least one value
        self.has_data = threading.Event()

        # Logging setup
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def start_monitoring(self):
        """Start the monitoring thread (no-op if it is already alive)."""
        if self.monitor_thread is not None and self.monitor_thread.is_alive():
            self.logger.warning("Monitoring thread is already running")
            return

        self.is_running.set()
        self.pause_monitoring.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
        self.logger.info("Started monitoring thread")

    def stop_monitoring(self):
        """Stop the monitoring thread and wait for it to exit."""
        self.is_running.clear()
        # BUGFIX: unblock a paused loop so join() cannot hang forever.
        self.pause_monitoring.clear()
        if self.monitor_thread:
            self.monitor_thread.join()
        self.logger.info("Stopped monitoring thread")

    def pause(self):
        """Pause the monitoring and block until the loop is actually idle."""
        self.pause_monitoring.set()
        # Bounded wait so a stopped/dead monitor thread cannot deadlock callers.
        self.monitor_paused.wait(timeout=max(self.check_interval * 2, 1.0))
        self.logger.info("Monitoring paused")

    def resume(self):
        """Resume the monitoring and immediately re-check for data."""
        self.pause_monitoring.clear()
        self.monitor_paused.clear()
        # Perform an immediate check instead of waiting for the next poll.
        self.logger.info("Monitoring resumed, checking for new data...")
        values = self.sheet_manager.get_all_values()
        if values:
            self.has_data.set()
            self.logger.info(f"Found data after resume: {values}")

    def _monitor_loop(self):
        """Main monitoring loop that checks for data in sheet."""
        while self.is_running.is_set():
            if self.pause_monitoring.is_set():
                # BUGFIX: the original called pause_monitoring.wait() while the
                # event was already set, which returns immediately — so the loop
                # never actually paused and kept hitting the API. Idle here until
                # resume() clears the flag (or the monitor is stopped).
                self.monitor_paused.set()
                while self.pause_monitoring.is_set() and self.is_running.is_set():
                    time.sleep(0.05)
                self.monitor_paused.clear()
                continue

            try:
                # Check if there's any data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"Monitoring: Current column={self.sheet_manager.column_name}, "
                                 f"Values found={len(values)}, "
                                 f"Has data={self.has_data.is_set()}")

                if values:  # If there's any non-empty value
                    self.has_data.set()
                    self.logger.info(f"Data detected: {values}")
                else:
                    self.has_data.clear()
                    self.logger.info("No data in sheet, waiting...")

                time.sleep(self.check_interval)

            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {str(e)}")
                time.sleep(self.check_interval)
93
+
94
class MainLoop:
    def __init__(self, sheet_manager, sheet_monitor, callback_function: Callable = None):
        """
        Drive the processing loop: wait for queued work, pop it, run the callback.

        Args:
            sheet_manager: Object exposing pop(), change_column(), get_all_values()
                and column_name.
            sheet_monitor: SheetMonitor instance owning the has_data flag.
            callback_function: Optional callable invoked with
                (model_id, benchmark_name, prompt_cfg_name) for each job.
        """
        self.sheet_manager = sheet_manager
        self.monitor = sheet_monitor
        self.callback = callback_function
        self.is_running = threading.Event()
        self.logger = logging.getLogger(__name__)

    def start(self):
        """Start the main processing loop (blocks until stop() is called)."""
        self.is_running.set()
        self.monitor.start_monitoring()
        self._main_loop()

    def stop(self):
        """Stop the main processing loop and the monitor."""
        self.is_running.clear()
        self.monitor.stop_monitoring()

    def process_new_value(self):
        """Pop one job (model id, benchmark, prompt cfg) and run the callback.

        Returns:
            Tuple (model_id, benchmark_name, prompt_cfg_name) on success,
            None when the queue is empty or an error occurred.
        """
        # Remember the active column so it can be restored afterwards / on error.
        original_column = self.sheet_manager.column_name
        try:
            # Pop from huggingface_id column
            model_id = self.sheet_manager.pop()

            # BUGFIX: the original fell through to a return statement that
            # referenced benchmark_name/prompt_cfg_name even when model_id was
            # empty, raising NameError (silently logged as a generic error).
            if not model_id:
                return None

            # Pop from benchmark_name column
            self.sheet_manager.change_column("benchmark_name")
            benchmark_name = self.sheet_manager.pop()

            # Pop from prompt_cfg_name column
            self.sheet_manager.change_column("prompt_cfg_name")
            prompt_cfg_name = self.sheet_manager.pop()

            # Return to original column
            self.sheet_manager.change_column(original_column)

            self.logger.info(f"Processed values - model_id: {model_id}, "
                             f"benchmark_name: {benchmark_name}, "
                             f"prompt_cfg_name: {prompt_cfg_name}")

            if self.callback:
                # Pass all three values to callback
                self.callback(model_id, benchmark_name, prompt_cfg_name)

            return model_id, benchmark_name, prompt_cfg_name

        except Exception as e:
            self.logger.error(f"Error processing values: {str(e)}")
            # Best-effort: restore the original column after a failure.
            try:
                self.sheet_manager.change_column(original_column)
            except Exception:
                pass
            return None

    def _main_loop(self):
        """Main processing loop: wait for data, pause monitor, process, resume."""
        while self.is_running.is_set():
            # Wait (with timeout, so stop() is noticed) for data to be available.
            if self.monitor.has_data.wait(timeout=1.0):
                # Pause monitoring while we mutate the sheet.
                self.monitor.pause()

                # Process the value
                self.process_new_value()

                # Check if there's still data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"After processing: Current column={self.sheet_manager.column_name}, "
                                 f"Values remaining={len(values)}")

                if not values:
                    self.monitor.has_data.clear()
                    self.logger.info("All data processed, clearing has_data flag")
                else:
                    self.logger.info(f"Remaining data: {values}")

                # Resume monitoring
                self.monitor.resume()
                # TODO: back off and retry when the Sheets API per-minute rate
                # limit causes a read to fail mid-loop.
182
+
183
+
184
# Example usage
if __name__ == "__main__":
    import sys
    from pathlib import Path
    # Make the project root importable when running this file directly.
    sys.path.append(str(Path(__file__).parent.parent.parent))
    from sheet_manager.sheet_crud.sheet_crud import SheetManager
    from pia_bench.pipe_line.piepline import PiaBenchMark
    def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
        # Run the PIA benchmark for one queued (model, benchmark, prompt-cfg) job.
        piabenchmark = PiaBenchMark(huggingface_id, benchmark_name, prompt_cfg_name)
        piabenchmark.bench_start()

    # Initialize components
    sheet_manager = SheetManager()
    monitor = SheetMonitor(sheet_manager, check_interval=10.0)
    main_loop = MainLoop(sheet_manager, monitor, callback_function=my_custom_function)

    try:
        # start() blocks inside the processing loop; the sleep loop below only
        # runs if start() ever returns, keeping the process alive until Ctrl-C.
        main_loop.start()
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        main_loop.stop()
topk.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "VIDEO_CFG": {
3
+ "window_size": 6,
4
+ "time_sampling": 15,
5
+ "tile_size": null
6
+ },
7
+ "MODEL_CFG": {
8
+ "_comment": "",
9
+ "_link": "",
10
+ "name": "assets/c7.pt",
11
+ "type": "clip4clip"
12
+ },
13
+ "PROMPT_CFG": [
14
+ {
15
+ "event": "falldown",
16
+ "top_candidates": 1,
17
+ "alert_threshold": 1,
18
+ "prompts": {
19
+ "normal": [
20
+ {
21
+ "sentence": "typical"
22
+ }
23
+ ],
24
+ "abnormal": [
25
+ {
26
+ "sentence": "falldown"
27
+ }
28
+ ]
29
+ }
30
+ },
31
+ {
32
+ "event": "violence",
33
+ "top_candidates": 1,
34
+ "alert_threshold": 1,
35
+ "prompts": {
36
+ "normal": [
37
+ {
38
+ "sentence": "normal"
39
+ },
40
+ {
41
+ "sentence": "average"
42
+ },
43
+ {
44
+ "sentence": "typical"
45
+ }
46
+ ],
47
+ "abnormal": [
48
+ {
49
+ "sentence": "violence with kicking and punching"
50
+ },
51
+ {
52
+ "sentence": "physical confrontation between people"
53
+ },
54
+ {
55
+ "sentence": "violence"
56
+ }
57
+ ]
58
+ }
59
+ },
60
+ {
61
+ "event": "fire",
62
+ "top_candidates": 1,
63
+ "alert_threshold": 1,
64
+ "prompts": {
65
+ "normal": [
66
+ {
67
+ "sentence": "tomato"
68
+ }
69
+ ],
70
+ "abnormal": [
71
+ {
72
+ "sentence": "fire"
73
+ },
74
+ {
75
+ "sentence": "video of be on fire with a stove"
76
+ },
77
+ {
78
+ "sentence": "a fire is burning"
79
+ },
80
+ {
81
+ "sentence": "embers are burning"
82
+ }
83
+ ]
84
+ }
85
+ }
86
+
87
+ ]
88
+ }
utils/bench_meta.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ from utils.except_dir import cust_listdir
7
+ def get_video_metadata(video_path, category, benchmark):
8
+ """Extract metadata from a video file."""
9
+ cap = cv2.VideoCapture(video_path)
10
+
11
+ if not cap.isOpened():
12
+ return None
13
+ # Extract metadata
14
+ video_name = os.path.basename(video_path)
15
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
16
+ fps = cap.get(cv2.CAP_PROP_FPS)
17
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
18
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
19
+ resolution = f"{frame_width}x{frame_height}"
20
+ duration_seconds = frame_count / fps if fps > 0 else 0
21
+ aspect_ratio = round(frame_width / frame_height, 2) if frame_height > 0 else 0
22
+ file_size = os.path.getsize(video_path) / (1024 * 1024) # MB
23
+ file_format = os.path.splitext(video_name)[1].lower()
24
+ cap.release()
25
+
26
+ return {
27
+ "video_name": video_name,
28
+ "resolution": resolution,
29
+ "video_duration": f"{duration_seconds // 60:.0f}:{duration_seconds % 60:.0f}",
30
+ "category": category,
31
+ "benchmark": benchmark,
32
+ "duration_seconds": duration_seconds,
33
+ "total_frames": frame_count,
34
+ "file_format": file_format,
35
+ "file_size_mb": round(file_size, 2),
36
+ "aspect_ratio": aspect_ratio,
37
+ "fps": fps
38
+ }
39
+
40
def process_videos_in_directory(root_dir):
    """Collect metadata for every video under ``root_dir``.

    Expected layout: root_dir/<benchmark>/dataset/<category>/<video files>.
    Unreadable videos (get_video_metadata returns None) are skipped.

    Args:
        root_dir: Root directory containing benchmark folders.

    Returns:
        pd.DataFrame with one row of metadata per readable video file.
    """
    video_metadata_list = []

    # Iterate over benchmark folders
    for benchmark in cust_listdir(root_dir):
        benchmark_path = os.path.join(root_dir, benchmark)
        if not os.path.isdir(benchmark_path):
            continue

        # Path of the "dataset" folder inside this benchmark
        dataset_path = os.path.join(benchmark_path, "dataset")
        if not os.path.isdir(dataset_path):
            continue

        # Iterate over category folders inside the dataset folder
        for category in cust_listdir(dataset_path):
            category_path = os.path.join(dataset_path, category)
            if not os.path.isdir(category_path):
                continue

            # Process the video files inside each category folder
            for file in cust_listdir(category_path):
                file_path = os.path.join(category_path, file)

                # BUGFIX: the original tuple contained 'MOV' (no leading dot),
                # which can never match a .lower()-cased path; '.mov' covers it.
                if file_path.lower().endswith(('.mp4', '.avi', '.mkv', '.mov')):
                    metadata = get_video_metadata(file_path, category, benchmark)
                    if metadata:
                        video_metadata_list.append(metadata)
    return pd.DataFrame(video_metadata_list)
72
+
utils/except_dir.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ import enviroments.config as config
4
+
5
def cust_listdir(directory: str) -> List[str]:
    """
    Like os.listdir, but omits the folders/files configured for exclusion.

    Args:
        directory (str): Path of the directory to scan.

    Returns:
        List[str]: Directory entries with everything in config.EXCLUDE_DIRS
        filtered out.
    """
    excluded = config.EXCLUDE_DIRS
    entries = os.listdir(directory)
    return [entry for entry in entries if entry not in excluded]
utils/hf_api.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi
2
+ from typing import Optional, List, Dict, Any
3
+ from dataclasses import dataclass
4
+
5
@dataclass
class ModelInfo:
    """Snapshot of one HuggingFace model's metadata."""
    model_id: str               # repo id, e.g. "org/model"
    last_modified: Any          # last-update timestamp from the API
    downloads: int              # download counter
    private: bool               # True when the repo is private
    attributes: Dict[str, Any]  # all public attributes of the raw API object
13
+
14
class HuggingFaceInfoManager:
    def __init__(self, access_token: Optional[str] = None, organization: str = "PIA-SPACE-LAB"):
        """
        Fetch and cache model metadata for one HuggingFace organization.

        Args:
            access_token (str, optional): HuggingFace access token (required).
            organization (str): Organization name (default: "PIA-SPACE-LAB").

        Raises:
            ValueError: If access_token is None.
        """
        if access_token is None:
            raise ValueError("액세스 토큰은 필수 입력값입니다. HuggingFace에서 발급받은 토큰을 입력해주세요.")

        self.api = HfApi()
        self.access_token = access_token
        self.organization = organization

        # Call the API once and cache the processed results.
        api_models = self.api.list_models(author=self.organization, use_auth_token=self.access_token)
        self._stored_models = []
        self._model_infos = []

        for model in api_models:
            # Capture every public attribute of the raw API object.
            model_attrs = {attr: getattr(model, attr)
                           for attr in dir(model) if not attr.startswith("_")}

            model_info = ModelInfo(
                model_id=model.modelId,
                last_modified=model.lastModified,
                downloads=model.downloads,
                private=model.private,
                attributes=model_attrs
            )
            self._model_infos.append(model_info)
            self._stored_models.append(model)

    @staticmethod
    def _info_to_dict(info) -> Dict[str, Any]:
        """Flatten one ModelInfo into a dict (base fields + raw attributes).

        Factored out because the original repeated this dict literal verbatim
        in get_model_info / get_private_models / get_public_models.
        """
        return {
            'model_id': info.model_id,
            'last_modified': info.last_modified,
            'downloads': info.downloads,
            'private': info.private,
            **info.attributes
        }

    def get_model_info(self) -> List[Dict[str, Any]]:
        """Return info dicts for every cached model."""
        return [self._info_to_dict(info) for info in self._model_infos]

    def get_model_ids(self) -> List[str]:
        """Return the ids of every cached model."""
        return [info.model_id for info in self._model_infos]

    def get_private_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for the private models only."""
        return [self._info_to_dict(info) for info in self._model_infos if info.private]

    def get_public_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for the public models only."""
        return [self._info_to_dict(info) for info in self._model_infos if not info.private]

    def refresh_models(self) -> None:
        """Refresh cached model info by performing a new API call."""
        # Re-run __init__ with the stored credentials to rebuild both caches.
        self.__init__(self.access_token, self.organization)
utils/parser.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Dict, List, Tuple
3
+
4
def load_config(config_path: str) -> Dict:
    """Read a JSON configuration file and return its contents as a dict."""
    with open(config_path, 'r', encoding='utf-8') as config_file:
        parsed = json.load(config_file)
    return parsed
8
+
9
class PromptManager:
    def __init__(self, config_path: str):
        self.config = load_config(config_path)
        self.sentences, self.index_mapping = self._extract_all_sentences_with_index()
        self.reverse_mapping = self._create_reverse_mapping()

    def _extract_all_sentences_with_index(self) -> Tuple[List[str], Dict]:
        """Collect every prompt sentence and an (event, status, prompt) index map."""
        collected = []
        mapping = {}

        events = self.config.get('PROMPT_CFG', [])
        for event_idx, event_config in enumerate(events):
            prompt_groups = event_config.get('prompts', {})
            for status in ('normal', 'abnormal'):
                for prompt_idx, prompt in enumerate(prompt_groups.get(status, [])):
                    text = prompt.get('sentence', '')
                    collected.append(text)
                    mapping[(event_idx, status, prompt_idx)] = text

        return collected, mapping

    def _create_reverse_mapping(self) -> Dict:
        """Build the sentence -> list-of-index-tuples reverse lookup."""
        reverse_map = {}
        for indices, sentence in self.index_mapping.items():
            reverse_map.setdefault(sentence, []).append(indices)
        return reverse_map

    def get_sentence_indices(self, sentence: str) -> List[Tuple[int, str, int]]:
        """Return every (event_idx, status, prompt_idx) where the sentence appears."""
        return self.reverse_mapping.get(sentence, [])

    def get_details_by_sentence(self, sentence: str) -> List[Dict]:
        """Return detailed info for every occurrence of the sentence."""
        occurrences = self.get_sentence_indices(sentence)
        return [self.get_details_by_index(*idx) for idx in occurrences]

    def get_details_by_index(self, event_idx: int, status: str, prompt_idx: int) -> Dict:
        """Return detailed info for one (event, status, prompt) position."""
        event_config = self.config['PROMPT_CFG'][event_idx]
        prompt = event_config['prompts'][status][prompt_idx]

        details = {
            'event': event_config['event'],
            'status': status,
            'sentence': prompt['sentence'],
            'top_candidates': event_config['top_candidates'],
            'alert_threshold': event_config['alert_threshold'],
            'event_idx': event_idx,
            'prompt_idx': prompt_idx
        }
        return details

    def get_all_sentences(self) -> List[str]:
        """Return every collected sentence (duplicates preserved, config order)."""
        return self.sentences