jsi6452 committed on
Commit
ec2d80c
ยท
1 Parent(s): ce58bca
.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
173
+ # *.json
174
+
175
+ assets
176
+ DevMACS-AI-solution-devmacs
177
+ Research-AI-research-t2v_f1score_evaluator
178
+ .env
179
+ enviroments/abnormal-situation-leaderboard-3ca42d06719e.json
180
+ leaderboard_test
181
+ enviroments/deep-byte-352904-a072fdf439e7.json
182
+ PIA-SPACE_LeaderBoard
183
+ *.txt
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README copy.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # MAP
app.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from pathlib import Path
from leaderboard_ui.tab.leaderboard_tab import maat_gag_tab
from leaderboard_ui.tab.food_map_tab import map_food_tab

# Directory containing this file; useful for resolving asset paths relative
# to the app (currently not used further below).
abs_path = Path(__file__).parent

# Top-level Gradio app: a markdown title followed by one tab per feature module.
with gr.Blocks() as demo:
    gr.Markdown("""
    # ๐Ÿ—บ๏ธ ๋ง›๊ฐ ๐Ÿ—บ๏ธ
    """)
    with gr.Tabs():
        maat_gag_tab()   # leaderboard tab (from leaderboard_tab module)
        map_food_tab()   # restaurant map tab (from food_map_tab module)

if __name__ == "__main__":
    demo.launch()
enviroments/.gitkeep ADDED
File without changes
enviroments/config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd  # NOTE(review): pandas appears unused in this module — confirm before removing.

# Root directory of the benchmark video datasets on the NAS share.
BASE_BENCH_PATH = "/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
# Directory names to skip when scanning BASE_BENCH_PATH (Synology metadata, temp files).
EXCLUDE_DIRS = {"@eaDir", 'temp'}
# Per-column datatypes for the leaderboard component, in column order.
# NOTE(review): this list must stay aligned with the sheet's column order — verify against sheet2df output.
TYPES = [
    "str",
    "str",
    "str",
    "str",
    "str",
    "str",
    "str",
    "number",
    "number",
    "number",
    "number",
    "number",
    "markdown",
    "markdown",
    "number",
    "number",
    "number",
    "number",
    "number",
    "number",
    "number",
    "str",
    "str",
    "str",
    "str",
    "bool",
    "str",
    "number",
    "number",
    "bool",
    "str",
    "bool",
    "bool",
    "str",
]

# Columns selected by default when the leaderboard first loads.
ON_LOAD_COLUMNS = [
    "๋ถ„๋ฅ˜",
]

# Columns the user cannot deselect in the column picker.
OFF_LOAD_COLUMNS = []
# Columns hidden from display (the scaled *100 rating columns used only by slider filters).
HIDE_COLUMNS = ["๋„ค์ด๋ฒ„๋ณ„์ *100", "์นด์นด์˜ค๋ณ„์ *100"]
# Columns exposed as filters in the UI.
FILTER_COLUMNS = ["์‹๋‹น๋ช…"]

# Columns treated as numeric for sorting/aggregation.
NUMERIC_COLUMNS = ["PIA"]
enviroments/convert.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from dotenv import load_dotenv
4
+ load_dotenv()
5
+
6
def get_json_from_env_var(env_var_name):
    """Load a JSON document stored in an environment variable.

    Raw newline characters in the value (common when a multi-line service
    key is pasted into an env var) are escaped to the two-character ``\\n``
    sequence before parsing so ``json.loads`` accepts them.

    :param env_var_name: name of the environment variable to read
    :return: the parsed JSON value (typically a dict)
    :raises EnvironmentError: if the variable is unset or empty
    :raises ValueError: if the value is not valid JSON after escaping
    """
    raw = os.getenv(env_var_name)
    if not raw:
        raise EnvironmentError(f"ํ™˜๊ฒฝ ๋ณ€์ˆ˜ '{env_var_name}'๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")

    # Turn literal newlines into escape sequences so the string parses.
    escaped = raw.replace("\n", "\\n")
    try:
        return json.loads(escaped)
    except json.JSONDecodeError as e:
        raise ValueError(f"JSON ๋ณ€ํ™˜ ์‹คํŒจ: {e}")
26
+
27
+
28
+
29
def json_to_env_var(json_file_path, env_var_name="JSON_ENV_VAR"):
    """Print the contents of a JSON file as a ``NAME=value`` line for a .env file.

    Reads the file, serializes it back to a single-line JSON string, and
    prints it prefixed with *env_var_name* so the user can copy-paste it.
    Errors (missing file, invalid JSON) are reported via ``print`` rather
    than raised.

    :param json_file_path: path of the JSON file to read
    :param env_var_name: environment-variable name to prefix (default: JSON_ENV_VAR)
    :return: None
    """
    # Read and parse up front; bail out with a message on either failure mode.
    try:
        with open(json_file_path, 'r') as fp:
            payload = json.load(fp)
    except FileNotFoundError:
        print(f"ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {json_file_path}")
        return
    except json.JSONDecodeError:
        print(f"์œ ํšจํ•œ JSON ํŒŒ์ผ์ด ์•„๋‹™๋‹ˆ๋‹ค: {json_file_path}")
        return

    # Re-serialize to a compact single line suitable for an env var value.
    serialized = json.dumps(payload)
    print("\nํ™˜๊ฒฝ ๋ณ€์ˆ˜๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ์ถœ๋ ฅ๊ฐ’:\n")
    print(f'{env_var_name}={serialized}')
    print("\n์œ„ ๊ฐ’์„ .env ํŒŒ์ผ์— ๋ณต์‚ฌํ•˜์—ฌ ๋ถ™์—ฌ๋„ฃ์œผ์„ธ์š”.")
54
+
leaderboard_ui/tab/dataset_visual_tab.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from leaderboard_ui.tab.submit_tab import submit_tab
4
+ from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
5
+ abs_path = Path(__file__).parent
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import pandas as pd
9
+ import numpy as np
10
+ from utils.bench_meta import process_videos_in_directory
11
+ # Mock ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
12
def create_mock_data():
    """Build a synthetic per-video metadata table for three mock benchmarks.

    Each benchmark contributes a random number of rows (50-99). All values
    are drawn from ``np.random`` and are non-deterministic unless the caller
    seeds the generator first.

    :return: ``pd.DataFrame`` with one row of fake metadata per video.
    """
    benchmarks = ['VQA-2023', 'ImageQuality-2024', 'VideoEnhance-2024']
    categories = ['Animation', 'Game', 'Movie', 'Sports', 'Vlog']

    def _fake_video(benchmark):
        # One synthetic metadata record; field names mirror what the real
        # directory scanner (process_videos_in_directory) would produce.
        return {
            "video_name": f"video_{np.random.randint(1000, 9999)}.mp4",
            "resolution": np.random.choice(["1920x1080", "3840x2160", "1280x720"]),
            "video_duration": f"{np.random.randint(0, 10)}:{np.random.randint(0, 60)}",
            "category": np.random.choice(categories),
            "benchmark": benchmark,
            "duration_seconds": np.random.randint(30, 600),
            "total_frames": np.random.randint(1000, 10000),
            "file_format": ".mp4",
            "file_size_mb": round(np.random.uniform(10, 1000), 2),
            "aspect_ratio": 16 / 9,
            "fps": np.random.choice([24, 30, 60]),
        }

    rows = [
        _fake_video(bench)
        for bench in benchmarks
        for _ in range(np.random.randint(50, 100))
    ]
    return pd.DataFrame(rows)
38
+
39
# Mock data creation — load the video-metadata table once at import time.
# The live scan of the NAS benchmark directory is disabled; a pre-computed
# CSV snapshot is read instead.
# df = process_videos_in_directory("/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench")
df = pd.read_csv("sample.csv")  # NOTE(review): relative path — depends on process CWD; confirm.
print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)
print("DataFrame head:\n", df.head())
45
def create_category_pie_chart(df, selected_benchmark, selected_categories=None):
    """Donut chart of how many videos each category contributes to a benchmark.

    :param df: video-metadata table with 'benchmark' and 'category' columns
    :param selected_benchmark: benchmark name to restrict the chart to
    :param selected_categories: optional category whitelist; falsy means "all"
    :return: a plotly figure
    """
    subset = df[df['benchmark'] == selected_benchmark]
    if selected_categories:
        subset = subset[subset['category'].isin(selected_categories)]

    counts = subset['category'].value_counts()

    # hole=0.3 renders the pie as a donut.
    figure = px.pie(
        values=counts.values,
        names=counts.index,
        title=f'{selected_benchmark} - Video Distribution by Category',
        hole=0.3,
    )
    figure.update_traces(textposition='inside', textinfo='percent+label')

    return figure
62
+
63
+ ###TODO ์ŠคํŠธ๋ง์ผ๊ฒฝ์šฐ ์–ด์ผ€ ์ฒ˜๋ฆฌ
64
+
65
def create_bar_chart(df, selected_benchmark, selected_categories, selected_column):
    """Horizontal bar chart comparing one metadata column across videos.

    :param df: video-metadata table
    :param selected_benchmark: benchmark name to restrict the chart to
    :param selected_categories: optional category whitelist; falsy means "all"
    :param selected_column: column plotted on the x axis
    :return: a plotly figure, colored by category
    """
    # Restrict to the chosen benchmark, then optionally to chosen categories.
    subset = df[df['benchmark'] == selected_benchmark]
    if selected_categories:
        subset = subset[subset['category'].isin(selected_categories)]

    figure = px.bar(
        subset,
        x=selected_column,
        y='video_name',
        color='category',
        title=f'{selected_benchmark} - Video {selected_column}',
        orientation='h',
        color_discrete_sequence=px.colors.qualitative.Set3,
    )

    # Grow the plot with the row count so long video lists stay readable,
    # sort bars by value, and park a horizontal legend above the plot area.
    figure.update_layout(
        height=max(400, len(subset) * 30),
        yaxis={'categoryorder': 'total ascending'},
        margin=dict(l=200),  # room for long video names
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )

    return figure
98
+
99
# NOTE(review): this redefines `submit_tab`, shadowing the function imported
# from leaderboard_ui.tab.submit_tab at the top of this module — confirm which
# implementation is intended and remove the other.
def submit_tab():
    """Render the (placeholder) result-submission tab."""
    with gr.Tab("๐Ÿš€ Submit here! "):
        with gr.Row():
            gr.Markdown("# โœ‰๏ธโœจ Submit your Result here!")
103
+
104
def visual_tab():
    """Render the 'Bench Info' tab.

    Builds benchmark/category pickers that drive two plots over the
    module-level ``df``: a pie chart of category distribution and a bar
    chart of a user-selected metadata column.
    """
    with gr.Tab("๐Ÿ“Š Bench Info"):
        with gr.Row():
            # Which benchmark to inspect; defaults to the first (sorted).
            benchmark_dropdown = gr.Dropdown(
                choices=sorted(df['benchmark'].unique().tolist()),
                value=sorted(df['benchmark'].unique().tolist())[0],
                label="Select Benchmark",
                interactive=True
            )

            # Optional category filter; leaving it empty means "all categories".
            category_multiselect = gr.CheckboxGroup(
                choices=sorted(df['category'].unique().tolist()),
                label="Select Categories (empty for all)",
                interactive=True
            )

        # Pie chart
        pie_plot_output = gr.Plot(label="pie")

        # Column selection dropdown
        column_options = [
            "video_duration", "duration_seconds", "total_frames",
            "file_size_mb", "aspect_ratio", "fps", "file_format"
        ]

        column_dropdown = gr.Dropdown(
            choices=column_options,
            value=column_options[0],
            label="Select Data to Compare",
            interactive=True
        )

        # Bar chart
        bar_plot_output = gr.Plot(label="video")

        def update_plots(benchmark, categories, selected_column):
            # Recompute both figures whenever any control changes.
            pie_chart = create_category_pie_chart(df, benchmark, categories)
            bar_chart = create_bar_chart(df, benchmark, categories, selected_column)
            return pie_chart, bar_chart

        # Connect event handlers — every control re-renders both plots.
        benchmark_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        category_multiselect.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        column_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
160
+
leaderboard_ui/tab/food_map_tab.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import pandas as pd
4
+ from dotenv import load_dotenv
5
+ import plotly.graph_objects as go
6
+ import gradio as gr
7
+ from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
8
+ import enviroments.config as config
9
+ from sheet_manager.sheet_loader.sheet2df import sheet2df, add_scaled_columns
10
+
11
+ load_dotenv()
12
+ NAVER_CLIENT_ID = os.getenv("NAVER_KEY")
13
+ NAVER_CLIENT_SECRET = os.getenv("NAVER_SECRET_KEY")
14
+
15
def get_lat_lon_naver(address: str):
    """Resolve a street address to (latitude, longitude) via the Naver
    Geocoding API.

    :param address: address string to geocode
    :return: ``(lat, lon)`` floats on success, ``(None, None)`` when the
             lookup fails (non-200 response or no matching address)
    :raises requests.exceptions.RequestException: on network errors/timeouts
    """
    url = "https://naveropenapi.apigw.ntruss.com/map-geocode/v2/geocode"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": NAVER_CLIENT_ID,
        "X-NCP-APIGW-API-KEY": NAVER_CLIENT_SECRET
    }
    params = {"query": address}

    # BUGFIX: requests.get without a timeout can hang forever; this function
    # is called once per row at import time, so a single stall blocked the app.
    response = requests.get(url, headers=headers, params=params, timeout=10)

    if response.status_code == 200:
        data = response.json()
        if "addresses" in data and len(data["addresses"]) > 0:
            lat = float(data["addresses"][0]["y"])  # latitude
            lon = float(data["addresses"][0]["x"])  # longitude
            return lat, lon
    return None, None  # lookup failed
38
+
39
# Load the restaurant sheet once at import time and derive scaled rating columns.
df = sheet2df(sheet_name="์„œ์šธ")
for i in ["๋„ค์ด๋ฒ„๋ณ„์ ", "์นด์นด์˜ค๋ณ„์ "]:
    df = add_scaled_columns(df, i)

# Geocode every address (one API call per row).
# NOTE(review): this runs at import time and can be slow or rate-limited;
# consider caching the coordinates.
df[['์œ„๋„', '๊ฒฝ๋„']] = df['์ฃผ์†Œ'].apply(lambda x: pd.Series(get_lat_lon_naver(x)))
# Drop rows the geocoder could not resolve, then normalize to float.
df = df.dropna(subset=["์œ„๋„", "๊ฒฝ๋„"])
df["์œ„๋„"] = df["์œ„๋„"].astype(float)
df["๊ฒฝ๋„"] = df["๊ฒฝ๋„"].astype(float)
print(df[["์‹๋‹น๋ช…", "์ฃผ์†Œ", "์œ„๋„", "๊ฒฝ๋„"]])  # sanity-check the geocoded data
print(df[["์œ„๋„", "๊ฒฝ๋„"]].dtypes)  # sanity-check dtypes
49
+
50
def plot_map(df):
    """Render restaurant locations as red markers on an OpenStreetMap basemap.

    :param df: table with float '์œ„๋„'/'๊ฒฝ๋„' columns and a '์‹๋‹น๋ช…' column
    :return: a plotly mapbox figure centered on the mean coordinate
    """
    # Center the viewport on the average of all points.
    lat_center = df["์œ„๋„"].mean()
    lon_center = df["๊ฒฝ๋„"].mean()

    marker_trace = go.Scattermapbox(
        lat=df["์œ„๋„"].tolist(),
        lon=df["๊ฒฝ๋„"].tolist(),
        mode='markers',
        marker=go.scattermapbox.Marker(
            size=10,
            color="red",
            opacity=0.7
        ),
        customdata=df["์‹๋‹น๋ช…"].tolist(),
        hoverinfo="text",
        # Text shown on mouse-over: the restaurant name.
        hovertemplate="<b>์‹๋‹น๋ช…</b>: %{customdata}<br>"
    )

    fig = go.Figure(marker_trace)
    fig.update_layout(
        mapbox_style="open-street-map",  # free OpenStreetMap tiles
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=lat_center, lon=lon_center),
            pitch=0,
            zoom=12  # initial zoom level
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0}
    )
    return fig
88
+
89
+ # โœ… Gradio UI
90
+ def map_interface():
91
+ return plot_map(df)
92
+
93
+ def map_food_tab():
94
+ with gr.Tab("๐Ÿ˜‹์ง€๋„๐Ÿ˜‹"):
95
+ gr.Markdown("# ๐Ÿ“ ๋ง›์ง‘ ์ง€๋„ ์‹œ๊ฐํ™”")
96
+ map_plot = gr.Plot(map_interface)
leaderboard_ui/tab/leaderboard_tab.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
3
+ import enviroments.config as config
4
+ from sheet_manager.sheet_loader.sheet2df import sheet2df, add_scaled_columns
5
+
6
# Load the sheet once at import time and derive the scaled "*100" rating columns.
df = sheet2df(sheet_name="์„œ์šธ")
for i in ["๋„ค์ด๋ฒ„๋ณ„์ ", "์นด์นด์˜ค๋ณ„์ "]:
    df = add_scaled_columns(df, i)
# NOTE(review): `columns` is captured BEFORE the strip below, so it may keep
# header names with stray whitespace — confirm this is intended.
columns = df.columns.tolist()
print(columns)
df.columns = df.columns.str.strip()  # normalize header whitespace
print(df.columns.tolist())
# print(df["๋„ค์ด๋ฒ„๋ณ„์ *100"])
14
+
15
def maat_gag_tab():
    """Render the main leaderboard tab backed by the module-level dataframe.

    Column types, hidden columns, and deselect rules come from
    ``enviroments.config``; the slider filters operate on the derived
    ``*100`` rating columns produced by ``add_scaled_columns``.
    """
    with gr.Tab("๐Ÿ˜‹๋ง›๊ฐ๋ชจ์ž„๐Ÿ˜‹"):
        leaderboard = Leaderboard(
            value=df,
            select_columns=SelectColumns(
                # default_selection=config.ON_LOAD_COLUMNS,
                default_selection=columns,
                cant_deselect=config.OFF_LOAD_COLUMNS,
                label="Select Columns to Display:",
                info="Check"
            ),
            # Primary search hits the restaurant name; secondary the signature menu.
            search_columns=SearchColumns(
                primary_column="์‹๋‹น๋ช…",
                secondary_columns=["๋Œ€ํ‘œ๋ฉ”๋‰ด"],
                placeholder="Search",
                label="Search"
            ),
            hide_columns=config.HIDE_COLUMNS,
            filter_columns=[
                # Sliders act on the hidden rating*100 columns (the real
                # rating scaled by 100, so 400-500 means 4.0-5.0 stars).
                ColumnFilter(
                    column="์นด์นด์˜ค๋ณ„์ *100",
                    type="slider",
                    min=0,
                    max=500,
                    default=[400, 500],
                    label="์นด์นด์˜ค๋ณ„์ "
                ),
                ColumnFilter(
                    column="๋„ค์ด๋ฒ„๋ณ„์ *100",
                    type="slider",
                    min=0,
                    max=500,
                    default=[400, 500],
                    label="๋„ค์ด๋ฒ„๋ณ„์ "
                ),
                ColumnFilter(
                    column="์ง€์—ญ",
                    label="์ง€์—ญ"
                )
            ],
            datatype=config.TYPES,
            # column_widths=["33%", "10%"],
        )
        refresh_button = gr.Button("๐Ÿ”„ Refresh Leaderboard")

        def refresh_leaderboard():
            # BUGFIX: the previous refresh called sheet2df() with the default
            # sheet and skipped add_scaled_columns / header strip, so the
            # refreshed table no longer contained the "*100" columns the
            # configured slider filters require. Rebuild the frame exactly
            # like the module-level initial load.
            fresh = sheet2df(sheet_name="์„œ์šธ")
            for col in ["๋„ค์ด๋ฒ„๋ณ„์ ", "์นด์นด์˜ค๋ณ„์ "]:
                fresh = add_scaled_columns(fresh, col)
            fresh.columns = fresh.columns.str.strip()
            return fresh

        refresh_button.click(
            refresh_leaderboard,
            inputs=[],
            outputs=leaderboard,
        )
72
+
leaderboard_ui/tab/map_tab.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import pandas as pd
4
+ from dotenv import load_dotenv
5
load_dotenv()
# โœ… Load the Naver Cloud Platform API credentials from the environment (.env).
# NOTE(review): both values are None when the .env entries are missing, and
# requests would then be sent with null auth headers - confirm deployment
# always provides NAVER_KEY / NAVER_SECRET_KEY.
NAVER_CLIENT_ID = os.getenv("NAVER_KEY")
NAVER_CLIENT_SECRET = os.getenv("NAVER_SECRET_KEY")
9
+
10
def get_lat_lon_naver(address: str, timeout: float = 10.0):
    """
    Convert an address to (latitude, longitude) via the Naver Geocoding API.

    :param address: address to geocode (e.g. "์„œ์šธ ์„œ์ดˆ๊ตฌ ๋‚จ๋ถ€์ˆœํ™˜๋กœ358๊ธธ 8")
    :param timeout: per-request timeout in seconds (keyword added so a stalled
        API call can no longer hang the caller forever; default 10.0)
    :return: (latitude, longitude) tuple of floats, or (None, None) on any
        failure (non-200 response, empty result set, or network error)
    """
    url = "https://naveropenapi.apigw.ntruss.com/map-geocode/v2/geocode"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": NAVER_CLIENT_ID,
        "X-NCP-APIGW-API-KEY": NAVER_CLIENT_SECRET
    }
    params = {"query": address}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException:
        # Network failure / timeout: honor the documented "(None, None) on
        # failure" contract instead of propagating an exception into the
        # per-row DataFrame .apply() that calls this.
        return None, None

    if response.status_code == 200:
        data = response.json()
        if "addresses" in data and len(data["addresses"]) > 0:
            lat = float(data["addresses"][0]["y"])  # latitude
            lon = float(data["addresses"][0]["x"])  # longitude
            return lat, lon
    return None, None  # conversion failed
33
+
34
+ # โœ… ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„ (์‹๋‹น ๋ชฉ๋ก)
35
+ df = pd.DataFrame({
36
+ "์‹๋‹น๋ช…": ["์˜๋™์กฑ๋ฐœ", "๊ณ ๋„์‹", "๋‚œํฌ", "์„์ง€๋กœ๋ณด์„", "์–‘๋ฐ์‚ฐ"],
37
+ "์ฃผ์†Œ": [
38
+ "์„œ์šธ ์„œ์ดˆ๊ตฌ ๋‚จ๋ถ€์ˆœํ™˜๋กœ358๊ธธ 8",
39
+ "์„œ์šธ ์†กํŒŒ๊ตฌ ๋ฐฑ์ œ๊ณ ๋ถ„๋กœ45๊ธธ 28",
40
+ "์„œ์šธ ์„ฑ๋™๊ตฌ ์„œ์šธ์ˆฒ4๊ธธ 18-8",
41
+ "์„œ์šธ ์ค‘๊ตฌ ๋งˆ๋ฅธ๋‚ด๋กœ 11-10",
42
+ "์„œ์šธ ์„œ๋Œ€๋ฌธ๊ตฌ ๊ฐ€์ขŒ๋กœ 36"
43
+ ]
44
+ })
45
+
46
+ # โœ… ์ฃผ์†Œ๋ฅผ ์œ„๋„/๊ฒฝ๋„๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์— ์ถ”๊ฐ€
47
+ df[['์œ„๋„', '๊ฒฝ๋„']] = df['์ฃผ์†Œ'].apply(lambda x: pd.Series(get_lat_lon_naver(x)))
48
+
49
+ # โœ… ๋ณ€ํ™˜๋œ ๋ฐ์ดํ„ฐ ์ถœ๋ ฅ
50
+ print(df)
leaderboard_ui/tab/metric_visaul_tab.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ abs_path = Path(__file__).parent
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ import pandas as pd
7
+ import numpy as np
8
+ from sheet_manager.sheet_loader.sheet2df import sheet2df
9
+ from sheet_manager.sheet_convert.json2sheet import str2json
10
+ # Mock ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
11
def calculate_avg_metrics(df):
    """
    Compute each model's average performance metrics across event categories.

    For every row, the JSON string in the "PIA" column is parsed and each
    metric is averaged over the falldown / violence / fire categories.
    Rows with missing or malformed data are skipped with a diagnostic print.

    :param df: sheet DataFrame with "Model name" and "PIA" columns
    :return: DataFrame with one row per valid model ("model_name" + metrics)
    """
    REQUIRED_CATEGORIES = ('falldown', 'violence', 'fire')
    REQUIRED_METRICS = ('accuracy', 'precision', 'recall', 'specificity', 'f1',
                        'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far')

    rows = []
    for _, record in df.iterrows():
        model_name = record['Model name']
        raw_pia = record['PIA']

        # Skip rows whose PIA cell is empty or not a JSON string.
        if pd.isna(raw_pia) or not isinstance(raw_pia, str):
            print(f"Skipping model {model_name}: Invalid PIA data")
            continue

        try:
            parsed = str2json(raw_pia)

            # Skip when parsing yielded nothing usable.
            if not parsed or not isinstance(parsed, dict):
                print(f"Skipping model {model_name}: Invalid JSON format")
                continue

            # Every category must be present before averaging.
            if not all(category in parsed for category in REQUIRED_CATEGORIES):
                print(f"Skipping model {model_name}: Missing required categories")
                continue

            averages = {}
            for metric in REQUIRED_METRICS:
                try:
                    samples = [parsed[category][metric]
                               for category in REQUIRED_CATEGORIES
                               if metric in parsed[category]]
                    # Average only over categories that carry the metric;
                    # fall back to 0 when none do.
                    averages[metric] = sum(samples) / len(samples) if samples else 0
                except (KeyError, TypeError) as e:
                    print(f"Error calculating {metric} for {model_name}: {str(e)}")
                    averages[metric] = 0  # default on per-metric failure

            rows.append({'model_name': model_name, **averages})

        except Exception as e:
            print(f"Error processing model {model_name}: {str(e)}")
            continue

    return pd.DataFrame(rows)
66
+
67
def create_performance_chart(df, selected_metrics):
    """
    Build a horizontal grouped bar chart comparing models on selected metrics.

    :param df: DataFrame with a "model_name" column plus one column per metric
    :param selected_metrics: metric column names to plot (one trace each)
    :return: plotly Figure, models on the y-axis sorted by total score
    """
    figure = go.Figure()

    # Scale the left margin with the longest model name (7 px per character,
    # capped at 500 px) so labels are never clipped.
    longest_name = max([len(name) for name in df['model_name']])
    margin_left = min(longest_name * 7, 500)

    for metric in selected_metrics:
        scores = df[metric]
        figure.add_trace(go.Bar(
            name=metric,
            y=df['model_name'],            # model names on the y-axis
            x=scores,                      # metric values on the x-axis
            text=[f'{value:.3f}' for value in scores],
            textposition='auto',
            orientation='h'                # horizontal bars
        ))

    figure.update_layout(
        title='Model Performance Comparison',
        yaxis_title='Model Name',
        xaxis_title='Performance',
        barmode='group',
        height=max(400, len(df) * 40),     # grow with the number of models
        margin=dict(l=margin_left, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        yaxis={'categoryorder': 'total ascending'}  # sort by performance
    )

    # Shrink the y-axis tick labels a little.
    figure.update_yaxes(tickfont=dict(size=10))

    return figure
109
def create_confusion_matrix(metrics_data, selected_category):
    """Create a confusion-matrix heatmap for one category.

    :param metrics_data: dict keyed by category, each holding 'tp'/'tn'/'fp'/'fn' counts
    :param selected_category: category key to visualize (e.g. 'falldown')
    :return: plotly Figure (2x2 heatmap, Actual vs Predicted)
    """
    # Confusion-matrix counts for the selected category.
    tp = metrics_data[selected_category]['tp']
    tn = metrics_data[selected_category]['tn']
    fp = metrics_data[selected_category]['fp']
    fn = metrics_data[selected_category]['fn']

    # Matrix layout: rows are Actual, columns are Predicted.
    z = [[tn, fp], [fn, tp]]
    x = ['Negative', 'Positive']
    y = ['Negative', 'Positive']

    # Build the heatmap with each count rendered inside its cell.
    fig = go.Figure(data=go.Heatmap(
        z=z,
        x=x,
        y=y,
        colorscale=[[0, '#f7fbff'], [1, '#08306b']],
        showscale=False,
        text=[[str(val) for val in row] for row in z],
        texttemplate="%{text}",
        textfont={"color": "black", "size": 16},  # fixed black text for readability
    ))

    # Layout: centered title, square plot, generous margins.
    fig.update_layout(
        title={
            'text': f'Confusion Matrix - {selected_category}',
            'y':0.9,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Predicted',
        yaxis_title='Actual',
        width=600,
        height=600,
        margin=dict(l=80, r=80, t=100, b=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font=dict(size=14)  # overall font size
    )

    # Axis tick styling.
    fig.update_xaxes(side="bottom", tickfont=dict(size=14))
    fig.update_yaxes(side="left", tickfont=dict(size=14))

    return fig
158
+
159
def get_metrics_for_model(df, model_name, benchmark_name):
    """Return the parsed PIA metrics for a model/benchmark pair, or None.

    :param df: sheet DataFrame with "Model name", "Benchmark" and "PIA" columns
    :param model_name: exact model name to match
    :param benchmark_name: exact benchmark name to match
    :return: metrics dict parsed from the first matching row's PIA cell,
             or None when no row matches
    """
    matched = df[(df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)]
    if matched.empty:
        return None
    return str2json(matched['PIA'].iloc[0])
166
+
167
def metric_visual_tab():
    """Build the "Performance Visualization" tab (model-level metric bars).

    NOTE(review): this function is redefined later in this module with an
    extended body; at import time the later definition wins, so this first
    version is shadowed dead code. Consider deleting one of the two.
    """
    # Load data
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All selectable metrics
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                  'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("๐Ÿ“Š Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # no initial selection
                label="Select Performance Metrics",
                interactive=True
            )

        # Performance comparison chart (starts empty)
        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            # Nothing to draw until at least one metric is ticked.
            if not selected_metrics:
                return None

            try:
                # Sort by accuracy so bars appear in performance order.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        # Connect event handler
        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )
206
+
207
def create_category_metrics_chart(metrics_data, selected_metrics):
    """
    Plot the selected metrics for each event category of one model.

    :param metrics_data: dict keyed by category ('falldown'/'violence'/'fire'),
        each mapping metric names to scores
    :param selected_metrics: metric names to plot (one grouped-bar trace each)
    :return: plotly Figure
    """
    figure = go.Figure()
    category_names = ['falldown', 'violence', 'fire']

    for metric in selected_metrics:
        # Gather this metric's score for every category, in display order.
        scores = [metrics_data[category][metric] for category in category_names]

        figure.add_trace(go.Bar(
            name=metric,
            x=category_names,
            y=scores,
            text=[f'{value:.3f}' for value in scores],
            textposition='auto',
        ))

    figure.update_layout(
        title='Performance Metrics by Category',
        xaxis_title='Category',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return figure
244
+
245
def metric_visual_tab():
    """Build the "Performance Visualization" tab.

    Contains two sections:
      1. A multi-metric horizontal bar chart comparing all models.
      2. A per-model detail view: confusion-matrix heatmap plus a
         per-category metric bar chart driven by dropdown selections.

    NOTE(review): this redefinition shadows the earlier function of the same
    name in this module; only this version is ever used.
    """
    # Load data and build the first visualization section.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All selectable metrics.
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                  'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("๐Ÿ“Š Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # no initial selection
                label="Select Performance Metrics",
                interactive=True
            )

        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            # Empty selection -> clear the plot.
            if not selected_metrics:
                return None
            try:
                # Sort by accuracy so bars appear in performance order.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )

        # Second visualization section.
        gr.Markdown("## Detailed Model Analysis")

        with gr.Row():
            # Model selector.
            model_dropdown = gr.Dropdown(
                choices=sorted(df['Model name'].unique().tolist()),
                label="Select Model",
                interactive=True
            )

            # Column selector (excluding 'Model name').
            column_dropdown = gr.Dropdown(
                choices=[col for col in df.columns if col != 'Model name'],
                label="Select Metric Column",
                interactive=True
            )

            # Category selector.
            category_dropdown = gr.Dropdown(
                choices=['falldown', 'violence', 'fire'],
                label="Select Category",
                interactive=True
            )

        # Confusion-matrix visualization, centered via spacer columns.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)  # container=True for sizing
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer

        with gr.Column(scale=2):
            # Metric selector for the per-category chart.
            metrics_select = gr.CheckboxGroup(
                choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
                        'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                value=['accuracy'],  # default
                label="Select Metrics to Display",
                interactive=True
            )
            category_metrics_plot = gr.Plot()

        def update_visualizations(model, column, category, selected_metrics):
            if not all([model, column]):  # category only needed for the confusion matrix
                return None, None

            try:
                # Fetch the selected model's raw JSON metric cell.
                selected_data = df[df['Model name'] == model][column].iloc[0]
                metrics = str2json(selected_data)

                if not metrics:
                    return None, None

                # Confusion matrix (left).
                confusion_fig = create_confusion_matrix(metrics, category) if category else None

                # Per-category metric chart (right).
                if not selected_metrics:
                    selected_metrics = ['accuracy']
                category_fig = create_category_metrics_chart(metrics, selected_metrics)

                return confusion_fig, category_fig

            except Exception as e:
                print(f"Error updating visualizations: {str(e)}")
                return None, None

        # Wire event handlers: any selector change refreshes both plots.
        for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
            input_component.change(
                fn=update_visualizations,
                inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
                outputs=[confusion_matrix_plot, category_metrics_plot]
            )
leaderboard_ui/tab/submit_tab.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
3
+ import pandas as pd
4
+
5
def list_to_dataframe(data):
    """
    Convert a list into a single-row DataFrame.

    Each element becomes one column ("Queue 0", "Queue 1", ...) of a single
    row, so the whole queue renders horizontally in the UI.

    :param data: list of queue entries
    :return: pandas.DataFrame of shape (1, len(data))
    :raises ValueError: if *data* is not a list
    """
    if not isinstance(data, list):
        raise ValueError("์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋Š” ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค.")

    # One "Queue i" column per element.
    header = [f"Queue {index}" for index, _ in enumerate(data)]
    return pd.DataFrame([data], columns=header)
20
+
21
def model_submit(model_id , benchmark_name, prompt_cfg_name):
    """Append an evaluation request to the pending-queue sheet.

    Pushes the bare model name, snapshots the queue, then records the
    benchmark and prompt-config names in their own sheet columns.

    :param model_id: HuggingFace id ("org/repo"); only the repo part is kept
    :param benchmark_name: benchmark to evaluate against
    :param prompt_cfg_name: prompt configuration name
    :return: single-row DataFrame snapshot of the queue (taken after the
        model push but before the benchmark/prompt pushes)
    """
    model_id = model_id.split("/")[-1]  # keep only the repo name from "org/repo"
    sheet_manager = SheetManager()
    sheet_manager.push(model_id)
    # Snapshot here so the returned queue already includes the new model id.
    model_q = list_to_dataframe(sheet_manager.get_all_values())
    # Record the remaining fields in their dedicated sheet columns.
    sheet_manager.change_column("benchmark_name")
    sheet_manager.push(benchmark_name)
    sheet_manager.change_column("prompt_cfg_name")
    sheet_manager.push(prompt_cfg_name)

    return model_q
32
+
33
def read_queue():
    """Fetch the current evaluation queue as a single-row DataFrame."""
    manager = SheetManager()
    queue_values = manager.get_all_values()
    return list_to_dataframe(queue_values)
36
+
37
def submit_tab():
    """Build the "Submit here!" tab.

    Model sub-tab: textboxes for HF id / benchmark / prompt-config, a live
    queue view with a refresh button, and a submit button wired to
    ``model_submit``. Prompt sub-tab: placeholder dropdowns only - its
    submit button is not wired to any handler yet (NOTE(review)).
    """
    with gr.Tab("๐Ÿš€ Submit here! "):
        with gr.Row():
            gr.Markdown("# โœ‰๏ธโœจ Submit your Result here!")

        with gr.Row():
            with gr.Tab("Model"):
                with gr.Row():
                    with gr.Column():
                        # Left column: submission inputs.
                        model_id_textbox = gr.Textbox(
                            label="huggingface_id",
                            placeholder="PIA-SPACE-LAB/T2V_CLIP4Clip",
                            interactive = True
                        )
                        benchmark_name_textbox = gr.Textbox(
                            label="benchmark_name",
                            placeholder="PiaFSV",
                            interactive = True,
                            value="PIA"
                        )
                        prompt_cfg_name_textbox = gr.Textbox(
                            label="prompt_cfg_name",
                            placeholder="topk",
                            interactive = True,
                            value="topk"
                        )
                    with gr.Column():
                        # Right column: instructions and the pending queue.
                        gr.Markdown("## ํ‰๊ฐ€๋ฅผ ๋ฐ›์•„๋ณด์„ธ์š” ๋ฐ˜๋“œ์‹œ ํ—ˆ๊น…ํŽ˜์ด์Šค์— ์—…๋กœ๋“œ๋œ ๋ชจ๋ธ์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
                        gr.Markdown("#### ํ˜„์žฌ ํ‰๊ฐ€ ๋Œ€๊ธฐ์ค‘ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.")
                        model_queue = gr.Dataframe()
                        refresh_button = gr.Button("refresh")
                        refresh_button.click(
                            fn=read_queue,
                            outputs=model_queue
                        )
                with gr.Row():
                    model_submit_button = gr.Button("Submit Eval")
                    model_submit_button.click(
                        fn=model_submit,
                        inputs=[model_id_textbox,
                                benchmark_name_textbox ,
                                prompt_cfg_name_textbox],
                        outputs=model_queue
                    )
            with gr.Tab("Prompt"):
                with gr.Row():
                    with gr.Column():
                        prompt_cfg_selector = gr.Dropdown(
                            choices=["์ „๋ถ€"],
                            label="Prompt_CFG",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                        weight_type = gr.Dropdown(
                            choices=["์ „๋ถ€"],
                            label="Weights type",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                    with gr.Column():
                        gr.Markdown("## ํ‰๊ฐ€๋ฅผ ๋ฐ›์•„๋ณด์„ธ์š” ๋ฐ˜๋“œ์‹œ ํ—ˆ๊น…ํŽ˜์ด์Šค์— ์—…๋กœ๋“œ๋œ ๋ชจ๋ธ์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")

                with gr.Row():
                    # NOTE(review): no .click() handler attached yet.
                    prompt_submit_button = gr.Button("Submit Eval")
103
+
pia_bench/bench.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from devmacs_core.devmacs_core import DevMACSCore
4
+ import json
5
+ from typing import Dict, List, Tuple
6
+ from pathlib import Path
7
+ import pandas as pd
8
+ from utils.except_dir import cust_listdir
9
+ from utils.parser import load_config
10
+ from utils.logger import custom_logger
11
+
12
logger = custom_logger(__name__)

# Canonical folder / file names used to build the benchmark directory tree.
DATA_SET = "dataset"
CFG = "CFG"
VECTOR = "vector"
TEXT = "text"
VIDEO = "video"
EXECPT = ["@eaDir", "README.md"]  # NOTE(review): typo of "EXCEPT"; name kept for compatibility
ALRAM = "alarm"                   # NOTE(review): typo of "ALARM"; name kept for compatibility
METRIC = "metric"
MSRVTT = "MSRVTT"
MODEL = "models"
24
+
25
class PiaBenchMark:
    def __init__(self, benchmark_path :str, cfg_target_path : str = None , model_name : str = MSRVTT , token:str =None):
        """
        Class for building the PIA benchmark system.

        Provides dataset folder structuring, vector extraction, and
        directory-tree creation.

        Attributes:
            benchmark_path (str): benchmark base path
            cfg_target_path (str): config file path
            model_name (str): model name to use
            token (str): auth token (HuggingFace access token)
            categories (List[str]): categories to process
        """
        self.benchmark_path = benchmark_path
        self.token = token
        self.model_name = model_name
        self.devmacs_core = None  # lazily created in extract_visual_vector()
        self.cfg_target_path = cfg_target_path
        # NOTE(review): cfg_target_path defaults to None but Path(None) raises -
        # the parameter is effectively required.
        self.cfg_name = Path(cfg_target_path).stem
        self.cfg_dict = load_config(self.cfg_target_path)

        # <benchmark>/dataset and <benchmark>/CFG
        self.dataset_path = os.path.join(benchmark_path, DATA_SET)
        self.cfg_path = os.path.join(benchmark_path , CFG)

        # <benchmark>/models/<model>/CFG/<cfg>/{alarm,metric}
        self.model_path = os.path.join(self.benchmark_path , MODEL)
        self.model_name_path = os.path.join(self.model_path ,self.model_name)
        self.model_name_cfg_path = os.path.join(self.model_name_path , CFG)
        self.model_name_cfg_name_path = os.path.join(self.model_name_cfg_path , self.cfg_name)
        self.alram_path = os.path.join(self.model_name_cfg_name_path , ALRAM)
        self.metric_path = os.path.join(self.model_name_cfg_name_path , METRIC)

        # <benchmark>/models/<model>/vector/{text,video}
        self.vector_path = os.path.join(self.model_name_path , VECTOR)
        self.vector_text_path = os.path.join(self.vector_path , TEXT)
        self.vector_video_path = os.path.join(self.vector_path , VIDEO)

        self.categories = []  # filled by preprocess_structure()/preprocess_label_to_csv()

    def _create_frame_labels(self, label_data: Dict, total_frames: int) -> pd.DataFrame:
        """
        Build a frame-indexed label DataFrame.

        Args:
            label_data (Dict): dictionary holding the label (clip) information
            total_frames (int): total number of frames

        Returns:
            pd.DataFrame: per-frame labels, one column per category

        Note:
            Each category column holds 0/1 flags marking whether the category
            is present in that frame.
        """
        # NOTE(review): "colmuns" is a typo of "columns"; local-only, harmless.
        colmuns = ['frame'] + sorted(self.categories)
        df = pd.DataFrame(0, index=range(total_frames), columns=colmuns)
        df['frame'] = range(total_frames)

        for clip_info in label_data['clips'].values():
            category = clip_info['category']
            if category in self.categories:  # only process known categories
                start_frame, end_frame = clip_info['timestamp']
                # .loc slicing is inclusive of end_frame.
                df.loc[start_frame:end_frame, category] = 1

        return df

    def preprocess_label_to_csv(self):
        """
        Convert every JSON label file in the dataset to a frame-based CSV.

        Raises:
            ValueError: when no JSON files are found

        Note:
            - Processes the JSON files inside each category folder.
            - Files already converted to CSV are skipped (count comparison).
        """
        json_files = []
        csv_files = []

        # Only (re)populate categories when the list is still empty.
        if not self.categories:
            for cate in cust_listdir(self.dataset_path):
                if os.path.isdir(os.path.join(self.dataset_path, cate)):
                    self.categories.append(cate)

        for category in self.categories:
            category_path = os.path.join(self.dataset_path, category)
            category_jsons = [os.path.join(category, f) for f in cust_listdir(category_path) if f.endswith('.json')]
            json_files.extend(category_jsons)
            category_csvs = [os.path.join(category, f) for f in cust_listdir(category_path) if f.endswith('.csv')]
            csv_files.extend(category_csvs)

        if not json_files:
            logger.error("No JSON files found in any category directory")
            raise ValueError("No JSON files found in any category directory")

        # Equal counts means every JSON already has its CSV - nothing to do.
        if len(json_files) == len(csv_files):
            logger.info("All JSON files have already been processed to CSV. No further processing needed.")
            return

        for json_file in json_files:
            json_path = os.path.join(self.dataset_path, json_file)
            # json_file includes the category sub-folder, so the CSV lands
            # next to its source JSON.
            video_name = os.path.splitext(json_file)[0]

            label_info = load_config(json_path)
            video_info = label_info['video_info']
            total_frames = video_info['total_frame']

            df = self._create_frame_labels( label_info, total_frames)

            output_path = os.path.join(self.dataset_path, f"{video_name}.csv")
            df.to_csv(output_path , index=False)
        logger.info("Complete !")

    def preprocess_structure(self):
        """
        Create the directory structure required by the benchmark system.

        Structure created:
            - dataset/: dataset storage
            - cfg/: config files
            - vector/: extracted vectors
            - alarm/: alarm-related files
            - metric/: evaluation metrics

        Note:
            Keeps an existing category structure if present; otherwise moves
            loose category folders into dataset/ and records them.
        """
        logger.info("Starting directory structure preprocessing...")
        os.makedirs(self.dataset_path, exist_ok=True)
        os.makedirs(self.cfg_path, exist_ok=True)
        os.makedirs(self.vector_text_path, exist_ok=True)
        os.makedirs(self.vector_video_path, exist_ok=True)
        os.makedirs(self.alram_path, exist_ok=True)
        os.makedirs(self.metric_path, exist_ok=True)
        os.makedirs(self.model_name_cfg_name_path , exist_ok=True)


        # If dataset/ already exists and contains category folders, reuse them.
        if os.path.exists(self.dataset_path) and any(os.path.isdir(os.path.join(self.dataset_path, d)) for d in cust_listdir(self.dataset_path)):
            # Already-structured layout: read categories from dataset/.
            self.categories = [d for d in cust_listdir(self.dataset_path) if os.path.isdir(os.path.join(self.dataset_path, d))]
        else:
            # First run: move loose category folders into dataset/.
            for item in cust_listdir(self.benchmark_path):
                item_path = os.path.join(self.benchmark_path, item)

                # Skip reserved names, files, and system folders ("@...").
                if item.startswith("@") or item in [METRIC ,"README.md",MODEL, CFG, DATA_SET, VECTOR, ALRAM] or not os.path.isdir(item_path):
                    continue
                target_path = os.path.join(self.dataset_path, item)
                if not os.path.exists(target_path):
                    shutil.move(item_path, target_path)
                self.categories.append(item)

        # Mirror the category layout under vector/video/.
        for category in self.categories:
            category_path = os.path.join(self.vector_video_path, category)
            os.makedirs(category_path, exist_ok=True)

        logger.info("Folder preprocessing completed.")

    def extract_visual_vector(self):
        """
        Extract visual feature vectors from the dataset.

        Note:
            - Uses a Hugging Face model to extract features.
            - Extracted vectors are saved under vector/video/.

        Requires:
            DevMACSCore must be initializable (valid token / repo access).
        """
        logger.info(f"Starting visual vector extraction using model: {self.model_name}")
        try:
            self.devmacs_core = DevMACSCore.from_huggingface(token=self.token, repo_id=f"PIA-SPACE-LAB/{self.model_name}")
            self.devmacs_core.save_visual_results(
                vid_dir = self.dataset_path,
                result_dir = self.vector_video_path
            )
        except Exception as e:
            logger.error(f"Error during vector extraction: {str(e)}")
            raise
204
+
205
if __name__ == "__main__":
    # Script entry point: build the PIA benchmark structure and convert labels.
    from dotenv import load_dotenv
    import os
    load_dotenv()

    access_token = os.getenv("ACCESS_TOKEN")
    model_name = "T2V_CLIP4CLIP_MSRVTT"

    # NOTE(review): machine-specific absolute paths - parameterize (argparse /
    # env vars) before running anywhere else.
    benchmark_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA"
    cfg_target_path= "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA/CFG/topk.json"

    pia_benchmark = PiaBenchMark(benchmark_path ,model_name=model_name, cfg_target_path= cfg_target_path , token=access_token )
    pia_benchmark.preprocess_structure()
    pia_benchmark.preprocess_label_to_csv()
    print("Categories identified:", pia_benchmark.categories)
pia_bench/checker/bench_checker.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Dict, Optional, Tuple
4
+ from pathlib import Path
5
+ import json
6
+ import numpy as np
7
+ logging.basicConfig(level=logging.INFO)
8
+
9
+ class BenchChecker:
10
    def __init__(self, base_path: str):
        """Initialize BenchChecker with base assets path.

        Args:
            base_path (str): Base path to assets directory containing benchmark folders
        """
        # Normalize to Path once; all checks below derive from this root.
        self.base_path = Path(base_path)
        self.logger = logging.getLogger(__name__)
18
+
19
    def check_benchmark_exists(self, benchmark_name: str) -> bool:
        """Check if benchmark folder exists.

        Returns True only when <base>/<benchmark_name> exists AND is a
        directory; logs the outcome either way.
        """
        benchmark_path = self.base_path / benchmark_name
        exists = benchmark_path.exists() and benchmark_path.is_dir()
        if exists:
            self.logger.info(f"Found benchmark directory: {benchmark_name}")
        else:
            self.logger.error(f"Benchmark directory not found: {benchmark_name}")
        return exists
28
+
29
    def get_video_list(self, benchmark_name: str) -> List[str]:
        """Get list of videos from benchmark's dataset directory. Return empty list if no videos found.

        Scans one level of category folders under <benchmark>/dataset for
        *.mp4 files and returns their stems (file names without extension).
        """
        dataset_path = self.base_path / benchmark_name / "dataset"
        videos = []

        if not dataset_path.exists():
            self.logger.info(f"Dataset directory exists but no videos found for {benchmark_name}")
            return videos  # return empty list

        # Collect .mp4 files from each category folder (one level deep).
        for category in dataset_path.glob("*"):
            if category.is_dir():
                for video_file in category.glob("*.mp4"):
                    videos.append(video_file.stem)

        self.logger.info(f"Found {len(videos)} videos in {benchmark_name} dataset")
        return videos
46
+
47
    def check_model_exists(self, benchmark_name: str, model_name: str) -> bool:
        """Check if model directory exists in benchmark's models directory.

        Returns True only when <benchmark>/models/<model_name> exists AND is
        a directory; logs the outcome either way.
        """
        model_path = self.base_path / benchmark_name / "models" / model_name
        exists = model_path.exists() and model_path.is_dir()
        if exists:
            self.logger.info(f"Found model directory: {model_name}")
        else:
            self.logger.error(f"Model directory not found: {model_name}")
        return exists
56
+
57
    def check_cfg_files(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Tuple[bool, bool]:
        """Check if CFG files/directories exist in both benchmark and model directories.

        Returns:
            (benchmark_cfg_exists, model_cfg_exists): the benchmark side is a
            JSON *file* (<benchmark>/CFG/<cfg_prompt>.json); the model side is
            a *directory* (<benchmark>/models/<model>/CFG/<cfg_prompt>).
        """
        # Check benchmark CFG json
        benchmark_cfg = self.base_path / benchmark_name / "CFG" / f"{cfg_prompt}.json"
        benchmark_cfg_exists = benchmark_cfg.exists() and benchmark_cfg.is_file()

        # Check model CFG directory
        model_cfg = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt
        model_cfg_exists = model_cfg.exists() and model_cfg.is_dir()

        if benchmark_cfg_exists:
            self.logger.info(f"Found benchmark CFG file: {cfg_prompt}.json")
        else:
            self.logger.error(f"Benchmark CFG file not found: {cfg_prompt}.json")

        if model_cfg_exists:
            self.logger.info(f"Found model CFG directory: {cfg_prompt}")
        else:
            self.logger.error(f"Model CFG directory not found: {cfg_prompt}")

        return benchmark_cfg_exists, model_cfg_exists
78
+ def check_vector_files(self, benchmark_name: str, model_name: str, video_list: List[str]) -> bool:
79
+ """Check if video vectors match with dataset."""
80
+ vector_path = self.base_path / benchmark_name / "models" / model_name / "vector" / "video"
81
+
82
+ # ๋น„๋””์˜ค๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ๋Š” ๋ฌด์กฐ๊ฑด False
83
+ if not video_list:
84
+ self.logger.error("No videos found in dataset - cannot proceed")
85
+ return False
86
+
87
+ # ๋ฒกํ„ฐ ๋””๋ ‰ํ† ๋ฆฌ๊ฐ€ ์žˆ๋Š”์ง€ ํ™•์ธ
88
+ if not vector_path.exists():
89
+ self.logger.error("Vector directory doesn't exist")
90
+ return False
91
+
92
+ # ๋ฒกํ„ฐ ํŒŒ์ผ ๋ฆฌ์ŠคํŠธ ๊ฐ€์ ธ์˜ค๊ธฐ
93
+ # vector_files = [f.stem for f in vector_path.glob("*.npy")]
94
+ vector_files = [f.stem for f in vector_path.rglob("*.npy")]
95
+
96
+ missing_vectors = set(video_list) - set(vector_files)
97
+ extra_vectors = set(vector_files) - set(video_list)
98
+
99
+ if missing_vectors:
100
+ self.logger.error(f"Missing vectors for videos: {missing_vectors}")
101
+ return False
102
+ if extra_vectors:
103
+ self.logger.error(f"Extra vectors found: {extra_vectors}")
104
+ return False
105
+
106
+ self.logger.info(f"Vector status: videos={len(video_list)}, vectors={len(vector_files)}")
107
+ return len(video_list) == len(vector_files)
108
+
109
+ def check_metrics_file(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> bool:
110
+ """Check if overall_metrics.json exists in the model's CFG/metrics directory."""
111
+ metrics_path = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt / "metric" / "overall_metrics.json"
112
+ exists = metrics_path.exists() and metrics_path.is_file()
113
+
114
+ if exists:
115
+ self.logger.info(f"Found overall metrics file for {model_name}")
116
+ else:
117
+ self.logger.error(f"Overall metrics file not found for {model_name}")
118
+ return exists
119
+
120
+ def check_benchmark(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Dict[str, bool]:
121
+ """
122
+ Perform all benchmark checks and return status.
123
+ """
124
+ status = {
125
+ 'benchmark_exists': False,
126
+ 'model_exists': False,
127
+ 'cfg_files_exist': False,
128
+ 'vectors_match': False,
129
+ 'metrics_exist': False
130
+ }
131
+
132
+ # Check benchmark directory
133
+ status['benchmark_exists'] = self.check_benchmark_exists(benchmark_name)
134
+ if not status['benchmark_exists']:
135
+ return status
136
+
137
+ # Get video list
138
+ video_list = self.get_video_list(benchmark_name)
139
+
140
+ # Check model directory
141
+ status['model_exists'] = self.check_model_exists(benchmark_name, model_name)
142
+ if not status['model_exists']:
143
+ return status
144
+
145
+ # Check CFG files
146
+ benchmark_cfg, model_cfg = self.check_cfg_files(benchmark_name, model_name, cfg_prompt)
147
+ status['cfg_files_exist'] = benchmark_cfg and model_cfg
148
+ if not status['cfg_files_exist']:
149
+ return status
150
+
151
+ # Check vectors
152
+ status['vectors_match'] = self.check_vector_files(benchmark_name, model_name, video_list)
153
+
154
+ # Check metrics file (only if vectors match)
155
+ if status['vectors_match']:
156
+ status['metrics_exist'] = self.check_metrics_file(benchmark_name, model_name, cfg_prompt)
157
+
158
+ return status
159
+
160
+ def get_benchmark_status(self, check_status: Dict[str, bool]) -> str:
161
+ """Determine which execution path to take based on check results."""
162
+ basic_checks = ['benchmark_exists', 'model_exists', 'cfg_files_exist']
163
+ if not all(check_status[check] for check in basic_checks):
164
+ return "cannot_execute"
165
+ if check_status['vectors_match'] and check_status['metrics_exist']:
166
+ return "all_passed"
167
+ elif not check_status['vectors_match']:
168
+ return "no_vectors"
169
+ else: # vectors exist but no metrics
170
+ return "no_metrics"
171
+
172
+ # Example usage
173
+ if __name__ == "__main__":
174
+
175
+ bench_checker = BenchChecker("assets")
176
+ status = bench_checker.check_benchmark(
177
+ benchmark_name="huggingface_benchmarks_dataset",
178
+ model_name="MSRVTT",
179
+ cfg_prompt="topk"
180
+ )
181
+
182
+ execution_path = bench_checker.get_benchmark_status(status)
183
+ print(f"Checks completed. Execution path: {execution_path}")
184
+ print(f"Status: {status}")
pia_bench/checker/sheet_checker.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Set, Tuple
2
+ import logging
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ from typing import Optional, List
6
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
7
+
8
+ load_dotenv()
9
class SheetChecker:
    """Checks and maintains model rows / benchmark columns in a Google Sheet
    via a project-provided SheetManager (gspread-backed)."""

    def __init__(self, sheet_manager):
        """Initialize SheetChecker with a sheet manager instance."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None  # replaced immediately by _init_bench_sheet()
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Initialize sheet manager for the model sheet."""
        # Instantiate the same concrete manager class as sheet_manager,
        # bound to the "model" worksheet with "Model name" as active column.
        self.bench_sheet_manager = type(self.sheet_manager)(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )

    def add_benchmark_column(self, column_name: str):
        """Add a new benchmark column to the sheet (no-op if it already exists)."""
        try:
            # Get current headers
            headers = self.bench_sheet_manager.get_available_columns()

            # If column already exists, return
            if column_name in headers:
                return

            # Add new column header in the first free column of header row 1.
            new_col_index = len(headers) + 1
            cell = gspread.utils.rowcol_to_a1(1, new_col_index)
            # gspread's Worksheet.update() expects a 2D array of values.
            self.bench_sheet_manager.sheet.update(cell, [[column_name]])
            self.logger.info(f"Added new benchmark column: {column_name}")

            # Refresh cached headers in bench_sheet_manager.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"Error adding benchmark column {column_name}: {str(e)}")
            raise

    def validate_benchmark_columns(self, benchmark_columns: List[str]) -> Tuple[List[str], List[str]]:
        """
        Validate benchmark columns and add missing ones.

        Args:
            benchmark_columns: List of benchmark column names to validate

        Returns:
            Tuple[List[str], List[str]]: (valid columns, invalid columns)
        """
        available_columns = self.bench_sheet_manager.get_available_columns()
        valid_columns = []
        invalid_columns = []

        for col in benchmark_columns:
            if col in available_columns:
                valid_columns.append(col)
            else:
                # Missing columns are created on the fly; only columns whose
                # creation fails end up in invalid_columns.
                try:
                    self.add_benchmark_column(col)
                    valid_columns.append(col)
                    self.logger.info(f"Added new benchmark column: {col}")
                except Exception as e:
                    invalid_columns.append(col)
                    self.logger.error(f"Failed to add benchmark column '{col}': {str(e)}")

        return valid_columns, invalid_columns

    def check_model_and_benchmarks(self, model_name: str, benchmark_columns: List[str]) -> Dict[str, List[str]]:
        """
        Check model existence and which benchmarks need to be filled.

        Args:
            model_name: Name of the model to check
            benchmark_columns: List of benchmark column names to check

        Returns:
            Dict with keys:
                'status': 'model_not_found' or 'model_exists'
                'empty_benchmarks': List of benchmark columns that need to be filled
                'filled_benchmarks': List of benchmark columns that are already filled
                'invalid_benchmarks': List of benchmark columns that don't exist
        """
        result = {
            'status': '',
            'empty_benchmarks': [],
            'filled_benchmarks': [],
            'invalid_benchmarks': []
        }

        # First check if model exists
        exists = self.check_model_exists(model_name)
        if not exists:
            result['status'] = 'model_not_found'
            return result

        result['status'] = 'model_exists'

        # Validate benchmark columns
        valid_columns, invalid_columns = self.validate_benchmark_columns(benchmark_columns)
        result['invalid_benchmarks'] = invalid_columns

        if not valid_columns:
            return result

        # Check which valid benchmarks are empty
        self.bench_sheet_manager.change_column("Model name")
        all_values = self.bench_sheet_manager.get_all_values()
        # NOTE(review): the +2 presumably converts a 0-based position in a
        # header-less value list to a 1-based sheet row below the header —
        # confirm against SheetManager.get_all_values().
        row_index = all_values.index(model_name) + 2

        for column in valid_columns:
            try:
                self.bench_sheet_manager.change_column(column)
                value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value
                if not value or not value.strip():
                    result['empty_benchmarks'].append(column)
                else:
                    result['filled_benchmarks'].append(column)
            except Exception as e:
                # On read failure, treat the column as empty so it is refilled.
                self.logger.error(f"Error checking column {column}: {str(e)}")
                result['empty_benchmarks'].append(column)

        return result

    def update_model_info(self, model_name: str, model_info: Dict[str, str]):
        """Update basic model information columns."""
        try:
            # Each (column, value) pair is pushed through the manager's
            # active-column mechanism.
            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)
            self.logger.info(f"Successfully added new model: {model_name}")
        except Exception as e:
            self.logger.error(f"Error updating model info: {str(e)}")
            raise

    def update_benchmarks(self, model_name: str, benchmark_values: Dict[str, str]):
        """
        Update benchmark values.

        Args:
            model_name: Name of the model
            benchmark_values: Dictionary of benchmark column names and their values
        """
        try:
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            # Same header-offset assumption as in check_model_and_benchmarks.
            row_index = all_values.index(model_name) + 2

            for column, value in benchmark_values.items():
                self.bench_sheet_manager.change_column(column)
                self.bench_sheet_manager.sheet.update_cell(row_index, self.bench_sheet_manager.col_index, value)
                self.logger.info(f"Updated benchmark {column} for model {model_name}")

        except Exception as e:
            self.logger.error(f"Error updating benchmarks: {str(e)}")
            raise

    def check_model_exists(self, model_name: str) -> bool:
        """Check if model exists in the sheet."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            values = self.bench_sheet_manager.get_all_values()
            return model_name in values
        except Exception as e:
            # Any sheet error is treated as "not found" rather than raised.
            self.logger.error(f"Error checking model existence: {str(e)}")
            return False
173
+
174
+
175
+ def process_model_benchmarks(
176
+ model_name: str,
177
+ bench_checker: SheetChecker,
178
+ model_info_func,
179
+ benchmark_processor_func: callable,
180
+ benchmark_columns: List[str],
181
+ cfg_prompt: str
182
+ ) -> None:
183
+ """
184
+ Process model benchmarks according to the specified workflow.
185
+
186
+ Args:
187
+ model_name: Name of the model to process
188
+ bench_checker: SheetChecker instance
189
+ model_info_func: Function that returns model info (name, link, etc.)
190
+ benchmark_processor_func: Function that processes empty benchmarks and returns values
191
+ benchmark_columns: List of benchmark columns to check
192
+ """
193
+ try:
194
+ # Check model and benchmarks
195
+ check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)
196
+
197
+ # Handle invalid benchmark columns
198
+ if check_result['invalid_benchmarks']:
199
+ bench_checker.logger.warning(
200
+ f"Skipping invalid benchmark columns: {', '.join(check_result['invalid_benchmarks'])}"
201
+ )
202
+
203
+ # If model doesn't exist, add it
204
+ if check_result['status'] == 'model_not_found':
205
+ model_info = model_info_func(model_name)
206
+ bench_checker.update_model_info(model_name, model_info)
207
+ bench_checker.logger.info(f"Added new model: {model_name}")
208
+ # Recheck benchmarks after adding model
209
+ check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)
210
+
211
+ # Log filled benchmarks
212
+ if check_result['filled_benchmarks']:
213
+ bench_checker.logger.info(
214
+ f"Skipping filled benchmark columns: {', '.join(check_result['filled_benchmarks'])}"
215
+ )
216
+
217
+ # Process empty benchmarks
218
+ if check_result['empty_benchmarks']:
219
+ bench_checker.logger.info(
220
+ f"Processing empty benchmark columns: {', '.join(check_result['empty_benchmarks'])}"
221
+ )
222
+ # Get benchmark values from processor function
223
+ benchmark_values = benchmark_processor_func(
224
+ model_name,
225
+ check_result['empty_benchmarks'],
226
+ cfg_prompt
227
+ )
228
+ # Update benchmarks
229
+ bench_checker.update_benchmarks(model_name, benchmark_values)
230
+ else:
231
+ bench_checker.logger.info("No empty benchmark columns to process")
232
+
233
+ except Exception as e:
234
+ bench_checker.logger.error(f"Error processing model {model_name}: {str(e)}")
235
+ raise
236
+
237
+ def get_model_info(model_name: str) -> Dict[str, str]:
238
+ return {
239
+ "Model name": model_name,
240
+ "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
241
+ "Model": f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
242
+
243
+ }
244
+
245
+ def process_benchmarks(
246
+ model_name: str,
247
+ empty_benchmarks: List[str],
248
+ cfg_prompt: str
249
+ ) -> Dict[str, str]:
250
+ """
251
+ Measure benchmark scores for given model with specific configuration.
252
+
253
+ Args:
254
+ model_name: Name of the model to evaluate
255
+ empty_benchmarks: List of benchmarks to measure
256
+ cfg_prompt: Prompt configuration for evaluation
257
+
258
+ Returns:
259
+ Dict[str, str]: Dictionary mapping benchmark names to their scores
260
+ """
261
+ result = {}
262
+ for benchmark in empty_benchmarks:
263
+ # ์‹ค์ œ ๋ฒค์น˜๋งˆํฌ ์ธก์ • ์ˆ˜ํ–‰
264
+ # score = measure_benchmark(model_name, benchmark, cfg_prompt)
265
+ if benchmark == "COCO":
266
+ score = 0.5
267
+ elif benchmark == "ImageNet":
268
+ score = 15.0
269
+ result[benchmark] = str(score)
270
+ return result
271
+ # Example usage
272
+ if __name__ == "__main__":
273
+
274
+ sheet_manager = SheetManager()
275
+ bench_checker = SheetChecker(sheet_manager)
276
+
277
+ process_model_benchmarks(
278
+ "test-model",
279
+ bench_checker,
280
+ get_model_info,
281
+ process_benchmarks,
282
+ ["COCO", "ImageNet"],
283
+ "cfg_prompt_value"
284
+ )
pia_bench/event_alarm.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from typing import Dict, List, Tuple
5
+ from devmacs_core.devmacs_core import DevMACSCore
6
+ # from devmacs_core.devmacs_core_copy import DevMACSCore
7
+ from devmacs_core.utils.common.cal import loose_similarity
8
+ from utils.parser import load_config, PromptManager
9
+ import json
10
+ import pandas as pd
11
+ from tqdm import tqdm
12
+ import logging
13
+ from datetime import datetime
14
+ from utils.except_dir import cust_listdir
15
+ from utils.logger import custom_logger
16
+
17
+ logger = custom_logger(__name__)
18
+
19
+ class EventDetector:
20
+ def __init__(self, config_path: str , model_name:str = None, token:str = None):
21
+ self.config = load_config(config_path)
22
+ self.macs = DevMACSCore.from_huggingface(token=token, repo_id=f"PIA-SPACE-LAB/{model_name}")
23
+ # self.macs = DevMACSCore(model_type="clip4clip_web")
24
+
25
+ self.prompt_manager = PromptManager(config_path)
26
+ self.sentences = self.prompt_manager.sentences
27
+ self.text_vectors = self.macs.get_text_vector(self.sentences)
28
+
29
+ def process_and_save_predictions(self, vector_base_dir: str, label_base_dir: str, save_base_dir: str):
30
+ """๋น„๋””์˜ค ๋ฒกํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ CSV๋กœ ์ €์žฅ"""
31
+
32
+ # ์ „์ฒด ๋น„๋””์˜ค ํŒŒ์ผ ์ˆ˜ ๊ณ„์‚ฐ
33
+ total_videos = sum(len([f for f in cust_listdir(os.path.join(vector_base_dir, d))
34
+ if f.endswith('.npy')])
35
+ for d in cust_listdir(vector_base_dir)
36
+ if os.path.isdir(os.path.join(vector_base_dir, d)))
37
+ pbar = tqdm(total=total_videos, desc="Processing videos")
38
+
39
+ for category in cust_listdir(vector_base_dir):
40
+ category_path = os.path.join(vector_base_dir, category)
41
+ if not os.path.isdir(category_path):
42
+ continue
43
+
44
+ # ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
45
+ save_category_dir = os.path.join(save_base_dir, category)
46
+ os.makedirs(save_category_dir, exist_ok=True)
47
+
48
+ for file in cust_listdir(category_path):
49
+ if file.endswith('.npy'):
50
+ video_name = os.path.splitext(file)[0]
51
+ vector_path = os.path.join(category_path, file)
52
+
53
+ # ๋ผ๋ฒจ ํŒŒ์ผ ์ฝ๊ธฐ
54
+ label_path = os.path.join(label_base_dir, category, f"{video_name}.json")
55
+ with open(label_path, 'r') as f:
56
+ label_data = json.load(f)
57
+ total_frames = label_data['video_info']['total_frame']
58
+
59
+ # ์˜ˆ์ธก ๊ฒฐ๊ณผ ์ƒ์„ฑ ๋ฐ ์ €์žฅ
60
+ self._process_and_save_single_video(
61
+ vector_path=vector_path,
62
+ total_frames=total_frames,
63
+ save_path=os.path.join(save_category_dir, f"{video_name}.csv")
64
+ )
65
+ pbar.update(1)
66
+ pbar.close()
67
+
68
+ def _process_and_save_single_video(self, vector_path: str, total_frames: int, save_path: str):
69
+ """๋‹จ์ผ ๋น„๋””์˜ค ์ฒ˜๋ฆฌ ๋ฐ ์ €์žฅ"""
70
+ # ๊ธฐ๋ณธ ์˜ˆ์ธก ์ˆ˜ํ–‰
71
+ sparse_predictions = self._process_single_vector(vector_path)
72
+
73
+ # ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์œผ๋กœ ๋ณ€ํ™˜ ๋ฐ ํ™•์žฅ
74
+ df = self._expand_predictions(sparse_predictions, total_frames)
75
+
76
+ # CSV๋กœ ์ €์žฅ
77
+ df.to_csv(save_path, index=False)
78
+
79
+ def _process_single_vector(self, vector_path: str) -> Dict:
80
+ """๊ธฐ์กด ์˜ˆ์ธก ๋กœ์ง"""
81
+ video_vector = np.load(vector_path)
82
+ processed_vectors = []
83
+ frame_interval = 15
84
+
85
+ for vector in video_vector:
86
+ v = vector.squeeze(0) # numpy array
87
+ v = torch.from_numpy(v).unsqueeze(0).cuda() # torch tensor๋กœ ๋ณ€ํ™˜ ํ›„ GPU๋กœ
88
+ processed_vectors.append(v)
89
+
90
+ frame_results = {}
91
+ for vector_idx, v in enumerate(processed_vectors):
92
+ actual_frame = vector_idx * frame_interval
93
+ sim_scores = loose_similarity(
94
+ sequence_output=self.text_vectors.cuda(),
95
+ visual_output=v.unsqueeze(1)
96
+ )
97
+ frame_results[actual_frame] = self._calculate_alarms(sim_scores)
98
+
99
+ return frame_results
100
+
101
+ def _expand_predictions(self, sparse_predictions: Dict, total_frames: int) -> pd.DataFrame:
102
+ """์˜ˆ์ธก์„ ์ „์ฒด ํ”„๋ ˆ์ž„์œผ๋กœ ํ™•์žฅ"""
103
+ # ์นดํ…Œ๊ณ ๋ฆฌ ๋ชฉ๋ก ์ถ”์ถœ (์ฒซ ๋ฒˆ์งธ ํ”„๋ ˆ์ž„์˜ ์•Œ๋žŒ ๊ฒฐ๊ณผ์—์„œ)
104
+ first_frame = list(sparse_predictions.keys())[0]
105
+ categories = list(sparse_predictions[first_frame].keys())
106
+
107
+ # ์ „์ฒด ํ”„๋ ˆ์ž„ ์ƒ์„ฑ
108
+ df = pd.DataFrame({'frame': range(total_frames)})
109
+
110
+ # ๊ฐ ์นดํ…Œ๊ณ ๋ฆฌ์— ๋Œ€ํ•œ ์•Œ๋žŒ ๊ฐ’ ์ดˆ๊ธฐํ™”
111
+ for category in categories:
112
+ df[category] = 0
113
+
114
+ # ์˜ˆ์ธก๊ฐ’ ์ฑ„์šฐ๊ธฐ
115
+ frame_keys = sorted(sparse_predictions.keys())
116
+ for i in range(len(frame_keys)):
117
+ current_frame = frame_keys[i]
118
+ next_frame = frame_keys[i + 1] if i + 1 < len(frame_keys) else total_frames
119
+
120
+ # ๊ฐ ์นดํ…Œ๊ณ ๋ฆฌ์˜ ์•Œ๋žŒ ๊ฐ’ ์„ค์ •
121
+ for category in categories:
122
+ alarm_value = sparse_predictions[current_frame][category]['alarm']
123
+ df.loc[current_frame:next_frame-1, category] = alarm_value
124
+
125
+ return df
126
+
127
+
128
+ def _calculate_alarms(self, sim_scores: torch.Tensor) -> Dict:
129
+ """์œ ์‚ฌ๋„ ์ ์ˆ˜๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๊ฐ ์ด๋ฒคํŠธ์˜ ์•Œ๋žŒ ์ƒํƒœ ๊ณ„์‚ฐ"""
130
+ # ๋กœ๊ฑฐ ์„ค์ •
131
+ log_filename = f"alarm_calculation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
132
+ logging.basicConfig(
133
+ filename=log_filename,
134
+ level=logging.ERROR,
135
+ format='%(asctime)s - %(message)s',
136
+ datefmt='%Y-%m-%d %H:%M:%S'
137
+ )
138
+ logger = logging.getLogger(__name__)
139
+
140
+ event_alarms = {}
141
+
142
+ for event_config in self.config['PROMPT_CFG']:
143
+ event = event_config['event']
144
+ top_k = event_config['top_candidates']
145
+ threshold = event_config['alert_threshold']
146
+
147
+ # logger.info(f"\nProcessing event: {event}")
148
+ # logger.info(f"Top K: {top_k}, Threshold: {threshold}")
149
+
150
+ event_prompts = self._get_event_prompts(event)
151
+
152
+ # logger.debug(f"\nEvent Prompts Debug for {event}:")
153
+ # logger.debug(f"Indices: {event_prompts['indices']}")
154
+ # logger.debug(f"Types: {event_prompts['types']}")
155
+ # logger.debug(f"\nSim Scores Debug:")
156
+ # logger.debug(f"Shape: {sim_scores.shape}")
157
+ # logger.debug(f"Raw scores: {sim_scores}")
158
+
159
+ # event_scores = sim_scores[event_prompts['indices']]
160
+ event_scores = sim_scores[event_prompts['indices']].squeeze(-1) # shape ๋ณ€๊ฒฝ
161
+
162
+ # logger.debug(f"Event scores shape: {event_scores.shape}")
163
+ # logger.debug(f"Event scores: {event_scores}")
164
+ # ๊ฐ ํ”„๋กฌํ”„ํŠธ์™€ ์ ์ˆ˜ ์ถœ๋ ฅ
165
+ # logger.info("\nDEBUG VALUES:")
166
+ # logger.info(f"event_scores: {event_scores}")
167
+ # logger.info(f"indices: {event_prompts['indices']}")
168
+ # logger.info(f"types: {event_prompts['types']}")
169
+
170
+ # logger.info("\nAll prompts and scores:")
171
+ # for idx, (score, prompt_type) in enumerate(zip(event_scores, event_prompts['types'])):
172
+ # logger.info(f"Type: {prompt_type}, Score: {score.item():.4f}")
173
+
174
+ top_k_values, top_k_indices = torch.topk(event_scores, min(top_k, len(event_scores)))
175
+
176
+ # logger.info(f"top_k_values: {top_k_values}")
177
+ # logger.info(f"top_k_indices (raw): {top_k_indices}")
178
+ # Top K ๊ฒฐ๊ณผ ์ถœ๋ ฅ
179
+ # logger.info(f"\nTop {top_k} selections:")
180
+ for idx, (value, index) in enumerate(zip(top_k_values, top_k_indices)):
181
+ # indices[index]๊ฐ€ ์•„๋‹Œ index๋ฅผ ์ง์ ‘ ์‚ฌ์šฉ
182
+ prompt_type = event_prompts['types'][index] # ์ˆ˜์ •๋œ ๋ถ€๋ถ„
183
+ # logger.info(f"DEBUG: index={index}, types={event_prompts['types']}, selected_type={prompt_type}")
184
+ # logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
185
+
186
+ abnormal_count = sum(1 for idx in top_k_indices
187
+ if event_prompts['types'][idx] == 'abnormal') # ์ˆ˜์ •๋œ ๋ถ€๋ถ„
188
+ # for idx, (value, orig_idx) in enumerate(zip(top_k_values, top_k_indices)):
189
+ # prompt_type = event_prompts['types'][orig_idx.item()]
190
+ # logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
191
+
192
+ # abnormal_count = sum(1 for idx in top_k_indices
193
+ # if event_prompts['types'][idx.item()] == 'abnormal')
194
+
195
+ # ์•Œ๋žŒ ๊ฒฐ์ • ๊ณผ์ • ์ถœ๋ ฅ
196
+ # logger.info(f"\nAbnormal count: {abnormal_count}")
197
+ alarm_result = 1 if abnormal_count >= threshold else 0
198
+ # logger.info(f"Final alarm decision: {alarm_result}")
199
+ # logger.info("-" * 50)
200
+
201
+ event_alarms[event] = {
202
+ 'alarm': alarm_result,
203
+ 'scores': top_k_values.tolist(),
204
+ 'top_k_types': [event_prompts['types'][idx.item()] for idx in top_k_indices]
205
+ }
206
+
207
+ # ๋กœ๊ฑฐ ์ข…๋ฃŒ
208
+ logging.shutdown()
209
+
210
+ return event_alarms
211
+
212
+ def _get_event_prompts(self, event: str) -> Dict:
213
+ indices = []
214
+ types = []
215
+ current_idx = 0
216
+
217
+ for event_config in self.config['PROMPT_CFG']:
218
+ if event_config['event'] == event:
219
+ for status in ['normal', 'abnormal']:
220
+ for _ in range(len(event_config['prompts'][status])):
221
+ indices.append(current_idx)
222
+ types.append(status)
223
+ current_idx += 1
224
+
225
+ return {'indices': indices, 'types': types}
226
+
227
+
pia_bench/metric.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
5
+ from typing import Dict, List
6
+ import json
7
+ from utils.except_dir import cust_listdir
8
+
9
+ class MetricsEvaluator:
10
+ def __init__(self, pred_dir: str, label_dir: str, save_dir: str):
11
+ """
12
+ Args:
13
+ pred_dir: ์˜ˆ์ธก csv ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ
14
+ label_dir: ์ •๋‹ต csv ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ
15
+ save_dir: ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ
16
+ """
17
+ self.pred_dir = pred_dir
18
+ self.label_dir = label_dir
19
+ self.save_dir = save_dir
20
+
21
+ def evaluate(self) -> Dict:
22
+ """์ „์ฒด ํ‰๊ฐ€ ์ˆ˜ํ–‰"""
23
+ category_metrics = {} # ์นดํ…Œ๊ณ ๋ฆฌ๋ณ„ ํ‰๊ท  ์„ฑ๋Šฅ ์ €์žฅ
24
+ all_metrics = { # ๋ชจ๋“  ์นดํ…Œ๊ณ ๋ฆฌ ํ†ตํ•ฉ ๋ฉ”ํŠธ๋ฆญ
25
+ 'falldown': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
26
+ 'violence': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
27
+ 'fire': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []}
28
+ }
29
+
30
+ # ๋ชจ๋“  ์นดํ…Œ๊ณ ๋ฆฌ์˜ metrics๋ฅผ ์ €์žฅํ•  DataFrame ๋ฆฌ์ŠคํŠธ
31
+ all_categories_metrics = []
32
+
33
+ for category in cust_listdir(self.pred_dir):
34
+ if not os.path.isdir(os.path.join(self.pred_dir, category)):
35
+ continue
36
+
37
+ pred_category_path = os.path.join(self.pred_dir, category)
38
+ label_category_path = os.path.join(self.label_dir, category)
39
+ save_category_path = os.path.join(self.save_dir, category)
40
+ os.makedirs(save_category_path, exist_ok=True)
41
+
42
+ # ๊ฒฐ๊ณผ ์ €์žฅ์„ ์œ„ํ•œ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„ ์ƒ์„ฑ
43
+ metrics_df = self._evaluate_category(category, pred_category_path, label_category_path)
44
+
45
+ metrics_df['category'] = category
46
+
47
+ metrics_df.to_csv(os.path.join(save_category_path, f"{category}_metrics.csv"), index=False)
48
+
49
+ all_categories_metrics.append(metrics_df)
50
+
51
+ # ์นดํ…Œ๊ณ ๋ฆฌ๋ณ„ ํ‰๊ท  ์„ฑ๋Šฅ ์ €์žฅ
52
+ category_metrics[category] = metrics_df.iloc[-1].to_dict() # ๋งˆ์ง€๋ง‰ row(ํ‰๊ท )
53
+
54
+ # ์ „์ฒด ํ‰๊ท ์„ ์œ„ํ•œ ๋ฉ”ํŠธ๋ฆญ ์ˆ˜์ง‘
55
+ # for col in metrics_df.columns:
56
+ # if col != 'video_name':
57
+ # event_type, metric_type = col.split('_')
58
+ # all_metrics[event_type][metric_type].append(category_metrics[category][col])
59
+
60
+ for col in metrics_df.columns:
61
+ if col != 'video_name':
62
+ try:
63
+ # ์ฒซ ๋ฒˆ์งธ ์–ธ๋”์Šค์ฝ”์–ด๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์ด๋ฒคํŠธ ํƒ€์ž…๊ณผ ๋ฉ”ํŠธ๋ฆญ ํƒ€์ž… ๋ถ„๋ฆฌ
64
+ parts = col.split('_', 1) # maxsplit=1๋กœ ์ฒซ ๋ฒˆ์งธ ์–ธ๋”์Šค์ฝ”์–ด์—์„œ๋งŒ ๋ถ„๋ฆฌ
65
+ if len(parts) == 2:
66
+ event_type, metric_type = parts
67
+ if event_type in all_metrics and metric_type in all_metrics[event_type]:
68
+ all_metrics[event_type][metric_type].append(category_metrics[category][col])
69
+ except Exception as e:
70
+ print(f"Warning: Could not process column {col}: {str(e)}")
71
+ continue
72
+
73
+ # ๊ฐ DataFrame์—์„œ ๋งˆ์ง€๋ง‰ ํ–‰(average)์„ ์ œ๊ฑฐ
74
+ all_categories_metrics_without_avg = [df.iloc[:-1] for df in all_categories_metrics]
75
+ # ๋ชจ๋“  ์นดํ…Œ๊ณ ๋ฆฌ์˜ metrics๋ฅผ ํ•˜๋‚˜์˜ DataFrame์œผ๋กœ ํ•ฉ์น˜๊ธฐ
76
+ combined_metrics_df = pd.concat(all_categories_metrics_without_avg, ignore_index=True)
77
+ # ํ•ฉ์ณ์ง„ metrics๋ฅผ json ํŒŒ์ผ๊ณผ ๊ฐ™์€ ์œ„์น˜์— ์ €์žฅ
78
+ combined_metrics_df.to_csv(os.path.join(self.save_dir, "all_categories_metrics.csv"), index=False)
79
+ # ๊ฒฐ๊ณผ ์ถœ๋ ฅ
80
+ # print("\nCategory-wise Average Metrics:")
81
+ # for category, metrics in category_metrics.items():
82
+ # print(f"\n{category}:")
83
+ # for metric_name, value in metrics.items():
84
+ # if metric_name != "video_name":
85
+ # print(f"{metric_name}: {value:.3f}")
86
+
87
+ print("\nCategory-wise Average Metrics:")
88
+ for category, metrics in category_metrics.items():
89
+ print(f"\n{category}:")
90
+ for metric_name, value in metrics.items():
91
+ if metric_name != "video_name":
92
+ try:
93
+ if isinstance(value, str):
94
+ print(f"{metric_name}: {value}")
95
+ elif metric_name in ['tp', 'tn', 'fp', 'fn']:
96
+ print(f"{metric_name}: {int(value)}")
97
+ else:
98
+ print(f"{metric_name}: {float(value):.3f}")
99
+ except (ValueError, TypeError):
100
+ print(f"{metric_name}: {value}")
101
+ # ์ „์ฒด ํ‰๊ท  ๊ณ„์‚ฐ ๋ฐ ์ถœ๋ ฅ
102
+ print("\n" + "="*50)
103
+ print("Overall Average Metrics Across All Categories:")
104
+ print("="*50)
105
+
106
+ # for event_type in all_metrics:
107
+ # print(f"\n{event_type}:")
108
+ # for metric_type, values in all_metrics[event_type].items():
109
+ # avg_value = np.mean(values)
110
+ # print(f"{metric_type}: {avg_value:.3f}")
111
+
112
+ for event_type in all_metrics:
113
+ print(f"\n{event_type}:")
114
+ for metric_type, values in all_metrics[event_type].items():
115
+ avg_value = np.mean(values)
116
+ if metric_type in ['tp', 'tn', 'fp', 'fn']: # ์ •์ˆ˜ ๊ฐ’
117
+ print(f"{metric_type}: {int(avg_value)}")
118
+ else: # ์†Œ์ˆ˜์  ๊ฐ’
119
+ print(f"{metric_type}: {avg_value:.3f}")
120
+ ##################################################################################################
121
+ # ์ตœ์ข… ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•  ๋”•์…”๋„ˆ๋ฆฌ
122
+ final_results = {
123
+ "category_metrics": {},
124
+ "overall_metrics": {}
125
+ }
126
+ # ์นดํ…Œ๊ณ ๋ฆฌ๋ณ„ ๋ฉ”ํŠธ๋ฆญ ์ €์žฅ
127
+
128
+ for category, metrics in category_metrics.items():
129
+ final_results["category_metrics"][category] = {}
130
+ for metric_name, value in metrics.items():
131
+ if metric_name != "video_name":
132
+ if isinstance(value, (int, float)):
133
+ final_results["category_metrics"][category][metric_name] = float(value)
134
+
135
+ # ์ „์ฒด ํ‰๊ท  ๊ณ„์‚ฐ ๋ฐ ์ €์žฅ
136
+ for event_type in all_metrics:
137
+ # print(f"\n{event_type}:")
138
+ final_results["overall_metrics"][event_type] = {}
139
+ for metric_type, values in all_metrics[event_type].items():
140
+ avg_value = float(np.mean(values))
141
+ # print(f"{metric_type}: {avg_value:.3f}")
142
+ final_results["overall_metrics"][event_type][metric_type] = avg_value
143
+
144
+ # JSON ํŒŒ์ผ๋กœ ์ €์žฅ
145
+ json_path = os.path.join(self.save_dir, "overall_metrics.json")
146
+ with open(json_path, 'w', encoding='utf-8') as f:
147
+ json.dump(final_results, f, indent=4)
148
+
149
+ # return category_metrics
150
+
151
+ # ๋ˆ„์  ๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ
152
+ accumulated_metrics = self.calculate_accumulated_metrics(combined_metrics_df)
153
+
154
+ # JSON์— ๋ˆ„์  ๋ฉ”ํŠธ๋ฆญ ์ถ”๊ฐ€
155
+ final_results["accumulated_metrics"] = accumulated_metrics
156
+
157
+ # ๋ˆ„์  ๋ฉ”ํŠธ๋ฆญ๋งŒ ๋”ฐ๋กœ ์ €์žฅ
158
+ accumulated_json_path = os.path.join(self.save_dir, "accumulated_metrics.json")
159
+ with open(accumulated_json_path, 'w', encoding='utf-8') as f:
160
+ json.dump(accumulated_metrics, f, indent=4)
161
+
162
+ return accumulated_metrics
163
+
164
def _evaluate_category(self, category: str, pred_path: str, label_path: str) -> pd.DataFrame:
    """Evaluate every prediction CSV under *pred_path* against its label CSV.

    For each video a row of per-event metrics is collected; a final
    'average' row with the column-wise means is appended.

    Args:
        category: Category name (currently unused inside the method).
        pred_path: Directory containing per-video prediction CSV files.
        label_path: Directory containing the matching ground-truth CSV files.

    Returns:
        DataFrame with one row per video plus an 'average' row; columns are
        'video_name' followed by '<event>_<metric>' columns.
    """
    per_video_rows = []
    column_order = ['video_name']

    for file_name in cust_listdir(pred_path):
        if not file_name.endswith('.csv'):
            continue

        video_name = os.path.splitext(file_name)[0]
        pred_df = pd.read_csv(os.path.join(pred_path, file_name))

        # Locate the ground-truth CSV for this video; skip with a warning if absent.
        label_csv = os.path.join(label_path, f"{video_name}.csv")
        if not os.path.exists(label_csv):
            print(f"Warning: Label file not found for {video_name}")
            continue
        label_df = pd.read_csv(label_csv)

        # Every prediction column except the frame index is an event category.
        row = {'video_name': video_name}
        for event in (c for c in pred_df.columns if c != 'frame'):
            event_scores = self._calculate_metrics(label_df[event].values,
                                                   pred_df[event].values)
            for metric_name, metric_value in event_scores.items():
                key = f"{event}_{metric_name}"
                row[key] = metric_value
                if key not in column_order:
                    column_order.append(key)

        per_video_rows.append(row)

    metrics_df = pd.DataFrame(per_video_rows, columns=column_order)

    # Append a summary row holding the mean of every metric column.
    summary = {'video_name': 'average'}
    for key in column_order[1:]:  # skip 'video_name'
        summary[key] = metrics_df[key].mean()

    return pd.concat([metrics_df, pd.DataFrame([summary])], ignore_index=True)
218
+
219
+ # def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
220
+ # """์„ฑ๋Šฅ ์ง€ํ‘œ ๊ณ„์‚ฐ"""
221
+ # tn = np.sum((y_true == 0) & (y_pred == 0))
222
+ # fp = np.sum((y_true == 0) & (y_pred == 1))
223
+
224
+ # metrics = {
225
+ # 'f1': f1_score(y_true, y_pred, zero_division=0),
226
+ # 'accuracy': accuracy_score(y_true, y_pred),
227
+ # 'precision': precision_score(y_true, y_pred, zero_division=0),
228
+ # 'recall': recall_score(y_true, y_pred, zero_division=0),
229
+ # 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0
230
+ # }
231
+
232
+ # return metrics
233
+
234
+
235
def _confusion_metrics(tp, tn, fp, fn) -> Dict:
    """Derive classification metrics from raw confusion-matrix counts.

    All ratio metrics guard against zero denominators and fall back to 0.
    """
    total = tp + tn + fp + fn
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0          # TPR
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0     # TNR
    # float() before sqrt avoids huge-int overflow issues in the MCC denominator.
    mcc_denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    return {
        'tp': int(tp),
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn),
        'accuracy': (tp + tn) / total if total > 0 else 0,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
        'balanced_accuracy': (recall + specificity) / 2,
        'g_mean': np.sqrt(recall * specificity) if (recall * specificity) > 0 else 0,
        'mcc': ((tp * tn) - (fp * fn)) / mcc_denom if mcc_denom > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'far': 1 - specificity,  # false alarm rate = FPR
    }


def calculate_accumulated_metrics(self, all_categories_metrics_df: pd.DataFrame) -> Dict:
    """Compute per-category and micro-averaged metrics from accumulated confusion matrices.

    Args:
        all_categories_metrics_df: DataFrame with '<category>_tp/tn/fp/fn'
            columns for each category in ('falldown', 'violence', 'fire').
            NOTE(review): if this frame still contains the per-category
            'average' rows, their counts are summed in as well — confirm
            the caller strips them if that is not intended.

    Returns:
        Dict mapping each category (and "micro_avg") to its metrics dict.
        Fix vs. original: micro_avg now guards its accuracy division and
        carries the full metric set (specificity, balanced_accuracy,
        g_mean, mcc, npv, far) instead of only a partial subset.
    """
    categories = ['falldown', 'violence', 'fire']
    accumulated_results = {"micro_avg": {}}

    # Per-category metrics from that category's accumulated counts.
    for category in categories:
        tp = all_categories_metrics_df[f'{category}_tp'].sum()
        tn = all_categories_metrics_df[f'{category}_tn'].sum()
        fp = all_categories_metrics_df[f'{category}_fp'].sum()
        fn = all_categories_metrics_df[f'{category}_fn'].sum()
        accumulated_results[category] = _confusion_metrics(tp, tn, fp, fn)

    # Micro average: pool the counts of every category, then recompute.
    total_tp = sum(accumulated_results[cat]['tp'] for cat in categories)
    total_tn = sum(accumulated_results[cat]['tn'] for cat in categories)
    total_fp = sum(accumulated_results[cat]['fp'] for cat in categories)
    total_fn = sum(accumulated_results[cat]['fn'] for cat in categories)
    accumulated_results["micro_avg"] = _confusion_metrics(total_tp, total_tn,
                                                          total_fp, total_fn)

    return accumulated_results
303
def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
    """Compute per-frame binary classification metrics.

    Args:
        y_true: Ground-truth 0/1 labels.
        y_pred: Predicted 0/1 labels of the same length.

    Returns:
        Dict with sklearn f1/accuracy/precision/recall, a manually computed
        specificity, and the raw confusion-matrix counts (tp/tn/fp/fn).
    """
    # Confusion-matrix counts via boolean masks.
    true_neg = int(np.sum((y_true == 0) & (y_pred == 0)))
    false_pos = int(np.sum((y_true == 0) & (y_pred == 1)))
    false_neg = int(np.sum((y_true == 1) & (y_pred == 0)))
    true_pos = int(np.sum((y_true == 1) & (y_pred == 1)))

    # Specificity is not provided by sklearn's basic scorers; guard the division.
    neg_total = true_neg + false_pos
    specificity = true_neg / neg_total if neg_total > 0 else 0

    return {
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'specificity': specificity,
        'tp': true_pos,
        'tn': true_neg,
        'fp': false_pos,
        'fn': false_neg,
    }
pia_bench/pipe_line/piepline.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pia_bench.checker.bench_checker import BenchChecker
2
+ from pia_bench.checker.sheet_checker import SheetChecker
3
+ from pia_bench.event_alarm import EventDetector
4
+ from pia_bench.metric import MetricsEvaluator
5
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
6
+ from pia_bench.bench import PiaBenchMark
7
+ from dotenv import load_dotenv
8
+ from typing import Optional, List , Dict
9
+ import os
10
+ load_dotenv()
11
+ import numpy as np
12
+ from typing import Dict, Tuple
13
+ from typing import Dict, Optional, Tuple
14
+ import logging
15
+ from dataclasses import dataclass
16
+ from sheet_manager.sheet_checker.sheet_check import SheetChecker
17
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
18
+ from pia_bench.checker.bench_checker import BenchChecker
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+
22
+ from enviroments.config import BASE_BENCH_PATH
23
+
24
@dataclass
class PipelineConfig:
    """Immutable settings for a single benchmark pipeline run."""

    # Name of the model to benchmark (also used for the Hugging Face repo lookup).
    model_name: str
    # Benchmark (dataset) name, e.g. "PIA".
    benchmark_name: str
    # Path to the detection config file; its basename becomes the prompt id.
    cfg_target_path: str
    # Root directory containing all benchmark folders.
    base_path: str = BASE_BENCH_PATH
31
+
32
class BenchmarkPipelineStatus:
    """Tracks the stage, intermediate check results and final verdict of a run."""

    def __init__(self):
        # (model_added, benchmark_exists) as reported by the sheet check.
        self.sheet_status: Tuple[bool, bool] = (False, False)
        # Per-item results from the local benchmark folder check.
        self.bench_status: Dict[str, bool] = {}
        # Aggregated verdict string from the bench checker (e.g. "all_passed").
        self.bench_result: str = ""
        # Stage the pipeline is currently in.
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """True when the model was already known, the benchmark exists and all checks passed."""
        model_added, benchmark_exists = self.sheet_status
        return (not model_added
                and benchmark_exists
                and self.bench_result == "all_passed")

    def __str__(self) -> str:
        parts = [
            f"Current Stage: {self.current_stage}",
            f"Sheet Status: {self.sheet_status}",
            f"Bench Status: {self.bench_status}",
            f"Bench Result: {self.bench_result}",
        ]
        return "\n".join(parts)
51
+
52
class BenchmarkPipeline:
    """Orchestrates a benchmark run: sheet check -> local bench check -> execution.

    Refactor: the three ``_execute_*`` paths previously triplicated the
    construction of PiaBenchMark / EventDetector / MetricsEvaluator; that
    shared wiring now lives in ``_build_benchmark``, ``_run_detection`` and
    ``_run_metrics``. External behavior is unchanged.
    """

    def __init__(self, config: PipelineConfig):
        """Wire up checkers and derive run parameters from *config*."""
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        self.access_token = os.getenv("ACCESS_TOKEN")
        # Prompt id = config file basename without extension ("topk" from "topk.json").
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]

        # Collaborators for the two pre-flight checks.
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)

        # Populated with the metrics dict once an execution path completes.
        self.bench_result_dict = None

    def run(self) -> BenchmarkPipelineStatus:
        """Execute the pipeline end-to-end and return its status object.

        Any exception is logged, ``current_stage`` is set to "error", and the
        partial status is returned instead of re-raising.
        """
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()

            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("๋ฒค์น˜๋งˆํฌ๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•˜์—ฌ ์ถ”๊ฐ€ ์ž‘์—…์ด ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()

            self.status.current_stage = "completed"
            return self.status

        except Exception as e:
            self.logger.error(f"ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ ์ค‘ ์—๋Ÿฌ ๋ฐœ์ƒ: {str(e)}")
            self.status.current_stage = "error"
            return self.status

    def _check_sheet(self) -> bool:
        """Check the Google sheet; return True only when a benchmark run is needed."""
        self.logger.info("์‹œํŠธ ์ƒํƒœ ์ฒดํฌ ์‹œ์ž‘")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("์ƒˆ๋กœ์šด ๋ชจ๋ธ์ด ์ถ”๊ฐ€๋˜์—ˆ์Šต๋‹ˆ๋‹ค")
        if not benchmark_exists:
            self.logger.info("๋ฒค์น˜๋งˆํฌ ์ธก์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")
            return True  # only proceed when the benchmark value is missing

        self.logger.info("์ด๋ฏธ ๋ฒค์น˜๋งˆํฌ๊ฐ€ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ํŒŒ์ดํ”„๋ผ์ธ์„ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
        return False

    def _check_bench(self) -> bool:
        """Check the local benchmark environment and record its verdict."""
        self.logger.info("๋ฒค์น˜๋งˆํฌ ํ™˜๊ฒฝ ์ฒดํฌ ์‹œ์ž‘")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )

        # "no bench": the benchmark was never run and the folder structure is absent.
        if self.status.bench_result == "no bench":
            self.logger.error("๋ฒค์น˜๋งˆํฌ ์‹คํ–‰์— ํ•„์š”ํ•œ ๊ธฐ๋ณธ ํด๋”๊ตฌ์กฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
            # NOTE(review): True is returned even though an error is logged, so the
            # pipeline proceeds and _execute_based_on_status falls into its else
            # branch (vector generation). Confirm this fall-through is intentional.
            return True

        return True

    def _execute_based_on_status(self):
        """Dispatch to the execution path matching the bench check verdict."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            # Unknown/missing structure: rebuild everything from vectors.
            self._execute_vector_generation()
            self.logger.warning("ํด๋”๊ตฌ์กฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")

    def _build_benchmark(self):
        """Create the PiaBenchMark for this run and prepare its folder structure."""
        pia_benchmark = PiaBenchMark(
            benchmark_path=f"{BASE_BENCH_PATH}/{self.config.benchmark_name}",
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token)
        pia_benchmark.preprocess_structure()
        return pia_benchmark

    def _run_detection(self, pia_benchmark):
        """Run event detection over the extracted vectors and save alarm CSVs."""
        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

    def _run_metrics(self, pia_benchmark):
        """Evaluate alarms against labels and store the resulting metrics dict."""
        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

    def _execute_full_pipeline(self):
        """All artifacts exist: only (re)compute the metrics."""
        self.logger.info("์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ ์ค‘...")
        pia_benchmark = self._build_benchmark()
        print("Categories identified:", pia_benchmark.categories)
        self._run_metrics(pia_benchmark)

    def _execute_vector_generation(self):
        """Vectors missing: extract vectors, then detect events and compute metrics."""
        self.logger.info("๋ฒกํ„ฐ ์ƒ์„ฑ ์ค‘...")
        pia_benchmark = self._build_benchmark()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        pia_benchmark.extract_visual_vector()
        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)

    def _execute_metrics_generation(self):
        """Vectors exist but metrics missing: detect events and compute metrics."""
        self.logger.info("๋ฉ”ํŠธ๋ฆญ ์ƒ์„ฑ ์ค‘...")
        pia_benchmark = self._build_benchmark()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)
213
+
214
+
215
if __name__ == "__main__":
    # Assemble the run configuration.
    pipeline_config = PipelineConfig(
        model_name="T2V_CLIP4CLIP_MSRVTT",
        benchmark_name="PIA",
        cfg_target_path="topk.json",
        base_path=f"{BASE_BENCH_PATH}"
    )

    # Execute the pipeline and report its final status.
    result = BenchmarkPipeline(pipeline_config).run()

    print(f"\nํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ ๊ฒฐ๊ณผ:")
    print(str(result))
sheet_manager/sheet_checker/sheet_check.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+ import logging
3
+ import gspread
4
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
5
+
6
class SheetChecker:
    """Checks and maintains model rows and benchmark columns in the Google sheet."""

    def __init__(self, sheet_manager: SheetManager):
        """Bind to an existing SheetManager and open a second one on the "model" sheet."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Create a SheetManager of the same class pointed at the "model" worksheet."""
        manager_cls = type(self.sheet_manager)
        self.bench_sheet_manager = manager_cls(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )

    def add_benchmark_column(self, column_name: str):
        """Append *column_name* plus a companion "<name>*100" column to the header row."""
        try:
            existing = self.bench_sheet_manager.get_available_columns()
            if column_name in existing:
                return  # already present; nothing to add

            # Write the two new header cells right after the last column.
            base_index = len(existing) + 1
            self.bench_sheet_manager.sheet.update(
                gspread.utils.rowcol_to_a1(1, base_index), [[column_name]])
            self.bench_sheet_manager.sheet.update(
                gspread.utils.rowcol_to_a1(1, base_index + 1), [[f"{column_name}*100"]])

            self.logger.info(f"์ƒˆ๋กœ์šด ๋ฒค์น˜๋งˆํฌ ์ปฌ๋Ÿผ๋“ค ์ถ”๊ฐ€๋จ: {column_name}, {column_name}*100")
            # Reconnect so the manager picks up the new headers.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"๋ฒค์น˜๋งˆํฌ ์ปฌ๋Ÿผ {column_name} ์ถ”๊ฐ€ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
            raise

    def check_model_and_benchmark(self, model_name: str, benchmark_name: str) -> Tuple[bool, bool]:
        """Ensure the model row and benchmark column exist; report their state.

        Args:
            model_name: Model to look up (added if missing).
            benchmark_name: Benchmark column to look up (added if missing).

        Returns:
            (model_added, benchmark_exists): whether a new model row was created,
            and whether the model already has a value in the benchmark column.
        """
        try:
            model_added = False
            if not self._check_model_exists(model_name):
                self._add_new_model(model_name)
                model_added = True
                self.logger.info(f"์ƒˆ๋กœ์šด ๋ชจ๋ธ ์ถ”๊ฐ€๋จ: {model_name}")

            # Create the benchmark column when it is missing from the header.
            if benchmark_name not in self.bench_sheet_manager.get_available_columns():
                self.add_benchmark_column(benchmark_name)
                self.logger.info(f"์ƒˆ๋กœ์šด ๋ฒค์น˜๋งˆํฌ ์ปฌ๋Ÿผ ์ถ”๊ฐ€๋จ: {benchmark_name}")

            benchmark_exists = self._check_benchmark_exists(model_name, benchmark_name)
            return model_added, benchmark_exists

        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ/๋ฒค์น˜๋งˆํฌ ํ™•์ธ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
            raise

    def _check_model_exists(self, model_name: str) -> bool:
        """Return True when *model_name* appears in the "Model name" column."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            return model_name in self.bench_sheet_manager.get_all_values()
        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
            raise

    def _add_new_model(self, model_name: str):
        """Append a new model row: name, Hugging Face link, and HTML anchor."""
        try:
            new_row = {
                "Model name": model_name,
                "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
                "Model": f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
            }

            # Push each field into its own column.
            for header, cell_value in new_row.items():
                self.bench_sheet_manager.change_column(header)
                self.bench_sheet_manager.push(cell_value)

        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ ์ •๋ณด ์ถ”๊ฐ€ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
            raise

    def _check_benchmark_exists(self, model_name: str, benchmark_name: str) -> bool:
        """Return True when the model's cell in *benchmark_name* holds a non-blank value."""
        try:
            # Locate the model's row (header occupies row 1, values start at row 2).
            self.bench_sheet_manager.change_column("Model name")
            names = self.bench_sheet_manager.get_all_values()
            row_number = names.index(model_name) + 2

            self.bench_sheet_manager.change_column(benchmark_name)
            cell_value = self.bench_sheet_manager.sheet.cell(
                row_number, self.bench_sheet_manager.col_index).value

            return bool(cell_value and cell_value.strip())

        except Exception as e:
            self.logger.error(f"๋ฒค์น˜๋งˆํฌ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
            raise
128
+
129
# Usage example
if __name__ == "__main__":
    manager = SheetManager()
    sheet_checker = SheetChecker(manager)

    added, exists = sheet_checker.check_model_and_benchmark(
        model_name="test-model",
        benchmark_name="COCO"
    )

    print(f"Model added: {added}")
    print(f"Benchmark exists: {exists}")
sheet_manager/sheet_convert/json2sheet.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from sheet_manager.sheet_crud.sheet_crud import SheetManager
3
+ import json
4
+ from typing import Optional, Dict
5
+
6
def update_benchmark_json(
    model_name: str,
    benchmark_data: dict,
    worksheet_name: str = "metric",
    target_column: str = "benchmark"
):
    """Serialize *benchmark_data* to JSON and store it in the model's row.

    Args:
        model_name: Model whose row should be updated.
        benchmark_data: Benchmark results to serialize.
        worksheet_name: Worksheet to operate on (default "metric").
        target_column: Column receiving the JSON string (default "benchmark").
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    # Keep non-ASCII characters readable in the stored JSON.
    payload = json.dumps(benchmark_data, ensure_ascii=False)

    # Find the row whose "Model name" matches and write the JSON into it.
    updated_row = manager.update_cell_by_condition(
        condition_column="Model name",
        condition_value=model_name,
        target_column=target_column,
        target_value=payload
    )

    if updated_row:
        print(f"Successfully updated {target_column} data for model: {model_name}")
    else:
        print(f"Model {model_name} not found in the sheet")
38
+
39
+
40
+
41
def get_benchmark_dict(
    model_name: str,
    worksheet_name: str = "metric",
    target_column: str = "benchmark",
    save_path: Optional[str] = None
) -> Dict:
    """Fetch a model's benchmark JSON from the sheet and return it as a dict.

    Args:
        model_name: Model whose row to read.
        worksheet_name: Worksheet to read from (default "metric").
        target_column: Column holding the JSON string (default "benchmark").
        save_path: Optional path; when given, the dict is also written there as JSON.

    Returns:
        Parsed benchmark dict; {} when the model/value is missing or on any error.
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    try:
        records = manager.sheet.get_all_records()

        # Locate the row belonging to the requested model.
        matching_row = next(
            (record for record in records if record.get("Model name") == model_name),
            None
        )
        if not matching_row:
            print(f"Model {model_name} not found in the sheet")
            return {}

        raw_json = matching_row.get(target_column)
        if not raw_json:
            print(f"No data found in {target_column} for model: {model_name}")
            return {}

        benchmark_dict = json.loads(raw_json)

        # Optionally persist the parsed dict to disk.
        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(benchmark_dict, f, ensure_ascii=False, indent=2)
            print(f"Successfully saved dictionary to: {save_path}")

        return benchmark_dict

    except json.JSONDecodeError:
        print(f"Failed to parse JSON data for model: {model_name}")
        return {}
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return {}
99
+
100
def str2json(json_str):
    """Parse a JSON-formatted string into a Python object.

    Args:
        json_str: String expected to contain valid JSON.

    Returns:
        The parsed object, or None when parsing fails (an error is printed).
    """
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"JSON Parsing Error: {e}")
        return None
    except Exception as e:
        # e.g. TypeError when json_str is not a string at all.
        print(f"Unexpected Error: {e}")
        return None
    return parsed
sheet_manager/sheet_crud/create_col.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from huggingface_hub import HfApi
5
+ import os
6
+ from dotenv import load_dotenv
7
+ from enviroments.convert import get_json_from_env_var
8
+
9
+ load_dotenv()
10
+
11
def push_model_names_to_sheet(spreadsheet_url, sheet_name, access_token, organization):
    """Sync Hugging Face model names/links for *organization* into a Google Sheet.

    Only models not already present (by "Model name") are appended; each new
    row gets the short name, the repo URL, and an HTML anchor.

    Args:
        spreadsheet_url: URL of the Google Spreadsheet.
        sheet_name: Worksheet to update.
        access_token: Hugging Face access token.
        organization: Hugging Face organization to list models from.
    """
    # Authorize with service-account credentials from GOOGLE_CREDENTIALS.
    scope = ['https://spreadsheets.google.com/feeds',
             'https://www.googleapis.com/auth/drive']
    key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
    creds = ServiceAccountCredentials.from_json_keyfile_dict(key_dict, scope)
    client = gspread.authorize(creds)

    # Open the target worksheet and snapshot its current contents.
    worksheet = client.open_by_url(spreadsheet_url).worksheet(sheet_name)
    current_df = pd.DataFrame(worksheet.get_all_records())

    # List the organization's models on Hugging Face.
    hf_api = HfApi()
    org_models = hf_api.list_models(author=organization, use_auth_token=access_token)

    # Build rows: short name, repo link, and a styled HTML anchor.
    incoming_rows = [{
        "Model name": model.modelId.split("/")[1],
        "Model link": f"https://huggingface.co/{model.modelId}",
        "Model": f"<a target=\"_blank\" href=\"https://huggingface.co/{model.modelId}\" style=\"color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;\">{model.modelId}</a>"
    } for model in org_models]
    incoming_df = pd.DataFrame(incoming_rows)

    # Drop models the sheet already knows about.
    if "Model name" in current_df.columns:
        known_names = current_df["Model name"].tolist()
    else:
        known_names = []
    incoming_df = incoming_df[~incoming_df["Model name"].isin(known_names)]

    if incoming_df.empty:
        print("No new model names to add.")
        return

    # Append the new rows and push the whole table back.
    merged = pd.concat([current_df, incoming_df], ignore_index=True)
    merged = merged.replace([float('inf'), float('-inf')], None)  # Sheets rejects infinities
    merged = merged.fillna('')  # blank out NaNs for the Sheets API
    worksheet.update([merged.columns.values.tolist()] + merged.values.tolist())
    print("New model names, links, and HTML links successfully added to Google Sheet.")
68
+
69
# Example usage
if __name__ == "__main__":
    push_model_names_to_sheet(
        spreadsheet_url=os.getenv("SPREADSHEET_URL"),
        sheet_name="์‹œํŠธ1",
        access_token=os.getenv("ACCESS_TOKEN"),
        organization="PIA-SPACE-LAB",
    )
sheet_manager/sheet_crud/sheet_crud.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ from enviroments.convert import get_json_from_env_var
6
+ from typing import Optional, List
7
+
8
+ load_dotenv(override=True)
9
+
10
class SheetManager:
    """CRUD helper bound to one worksheet/column of a Google Spreadsheet.

    The managed column is treated as a simple FIFO queue: ``push`` writes to
    the first empty cell, ``pop`` removes the first (oldest) value, and
    ``delete`` removes every occurrence of a value.
    """

    def __init__(self, spreadsheet_url: Optional[str] = None,
                 worksheet_name: str = "flag",
                 column_name: str = "huggingface_id"):
        """
        Initialize SheetManager with Google Sheets credentials and connection.

        Args:
            spreadsheet_url (str, optional): URL of the Google Spreadsheet.
                If None, taken from the SPREADSHEET_URL environment variable.
            worksheet_name (str): Name of the worksheet to operate on.
                Defaults to "flag".
            column_name (str): Name of the column to operate on.
                Defaults to "huggingface_id".

        Raises:
            ValueError: If no spreadsheet URL is available.
        """
        self.spreadsheet_url = spreadsheet_url or os.getenv("SPREADSHEET_URL")
        if not self.spreadsheet_url:
            raise ValueError("Spreadsheet URL not provided and not found in environment variables")

        self.worksheet_name = worksheet_name
        self.column_name = column_name

        # Initialize credentials and client
        self._init_google_client()

        # Initialize sheet connection
        self.doc = None
        self.sheet = None
        self.col_index = None
        self._connect_to_sheet(validate_column=True)

    def _init_google_client(self):
        """Initialize the gspread client from GOOGLE_CREDENTIALS."""
        scope = ['https://spreadsheets.google.com/feeds',
                 'https://www.googleapis.com/auth/drive']
        json_key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
        credentials = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
        self.client = gspread.authorize(credentials)

    def _connect_to_sheet(self, validate_column: bool = True):
        """
        Connect to the spreadsheet/worksheet and cache headers and column index.

        Args:
            validate_column (bool): Whether to resolve self.column_name to an
                index (falling back to the first column when it is missing).

        Raises:
            ValueError: Worksheet missing, or no columns exist.
            ConnectionError: Any other failure while opening the sheet.
        """
        try:
            self.doc = self.client.open_by_url(self.spreadsheet_url)

            try:
                self.sheet = self.doc.worksheet(self.worksheet_name)
            except gspread.exceptions.WorksheetNotFound:
                raise ValueError(f"Worksheet '{self.worksheet_name}' not found in spreadsheet")

            # First row holds the column names.
            self.headers = self.sheet.row_values(1)

            if validate_column:
                try:
                    self.col_index = self.headers.index(self.column_name) + 1
                except ValueError:
                    if self.headers:
                        # BUG FIX: report the column that was actually missing;
                        # the old code reassigned self.column_name first and
                        # then printed the fallback name instead.
                        missing = self.column_name
                        self.column_name = self.headers[0]
                        self.col_index = 1
                        print(f"Column '{missing}' not found. Using first available column: '{self.headers[0]}'")
                    else:
                        raise ValueError("No columns found in worksheet")

        except Exception as e:
            if isinstance(e, ValueError):
                raise e
            raise ConnectionError(f"Failed to connect to sheet: {str(e)}")

    def change_worksheet(self, worksheet_name: str, column_name: Optional[str] = None):
        """
        Change the current worksheet and optionally the column.

        On any failure the previous worksheet/column state is restored.

        Args:
            worksheet_name (str): Name of the worksheet to switch to
            column_name (str, optional): Name of the column to switch to
        """
        old_worksheet = self.worksheet_name
        old_column = self.column_name

        try:
            self.worksheet_name = worksheet_name
            if column_name:
                self.column_name = column_name

            # Connect first without validating, then resolve the column.
            self._connect_to_sheet(validate_column=False)

            if column_name:
                self.change_column(column_name)
            else:
                # Re-resolve the previously selected column in the new sheet.
                try:
                    self.col_index = self.headers.index(self.column_name) + 1
                except ValueError:
                    if self.headers:
                        self.column_name = self.headers[0]
                        self.col_index = 1
                        print(f"Column '{old_column}' not found in new worksheet. Using first available column: '{self.headers[0]}'")
                    else:
                        raise ValueError("No columns found in worksheet")

            print(f"Successfully switched to worksheet: {worksheet_name}, using column: {self.column_name}")

        except Exception as e:
            # Restore previous state on error
            self.worksheet_name = old_worksheet
            self.column_name = old_column
            self._connect_to_sheet()
            raise e

    def change_column(self, column_name: str):
        """
        Change the target column.

        Args:
            column_name (str): Name of the column to switch to

        Raises:
            ValueError: If the column does not exist in the worksheet.
        """
        if not self.headers:
            self.headers = self.sheet.row_values(1)

        try:
            self.col_index = self.headers.index(column_name) + 1
            self.column_name = column_name
            print(f"Successfully switched to column: {column_name}")
        except ValueError:
            raise ValueError(f"Column '{column_name}' not found in worksheet. Available columns: {', '.join(self.headers)}")

    def get_available_worksheets(self) -> List[str]:
        """Get list of all available worksheets in the spreadsheet."""
        return [worksheet.title for worksheet in self.doc.worksheets()]

    def get_available_columns(self) -> List[str]:
        """Get list of all available columns in the current worksheet."""
        return self.headers if self.headers else self.sheet.row_values(1)

    def _reconnect_if_needed(self):
        """Re-authorize and reconnect when the cached connection is stale."""
        try:
            self.sheet.row_values(1)
        except (gspread.exceptions.APIError, AttributeError):
            self._init_google_client()
            self._connect_to_sheet()

    def _fetch_column_data(self) -> List[str]:
        """Fetch all values of the managed column, excluding the header row."""
        values = self.sheet.col_values(self.col_index)
        return values[1:]

    def _update_sheet(self, data: List[str]):
        """Overwrite the managed column (rows 2..len(data)+1) with *data*."""
        try:
            if not data:
                # Nothing to write; an empty range would be invalid.
                return
            # Data rows start at row 2 (row 1 is the header).
            # BUG FIX: the last written row is len(data) + 1, not len(data) + 2
            # as in the original (which targeted one phantom extra row).
            start_cell = gspread.utils.rowcol_to_a1(2, self.col_index)
            end_cell = gspread.utils.rowcol_to_a1(len(data) + 1, self.col_index)
            range_name = f"{start_cell}:{end_cell}"

            # gspread expects a 2-D array (one inner list per row).
            cells = [[value] for value in data]

            self.sheet.update(range_name, cells)
        except Exception as e:
            print(f"Error updating sheet: {str(e)}")
            raise

    def push(self, text: str) -> int:
        """
        Push a text value to the next empty cell in the managed column.

        Args:
            text (str): Text to push to the sheet

        Returns:
            int: The row number where the text was pushed
        """
        try:
            self._reconnect_if_needed()

            column_values = self.sheet.col_values(self.col_index)

            # Find the first empty row; index 0 is the header, so start at 1.
            next_row = None
            for i in range(1, len(column_values)):
                if not column_values[i].strip():
                    next_row = i + 1
                    break

            # If no empty row found, append to the end
            if next_row is None:
                next_row = len(column_values) + 1

            self.sheet.update_cell(next_row, self.col_index, text)
            print(f"Successfully pushed value: {text} to row {next_row}")
            return next_row

        except Exception as e:
            print(f"Error pushing to sheet: {str(e)}")
            raise

    def pop(self) -> Optional[str]:
        """Remove and return the first (oldest) queued value, or None if empty.

        Note: despite the original docstring's claim of "most recent", this
        removes index 0 — FIFO semantics, matching push() which appends.
        """
        try:
            self._reconnect_if_needed()
            data = self._fetch_column_data()

            if not data or not data[0].strip():
                return None

            value = data.pop(0)  # Remove first value
            data.append("")      # Keep the column length constant

            self._update_sheet(data)
            print(f"Successfully popped value: {value}")
            return value

        except Exception as e:
            print(f"Error popping from sheet: {str(e)}")
            raise

    def delete(self, value: str) -> List[int]:
        """Delete all occurrences of *value*; return their former row offsets."""
        try:
            self._reconnect_if_needed()
            data = self._fetch_column_data()

            # Collect 1-based offsets (within the data region) before deleting.
            indices = [i + 1 for i, v in enumerate(data) if v.strip() == value.strip()]
            if not indices:
                print(f"Value '{value}' not found in sheet")
                return []

            # Drop matches and pad with blanks to keep the column length.
            data = [v for v in data if v.strip() != value.strip()]
            data.extend([""] * len(indices))

            self._update_sheet(data)
            print(f"Successfully deleted value '{value}' from rows: {indices}")
            return indices

        except Exception as e:
            print(f"Error deleting from sheet: {str(e)}")
            raise

    def update_cell_by_condition(self, condition_column: str, condition_value: str, target_column: str, target_value: str) -> Optional[int]:
        """
        Update the value of a cell based on a condition in another column.

        Args:
            condition_column (str): The column to check the condition on.
            condition_value (str): The value to match in the condition column.
            target_column (str): The column where the value should be updated.
            target_value (str): The new value to set in the target column.

        Returns:
            Optional[int]: The row number where the value was updated, or None
            if no matching row was found.

        Raises:
            ValueError: If either column name does not exist.
        """
        try:
            self._reconnect_if_needed()

            headers = self.sheet.row_values(1)

            try:
                condition_col_index = headers.index(condition_column) + 1
            except ValueError:
                raise ValueError(f"์กฐ๊ฑด ์นผ๋Ÿผ '{condition_column}'์ด(๊ฐ€) ์—†์Šต๋‹ˆ๋‹ค.")

            try:
                target_col_index = headers.index(target_column) + 1
            except ValueError:
                raise ValueError(f"๋ชฉํ‘œ ์นผ๋Ÿผ '{target_column}'์ด(๊ฐ€) ์—†์Šต๋‹ˆ๋‹ค.")

            data = self.sheet.get_all_records()

            # First matching row wins; only that row is updated.
            for i, row in enumerate(data):
                if row.get(condition_column) == condition_value:
                    row_number = i + 2  # records start at sheet row 2
                    self.sheet.update_cell(row_number, target_col_index, target_value)
                    print(f"Updated row {row_number}: Set {target_column} to '{target_value}' where {condition_column} is '{condition_value}'")
                    return row_number

            print(f"์กฐ๊ฑด์— ๋งž๋Š” ํ–‰์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {condition_column} = '{condition_value}'")
            return None

        except Exception as e:
            print(f"Error updating cell by condition: {str(e)}")
            raise

    def get_all_values(self) -> List[str]:
        """Get all non-empty values from the managed column."""
        self._reconnect_if_needed()
        return [v for v in self._fetch_column_data() if v.strip()]
318
+
319
# Example usage
if __name__ == "__main__":
    manager = SheetManager()

    # (push / pop / delete examples removed; see the methods' docstrings.)
    # Update the "pia" column of the row whose "model" column equals "msr".
    row_updated = manager.update_cell_by_condition(
        condition_column="model",
        condition_value="msr",
        target_column="pia",
        target_value="new_value",
    )
347
+
sheet_manager/sheet_loader/sheet2df.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ import gspread
4
+ from dotenv import load_dotenv
5
+ import os
6
+ import json
7
+ import pandas as pd
8
+ from enviroments.convert import get_json_from_env_var
9
+ load_dotenv()
10
+
11
def sheet2df(sheet_name: str = "model") -> pd.DataFrame:
    """
    Load one worksheet of the configured Google Spreadsheet into a DataFrame.

    Authenticates with the service-account credentials stored in the
    GOOGLE_CREDENTIALS environment variable, opens the spreadsheet whose URL
    is in SPREADSHEET_URL, reads all cell values from *sheet_name*, and
    promotes the first sheet row to column names.

    Args:
        sheet_name (str): Name of the worksheet to load. Defaults to "model".

    Returns:
        pd.DataFrame: Worksheet contents (all cell values, as returned by
        gspread) with the first row used as the column header.

    Raises:
        ValueError: If the SPREADSHEET_URL environment variable is not set.

    Dependencies:
        pandas, gspread, oauth2client.
    """
    # NOTE: the old docstring referenced a `json_key_path` variable that does
    # not exist; credentials actually come from GOOGLE_CREDENTIALS.
    scope = ['https://spreadsheets.google.com/feeds',
             'https://www.googleapis.com/auth/drive']
    json_key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
    credential = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
    gc = gspread.authorize(credential)

    spreadsheet_url = os.getenv("SPREADSHEET_URL")
    if not spreadsheet_url:
        # Fail fast with a clear message instead of handing None to gspread.
        raise ValueError("SPREADSHEET_URL environment variable is not set")
    doc = gc.open_by_url(spreadsheet_url)
    sheet = doc.worksheet(sheet_name)

    # Convert to DataFrame, then promote row 0 to the header and drop it.
    df = pd.DataFrame(sheet.get_all_values())
    df.rename(columns=df.iloc[0], inplace=True)
    df.drop(df.index[0], inplace=True)

    return df
54
+
55
+
56
def add_scaled_columns(df: pd.DataFrame, origin_col_name: str) -> pd.DataFrame:
    """
    Coerce a column to numeric and append a scaled copy of it.

    Non-numeric entries become 0.0. A new column named ``<origin>*100``
    holding the values multiplied by 100 is added in place.

    :param df: DataFrame to modify in place (also returned).
    :param origin_col_name: Name of the column holding numeric-like values.
    :return: The same DataFrame with the extra scaled column.
    """
    numeric = pd.to_numeric(df[origin_col_name], errors='coerce').fillna(0.0)
    df[origin_col_name] = numeric
    df[f"{origin_col_name}*100"] = numeric * 100
    return df
sheet_manager/sheet_monitor/sheet_sync.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+ from typing import Optional, Callable
4
+ import logging
5
+
6
class SheetMonitor:
    """Background poller that watches a SheetManager column for queued data."""

    def __init__(self, sheet_manager, check_interval: float = 1.0):
        """
        Initialize SheetMonitor with a sheet manager instance.

        Args:
            sheet_manager: SheetManager-like object exposing get_all_values()
                and a column_name attribute.
            check_interval (float): Seconds between polls of the sheet.
        """
        self.sheet_manager = sheet_manager
        self.check_interval = check_interval

        # Threading control
        self.monitor_thread = None
        self.is_running = threading.Event()        # loop runs while set
        self.pause_monitoring = threading.Event()  # request: stop polling
        self.monitor_paused = threading.Event()    # ack: loop is idle

        # Queue status: set while the sheet holds at least one value
        self.has_data = threading.Event()

        # Logging setup
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def start_monitoring(self):
        """Start the monitoring thread (no-op if already running)."""
        if self.monitor_thread is not None and self.monitor_thread.is_alive():
            self.logger.warning("Monitoring thread is already running")
            return

        self.is_running.set()
        self.pause_monitoring.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
        self.logger.info("Started monitoring thread")

    def stop_monitoring(self):
        """Stop the monitoring thread and wait for it to exit."""
        self.is_running.clear()
        if self.monitor_thread:
            self.monitor_thread.join()
        self.logger.info("Stopped monitoring thread")

    def pause(self):
        """Request a pause and block until the monitor loop acknowledges it.

        Blocks indefinitely if the monitor thread is not running (same as the
        original behavior) — call start_monitoring() first.
        """
        self.pause_monitoring.set()
        self.monitor_paused.wait()
        self.logger.info("Monitoring paused")

    def resume(self):
        """Resume monitoring and immediately re-check the sheet for data."""
        self.pause_monitoring.clear()
        self.monitor_paused.clear()
        self.logger.info("Monitoring resumed, checking for new data...")
        values = self.sheet_manager.get_all_values()
        if values:
            self.has_data.set()
            self.logger.info(f"Found data after resume: {values}")

    def _monitor_loop(self):
        """Poll the sheet for data; idle while a pause is requested."""
        while self.is_running.is_set():
            if self.pause_monitoring.is_set():
                # BUG FIX: the old code called pause_monitoring.wait() on an
                # event that was already SET, which returned immediately, so
                # polling silently continued while "paused". Acknowledge the
                # pause, then idle until it is cleared or the loop is stopped.
                self.monitor_paused.set()
                while self.pause_monitoring.is_set() and self.is_running.is_set():
                    time.sleep(min(0.05, self.check_interval))
                self.monitor_paused.clear()
                continue

            try:
                # Check if there's any data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"Monitoring: Current column={self.sheet_manager.column_name}, "
                                 f"Values found={len(values)}, "
                                 f"Has data={self.has_data.is_set()}")

                if values:  # If there's any non-empty value
                    self.has_data.set()
                    self.logger.info(f"Data detected: {values}")
                else:
                    self.has_data.clear()
                    self.logger.info("No data in sheet, waiting...")

                time.sleep(self.check_interval)

            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {str(e)}")
                time.sleep(self.check_interval)
+
94
class MainLoop:
    """Consumes queued sheet values whenever the monitor reports data."""

    def __init__(self, sheet_manager, sheet_monitor, callback_function: Callable = None):
        """
        Initialize MainLoop with sheet manager and monitor instances.

        Args:
            sheet_manager: SheetManager-like object (pop / change_column / ...).
            sheet_monitor: SheetMonitor coordinating the polling thread.
            callback_function (Callable, optional): Invoked with
                (model_id, benchmark_name, prompt_cfg_name) per work item.
        """
        self.sheet_manager = sheet_manager
        self.monitor = sheet_monitor
        self.callback = callback_function
        self.is_running = threading.Event()
        self.logger = logging.getLogger(__name__)

    def start(self):
        """Start monitoring and run the (blocking) processing loop."""
        self.is_running.set()
        self.monitor.start_monitoring()
        self._main_loop()

    def stop(self):
        """Stop the processing loop and the monitor thread."""
        self.is_running.clear()
        self.monitor.stop_monitoring()

    def process_new_value(self):
        """Pop one work item (model / benchmark / prompt cfg) and run the callback.

        Returns:
            tuple | None: (model_id, benchmark_name, prompt_cfg_name), or None
            when the queue was empty or an error occurred.
        """
        # BUG FIX: capture before the try block so the except handler can
        # always restore it (it was previously assigned inside the try).
        original_column = self.sheet_manager.column_name
        try:
            # Pop from the huggingface_id column
            model_id = self.sheet_manager.pop()
            if not model_id:
                return None

            # Pop the companion values from their columns
            self.sheet_manager.change_column("benchmark_name")
            benchmark_name = self.sheet_manager.pop()

            self.sheet_manager.change_column("prompt_cfg_name")
            prompt_cfg_name = self.sheet_manager.pop()

            # Return to original column
            self.sheet_manager.change_column(original_column)

            self.logger.info(f"Processed values - model_id: {model_id}, "
                             f"benchmark_name: {benchmark_name}, "
                             f"prompt_cfg_name: {prompt_cfg_name}")

            if self.callback:
                self.callback(model_id, benchmark_name, prompt_cfg_name)

            return model_id, benchmark_name, prompt_cfg_name

        except Exception as e:
            self.logger.error(f"Error processing values: {str(e)}")
            # Best-effort restore of the original column.
            try:
                self.sheet_manager.change_column(original_column)
            except Exception:
                # BUG FIX: was a bare `except:`; keep the best-effort intent
                # but stop swallowing SystemExit / KeyboardInterrupt.
                pass
            return None

    def _main_loop(self):
        """Block until the monitor flags data, then drain one item at a time."""
        while self.is_running.is_set():
            # Wait (with timeout so stop() is honored) for data to appear
            if self.monitor.has_data.wait(timeout=1.0):
                # Pause polling while we mutate the sheet
                self.monitor.pause()

                self.process_new_value()

                # Check whether anything is left to process
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"After processing: Current column={self.sheet_manager.column_name}, "
                                 f"Values remaining={len(values)}")

                if not values:
                    self.monitor.has_data.clear()
                    self.logger.info("All data processed, clearing has_data flag")
                else:
                    self.logger.info(f"Remaining data: {values}")

                self.monitor.resume()
                # TODO: back off and retry when the Sheets API per-minute
                # quota is exceeded, instead of failing the whole cycle.
184
# Example usage
if __name__ == "__main__":
    import sys
    from pathlib import Path
    sys.path.append(str(Path(__file__).parent.parent.parent))
    from sheet_manager.sheet_crud.sheet_crud import SheetManager
    from pia_bench.pipe_line.piepline import PiaBenchMark

    def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
        # Callback: run the PIA benchmark for one popped work item.
        PiaBenchMark(huggingface_id, benchmark_name, prompt_cfg_name).bench_start()

    # Wire up the components
    manager = SheetManager()
    watcher = SheetMonitor(manager, check_interval=10.0)
    runner = MainLoop(manager, watcher, callback_function=my_custom_function)

    try:
        runner.start()
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        runner.stop()
utils/bench_meta.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ from utils.except_dir import cust_listdir
7
def get_video_metadata(video_path, category, benchmark):
    """
    Extract basic metadata from a single video file via OpenCV.

    Args:
        video_path: Path to the video file.
        category: Category label attached to the record.
        benchmark: Benchmark label attached to the record.

    Returns:
        dict | None: Metadata record, or None if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        return None

    # Extract metadata
    video_name = os.path.basename(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    resolution = f"{frame_width}x{frame_height}"
    duration_seconds = frame_count / fps if fps > 0 else 0
    aspect_ratio = round(frame_width / frame_height, 2) if frame_height > 0 else 0
    file_size = os.path.getsize(video_path) / (1024 * 1024)  # MB
    file_format = os.path.splitext(video_name)[1].lower()
    cap.release()

    # BUG FIX: the old "{m:.0f}:{s:.0f}" format rounded seconds up
    # (59.6s -> "1:60") and did not zero-pad them ("1:5" instead of "1:05").
    minutes, seconds = divmod(int(duration_seconds), 60)

    return {
        "video_name": video_name,
        "resolution": resolution,
        "video_duration": f"{minutes}:{seconds:02d}",
        "category": category,
        "benchmark": benchmark,
        "duration_seconds": duration_seconds,
        "total_frames": frame_count,
        "file_format": file_format,
        "file_size_mb": round(file_size, 2),
        "aspect_ratio": aspect_ratio,
        "fps": fps
    }
39
+
40
def process_videos_in_directory(root_dir):
    """
    Walk <root_dir>/<benchmark>/dataset/<category>/ and collect video metadata.

    Args:
        root_dir: Root directory containing one folder per benchmark.

    Returns:
        pd.DataFrame: One row per readable video file (possibly empty).
    """
    # NOTE: paths are lower-cased before matching, so the old dot-less,
    # upper-case 'MOV' entry could never match — removed (behavior identical).
    video_suffixes = ('.mp4', '.avi', '.mkv', '.mov')
    video_metadata_list = []

    # Iterate over benchmark folders
    for benchmark in cust_listdir(root_dir):
        benchmark_path = os.path.join(root_dir, benchmark)
        if not os.path.isdir(benchmark_path):
            continue

        # "dataset" folder inside each benchmark
        dataset_path = os.path.join(benchmark_path, "dataset")
        if not os.path.isdir(dataset_path):
            continue

        # Category folders inside the dataset folder
        for category in cust_listdir(dataset_path):
            category_path = os.path.join(dataset_path, category)
            if not os.path.isdir(category_path):
                continue

            # Video files inside each category folder
            for file in cust_listdir(category_path):
                file_path = os.path.join(category_path, file)

                if file_path.lower().endswith(video_suffixes):
                    metadata = get_video_metadata(file_path, category, benchmark)
                    if metadata:
                        video_metadata_list.append(metadata)

    return pd.DataFrame(video_metadata_list)
72
+
utils/except_dir.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ import enviroments.config as config
4
+
5
def cust_listdir(directory: str) -> List[str]:
    """
    Like os.listdir, but filters out the folders/files named in
    config.EXCLUDE_DIRS.

    Args:
        directory (str): Directory path to list.

    Returns:
        List[str]: Directory entries minus the configured exclusions.
    """
    entries = []
    for entry in os.listdir(directory):
        if entry in config.EXCLUDE_DIRS:
            continue
        entries.append(entry)
    return entries
utils/hf_api.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi
2
+ from typing import Optional, List, Dict, Any
3
+ from dataclasses import dataclass
4
+
5
@dataclass
class ModelInfo:
    """Snapshot of the Hugging Face model metadata this app cares about."""
    model_id: str              # full id, e.g. "org/name"
    last_modified: Any         # timestamp as returned by the Hub API
    downloads: int             # download counter from the Hub
    private: bool              # True for private repositories
    attributes: Dict[str, Any] # every public attribute of the API object
13
+
14
class HuggingFaceInfoManager:
    """Fetches and caches model metadata for one Hugging Face organization."""

    def __init__(self, access_token: Optional[str] = None, organization: str = "PIA-SPACE-LAB"):
        """
        Initialize the manager and eagerly fetch all models of *organization*.

        Args:
            access_token (str, optional): Hugging Face access token (required).
            organization (str): Organization name (default: "PIA-SPACE-LAB").

        Raises:
            ValueError: If access_token is None.
        """
        if access_token is None:
            raise ValueError("์•ก์„ธ์Šค ํ† ํฐ์€ ํ•„์ˆ˜ ์ž…๋ ฅ๊ฐ’์ž…๋‹ˆ๋‹ค. HuggingFace์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ ํ† ํฐ์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")

        self.api = HfApi()
        self.access_token = access_token
        self.organization = organization

        # Fetch once and cache; refresh_models() re-runs this.
        api_models = self.api.list_models(author=self.organization, use_auth_token=self.access_token)
        self._stored_models = []
        self._model_infos = []

        for model in api_models:
            # Snapshot every public attribute of the API object.
            model_attrs = {attr: getattr(model, attr)
                           for attr in dir(model) if not attr.startswith("_")}

            self._model_infos.append(ModelInfo(
                model_id=model.modelId,
                last_modified=model.lastModified,
                downloads=model.downloads,
                private=model.private,
                attributes=model_attrs,
            ))
            self._stored_models.append(model)

    @staticmethod
    def _info_to_dict(info) -> Dict[str, Any]:
        """Flatten one ModelInfo into the dict shape returned by the getters.

        Extracted because get_model_info / get_private_models /
        get_public_models previously duplicated this construction verbatim.
        """
        return {
            'model_id': info.model_id,
            'last_modified': info.last_modified,
            'downloads': info.downloads,
            'private': info.private,
            **info.attributes,
        }

    def get_model_info(self) -> List[Dict[str, Any]]:
        """Return info dicts for every cached model."""
        return [self._info_to_dict(info) for info in self._model_infos]

    def get_model_ids(self) -> List[str]:
        """Return the IDs of every cached model."""
        return [info.model_id for info in self._model_infos]

    def get_private_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for private models only."""
        return [self._info_to_dict(info) for info in self._model_infos if info.private]

    def get_public_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for public models only."""
        return [self._info_to_dict(info) for info in self._model_infos if not info.private]

    def refresh_models(self) -> None:
        """Re-fetch model information from the API (re-initializes the cache)."""
        self.__init__(self.access_token, self.organization)
utils/logger.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from datetime import datetime
4
+ from logging.handlers import TimedRotatingFileHandler
5
+ from pathlib import Path
6
+
7
+
8
def custom_logger(name: str) -> logging.Logger:
    """
    Create a logger with simultaneous console and file handlers.

    Args:
        name (str): Logger name (usually __name__).

    Returns:
        logging.Logger: Configured Logger object.

    Log levels:
        - Console handler: INFO level (INFO, WARNING, ERROR, CRITICAL shown).
        - File handler: DEBUG level (every level is recorded).
        - Log files are written to the 'logs' directory per day, rotate at
          midnight, and only the last 10 days are kept.

    Usage:
        ```python
        logger = custom_logger(__name__)
        logger.debug("debug message")    # file only
        logger.info("info message")      # console + file
        ```

    Output format:
        Console: [HH:MM:SS] [LEVEL] [module:line] [funcName] message
        File:    [YYYY-MM-DD HH:MM:SS] [LEVEL] [module:line] [funcName] message
    """
    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # Return the cached logger if handlers were already attached
    if logger.handlers:
        return logger

    # Create the logs directory
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    # Formatters
    console_formatter = logging.Formatter(
        "[%(asctime)s] [%(levelname)s] [%(module)s:%(lineno)d] [%(funcName)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    file_formatter = logging.Formatter(
        "[%(asctime)s] [%(levelname)s] [%(module)s:%(lineno)d] [%(funcName)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(console_formatter)

    # File handler
    today = datetime.now().strftime("%Y%m%d")
    file_handler = TimedRotatingFileHandler(
        filename=log_dir / f"{today}.log",
        when="midnight",
        interval=1,
        backupCount=10,  # keep only 10 days
        encoding="utf-8",
    )
    # BUG FIX: was logging.WARNING, contradicting the documented contract
    # that the file handler records every level (DEBUG and above).
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(file_formatter)

    # Attach handlers
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    # Trim old log files beyond the retention window
    cleanup_old_logs(log_dir)

    return logger


def cleanup_old_logs(log_dir: Path):
    """Delete the oldest *.log files so that at most 10 remain."""
    log_files = sorted(log_dir.glob("*.log"), key=os.path.getctime)
    while len(log_files) > 10:
        log_files[0].unlink()  # delete the oldest file
        log_files = log_files[1:]
utils/parser.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Dict, List, Tuple
3
+
4
def load_config(config_path: str) -> Dict:
    """Load a JSON configuration file and return its contents.

    Args:
        config_path (str): Path to the JSON configuration file.

    Returns:
        Dict: Parsed configuration data.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        raw_text = config_file.read()
    return json.loads(raw_text)
16
+
17
class PromptManager:
    """Indexes the prompt sentences declared in a JSON configuration.

    Builds a forward mapping ``(event_idx, status, prompt_idx) -> sentence``
    and a reverse mapping ``sentence -> [indices, ...]`` so prompt metadata
    can be looked up from either direction.
    """

    def __init__(self, config_path: str):
        """Load the config and precompute both index mappings.

        Args:
            config_path (str): Path to the JSON configuration file.
        """
        self.config = load_config(config_path)
        self.sentences, self.index_mapping = self._extract_all_sentences_with_index()
        self.reverse_mapping = self._create_reverse_mapping()

    def _extract_all_sentences_with_index(self) -> Tuple[List[str], Dict]:
        """Collect every prompt sentence together with its index position.

        Returns:
            Tuple[List[str], Dict]: The flat sentence list (duplicates kept,
            in config order) and the mapping
            ``(event_idx, status, prompt_idx) -> sentence``.
        """
        sentences: List[str] = []
        index_mapping: Dict = {}

        for event_idx, event_config in enumerate(self.config.get('PROMPT_CFG', [])):
            prompts = event_config.get('prompts', {})
            # Scan both statuses; missing status lists are simply skipped.
            for status in ['normal', 'abnormal']:
                for prompt_idx, prompt in enumerate(prompts.get(status, [])):
                    sentence = prompt.get('sentence', '')
                    sentences.append(sentence)
                    index_mapping[(event_idx, status, prompt_idx)] = sentence

        return sentences, index_mapping

    def _create_reverse_mapping(self) -> Dict:
        """Build the reverse mapping ``sentence -> [indices, ...]``."""
        reverse_map: Dict = {}
        for indices, sent in self.index_mapping.items():
            # setdefault replaces the membership-test-then-append boilerplate.
            reverse_map.setdefault(sent, []).append(indices)
        return reverse_map

    def get_sentence_indices(self, sentence: str) -> List[Tuple[int, str, int]]:
        """Return every index position of *sentence* (empty list if unknown)."""
        return self.reverse_mapping.get(sentence, [])

    def get_details_by_sentence(self, sentence: str) -> List[Dict]:
        """Return the detail dicts for every occurrence of *sentence*."""
        indices = self.get_sentence_indices(sentence)
        return [self.get_details_by_index(*idx) for idx in indices]

    def get_details_by_index(self, event_idx: int, status: str, prompt_idx: int) -> Dict:
        """Return the detail info for one prompt position.

        Args:
            event_idx (int): Index into ``PROMPT_CFG``.
            status (str): Either ``'normal'`` or ``'abnormal'``.
            prompt_idx (int): Index into that status's prompt list.

        Returns:
            Dict: Event name, status, sentence, thresholds and the indices.

        Raises:
            KeyError, IndexError: If the indices are not present in the config.
        """
        event_config = self.config['PROMPT_CFG'][event_idx]
        prompt = event_config['prompts'][status][prompt_idx]

        return {
            'event': event_config['event'],
            'status': status,
            'sentence': prompt['sentence'],
            'top_candidates': event_config['top_candidates'],
            'alert_threshold': event_config['alert_threshold'],
            'event_idx': event_idx,
            'prompt_idx': prompt_idx
        }

    def get_all_sentences(self) -> List[str]:
        """Return the flat list of all prompt sentences."""
        return self.sentences