Spaces:
Build error
Build error
feat]
Browse files- .gitignore +183 -0
- LICENSE +201 -0
- README copy.md +1 -0
- app.py +17 -0
- enviroments/.gitkeep +0 -0
- enviroments/config.py +49 -0
- enviroments/convert.py +54 -0
- leaderboard_ui/tab/dataset_visual_tab.py +160 -0
- leaderboard_ui/tab/food_map_tab.py +96 -0
- leaderboard_ui/tab/leaderboard_tab.py +72 -0
- leaderboard_ui/tab/map_tab.py +50 -0
- leaderboard_ui/tab/metric_visaul_tab.py +418 -0
- leaderboard_ui/tab/submit_tab.py +103 -0
- pia_bench/bench.py +219 -0
- pia_bench/checker/bench_checker.py +184 -0
- pia_bench/checker/sheet_checker.py +284 -0
- pia_bench/event_alarm.py +227 -0
- pia_bench/metric.py +322 -0
- pia_bench/pipe_line/piepline.py +229 -0
- sheet_manager/sheet_checker/sheet_check.py +140 -0
- sheet_manager/sheet_convert/json2sheet.py +117 -0
- sheet_manager/sheet_crud/create_col.py +76 -0
- sheet_manager/sheet_crud/sheet_crud.py +347 -0
- sheet_manager/sheet_loader/sheet2df.py +70 -0
- sheet_manager/sheet_monitor/sheet_sync.py +205 -0
- utils/bench_meta.py +72 -0
- utils/except_dir.py +15 -0
- utils/hf_api.py +103 -0
- utils/logger.py +93 -0
- utils/parser.py +73 -0
.gitignore
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# PyPI configuration file
|
| 171 |
+
.pypirc
|
| 172 |
+
|
| 173 |
+
# *.json
|
| 174 |
+
|
| 175 |
+
assets
|
| 176 |
+
DevMACS-AI-solution-devmacs
|
| 177 |
+
Research-AI-research-t2v_f1score_evaluator
|
| 178 |
+
.env
|
| 179 |
+
enviroments/abnormal-situation-leaderboard-3ca42d06719e.json
|
| 180 |
+
leaderboard_test
|
| 181 |
+
enviroments/deep-byte-352904-a072fdf439e7.json
|
| 182 |
+
PIA-SPACE_LeaderBoard
|
| 183 |
+
*.txt
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README copy.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# MAP
|
app.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from leaderboard_ui.tab.leaderboard_tab import maat_gag_tab
|
| 4 |
+
from leaderboard_ui.tab.food_map_tab import map_food_tab
|
| 5 |
+
|
| 6 |
+
abs_path = Path(__file__).parent
|
| 7 |
+
|
| 8 |
+
with gr.Blocks() as demo:
|
| 9 |
+
gr.Markdown("""
|
| 10 |
+
# ๐บ๏ธ ๋ง๊ฐ ๐บ๏ธ
|
| 11 |
+
""")
|
| 12 |
+
with gr.Tabs():
|
| 13 |
+
maat_gag_tab()
|
| 14 |
+
map_food_tab()
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
|
| 17 |
+
demo.launch()
|
enviroments/.gitkeep
ADDED
|
File without changes
|
enviroments/config.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
BASE_BENCH_PATH = "/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
|
| 3 |
+
EXCLUDE_DIRS = {"@eaDir", 'temp'}
|
| 4 |
+
TYPES = [
|
| 5 |
+
"str",
|
| 6 |
+
"str",
|
| 7 |
+
"str",
|
| 8 |
+
"str",
|
| 9 |
+
"str",
|
| 10 |
+
"str",
|
| 11 |
+
"str",
|
| 12 |
+
"number",
|
| 13 |
+
"number",
|
| 14 |
+
"number",
|
| 15 |
+
"number",
|
| 16 |
+
"number",
|
| 17 |
+
"markdown",
|
| 18 |
+
"markdown",
|
| 19 |
+
"number",
|
| 20 |
+
"number",
|
| 21 |
+
"number",
|
| 22 |
+
"number",
|
| 23 |
+
"number",
|
| 24 |
+
"number",
|
| 25 |
+
"number",
|
| 26 |
+
"str",
|
| 27 |
+
"str",
|
| 28 |
+
"str",
|
| 29 |
+
"str",
|
| 30 |
+
"bool",
|
| 31 |
+
"str",
|
| 32 |
+
"number",
|
| 33 |
+
"number",
|
| 34 |
+
"bool",
|
| 35 |
+
"str",
|
| 36 |
+
"bool",
|
| 37 |
+
"bool",
|
| 38 |
+
"str",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
ON_LOAD_COLUMNS = [
|
| 42 |
+
"๋ถ๋ฅ",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
OFF_LOAD_COLUMNS = []
|
| 46 |
+
HIDE_COLUMNS = ["๋ค์ด๋ฒ๋ณ์ *100" , "์นด์นด์ค๋ณ์ *100"]
|
| 47 |
+
FILTER_COLUMNS = ["์๋น๋ช
"]
|
| 48 |
+
|
| 49 |
+
NUMERIC_COLUMNS = ["PIA"]
|
enviroments/convert.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
def get_json_from_env_var(env_var_name):
|
| 7 |
+
"""
|
| 8 |
+
ํ๊ฒฝ ๋ณ์์์ JSON ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์ ๋์
๋๋ฆฌ๋ก ๋ณํํ๋ ํจ์.
|
| 9 |
+
:param env_var_name: ํ๊ฒฝ ๋ณ์ ์ด๋ฆ
|
| 10 |
+
:return: ๋์
๋๋ฆฌ ํํ์ JSON ๋ฐ์ดํฐ
|
| 11 |
+
"""
|
| 12 |
+
json_string = os.getenv(env_var_name)
|
| 13 |
+
if not json_string:
|
| 14 |
+
raise EnvironmentError(f"ํ๊ฒฝ ๋ณ์ '{env_var_name}'๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค.")
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
# ์ค๋ฐ๊ฟ(\n)์ ์ด์ค์ผ์ดํ ๋ฌธ์(\\n)๋ก ๋ณํ
|
| 18 |
+
json_string = json_string.replace("\n", "\\n")
|
| 19 |
+
|
| 20 |
+
# JSON ๋ฌธ์์ด์ ๋์
๋๋ฆฌ๋ก ๋ณํ
|
| 21 |
+
json_data = json.loads(json_string)
|
| 22 |
+
except json.JSONDecodeError as e:
|
| 23 |
+
raise ValueError(f"JSON ๋ณํ ์คํจ: {e}")
|
| 24 |
+
|
| 25 |
+
return json_data
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def json_to_env_var(json_file_path, env_var_name="JSON_ENV_VAR"):
|
| 30 |
+
"""
|
| 31 |
+
์ฃผ์ด์ง JSON ํ์ผ์ ๋ฐ์ดํฐ๋ฅผ ํ๊ฒฝ ๋ณ์ ํํ๋ก ๋ณํํ์ฌ ์ถ๋ ฅํ๋ ํจ์.
|
| 32 |
+
|
| 33 |
+
:param json_file_path: JSON ํ์ผ ๊ฒฝ๋ก
|
| 34 |
+
:param env_var_name: ํ๊ฒฝ ๋ณ์ ์ด๋ฆ (๊ธฐ๋ณธ๊ฐ: JSON_ENV_VAR)
|
| 35 |
+
:return: None
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
# JSON ํ์ผ ์ฝ๊ธฐ
|
| 39 |
+
with open(json_file_path, 'r') as json_file:
|
| 40 |
+
json_data = json.load(json_file)
|
| 41 |
+
|
| 42 |
+
# JSON ๋ฐ์ดํฐ๋ฅผ ๋ฌธ์์ด๋ก ๋ณํ
|
| 43 |
+
json_string = json.dumps(json_data)
|
| 44 |
+
|
| 45 |
+
# ํ๊ฒฝ ๋ณ์ ํํ๋ก ์ถ๋ ฅ
|
| 46 |
+
env_variable = f'{env_var_name}={json_string}'
|
| 47 |
+
print("\nํ๊ฒฝ ๋ณ์๋ก ์ฌ์ฉํ ์ ์๋ ์ถ๋ ฅ๊ฐ:\n")
|
| 48 |
+
print(env_variable)
|
| 49 |
+
print("\n์ ๊ฐ์ .env ํ์ผ์ ๋ณต์ฌํ์ฌ ๋ถ์ฌ๋ฃ์ผ์ธ์.")
|
| 50 |
+
except FileNotFoundError:
|
| 51 |
+
print(f"ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {json_file_path}")
|
| 52 |
+
except json.JSONDecodeError:
|
| 53 |
+
print(f"์ ํจํ JSON ํ์ผ์ด ์๋๋๋ค: {json_file_path}")
|
| 54 |
+
|
leaderboard_ui/tab/dataset_visual_tab.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from leaderboard_ui.tab.submit_tab import submit_tab
|
| 4 |
+
from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
|
| 5 |
+
abs_path = Path(__file__).parent
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from utils.bench_meta import process_videos_in_directory
|
| 11 |
+
# Mock ๋ฐ์ดํฐ ์์ฑ
|
| 12 |
+
def create_mock_data():
|
| 13 |
+
benchmarks = ['VQA-2023', 'ImageQuality-2024', 'VideoEnhance-2024']
|
| 14 |
+
categories = ['Animation', 'Game', 'Movie', 'Sports', 'Vlog']
|
| 15 |
+
|
| 16 |
+
data_list = []
|
| 17 |
+
|
| 18 |
+
for benchmark in benchmarks:
|
| 19 |
+
n_videos = np.random.randint(50, 100)
|
| 20 |
+
for _ in range(n_videos):
|
| 21 |
+
category = np.random.choice(categories)
|
| 22 |
+
|
| 23 |
+
data_list.append({
|
| 24 |
+
"video_name": f"video_{np.random.randint(1000, 9999)}.mp4",
|
| 25 |
+
"resolution": np.random.choice(["1920x1080", "3840x2160", "1280x720"]),
|
| 26 |
+
"video_duration": f"{np.random.randint(0, 10)}:{np.random.randint(0, 60)}",
|
| 27 |
+
"category": category,
|
| 28 |
+
"benchmark": benchmark,
|
| 29 |
+
"duration_seconds": np.random.randint(30, 600),
|
| 30 |
+
"total_frames": np.random.randint(1000, 10000),
|
| 31 |
+
"file_format": ".mp4",
|
| 32 |
+
"file_size_mb": round(np.random.uniform(10, 1000), 2),
|
| 33 |
+
"aspect_ratio": 16/9,
|
| 34 |
+
"fps": np.random.choice([24, 30, 60])
|
| 35 |
+
})
|
| 36 |
+
|
| 37 |
+
return pd.DataFrame(data_list)
|
| 38 |
+
|
| 39 |
+
# Mock ๋ฐ์ดํฐ ์์ฑ
|
| 40 |
+
# df = process_videos_in_directory("/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench")
|
| 41 |
+
df = pd.read_csv("sample.csv")
|
| 42 |
+
print("DataFrame shape:", df.shape)
|
| 43 |
+
print("DataFrame columns:", df.columns)
|
| 44 |
+
print("DataFrame head:\n", df.head())
|
| 45 |
+
def create_category_pie_chart(df, selected_benchmark, selected_categories=None):
|
| 46 |
+
filtered_df = df[df['benchmark'] == selected_benchmark]
|
| 47 |
+
|
| 48 |
+
if selected_categories:
|
| 49 |
+
filtered_df = filtered_df[filtered_df['category'].isin(selected_categories)]
|
| 50 |
+
|
| 51 |
+
category_counts = filtered_df['category'].value_counts()
|
| 52 |
+
|
| 53 |
+
fig = px.pie(
|
| 54 |
+
values=category_counts.values,
|
| 55 |
+
names=category_counts.index,
|
| 56 |
+
title=f'{selected_benchmark} - Video Distribution by Category',
|
| 57 |
+
hole=0.3
|
| 58 |
+
)
|
| 59 |
+
fig.update_traces(textposition='inside', textinfo='percent+label')
|
| 60 |
+
|
| 61 |
+
return fig
|
| 62 |
+
|
| 63 |
+
###TODO ์คํธ๋ง์ผ๊ฒฝ์ฐ ์ด์ผ ์ฒ๋ฆฌ
|
| 64 |
+
|
| 65 |
+
def create_bar_chart(df, selected_benchmark, selected_categories, selected_column):
|
| 66 |
+
# Filter by benchmark and categories
|
| 67 |
+
filtered_df = df[df['benchmark'] == selected_benchmark]
|
| 68 |
+
if selected_categories:
|
| 69 |
+
filtered_df = filtered_df[filtered_df['category'].isin(selected_categories)]
|
| 70 |
+
|
| 71 |
+
# Create bar chart for selected column
|
| 72 |
+
fig = px.bar(
|
| 73 |
+
filtered_df,
|
| 74 |
+
x=selected_column,
|
| 75 |
+
y='video_name',
|
| 76 |
+
color='category', # Color by category
|
| 77 |
+
title=f'{selected_benchmark} - Video {selected_column}',
|
| 78 |
+
orientation='h', # Horizontal bar chart
|
| 79 |
+
color_discrete_sequence=px.colors.qualitative.Set3 # Color palette
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Adjust layout
|
| 83 |
+
fig.update_layout(
|
| 84 |
+
height=max(400, len(filtered_df) * 30), # Adjust height based on data
|
| 85 |
+
yaxis={'categoryorder': 'total ascending'}, # Sort by value
|
| 86 |
+
margin=dict(l=200), # Margin for long video names
|
| 87 |
+
showlegend=True, # Show legend
|
| 88 |
+
legend=dict(
|
| 89 |
+
orientation="h", # Horizontal legend
|
| 90 |
+
yanchor="bottom",
|
| 91 |
+
y=1.02, # Place legend above graph
|
| 92 |
+
xanchor="right",
|
| 93 |
+
x=1
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return fig
|
| 98 |
+
|
| 99 |
+
def submit_tab():
|
| 100 |
+
with gr.Tab("๐ Submit here! "):
|
| 101 |
+
with gr.Row():
|
| 102 |
+
gr.Markdown("# โ๏ธโจ Submit your Result here!")
|
| 103 |
+
|
| 104 |
+
def visual_tab():
    """Render the "Bench Info" tab: a category pie chart plus a per-video
    bar chart, both driven by the module-level ``df``.
    """
    with gr.Tab("📊 Bench Info"):
        with gr.Row():
            # Benchmark selector; defaults to the alphabetically first benchmark.
            benchmark_dropdown = gr.Dropdown(
                choices=sorted(df['benchmark'].unique().tolist()),
                value=sorted(df['benchmark'].unique().tolist())[0],
                label="Select Benchmark",
                interactive=True
            )

            # Empty selection means "all categories" (the chart helpers only
            # filter when the list is non-empty).
            category_multiselect = gr.CheckboxGroup(
                choices=sorted(df['category'].unique().tolist()),
                label="Select Categories (empty for all)",
                interactive=True
            )

        # Pie chart
        pie_plot_output = gr.Plot(label="pie")

        # Column selection dropdown
        # NOTE(review): "file_format" is a string column — presumably the bar
        # chart needs special handling for it (see the module-level TODO).
        column_options = [
            "video_duration", "duration_seconds", "total_frames",
            "file_size_mb", "aspect_ratio", "fps", "file_format"
        ]

        column_dropdown = gr.Dropdown(
            choices=column_options,
            value=column_options[0],
            label="Select Data to Compare",
            interactive=True
        )

        # Bar chart
        bar_plot_output = gr.Plot(label="video")

        def update_plots(benchmark, categories, selected_column):
            # Recompute both figures whenever any control changes.
            pie_chart = create_category_pie_chart(df, benchmark, categories)
            bar_chart = create_bar_chart(df, benchmark, categories, selected_column)
            return pie_chart, bar_chart

        # Connect event handlers — each control refreshes both plots.
        benchmark_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        category_multiselect.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
        column_dropdown.change(
            fn=update_plots,
            inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
            outputs=[pie_plot_output, bar_plot_output]
        )
|
| 160 |
+
|
leaderboard_ui/tab/food_map_tab.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
|
| 8 |
+
import enviroments.config as config
|
| 9 |
+
from sheet_manager.sheet_loader.sheet2df import sheet2df, add_scaled_columns
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
NAVER_CLIENT_ID = os.getenv("NAVER_KEY")
|
| 13 |
+
NAVER_CLIENT_SECRET = os.getenv("NAVER_SECRET_KEY")
|
| 14 |
+
|
| 15 |
+
def get_lat_lon_naver(address: str, timeout: float = 5.0):
    """
    Convert an address to (latitude, longitude) via the Naver Geocoding API.

    :param address: address to geocode (e.g. "서울 서초구 남부순환로358길 8")
    :param timeout: request timeout in seconds; prevents the app from hanging
                    forever on a slow or unreachable endpoint (new, defaulted —
                    existing callers are unaffected)
    :return: (latitude, longitude) floats, or (None, None) on any failure
    """
    url = "https://naveropenapi.apigw.ntruss.com/map-geocode/v2/geocode"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": NAVER_CLIENT_ID,
        "X-NCP-APIGW-API-KEY": NAVER_CLIENT_SECRET
    }
    params = {"query": address}

    try:
        # Previously no timeout was set, so a stalled connection blocked
        # module import indefinitely (this function runs at import time).
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(f"Geocoding request failed for {address!r}: {e}")
        return None, None

    if response.status_code == 200:
        data = response.json()
        if "addresses" in data and len(data["addresses"]) > 0:
            lat = float(data["addresses"][0]["y"])  # latitude
            lon = float(data["addresses"][0]["x"])  # longitude
            return lat, lon
    return None, None  # lookup failed
|
| 38 |
+
|
| 39 |
+
# --- Module-level data load (runs once, at import time) ---
# Load the Seoul sheet and add min-max scaled rating columns.
df = sheet2df(sheet_name="서울")
for i in ["네이버별점", "카카오별점"]:
    df = add_scaled_columns(df, i)

# Geocode every address into latitude/longitude columns; rows that fail to
# geocode come back as (None, None) and are dropped.
# NOTE(review): this issues one Naver API call per row at import time —
# consider caching or lazy loading; confirm with the app's startup budget.
df[['위도', '경도']] = df['주소'].apply(lambda x: pd.Series(get_lat_lon_naver(x)))
df = df.dropna(subset=["위도", "경도"])
df["위도"] = df["위도"].astype(float)
df["경도"] = df["경도"].astype(float)
print(df[["식당명", "주소", "위도", "경도"]])  # ✅ sanity-check the data
print(df[["위도", "경도"]].dtypes)  # ✅ sanity-check the dtypes
|
| 49 |
+
|
| 50 |
+
def plot_map(df):
    """Visualize restaurant locations on an OpenStreetMap-backed Plotly map."""
    # Center the initial view on the mean coordinate of all rows.
    mid_lat = df["위도"].mean()
    mid_lon = df["경도"].mean()

    marker_style = go.scattermapbox.Marker(
        size=10,      # marker size
        color="red",  # marker color
        opacity=0.7
    )

    fig = go.Figure(go.Scattermapbox(
        lat=df["위도"].tolist(),
        lon=df["경도"].tolist(),
        mode='markers',
        marker=marker_style,
        customdata=df["식당명"].tolist(),  # restaurant name shown on hover
        hoverinfo="text",
        hovertemplate="<b>식당명</b>: %{customdata}<br>"  # hover text
    ))

    fig.update_layout(
        mapbox_style="open-street-map",  # ✅ free OpenStreetMap tiles
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=mid_lat, lon=mid_lon),
            pitch=0,
            zoom=12  # initial zoom level
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0}
    )
    return fig
|
| 88 |
+
|
| 89 |
+
# โ
Gradio UI
|
| 90 |
+
def map_interface():
    """Gradio hook: return the map figure built from the module-level df."""
    return plot_map(df)
|
| 92 |
+
|
| 93 |
+
def map_food_tab():
    """Render the restaurant-map tab."""
    with gr.Tab("🌏지도🌏"):
        gr.Markdown("# 📍 맛집 지도 시각화")
        map_plot = gr.Plot(map_interface)
|
leaderboard_ui/tab/leaderboard_tab.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
|
| 3 |
+
import enviroments.config as config
|
| 4 |
+
from sheet_manager.sheet_loader.sheet2df import sheet2df, add_scaled_columns
|
| 5 |
+
|
| 6 |
+
# --- Module-level data load (runs once, at import time) ---
# Load the Seoul sheet and add min-max scaled rating columns.
df = sheet2df(sheet_name="서울")
for i in ["네이버별점", "카카오별점"]:
    df = add_scaled_columns(df, i)

# Normalize header whitespace BEFORE capturing the column list, so the names
# handed to SelectColumns match the dataframe's actual columns.
# (Previously `columns` was captured before the strip, so any padded sheet
# header left a stale name in the leaderboard configuration.)
df.columns = df.columns.str.strip()
columns = df.columns.tolist()
print(columns)
print(df.columns.tolist())
# print(df["네이버별점*100"])
|
| 14 |
+
|
| 15 |
+
def maat_gag_tab():
    """Render the restaurant leaderboard tab backed by the "서울" sheet.

    Uses the module-level ``df``/``columns`` for the initial view; the
    refresh button reloads the same sheet with the same derived columns.
    """
    with gr.Tab("🍚맛각모임🍚"):
        leaderboard = Leaderboard(
            value=df,
            select_columns=SelectColumns(
                # default_selection=config.ON_LOAD_COLUMNS,
                default_selection=columns,
                cant_deselect=config.OFF_LOAD_COLUMNS,
                label="Select Columns to Display:",
                info="Check"
            ),

            search_columns=SearchColumns(
                primary_column="식당명",
                secondary_columns=["대표메뉴"],
                placeholder="Search",
                label="Search"
            ),
            hide_columns=config.HIDE_COLUMNS,
            filter_columns=[
                ColumnFilter(
                    column="카카오별점*100",
                    type="slider",
                    min=0,
                    max=500,
                    default=[400, 500],
                    label="카카오별점"  # stored value is the rating x100
                ),
                ColumnFilter(
                    column="네이버별점*100",
                    type="slider",
                    min=0,
                    max=500,
                    default=[400, 500],
                    label="네이버별점"  # stored value is the rating x100
                ),
                ColumnFilter(
                    column="지역",
                    label="지역"
                )
            ],

            datatype=config.TYPES,
            # column_widths=["33%", "10%"],
        )
        refresh_button = gr.Button("🔄 Refresh Leaderboard")

        def refresh_leaderboard():
            # Reload the SAME sheet the leaderboard was built from and re-apply
            # the derived "*100" columns. Previously this called sheet2df() with
            # no sheet name, so a refresh pulled a different sheet that lacked
            # the scaled columns the slider filters reference.
            refreshed = sheet2df(sheet_name="서울")
            for col in ["네이버별점", "카카오별점"]:
                refreshed = add_scaled_columns(refreshed, col)
            refreshed.columns = refreshed.columns.str.strip()
            return refreshed

        refresh_button.click(
            refresh_leaderboard,
            inputs=[],
            outputs=leaderboard,
        )
|
| 72 |
+
|
leaderboard_ui/tab/map_tab.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
load_dotenv()
|
| 6 |
+
# โ
๋ค์ด๋ฒ API ํค ๋ถ๋ฌ์ค๊ธฐ
|
| 7 |
+
NAVER_CLIENT_ID = os.getenv("NAVER_KEY")
|
| 8 |
+
NAVER_CLIENT_SECRET = os.getenv("NAVER_SECRET_KEY")
|
| 9 |
+
|
| 10 |
+
def get_lat_lon_naver(address: str, timeout: float = 5.0):
    """
    Convert an address to (latitude, longitude) via the Naver Geocoding API.

    :param address: address to geocode (e.g. "서울 서초구 남부순환로358길 8")
    :param timeout: request timeout in seconds; prevents the script from
                    hanging forever on a slow or unreachable endpoint (new,
                    defaulted — existing callers are unaffected)
    :return: (latitude, longitude) floats, or (None, None) on any failure
    """
    url = "https://naveropenapi.apigw.ntruss.com/map-geocode/v2/geocode"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": NAVER_CLIENT_ID,
        "X-NCP-APIGW-API-KEY": NAVER_CLIENT_SECRET
    }
    params = {"query": address}

    try:
        # Previously no timeout was set, so a stalled connection blocked forever.
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(f"Geocoding request failed for {address!r}: {e}")
        return None, None

    if response.status_code == 200:
        data = response.json()
        if "addresses" in data and len(data["addresses"]) > 0:
            lat = float(data["addresses"][0]["y"])  # latitude
            lon = float(data["addresses"][0]["x"])  # longitude
            return lat, lon
    return None, None  # lookup failed
|
| 33 |
+
|
| 34 |
+
# โ
ํ
์คํธ ๋ฐ์ดํฐํ๋ ์ (์๋น ๋ชฉ๋ก)
|
| 35 |
+
# ✅ Sample dataframe (restaurant list) used to smoke-test the geocoder.
df = pd.DataFrame({
    "식당명": ["원조족발", "고덕집", "대포", "을지로보석", "산방산"],
    "주소": [
        "서울 서초구 남부순환로358길 8",
        "서울 송파구 백제고분로45길 28",
        "서울 성동구 왕십리4길 18-8",
        "서울 중구 마른내로 11-10",
        "서울 서대문구 가좌로 36"
    ]
})

# ✅ Geocode each address and attach latitude/longitude columns.
# NOTE(review): runs at import time and makes one API call per row — confirm
# this module is only imported as a manual smoke test.
df[['위도', '경도']] = df['주소'].apply(lambda x: pd.Series(get_lat_lon_naver(x)))

# ✅ Print the geocoded dataframe.
print(df)
|
leaderboard_ui/tab/metric_visaul_tab.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
abs_path = Path(__file__).parent
|
| 4 |
+
import plotly.express as px
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sheet_manager.sheet_loader.sheet2df import sheet2df
|
| 9 |
+
from sheet_manager.sheet_convert.json2sheet import str2json
|
| 10 |
+
# Mock ๋ฐ์ดํฐ ์์ฑ
|
| 11 |
+
def calculate_avg_metrics(df):
    """
    Compute, per model, the mean of each performance metric across the three
    event categories ('falldown', 'violence', 'fire') parsed from the 'PIA'
    column (a JSON-encoded string per row).

    Rows with missing/unparsable PIA data or missing categories are skipped
    with a console message; a metric missing inside a category contributes 0.

    :param df: dataframe with at least 'Model name' and 'PIA' columns
    :return: DataFrame with one row per valid model ('model_name' plus one
             averaged column per metric); empty DataFrame if nothing parsed
    """
    metrics_data = []

    for _, row in df.iterrows():
        model_name = row['Model name']

        # Skip rows whose PIA cell is empty or not a JSON string.
        if pd.isna(row['PIA']) or not isinstance(row['PIA'], str):
            print(f"Skipping model {model_name}: Invalid PIA data")
            continue

        try:
            metrics = str2json(row['PIA'])

            # Skip when parsing failed or yielded something other than a dict.
            if not metrics or not isinstance(metrics, dict):
                print(f"Skipping model {model_name}: Invalid JSON format")
                continue

            # All three categories must be present to average across them.
            required_categories = ['falldown', 'violence', 'fire']
            if not all(cat in metrics for cat in required_categories):
                print(f"Skipping model {model_name}: Missing required categories")
                continue

            # Metrics expected inside each category dict.
            required_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                                'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

            avg_metrics = {}
            for metric in required_metrics:
                try:
                    # Average only over categories that actually carry this metric.
                    values = [metrics[cat][metric] for cat in required_categories
                              if metric in metrics[cat]]
                    if values:  # average only when at least one value exists
                        avg_metrics[metric] = sum(values) / len(values)
                    else:
                        avg_metrics[metric] = 0  # fallback default
                except (KeyError, TypeError) as e:
                    print(f"Error calculating {metric} for {model_name}: {str(e)}")
                    avg_metrics[metric] = 0  # fallback default on error

            metrics_data.append({
                'model_name': model_name,
                **avg_metrics
            })

        except Exception as e:
            # Catch-all so one bad row cannot abort the whole table build.
            print(f"Error processing model {model_name}: {str(e)}")
            continue

    return pd.DataFrame(metrics_data)
|
| 66 |
+
|
| 67 |
+
def create_performance_chart(df, selected_metrics):
    """
    Build a grouped horizontal bar chart comparing models on the selected
    averaged metrics.

    :param df: dataframe with a 'model_name' column plus one column per metric
    :param selected_metrics: metric column names to plot
    :return: plotly Figure
    """
    fig = go.Figure()

    # Size the left margin to the longest model name so labels are not clipped.
    # max(..., default=0) keeps this safe when df is empty — the previous
    # max([...]) raised ValueError on an empty sequence.
    max_name_length = max((len(name) for name in df['model_name']), default=0)
    left_margin = min(max_name_length * 7, 500)  # cap the margin at 500px

    for metric in selected_metrics:
        fig.add_trace(go.Bar(
            name=metric,
            y=df['model_name'],  # models on the y axis
            x=df[metric],        # metric value on the x axis
            text=[f'{val:.3f}' for val in df[metric]],
            textposition='auto',
            orientation='h'  # horizontal bars
        ))

    fig.update_layout(
        title='Model Performance Comparison',
        yaxis_title='Model Name',
        xaxis_title='Performance',
        barmode='group',
        height=max(400, len(df) * 40),  # grow with the number of models
        margin=dict(l=left_margin, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        yaxis={'categoryorder': 'total ascending'}  # sort by total value
    )

    # Shrink tick labels so long model names fit.
    fig.update_yaxes(tickfont=dict(size=10))

    return fig
|
| 109 |
+
def create_confusion_matrix(metrics_data, selected_category):
    """Render a 2x2 confusion-matrix heatmap for one category."""
    # Pull the four confusion-matrix counts for the chosen category.
    cat = metrics_data[selected_category]
    tp, tn, fp, fn = cat['tp'], cat['tn'], cat['fp'], cat['fn']

    # Rows = actual (Negative, Positive); columns = predicted.
    cells = [[tn, fp], [fn, tp]]
    labels = ['Negative', 'Positive']

    heatmap = go.Heatmap(
        z=cells,
        x=labels,
        y=labels,
        colorscale=[[0, '#f7fbff'], [1, '#08306b']],
        showscale=False,
        text=[[str(val) for val in row] for row in cells],
        texttemplate="%{text}",
        textfont={"color": "black", "size": 16},  # fixed black cell text
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title={
            'text': f'Confusion Matrix - {selected_category}',
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Predicted',
        yaxis_title='Actual',
        width=600,
        height=600,
        margin=dict(l=80, r=80, t=100, b=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font=dict(size=14)  # overall font size
    )

    # Axis placement and tick font.
    fig.update_xaxes(side="bottom", tickfont=dict(size=14))
    fig.update_yaxes(side="left", tickfont=dict(size=14))

    return fig
|
| 158 |
+
|
| 159 |
+
def get_metrics_for_model(df, model_name, benchmark_name):
    """Return the parsed PIA metrics for one (model, benchmark) pair, or None."""
    matched = df[(df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)]
    if matched.empty:
        return None
    return str2json(matched['PIA'].iloc[0])
|
| 166 |
+
|
| 167 |
+
# NOTE(review): this definition is shadowed by a second `def metric_visual_tab`
# later in this module, so this version is dead code after import — confirm
# and delete one of the two.
def metric_visual_tab():
    """Render the performance-visualization tab (metric comparison chart)."""
    # Load the metric sheet and pre-average per-category metrics per model.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All selectable metrics.
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # nothing selected initially
                label="Select Performance Metrics",
                interactive=True
            )

        # Performance comparison chart (starts empty).
        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            if not selected_metrics:  # empty selection clears the plot
                return None

            try:
                # Sort models by accuracy before charting.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        # Connect event handler
        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )
|
| 206 |
+
|
| 207 |
+
def create_category_metrics_chart(metrics_data, selected_metrics):
    """Grouped bar chart of the chosen metrics across the three categories."""
    categories = ['falldown', 'violence', 'fire']
    fig = go.Figure()

    for metric in selected_metrics:
        # One bar group per category for this metric.
        scores = [metrics_data[category][metric] for category in categories]
        fig.add_trace(go.Bar(
            name=metric,
            x=categories,
            y=scores,
            text=[f'{val:.3f}' for val in scores],
            textposition='auto',
        ))

    fig.update_layout(
        title='Performance Metrics by Category',
        xaxis_title='Category',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig
|
| 244 |
+
|
| 245 |
+
def metric_visual_tab():
    """
    Render the performance-visualization tab.

    Two sections:
      1. A model-comparison bar chart over user-selected averaged metrics.
      2. A per-model drill-down: confusion matrix for one category plus a
         per-category metric chart.

    (Removed ~55 lines of commented-out `update_confusion_matrix` dead code
    that was superseded by `update_visualizations` below.)
    """
    # Load the metric sheet and pre-average per-category metrics per model.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All selectable metrics.
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # nothing selected initially
                label="Select Performance Metrics",
                interactive=True
            )

        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            # An empty selection clears the plot.
            if not selected_metrics:
                return None
            try:
                # Sort models by accuracy before charting.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )

        # --- Second section: per-model drill-down ---
        gr.Markdown("## Detailed Model Analysis")

        with gr.Row():
            # Model selector.
            model_dropdown = gr.Dropdown(
                choices=sorted(df['Model name'].unique().tolist()),
                label="Select Model",
                interactive=True
            )

            # Metric-source column selector (everything except the model name).
            column_dropdown = gr.Dropdown(
                choices=[col for col in df.columns if col != 'Model name'],
                label="Select Metric Column",
                interactive=True
            )

            # Category selector for the confusion matrix.
            category_dropdown = gr.Dropdown(
                choices=['falldown', 'violence', 'fire'],
                label="Select Category",
                interactive=True
            )

        # Confusion-matrix plot centered between two spacer columns.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer

        with gr.Column(scale=2):
            # Metric selector for the per-category chart.
            metrics_select = gr.CheckboxGroup(
                choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
                         'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                value=['accuracy'],  # default
                label="Select Metrics to Display",
                interactive=True
            )
            category_metrics_plot = gr.Plot()

        def update_visualizations(model, column, category, selected_metrics):
            # `category` is only required for the confusion matrix.
            if not all([model, column]):
                return None, None

            try:
                # Fetch and parse the selected model's metric JSON.
                selected_data = df[df['Model name'] == model][column].iloc[0]
                metrics = str2json(selected_data)

                if not metrics:
                    return None, None

                # Left: confusion matrix (only when a category is chosen).
                confusion_fig = create_confusion_matrix(metrics, category) if category else None

                # Right: per-category metric chart (falls back to accuracy).
                if not selected_metrics:
                    selected_metrics = ['accuracy']
                category_fig = create_category_metrics_chart(metrics, selected_metrics)

                return confusion_fig, category_fig

            except Exception as e:
                print(f"Error updating visualizations: {str(e)}")
                return None, None

        # Any control change refreshes both plots.
        for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
            input_component.change(
                fn=update_visualizations,
                inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
                outputs=[confusion_matrix_plot, category_metrics_plot]
            )
|
leaderboard_ui/tab/submit_tab.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
def list_to_dataframe(data):
    """
    Convert a flat list into a single-row DataFrame.

    Each list element becomes its own COLUMN, named "Queue 0", "Queue 1", …,
    in one row. (The previous docstring claimed each value became a row,
    which did not match the implementation; columns is the actual behavior
    callers such as the queue display rely on.)

    :param data: list of values (the model queue)
    :return: pandas.DataFrame with one row and len(data) columns
    :raises ValueError: if *data* is not a list
    """
    if not isinstance(data, list):
        raise ValueError("입력 데이터는 리스트 형태여야 합니다.")

    # String column headers: "Queue 0" .. "Queue n-1".
    headers = [f"Queue {i}" for i in range(len(data))]
    df = pd.DataFrame([data], columns=headers)
    return df
|
| 20 |
+
|
| 21 |
+
def model_submit(model_id, benchmark_name, prompt_cfg_name):
    """Push one submission onto the queue sheet and return the current queue.

    :param model_id: huggingface id ("org/model"); only the repo name is kept
    :param benchmark_name: benchmark column value to record
    :param prompt_cfg_name: prompt-config column value to record
    :return: one-row DataFrame snapshot of the model queue
    """
    # Keep only the repo name from a full "org/model" id.
    model_id = model_id.split("/")[-1]

    sheet_manager = SheetManager()
    sheet_manager.push(model_id)
    model_q = list_to_dataframe(sheet_manager.get_all_values())

    # Record benchmark and prompt config in their respective columns.
    sheet_manager.change_column("benchmark_name")
    sheet_manager.push(benchmark_name)
    sheet_manager.change_column("prompt_cfg_name")
    sheet_manager.push(prompt_cfg_name)

    return model_q
|
| 32 |
+
|
| 33 |
+
def read_queue():
    """Return the current model queue as a one-row DataFrame."""
    manager = SheetManager()
    return list_to_dataframe(manager.get_all_values())
|
| 36 |
+
|
| 37 |
+
def submit_tab():
    """Build the "Submit" tab of the Gradio leaderboard UI.

    Two sub-tabs: "Model" (submit a Hugging Face model for evaluation and
    view/refresh the pending queue) and "Prompt" (prompt-config selection;
    its submit button is not wired to a handler yet).
    """
    with gr.Tab("๐ Submit here! "):
        with gr.Row():
            gr.Markdown("# โ๏ธโจ Submit your Result here!")

        with gr.Row():
            with gr.Tab("Model"):
                with gr.Row():
                    with gr.Column():
                        # Free-text inputs identifying the submission.
                        model_id_textbox = gr.Textbox(
                            label="huggingface_id",
                            placeholder="PIA-SPACE-LAB/T2V_CLIP4Clip",
                            interactive=True
                        )
                        benchmark_name_textbox = gr.Textbox(
                            label="benchmark_name",
                            placeholder="PiaFSV",
                            interactive=True,
                            value="PIA"
                        )
                        prompt_cfg_name_textbox = gr.Textbox(
                            label="prompt_cfg_name",
                            placeholder="topk",
                            interactive=True,
                            value="topk"
                        )
                    with gr.Column():
                        gr.Markdown("## ํ๊ฐ๋ฅผ ๋ฐ์๋ณด์ธ์ ๋ฐ๋์ ํ๊นํ์ด์ค์ ์๋ก๋๋ ๋ชจ๋ธ์ด์ด์ผ ํฉ๋๋ค.")
                        gr.Markdown("#### ํ์ฌ ํ๊ฐ ๋๊ธฐ์ค ๋ชจ๋ธ์๋๋ค.")
                        # Pending-evaluation queue; refreshed on demand.
                        model_queue = gr.Dataframe()
                        refresh_button = gr.Button("refresh")
                        refresh_button.click(
                            fn=read_queue,
                            outputs=model_queue
                        )
                with gr.Row():
                    # Pushes the submission to the sheet and re-renders the queue.
                    model_submit_button = gr.Button("Submit Eval")
                    model_submit_button.click(
                        fn=model_submit,
                        inputs=[model_id_textbox,
                                benchmark_name_textbox,
                                prompt_cfg_name_textbox],
                        outputs=model_queue
                    )
            with gr.Tab("Prompt"):
                with gr.Row():
                    with gr.Column():
                        # NOTE(review): placeholder choices; presumably meant to
                        # be populated dynamically — confirm intended source.
                        prompt_cfg_selector = gr.Dropdown(
                            choices=["์ ๋ถ"],
                            label="Prompt_CFG",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                        weight_type = gr.Dropdown(
                            choices=["์ ๋ถ"],
                            label="Weights type",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                    with gr.Column():
                        gr.Markdown("## ํ๊ฐ๋ฅผ ๋ฐ์๋ณด์ธ์ ๋ฐ๋์ ํ๊นํ์ด์ค์ ์๋ก๋๋ ๋ชจ๋ธ์ด์ด์ผ ํฉ๋๋ค.")

                with gr.Row():
                    # NOTE(review): button has no .click() handler wired yet.
                    prompt_submit_button = gr.Button("Submit Eval")
|
| 103 |
+
|
pia_bench/bench.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from devmacs_core.devmacs_core import DevMACSCore
|
| 4 |
+
import json
|
| 5 |
+
from typing import Dict, List, Tuple
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from utils.except_dir import cust_listdir
|
| 9 |
+
from utils.parser import load_config
|
| 10 |
+
from utils.logger import custom_logger
|
| 11 |
+
|
| 12 |
+
logger = custom_logger(__name__)

# Canonical folder / file names used throughout the benchmark layout.
DATA_SET = "dataset"              # raw videos + label files
CFG = "CFG"                       # prompt-configuration files
VECTOR = "vector"                 # extracted feature vectors
TEXT = "text"                     # text-vector sub-folder
VIDEO = "video"                   # video-vector sub-folder
EXECPT = ["@eaDir", "README.md"]  # entries to skip (sic: "EXCEPT")
ALRAM = "alarm"                   # alarm outputs (sic: "ALARM")
METRIC = "metric"                 # evaluation metrics
MSRVTT = "MSRVTT"                 # default model name
MODEL = "models"                  # per-model results root
|
| 24 |
+
|
| 25 |
+
class PiaBenchMark:
    """Builds and runs the PIA benchmark for one model / prompt config.

    Creates the expected directory layout, converts JSON labels into
    frame-level CSV tables and extracts visual feature vectors.

    Attributes:
        benchmark_path (str): benchmark root directory.
        cfg_target_path (str): path to the prompt-config JSON file.
        model_name (str): Hugging Face model (repo) name.
        token (str): Hugging Face access token.
        categories (List[str]): dataset category folder names.
    """

    def __init__(self, benchmark_path: str, cfg_target_path: str = None,
                 model_name: str = MSRVTT, token: str = None):
        """Store the inputs and pre-compute the whole directory layout.

        NOTE(review): despite the ``None`` default, ``cfg_target_path`` is
        effectively required — ``Path(None)`` raises ``TypeError``. Pass a
        real config path.
        """
        self.benchmark_path = benchmark_path
        self.token = token
        self.model_name = model_name
        self.devmacs_core = None  # created lazily in extract_visual_vector()
        self.cfg_target_path = cfg_target_path
        self.cfg_name = Path(cfg_target_path).stem
        self.cfg_dict = load_config(self.cfg_target_path)

        # <benchmark>/dataset and <benchmark>/CFG
        self.dataset_path = os.path.join(benchmark_path, DATA_SET)
        self.cfg_path = os.path.join(benchmark_path, CFG)

        # <benchmark>/models/<model>/CFG/<cfg>/{alarm,metric}
        self.model_path = os.path.join(self.benchmark_path, MODEL)
        self.model_name_path = os.path.join(self.model_path, self.model_name)
        self.model_name_cfg_path = os.path.join(self.model_name_path, CFG)
        self.model_name_cfg_name_path = os.path.join(self.model_name_cfg_path, self.cfg_name)
        self.alram_path = os.path.join(self.model_name_cfg_name_path, ALRAM)
        self.metric_path = os.path.join(self.model_name_cfg_name_path, METRIC)

        # <benchmark>/models/<model>/vector/{text,video}
        self.vector_path = os.path.join(self.model_name_path, VECTOR)
        self.vector_text_path = os.path.join(self.vector_path, TEXT)
        self.vector_video_path = os.path.join(self.vector_path, VIDEO)

        # Filled by preprocess_structure() / preprocess_label_to_csv().
        self.categories = []

    def _create_frame_labels(self, label_data: Dict, total_frames: int) -> pd.DataFrame:
        """Build a per-frame one-hot label table.

        Args:
            label_data (Dict): parsed label JSON; clips live under
                ``label_data['clips']``, each with a ``category`` and an
                inclusive ``[start, end]`` ``timestamp``.
            total_frames (int): number of frames in the video.

        Returns:
            pd.DataFrame: a ``frame`` column plus one 0/1 column per
            category in ``self.categories``.
        """
        columns = ['frame'] + sorted(self.categories)
        df = pd.DataFrame(0, index=range(total_frames), columns=columns)
        df['frame'] = range(total_frames)

        for clip_info in label_data['clips'].values():
            category = clip_info['category']
            # Only mark categories that are part of this benchmark run.
            if category in self.categories:
                start_frame, end_frame = clip_info['timestamp']
                # .loc label slicing is inclusive on both ends.
                df.loc[start_frame:end_frame, category] = 1

        return df

    def preprocess_label_to_csv(self):
        """Convert every JSON label in the dataset to a frame-level CSV.

        Raises:
            ValueError: when no JSON label files exist in any category.

        Note:
            Files whose CSV counterpart already exists are skipped
            individually, so this method is safe to re-run after a partial
            conversion. (Fix: the original only skipped when the JSON and
            CSV counts happened to be equal, and otherwise reprocessed
            every file.)
        """
        # Populate categories lazily if preprocess_structure() was not called.
        if not self.categories:
            self.categories = [
                cate for cate in cust_listdir(self.dataset_path)
                if os.path.isdir(os.path.join(self.dataset_path, cate))
            ]

        json_files = []
        csv_stems = set()
        for category in self.categories:
            category_path = os.path.join(self.dataset_path, category)
            for fname in cust_listdir(category_path):
                rel = os.path.join(category, fname)
                if fname.endswith('.json'):
                    json_files.append(rel)
                elif fname.endswith('.csv'):
                    csv_stems.add(os.path.splitext(rel)[0])

        if not json_files:
            logger.error("No JSON files found in any category directory")
            raise ValueError("No JSON files found in any category directory")

        for json_file in json_files:
            video_name = os.path.splitext(json_file)[0]
            if video_name in csv_stems:
                continue  # already converted on a previous run

            label_info = load_config(os.path.join(self.dataset_path, json_file))
            total_frames = label_info['video_info']['total_frame']
            df = self._create_frame_labels(label_info, total_frames)
            df.to_csv(os.path.join(self.dataset_path, f"{video_name}.csv"), index=False)
        logger.info("Complete !")

    def preprocess_structure(self):
        """Create the directory structure required by the benchmark.

        Layout created:
            - dataset/: dataset storage
            - cfg/: configuration files
            - vector/: extracted vectors
            - alarm/: alarm-related files
            - metric/: evaluation metrics

        Note:
            An existing category layout under dataset/ is kept as-is;
            otherwise loose category folders at the benchmark root are
            moved under dataset/.
        """
        logger.info("Starting directory structure preprocessing...")
        os.makedirs(self.dataset_path, exist_ok=True)
        os.makedirs(self.cfg_path, exist_ok=True)
        os.makedirs(self.vector_text_path, exist_ok=True)
        os.makedirs(self.vector_video_path, exist_ok=True)
        os.makedirs(self.alram_path, exist_ok=True)
        os.makedirs(self.metric_path, exist_ok=True)
        os.makedirs(self.model_name_cfg_name_path, exist_ok=True)

        # If dataset/ already contains category folders, reuse them.
        if os.path.exists(self.dataset_path) and any(
            os.path.isdir(os.path.join(self.dataset_path, d))
            for d in cust_listdir(self.dataset_path)
        ):
            self.categories = [
                d for d in cust_listdir(self.dataset_path)
                if os.path.isdir(os.path.join(self.dataset_path, d))
            ]
        else:
            # First run: move loose category folders under dataset/.
            for item in cust_listdir(self.benchmark_path):
                item_path = os.path.join(self.benchmark_path, item)

                # Skip reserved names, hidden NAS folders and plain files.
                if item.startswith("@") or item in [METRIC, "README.md", MODEL, CFG, DATA_SET, VECTOR, ALRAM] or not os.path.isdir(item_path):
                    continue
                target_path = os.path.join(self.dataset_path, item)
                if not os.path.exists(target_path):
                    shutil.move(item_path, target_path)
                self.categories.append(item)

        # Mirror category folders under vector/video/.
        for category in self.categories:
            os.makedirs(os.path.join(self.vector_video_path, category), exist_ok=True)

        logger.info("Folder preprocessing completed.")

    def extract_visual_vector(self):
        """Extract visual feature vectors from the dataset videos.

        Note:
            - Loads the model from the Hugging Face hub
              (``PIA-SPACE-LAB/<model_name>``).
            - Resulting vectors are written under vector/video/.

        Raises:
            Exception: re-raised after logging if extraction fails.
        """
        logger.info(f"Starting visual vector extraction using model: {self.model_name}")
        try:
            self.devmacs_core = DevMACSCore.from_huggingface(token=self.token, repo_id=f"PIA-SPACE-LAB/{self.model_name}")
            self.devmacs_core.save_visual_results(
                vid_dir=self.dataset_path,
                result_dir=self.vector_video_path
            )
        except Exception as e:
            logger.error(f"Error during vector extraction: {str(e)}")
            raise
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
    # Ad-hoc smoke test: build the structure and convert labels for one
    # local benchmark checkout.
    from dotenv import load_dotenv
    import os
    load_dotenv()

    access_token = os.getenv("ACCESS_TOKEN")
    model_name = "T2V_CLIP4CLIP_MSRVTT"

    # NOTE(review): machine-specific absolute paths — adjust before running.
    benchmark_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA"
    cfg_target_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA/CFG/topk.json"

    pia_benchmark = PiaBenchMark(benchmark_path, model_name=model_name, cfg_target_path=cfg_target_path, token=access_token)
    pia_benchmark.preprocess_structure()
    pia_benchmark.preprocess_label_to_csv()
    print("Categories identified:", pia_benchmark.categories)
|
pia_bench/checker/bench_checker.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import logging
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json
import numpy as np

# Module-wide logging setup. Library code normally leaves configuration to
# the application; this is convenient for standalone/script runs only.
logging.basicConfig(level=logging.INFO)
|
| 8 |
+
|
| 9 |
+
class BenchChecker:
    """Filesystem sanity-checker for benchmark runs.

    Verifies that a benchmark's directory layout (benchmark folder, model
    folder, CFG files, extracted vectors and metrics) is complete and
    mutually consistent, and maps the outcome onto an execution path.
    """

    def __init__(self, base_path: str):
        """Initialize BenchChecker with base assets path.

        Args:
            base_path (str): Path to the assets directory that contains one
                sub-folder per benchmark.
        """
        self.base_path = Path(base_path)
        self.logger = logging.getLogger(__name__)

    def check_benchmark_exists(self, benchmark_name: str) -> bool:
        """Return True when the benchmark directory exists."""
        benchmark_path = self.base_path / benchmark_name
        exists = benchmark_path.exists() and benchmark_path.is_dir()
        if exists:
            self.logger.info(f"Found benchmark directory: {benchmark_name}")
        else:
            self.logger.error(f"Benchmark directory not found: {benchmark_name}")
        return exists

    def get_video_list(self, benchmark_name: str) -> List[str]:
        """Return stems of all .mp4 files under <benchmark>/dataset/<category>/.

        Returns an empty list when the dataset directory is missing or holds
        no videos.
        """
        dataset_path = self.base_path / benchmark_name / "dataset"
        videos: List[str] = []

        if not dataset_path.exists():
            # Fix: the original message here ("Dataset directory exists but
            # no videos found") stated the opposite of this branch's
            # condition — the directory does NOT exist.
            self.logger.info(f"Dataset directory not found for {benchmark_name}")
            return videos

        # Exactly one level deep: dataset/<category>/<video>.mp4
        for category in dataset_path.glob("*"):
            if category.is_dir():
                for video_file in category.glob("*.mp4"):
                    videos.append(video_file.stem)

        self.logger.info(f"Found {len(videos)} videos in {benchmark_name} dataset")
        return videos

    def check_model_exists(self, benchmark_name: str, model_name: str) -> bool:
        """Return True when the model directory exists under the benchmark."""
        model_path = self.base_path / benchmark_name / "models" / model_name
        exists = model_path.exists() and model_path.is_dir()
        if exists:
            self.logger.info(f"Found model directory: {model_name}")
        else:
            self.logger.error(f"Model directory not found: {model_name}")
        return exists

    def check_cfg_files(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Tuple[bool, bool]:
        """Check that CFG artifacts exist in both benchmark and model trees.

        Returns:
            (benchmark_cfg_exists, model_cfg_exists): the benchmark side is a
            JSON file, the model side is a directory of the same name.
        """
        benchmark_cfg = self.base_path / benchmark_name / "CFG" / f"{cfg_prompt}.json"
        benchmark_cfg_exists = benchmark_cfg.exists() and benchmark_cfg.is_file()

        model_cfg = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt
        model_cfg_exists = model_cfg.exists() and model_cfg.is_dir()

        if benchmark_cfg_exists:
            self.logger.info(f"Found benchmark CFG file: {cfg_prompt}.json")
        else:
            self.logger.error(f"Benchmark CFG file not found: {cfg_prompt}.json")

        if model_cfg_exists:
            self.logger.info(f"Found model CFG directory: {cfg_prompt}")
        else:
            self.logger.error(f"Model CFG directory not found: {cfg_prompt}")

        return benchmark_cfg_exists, model_cfg_exists

    def check_vector_files(self, benchmark_name: str, model_name: str, video_list: List[str]) -> bool:
        """Check that the stored .npy vectors match the dataset videos 1:1.

        Returns False when there are no videos, no vector directory, or any
        missing/extra vector files.
        """
        vector_path = self.base_path / benchmark_name / "models" / model_name / "vector" / "video"

        # An empty dataset can never be consistent with anything.
        if not video_list:
            self.logger.error("No videos found in dataset - cannot proceed")
            return False

        if not vector_path.exists():
            self.logger.error("Vector directory doesn't exist")
            return False

        # Vectors may be nested per category, hence rglob.
        vector_files = [f.stem for f in vector_path.rglob("*.npy")]

        missing_vectors = set(video_list) - set(vector_files)
        extra_vectors = set(vector_files) - set(video_list)

        if missing_vectors:
            self.logger.error(f"Missing vectors for videos: {missing_vectors}")
            return False
        if extra_vectors:
            self.logger.error(f"Extra vectors found: {extra_vectors}")
            return False

        self.logger.info(f"Vector status: videos={len(video_list)}, vectors={len(vector_files)}")
        # No missing and no extra entries implies equal counts; the explicit
        # comparison is kept for defensiveness (duplicate-stem edge cases).
        return len(video_list) == len(vector_files)

    def check_metrics_file(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> bool:
        """Return True when overall_metrics.json exists for this model+cfg."""
        metrics_path = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt / "metric" / "overall_metrics.json"
        exists = metrics_path.exists() and metrics_path.is_file()

        if exists:
            self.logger.info(f"Found overall metrics file for {model_name}")
        else:
            self.logger.error(f"Overall metrics file not found for {model_name}")
        return exists

    def check_benchmark(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Dict[str, bool]:
        """Run all checks in dependency order and return their status.

        Later checks are skipped (left False) as soon as a prerequisite
        fails, so callers can rely on the first False being the root cause.
        """
        status = {
            'benchmark_exists': False,
            'model_exists': False,
            'cfg_files_exist': False,
            'vectors_match': False,
            'metrics_exist': False
        }

        status['benchmark_exists'] = self.check_benchmark_exists(benchmark_name)
        if not status['benchmark_exists']:
            return status

        video_list = self.get_video_list(benchmark_name)

        status['model_exists'] = self.check_model_exists(benchmark_name, model_name)
        if not status['model_exists']:
            return status

        benchmark_cfg, model_cfg = self.check_cfg_files(benchmark_name, model_name, cfg_prompt)
        status['cfg_files_exist'] = benchmark_cfg and model_cfg
        if not status['cfg_files_exist']:
            return status

        status['vectors_match'] = self.check_vector_files(benchmark_name, model_name, video_list)

        # Metrics are only meaningful once vectors are complete.
        if status['vectors_match']:
            status['metrics_exist'] = self.check_metrics_file(benchmark_name, model_name, cfg_prompt)

        return status

    def get_benchmark_status(self, check_status: Dict[str, bool]) -> str:
        """Map a check-status dict onto one of four execution paths:
        'cannot_execute', 'all_passed', 'no_vectors' or 'no_metrics'."""
        basic_checks = ['benchmark_exists', 'model_exists', 'cfg_files_exist']
        if not all(check_status[check] for check in basic_checks):
            return "cannot_execute"
        if check_status['vectors_match'] and check_status['metrics_exist']:
            return "all_passed"
        elif not check_status['vectors_match']:
            return "no_vectors"
        else:  # vectors exist but no metrics
            return "no_metrics"
|
| 171 |
+
|
| 172 |
+
# Example usage
|
| 173 |
+
# Example usage
if __name__ == "__main__":
    # Ad-hoc smoke test against a local assets checkout.
    bench_checker = BenchChecker("assets")
    status = bench_checker.check_benchmark(
        benchmark_name="huggingface_benchmarks_dataset",
        model_name="MSRVTT",
        cfg_prompt="topk"
    )

    execution_path = bench_checker.get_benchmark_status(status)
    print(f"Checks completed. Execution path: {execution_path}")
    print(f"Status: {status}")
|
pia_bench/checker/sheet_checker.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional, Set, Tuple
import logging
import gspread
from dotenv import load_dotenv
from typing import Optional, List  # NOTE(review): duplicate of the import above
from sheet_manager.sheet_crud.sheet_crud import SheetManager

# Load credentials / sheet configuration from .env at import time.
load_dotenv()
|
| 9 |
+
class SheetChecker:
    """Checks and updates model/benchmark entries in the leaderboard sheet.

    Wraps a SheetManager instance and a second manager bound to the "model"
    worksheet, providing column creation, existence checks and cell updates.
    """

    def __init__(self, sheet_manager):
        """Initialize SheetChecker with a sheet manager instance."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None  # bound to the "model" worksheet below
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Create a sibling manager bound to the "model" worksheet.

        Uses type(self.sheet_manager) so subclasses of SheetManager are
        instantiated as the same concrete type.
        """
        self.bench_sheet_manager = type(self.sheet_manager)(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )

    def add_benchmark_column(self, column_name: str):
        """Append a new benchmark column header if it does not exist yet.

        Raises:
            Exception: re-raised after logging on any sheet API failure.
        """
        try:
            headers = self.bench_sheet_manager.get_available_columns()

            # Idempotent: nothing to do when the column already exists.
            if column_name in headers:
                return

            # Write the header into the first free column of row 1.
            new_col_index = len(headers) + 1
            cell = gspread.utils.rowcol_to_a1(1, new_col_index)
            # gspread's update() expects a 2D array of values.
            self.bench_sheet_manager.sheet.update(cell, [[column_name]])
            self.logger.info(f"Added new benchmark column: {column_name}")

            # Refresh cached headers in bench_sheet_manager.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"Error adding benchmark column {column_name}: {str(e)}")
            raise

    def validate_benchmark_columns(self, benchmark_columns: List[str]) -> Tuple[List[str], List[str]]:
        """
        Validate benchmark columns and add missing ones.

        Args:
            benchmark_columns: List of benchmark column names to validate

        Returns:
            Tuple[List[str], List[str]]: (valid columns, invalid columns)
        """
        available_columns = self.bench_sheet_manager.get_available_columns()
        valid_columns = []
        invalid_columns = []

        for col in benchmark_columns:
            if col in available_columns:
                valid_columns.append(col)
            else:
                # Missing columns are created on the fly; only creation
                # failures mark a column as invalid.
                try:
                    self.add_benchmark_column(col)
                    valid_columns.append(col)
                    self.logger.info(f"Added new benchmark column: {col}")
                except Exception as e:
                    invalid_columns.append(col)
                    self.logger.error(f"Failed to add benchmark column '{col}': {str(e)}")

        return valid_columns, invalid_columns

    def check_model_and_benchmarks(self, model_name: str, benchmark_columns: List[str]) -> Dict[str, List[str]]:
        """
        Check model existence and which benchmarks need to be filled.

        Args:
            model_name: Name of the model to check
            benchmark_columns: List of benchmark column names to check

        Returns:
            Dict with keys:
                'status': 'model_not_found' or 'model_exists'
                'empty_benchmarks': List of benchmark columns that need to be filled
                'filled_benchmarks': List of benchmark columns that are already filled
                'invalid_benchmarks': List of benchmark columns that don't exist
        """
        result = {
            'status': '',
            'empty_benchmarks': [],
            'filled_benchmarks': [],
            'invalid_benchmarks': []
        }

        # First check if model exists
        exists = self.check_model_exists(model_name)
        if not exists:
            result['status'] = 'model_not_found'
            return result

        result['status'] = 'model_exists'

        # Validate benchmark columns
        valid_columns, invalid_columns = self.validate_benchmark_columns(benchmark_columns)
        result['invalid_benchmarks'] = invalid_columns

        if not valid_columns:
            return result

        # Check which valid benchmarks are empty.
        self.bench_sheet_manager.change_column("Model name")
        all_values = self.bench_sheet_manager.get_all_values()
        # +2: sheet rows are 1-based and row 1 is the header; presumably
        # get_all_values() excludes the header — TODO confirm against
        # SheetManager.
        row_index = all_values.index(model_name) + 2

        for column in valid_columns:
            try:
                self.bench_sheet_manager.change_column(column)
                value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value
                if not value or not value.strip():
                    result['empty_benchmarks'].append(column)
                else:
                    result['filled_benchmarks'].append(column)
            except Exception as e:
                # On read failure, treat the cell as empty so it gets
                # (re)computed rather than silently skipped.
                self.logger.error(f"Error checking column {column}: {str(e)}")
                result['empty_benchmarks'].append(column)

        return result

    def update_model_info(self, model_name: str, model_info: Dict[str, str]):
        """Update basic model information columns (one push per column).

        Raises:
            Exception: re-raised after logging on any sheet API failure.
        """
        try:
            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)
            self.logger.info(f"Successfully added new model: {model_name}")
        except Exception as e:
            self.logger.error(f"Error updating model info: {str(e)}")
            raise

    def update_benchmarks(self, model_name: str, benchmark_values: Dict[str, str]):
        """
        Update benchmark values.

        Args:
            model_name: Name of the model
            benchmark_values: Dictionary of benchmark column names and their values

        Raises:
            Exception: re-raised after logging; also raises ValueError via
            list.index() when model_name is not present in the sheet.
        """
        try:
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            # +2: 1-based rows plus header row (see check_model_and_benchmarks).
            row_index = all_values.index(model_name) + 2

            for column, value in benchmark_values.items():
                self.bench_sheet_manager.change_column(column)
                self.bench_sheet_manager.sheet.update_cell(row_index, self.bench_sheet_manager.col_index, value)
                self.logger.info(f"Updated benchmark {column} for model {model_name}")

        except Exception as e:
            self.logger.error(f"Error updating benchmarks: {str(e)}")
            raise

    def check_model_exists(self, model_name: str) -> bool:
        """Check if model exists in the sheet; False on any lookup error."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            values = self.bench_sheet_manager.get_all_values()
            return model_name in values
        except Exception as e:
            self.logger.error(f"Error checking model existence: {str(e)}")
            return False
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def process_model_benchmarks(
    model_name: str,
    bench_checker: SheetChecker,
    model_info_func,
    benchmark_processor_func: callable,
    benchmark_columns: List[str],
    cfg_prompt: str
) -> None:
    """
    Process model benchmarks according to the specified workflow.

    Workflow: check which benchmark columns are invalid / filled / empty;
    register the model first if it is missing; then compute and write values
    only for the empty benchmark columns.

    Args:
        model_name: Name of the model to process
        bench_checker: SheetChecker instance
        model_info_func: Function that returns model info (name, link, etc.)
        benchmark_processor_func: Function that processes empty benchmarks and returns values
        benchmark_columns: List of benchmark columns to check
        cfg_prompt: Prompt configuration forwarded to the processor function
    """
    try:
        log = bench_checker.logger
        check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        # Columns not recognised by the sheet are skipped outright.
        invalid = check_result['invalid_benchmarks']
        if invalid:
            log.warning(
                f"Skipping invalid benchmark columns: {', '.join(invalid)}"
            )

        # Unknown model: register it, then re-inspect the benchmark columns.
        if check_result['status'] == 'model_not_found':
            bench_checker.update_model_info(model_name, model_info_func(model_name))
            log.info(f"Added new model: {model_name}")
            check_result = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        filled = check_result['filled_benchmarks']
        if filled:
            log.info(
                f"Skipping filled benchmark columns: {', '.join(filled)}"
            )

        empty = check_result['empty_benchmarks']
        if not empty:
            log.info("No empty benchmark columns to process")
        else:
            log.info(
                f"Processing empty benchmark columns: {', '.join(empty)}"
            )
            # Compute scores for the missing columns and write them back.
            benchmark_values = benchmark_processor_func(
                model_name,
                empty,
                cfg_prompt
            )
            bench_checker.update_benchmarks(model_name, benchmark_values)

    except Exception as e:
        bench_checker.logger.error(f"Error processing model {model_name}: {str(e)}")
        raise
|
| 237 |
+
def get_model_info(model_name: str) -> Dict[str, str]:
    """Build the basic model-info columns for the benchmark sheet.

    Returns the raw model name, its Hugging Face hub URL, and the HTML
    anchor rendered by the leaderboard UI.
    """
    anchor = f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
    return {
        "Model name": model_name,
        "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
        "Model": anchor,
    }
|
| 245 |
+
def process_benchmarks(
    model_name: str,
    empty_benchmarks: List[str],
    cfg_prompt: str
) -> Dict[str, str]:
    """
    Measure benchmark scores for given model with specific configuration.

    Args:
        model_name: Name of the model to evaluate
        empty_benchmarks: List of benchmarks to measure
        cfg_prompt: Prompt configuration for evaluation

    Returns:
        Dict[str, str]: Dictionary mapping benchmark names to their scores.
        Benchmarks with no known score are omitted.
    """
    # Placeholder scores until real measurement is wired in:
    # score = measure_benchmark(model_name, benchmark, cfg_prompt)
    placeholder_scores = {"COCO": 0.5, "ImageNet": 15.0}

    result = {}
    for benchmark in empty_benchmarks:
        # Bug fix: the original if/elif left `score` undefined (NameError)
        # for an unknown first benchmark, and silently reused the previous
        # iteration's score for later unknown ones. Unknowns are now skipped.
        if benchmark in placeholder_scores:
            result[benchmark] = str(placeholder_scores[benchmark])
    return result
| 271 |
+
# Example usage
if __name__ == "__main__":

    # Wire the sheet CRUD layer into the checker, then run the full
    # check-register-measure-write workflow for one model.
    sheet_manager = SheetManager()
    bench_checker = SheetChecker(sheet_manager)

    # Columns "COCO" / "ImageNet" get placeholder scores from
    # process_benchmarks; get_model_info supplies the model-info columns.
    process_model_benchmarks(
        "test-model",
        bench_checker,
        get_model_info,
        process_benchmarks,
        ["COCO", "ImageNet"],
        "cfg_prompt_value"
    )
pia_bench/event_alarm.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
from devmacs_core.devmacs_core import DevMACSCore
|
| 6 |
+
# from devmacs_core.devmacs_core_copy import DevMACSCore
|
| 7 |
+
from devmacs_core.utils.common.cal import loose_similarity
|
| 8 |
+
from utils.parser import load_config, PromptManager
|
| 9 |
+
import json
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import logging
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from utils.except_dir import cust_listdir
|
| 15 |
+
from utils.logger import custom_logger
|
| 16 |
+
|
| 17 |
+
logger = custom_logger(__name__)
|
| 18 |
+
|
| 19 |
+
class EventDetector:
|
| 20 |
+
    def __init__(self, config_path: str , model_name:str = None, token:str = None):
        """Load the event config, fetch the model from the Hugging Face hub,
        and pre-compute text vectors for every configured prompt sentence.

        Args:
            config_path: Path to the YAML/JSON config consumed by
                load_config / PromptManager (contains PROMPT_CFG).
            model_name: Repo name under PIA-SPACE-LAB to download.
            token: Hugging Face access token for the private repo.
        """
        self.config = load_config(config_path)
        # Downloads/loads the model weights from the hub at construction time.
        self.macs = DevMACSCore.from_huggingface(token=token, repo_id=f"PIA-SPACE-LAB/{model_name}")
        # self.macs = DevMACSCore(model_type="clip4clip_web")

        # Prompt sentences are encoded once and reused for every video.
        self.prompt_manager = PromptManager(config_path)
        self.sentences = self.prompt_manager.sentences
        self.text_vectors = self.macs.get_text_vector(self.sentences)
|
| 29 |
+
def process_and_save_predictions(self, vector_base_dir: str, label_base_dir: str, save_base_dir: str):
|
| 30 |
+
"""๋น๋์ค ๋ฒกํฐ๋ฅผ ์ฒ๋ฆฌํ๊ณ ๊ฒฐ๊ณผ๋ฅผ CSV๋ก ์ ์ฅ"""
|
| 31 |
+
|
| 32 |
+
# ์ ์ฒด ๋น๋์ค ํ์ผ ์ ๊ณ์ฐ
|
| 33 |
+
total_videos = sum(len([f for f in cust_listdir(os.path.join(vector_base_dir, d))
|
| 34 |
+
if f.endswith('.npy')])
|
| 35 |
+
for d in cust_listdir(vector_base_dir)
|
| 36 |
+
if os.path.isdir(os.path.join(vector_base_dir, d)))
|
| 37 |
+
pbar = tqdm(total=total_videos, desc="Processing videos")
|
| 38 |
+
|
| 39 |
+
for category in cust_listdir(vector_base_dir):
|
| 40 |
+
category_path = os.path.join(vector_base_dir, category)
|
| 41 |
+
if not os.path.isdir(category_path):
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
# ์ ์ฅ ๋๋ ํ ๋ฆฌ ์์ฑ
|
| 45 |
+
save_category_dir = os.path.join(save_base_dir, category)
|
| 46 |
+
os.makedirs(save_category_dir, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
for file in cust_listdir(category_path):
|
| 49 |
+
if file.endswith('.npy'):
|
| 50 |
+
video_name = os.path.splitext(file)[0]
|
| 51 |
+
vector_path = os.path.join(category_path, file)
|
| 52 |
+
|
| 53 |
+
# ๋ผ๋ฒจ ํ์ผ ์ฝ๊ธฐ
|
| 54 |
+
label_path = os.path.join(label_base_dir, category, f"{video_name}.json")
|
| 55 |
+
with open(label_path, 'r') as f:
|
| 56 |
+
label_data = json.load(f)
|
| 57 |
+
total_frames = label_data['video_info']['total_frame']
|
| 58 |
+
|
| 59 |
+
# ์์ธก ๊ฒฐ๊ณผ ์์ฑ ๋ฐ ์ ์ฅ
|
| 60 |
+
self._process_and_save_single_video(
|
| 61 |
+
vector_path=vector_path,
|
| 62 |
+
total_frames=total_frames,
|
| 63 |
+
save_path=os.path.join(save_category_dir, f"{video_name}.csv")
|
| 64 |
+
)
|
| 65 |
+
pbar.update(1)
|
| 66 |
+
pbar.close()
|
| 67 |
+
|
| 68 |
+
def _process_and_save_single_video(self, vector_path: str, total_frames: int, save_path: str):
|
| 69 |
+
"""๋จ์ผ ๋น๋์ค ์ฒ๋ฆฌ ๋ฐ ์ ์ฅ"""
|
| 70 |
+
# ๊ธฐ๋ณธ ์์ธก ์ํ
|
| 71 |
+
sparse_predictions = self._process_single_vector(vector_path)
|
| 72 |
+
|
| 73 |
+
# ๋ฐ์ดํฐํ๋ ์์ผ๋ก ๋ณํ ๋ฐ ํ์ฅ
|
| 74 |
+
df = self._expand_predictions(sparse_predictions, total_frames)
|
| 75 |
+
|
| 76 |
+
# CSV๋ก ์ ์ฅ
|
| 77 |
+
df.to_csv(save_path, index=False)
|
| 78 |
+
|
| 79 |
+
def _process_single_vector(self, vector_path: str) -> Dict:
|
| 80 |
+
"""๊ธฐ์กด ์์ธก ๋ก์ง"""
|
| 81 |
+
video_vector = np.load(vector_path)
|
| 82 |
+
processed_vectors = []
|
| 83 |
+
frame_interval = 15
|
| 84 |
+
|
| 85 |
+
for vector in video_vector:
|
| 86 |
+
v = vector.squeeze(0) # numpy array
|
| 87 |
+
v = torch.from_numpy(v).unsqueeze(0).cuda() # torch tensor๋ก ๋ณํ ํ GPU๋ก
|
| 88 |
+
processed_vectors.append(v)
|
| 89 |
+
|
| 90 |
+
frame_results = {}
|
| 91 |
+
for vector_idx, v in enumerate(processed_vectors):
|
| 92 |
+
actual_frame = vector_idx * frame_interval
|
| 93 |
+
sim_scores = loose_similarity(
|
| 94 |
+
sequence_output=self.text_vectors.cuda(),
|
| 95 |
+
visual_output=v.unsqueeze(1)
|
| 96 |
+
)
|
| 97 |
+
frame_results[actual_frame] = self._calculate_alarms(sim_scores)
|
| 98 |
+
|
| 99 |
+
return frame_results
|
| 100 |
+
|
| 101 |
+
def _expand_predictions(self, sparse_predictions: Dict, total_frames: int) -> pd.DataFrame:
|
| 102 |
+
"""์์ธก์ ์ ์ฒด ํ๋ ์์ผ๋ก ํ์ฅ"""
|
| 103 |
+
# ์นดํ
๊ณ ๋ฆฌ ๋ชฉ๋ก ์ถ์ถ (์ฒซ ๋ฒ์งธ ํ๋ ์์ ์๋ ๊ฒฐ๊ณผ์์)
|
| 104 |
+
first_frame = list(sparse_predictions.keys())[0]
|
| 105 |
+
categories = list(sparse_predictions[first_frame].keys())
|
| 106 |
+
|
| 107 |
+
# ์ ์ฒด ํ๋ ์ ์์ฑ
|
| 108 |
+
df = pd.DataFrame({'frame': range(total_frames)})
|
| 109 |
+
|
| 110 |
+
# ๊ฐ ์นดํ
๊ณ ๋ฆฌ์ ๋ํ ์๋ ๊ฐ ์ด๊ธฐํ
|
| 111 |
+
for category in categories:
|
| 112 |
+
df[category] = 0
|
| 113 |
+
|
| 114 |
+
# ์์ธก๊ฐ ์ฑ์ฐ๊ธฐ
|
| 115 |
+
frame_keys = sorted(sparse_predictions.keys())
|
| 116 |
+
for i in range(len(frame_keys)):
|
| 117 |
+
current_frame = frame_keys[i]
|
| 118 |
+
next_frame = frame_keys[i + 1] if i + 1 < len(frame_keys) else total_frames
|
| 119 |
+
|
| 120 |
+
# ๊ฐ ์นดํ
๊ณ ๋ฆฌ์ ์๋ ๊ฐ ์ค์
|
| 121 |
+
for category in categories:
|
| 122 |
+
alarm_value = sparse_predictions[current_frame][category]['alarm']
|
| 123 |
+
df.loc[current_frame:next_frame-1, category] = alarm_value
|
| 124 |
+
|
| 125 |
+
return df
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _calculate_alarms(self, sim_scores: torch.Tensor) -> Dict:
|
| 129 |
+
"""์ ์ฌ๋ ์ ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ ์ด๋ฒคํธ์ ์๋ ์ํ ๊ณ์ฐ"""
|
| 130 |
+
# ๋ก๊ฑฐ ์ค์
|
| 131 |
+
log_filename = f"alarm_calculation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
| 132 |
+
logging.basicConfig(
|
| 133 |
+
filename=log_filename,
|
| 134 |
+
level=logging.ERROR,
|
| 135 |
+
format='%(asctime)s - %(message)s',
|
| 136 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
| 137 |
+
)
|
| 138 |
+
logger = logging.getLogger(__name__)
|
| 139 |
+
|
| 140 |
+
event_alarms = {}
|
| 141 |
+
|
| 142 |
+
for event_config in self.config['PROMPT_CFG']:
|
| 143 |
+
event = event_config['event']
|
| 144 |
+
top_k = event_config['top_candidates']
|
| 145 |
+
threshold = event_config['alert_threshold']
|
| 146 |
+
|
| 147 |
+
# logger.info(f"\nProcessing event: {event}")
|
| 148 |
+
# logger.info(f"Top K: {top_k}, Threshold: {threshold}")
|
| 149 |
+
|
| 150 |
+
event_prompts = self._get_event_prompts(event)
|
| 151 |
+
|
| 152 |
+
# logger.debug(f"\nEvent Prompts Debug for {event}:")
|
| 153 |
+
# logger.debug(f"Indices: {event_prompts['indices']}")
|
| 154 |
+
# logger.debug(f"Types: {event_prompts['types']}")
|
| 155 |
+
# logger.debug(f"\nSim Scores Debug:")
|
| 156 |
+
# logger.debug(f"Shape: {sim_scores.shape}")
|
| 157 |
+
# logger.debug(f"Raw scores: {sim_scores}")
|
| 158 |
+
|
| 159 |
+
# event_scores = sim_scores[event_prompts['indices']]
|
| 160 |
+
event_scores = sim_scores[event_prompts['indices']].squeeze(-1) # shape ๋ณ๊ฒฝ
|
| 161 |
+
|
| 162 |
+
# logger.debug(f"Event scores shape: {event_scores.shape}")
|
| 163 |
+
# logger.debug(f"Event scores: {event_scores}")
|
| 164 |
+
# ๊ฐ ํ๋กฌํํธ์ ์ ์ ์ถ๋ ฅ
|
| 165 |
+
# logger.info("\nDEBUG VALUES:")
|
| 166 |
+
# logger.info(f"event_scores: {event_scores}")
|
| 167 |
+
# logger.info(f"indices: {event_prompts['indices']}")
|
| 168 |
+
# logger.info(f"types: {event_prompts['types']}")
|
| 169 |
+
|
| 170 |
+
# logger.info("\nAll prompts and scores:")
|
| 171 |
+
# for idx, (score, prompt_type) in enumerate(zip(event_scores, event_prompts['types'])):
|
| 172 |
+
# logger.info(f"Type: {prompt_type}, Score: {score.item():.4f}")
|
| 173 |
+
|
| 174 |
+
top_k_values, top_k_indices = torch.topk(event_scores, min(top_k, len(event_scores)))
|
| 175 |
+
|
| 176 |
+
# logger.info(f"top_k_values: {top_k_values}")
|
| 177 |
+
# logger.info(f"top_k_indices (raw): {top_k_indices}")
|
| 178 |
+
# Top K ๊ฒฐ๊ณผ ์ถ๋ ฅ
|
| 179 |
+
# logger.info(f"\nTop {top_k} selections:")
|
| 180 |
+
for idx, (value, index) in enumerate(zip(top_k_values, top_k_indices)):
|
| 181 |
+
# indices[index]๊ฐ ์๋ index๋ฅผ ์ง์ ์ฌ์ฉ
|
| 182 |
+
prompt_type = event_prompts['types'][index] # ์์ ๋ ๋ถ๋ถ
|
| 183 |
+
# logger.info(f"DEBUG: index={index}, types={event_prompts['types']}, selected_type={prompt_type}")
|
| 184 |
+
# logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
|
| 185 |
+
|
| 186 |
+
abnormal_count = sum(1 for idx in top_k_indices
|
| 187 |
+
if event_prompts['types'][idx] == 'abnormal') # ์์ ๋ ๋ถ๋ถ
|
| 188 |
+
# for idx, (value, orig_idx) in enumerate(zip(top_k_values, top_k_indices)):
|
| 189 |
+
# prompt_type = event_prompts['types'][orig_idx.item()]
|
| 190 |
+
# logger.info(f"Rank {idx+1}: Type: {prompt_type}, Score: {value.item():.4f}")
|
| 191 |
+
|
| 192 |
+
# abnormal_count = sum(1 for idx in top_k_indices
|
| 193 |
+
# if event_prompts['types'][idx.item()] == 'abnormal')
|
| 194 |
+
|
| 195 |
+
# ์๋ ๊ฒฐ์ ๊ณผ์ ์ถ๋ ฅ
|
| 196 |
+
# logger.info(f"\nAbnormal count: {abnormal_count}")
|
| 197 |
+
alarm_result = 1 if abnormal_count >= threshold else 0
|
| 198 |
+
# logger.info(f"Final alarm decision: {alarm_result}")
|
| 199 |
+
# logger.info("-" * 50)
|
| 200 |
+
|
| 201 |
+
event_alarms[event] = {
|
| 202 |
+
'alarm': alarm_result,
|
| 203 |
+
'scores': top_k_values.tolist(),
|
| 204 |
+
'top_k_types': [event_prompts['types'][idx.item()] for idx in top_k_indices]
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
# ๋ก๊ฑฐ ์ข
๋ฃ
|
| 208 |
+
logging.shutdown()
|
| 209 |
+
|
| 210 |
+
return event_alarms
|
| 211 |
+
|
| 212 |
+
    def _get_event_prompts(self, event: str) -> Dict:
        """Return prompt indices and their 'normal'/'abnormal' types for *event*.

        NOTE(review): current_idx only advances inside the matching event's
        branch, so the returned indices always start at 0 even when other
        events' prompts precede this event's in the config. If
        self.text_vectors encodes ALL events' prompts in config order, the
        indices for any event after the first would point at the wrong
        rows of sim_scores — confirm against PromptManager.sentences ordering.
        """
        indices = []
        types = []
        current_idx = 0

        for event_config in self.config['PROMPT_CFG']:
            if event_config['event'] == event:
                # One index per prompt, normal prompts first, then abnormal.
                for status in ['normal', 'abnormal']:
                    for _ in range(len(event_config['prompts'][status])):
                        indices.append(current_idx)
                        types.append(status)
                        current_idx += 1

        return {'indices': indices, 'types': types}
pia_bench/metric.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
import json
|
| 7 |
+
from utils.except_dir import cust_listdir
|
| 8 |
+
|
| 9 |
+
class MetricsEvaluator:
|
| 10 |
+
    def __init__(self, pred_dir: str, label_dir: str, save_dir: str):
        """
        Args:
            pred_dir: directory containing prediction CSV files
            label_dir: directory containing ground-truth CSV files
            save_dir: directory where evaluation results are written
        """
        self.pred_dir = pred_dir
        self.label_dir = label_dir
        self.save_dir = save_dir
|
| 21 |
+
    def evaluate(self) -> Dict:
        """Run the full evaluation over every category directory.

        For each category under pred_dir: compute per-video metrics, write a
        per-category CSV and a combined CSV, print per-category and overall
        averages, dump them to overall_metrics.json, and finally compute
        confusion-matrix-accumulated metrics (also saved to
        accumulated_metrics.json), which are returned.
        """
        category_metrics = {}  # per-category average metrics (last row of each df)
        all_metrics = {  # pooled per-event metric lists across all categories
            'falldown': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
            'violence': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
            'fire': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []}
        }

        # One metrics DataFrame per category, collected for the combined CSV.
        all_categories_metrics = []

        for category in cust_listdir(self.pred_dir):
            if not os.path.isdir(os.path.join(self.pred_dir, category)):
                continue

            pred_category_path = os.path.join(self.pred_dir, category)
            label_category_path = os.path.join(self.label_dir, category)
            save_category_path = os.path.join(self.save_dir, category)
            os.makedirs(save_category_path, exist_ok=True)

            # Per-video metrics; _evaluate_category appends an 'average' row last.
            metrics_df = self._evaluate_category(category, pred_category_path, label_category_path)

            metrics_df['category'] = category

            metrics_df.to_csv(os.path.join(save_category_path, f"{category}_metrics.csv"), index=False)

            all_categories_metrics.append(metrics_df)

            # Last row of the df holds the per-category averages.
            category_metrics[category] = metrics_df.iloc[-1].to_dict()

            # Pool the averaged metrics for the overall summary. Column names
            # are "<event>_<metric>"; split on the FIRST underscore only.
            # Confusion counts (tp/tn/fp/fn) are skipped because all_metrics
            # only tracks the ratio metrics.
            for col in metrics_df.columns:
                if col != 'video_name':
                    try:
                        parts = col.split('_', 1)  # maxsplit=1: split at first underscore only
                        if len(parts) == 2:
                            event_type, metric_type = parts
                            if event_type in all_metrics and metric_type in all_metrics[event_type]:
                                all_metrics[event_type][metric_type].append(category_metrics[category][col])
                    except Exception as e:
                        print(f"Warning: Could not process column {col}: {str(e)}")
                        continue

        # Drop each category's trailing 'average' row before concatenating.
        all_categories_metrics_without_avg = [df.iloc[:-1] for df in all_categories_metrics]
        combined_metrics_df = pd.concat(all_categories_metrics_without_avg, ignore_index=True)
        # Persist the combined per-video metrics next to the JSON outputs.
        combined_metrics_df.to_csv(os.path.join(self.save_dir, "all_categories_metrics.csv"), index=False)

        print("\nCategory-wise Average Metrics:")
        for category, metrics in category_metrics.items():
            print(f"\n{category}:")
            for metric_name, value in metrics.items():
                if metric_name != "video_name":
                    try:
                        if isinstance(value, str):
                            print(f"{metric_name}: {value}")
                        elif metric_name in ['tp', 'tn', 'fp', 'fn']:
                            # confusion counts print as integers
                            print(f"{metric_name}: {int(value)}")
                        else:
                            print(f"{metric_name}: {float(value):.3f}")
                    except (ValueError, TypeError):
                        print(f"{metric_name}: {value}")
        # Overall averages across categories.
        print("\n" + "="*50)
        print("Overall Average Metrics Across All Categories:")
        print("="*50)

        for event_type in all_metrics:
            print(f"\n{event_type}:")
            for metric_type, values in all_metrics[event_type].items():
                avg_value = np.mean(values)
                if metric_type in ['tp', 'tn', 'fp', 'fn']:  # integer counts
                    print(f"{metric_type}: {int(avg_value)}")
                else:  # ratio metrics
                    print(f"{metric_type}: {avg_value:.3f}")
        ##################################################################################################
        # Final results accumulated for JSON output.
        final_results = {
            "category_metrics": {},
            "overall_metrics": {}
        }
        # Per-category metrics (numeric values only).

        for category, metrics in category_metrics.items():
            final_results["category_metrics"][category] = {}
            for metric_name, value in metrics.items():
                if metric_name != "video_name":
                    if isinstance(value, (int, float)):
                        final_results["category_metrics"][category][metric_name] = float(value)

        # Overall averages per event type.
        for event_type in all_metrics:
            final_results["overall_metrics"][event_type] = {}
            for metric_type, values in all_metrics[event_type].items():
                avg_value = float(np.mean(values))
                final_results["overall_metrics"][event_type][metric_type] = avg_value

        # Save summary JSON.
        json_path = os.path.join(self.save_dir, "overall_metrics.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=4)

        # Accumulated (pooled confusion-matrix) metrics.
        accumulated_metrics = self.calculate_accumulated_metrics(combined_metrics_df)

        # Attach to the summary and also save standalone.
        final_results["accumulated_metrics"] = accumulated_metrics

        accumulated_json_path = os.path.join(self.save_dir, "accumulated_metrics.json")
        with open(accumulated_json_path, 'w', encoding='utf-8') as f:
            json.dump(accumulated_metrics, f, indent=4)

        return accumulated_metrics
|
| 164 |
+
def _evaluate_category(self, category: str, pred_path: str, label_path: str) -> pd.DataFrame:
|
| 165 |
+
"""์นดํ
๊ณ ๋ฆฌ๋ณ ํ๊ฐ ์ํ"""
|
| 166 |
+
results = []
|
| 167 |
+
metrics_columns = ['video_name']
|
| 168 |
+
|
| 169 |
+
for pred_file in cust_listdir(pred_path):
|
| 170 |
+
if not pred_file.endswith('.csv'):
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
video_name = os.path.splitext(pred_file)[0]
|
| 174 |
+
pred_df = pd.read_csv(os.path.join(pred_path, pred_file))
|
| 175 |
+
|
| 176 |
+
# ํด๋น ๋น๋์ค์ ์ ๋ต CSV ํ์ผ ๋ก๋
|
| 177 |
+
label_file = f"{video_name}.csv"
|
| 178 |
+
label_path_full = os.path.join(label_path, label_file)
|
| 179 |
+
|
| 180 |
+
if not os.path.exists(label_path_full):
|
| 181 |
+
print(f"Warning: Label file not found for {video_name}")
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
label_df = pd.read_csv(label_path_full)
|
| 185 |
+
|
| 186 |
+
# ๊ฐ ์นดํ
๊ณ ๋ฆฌ๋ณ ๋ฉํธ๋ฆญ ๊ณ์ฐ
|
| 187 |
+
video_metrics = {'video_name': video_name}
|
| 188 |
+
categories = [col for col in pred_df.columns if col != 'frame']
|
| 189 |
+
|
| 190 |
+
for cat in categories:
|
| 191 |
+
# ์ ๋ต๊ฐ๊ณผ ์์ธก๊ฐ
|
| 192 |
+
y_true = label_df[cat].values
|
| 193 |
+
y_pred = pred_df[cat].values
|
| 194 |
+
|
| 195 |
+
# ๋ฉํธ๋ฆญ ๊ณ์ฐ
|
| 196 |
+
metrics = self._calculate_metrics(y_true, y_pred)
|
| 197 |
+
|
| 198 |
+
# ๊ฒฐ๊ณผ ์ ์ฅ
|
| 199 |
+
for metric_name, value in metrics.items():
|
| 200 |
+
col_name = f"{cat}_{metric_name}"
|
| 201 |
+
video_metrics[col_name] = value
|
| 202 |
+
if col_name not in metrics_columns:
|
| 203 |
+
metrics_columns.append(col_name)
|
| 204 |
+
|
| 205 |
+
results.append(video_metrics)
|
| 206 |
+
|
| 207 |
+
# ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์ดํฐํ๋ ์์ผ๋ก ๋ณํ
|
| 208 |
+
metrics_df = pd.DataFrame(results, columns=metrics_columns)
|
| 209 |
+
|
| 210 |
+
# ํ๊ท ๊ณ์ฐํ์ฌ ์ถ๊ฐ
|
| 211 |
+
avg_metrics = {'video_name': 'average'}
|
| 212 |
+
for col in metrics_columns[1:]: # video_name ์ ์ธ
|
| 213 |
+
avg_metrics[col] = metrics_df[col].mean()
|
| 214 |
+
|
| 215 |
+
metrics_df = pd.concat([metrics_df, pd.DataFrame([avg_metrics])], ignore_index=True)
|
| 216 |
+
|
| 217 |
+
return metrics_df
|
| 218 |
+
|
| 219 |
+
# def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
|
| 220 |
+
# """์ฑ๋ฅ ์งํ ๊ณ์ฐ"""
|
| 221 |
+
# tn = np.sum((y_true == 0) & (y_pred == 0))
|
| 222 |
+
# fp = np.sum((y_true == 0) & (y_pred == 1))
|
| 223 |
+
|
| 224 |
+
# metrics = {
|
| 225 |
+
# 'f1': f1_score(y_true, y_pred, zero_division=0),
|
| 226 |
+
# 'accuracy': accuracy_score(y_true, y_pred),
|
| 227 |
+
# 'precision': precision_score(y_true, y_pred, zero_division=0),
|
| 228 |
+
# 'recall': recall_score(y_true, y_pred, zero_division=0),
|
| 229 |
+
# 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0
|
| 230 |
+
# }
|
| 231 |
+
|
| 232 |
+
# return metrics
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def calculate_accumulated_metrics(self, all_categories_metrics_df: pd.DataFrame) -> Dict:
|
| 236 |
+
"""๋์ ๋ ํผ๋ํ๋ ฌ๋ก ๊ฐ ์นดํ
๊ณ ๋ฆฌ๋ณ ์ฑ๋ฅ ์งํ ๊ณ์ฐ"""
|
| 237 |
+
accumulated_results = {"micro_avg": {}}
|
| 238 |
+
categories = ['falldown', 'violence', 'fire']
|
| 239 |
+
|
| 240 |
+
for category in categories:
|
| 241 |
+
# ํด๋น ์นดํ
๊ณ ๋ฆฌ์ ํผ๋ํ๋ ฌ ๊ฐ๋ค ๋์
|
| 242 |
+
tp = all_categories_metrics_df[f'{category}_tp'].sum()
|
| 243 |
+
tn = all_categories_metrics_df[f'{category}_tn'].sum()
|
| 244 |
+
fp = all_categories_metrics_df[f'{category}_fp'].sum()
|
| 245 |
+
fn = all_categories_metrics_df[f'{category}_fn'].sum()
|
| 246 |
+
|
| 247 |
+
# ๊ธฐ๋ณธ ๋ฉํธ๋ฆญ ๊ณ์ฐ
|
| 248 |
+
metrics = {
|
| 249 |
+
'tp': int(tp),
|
| 250 |
+
'tn': int(tn),
|
| 251 |
+
'fp': int(fp),
|
| 252 |
+
'fn': int(fn),
|
| 253 |
+
'accuracy': (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0,
|
| 254 |
+
'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
|
| 255 |
+
'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
|
| 256 |
+
'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
|
| 257 |
+
'f1': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
# ์ถ๊ฐ ๋ฉํธ๋ฆญ ๊ณ์ฐ
|
| 261 |
+
tpr = metrics['recall'] # TPR = recall
|
| 262 |
+
tnr = metrics['specificity'] # TNR = specificity
|
| 263 |
+
|
| 264 |
+
# Balanced Accuracy
|
| 265 |
+
metrics['balanced_accuracy'] = (tpr + tnr) / 2
|
| 266 |
+
|
| 267 |
+
# G-Mean
|
| 268 |
+
metrics['g_mean'] = np.sqrt(tpr * tnr) if (tpr * tnr) > 0 else 0
|
| 269 |
+
|
| 270 |
+
# MCC (Matthews Correlation Coefficient)
|
| 271 |
+
numerator = (tp * tn) - (fp * fn)
|
| 272 |
+
denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
|
| 273 |
+
metrics['mcc'] = numerator / denominator if denominator > 0 else 0
|
| 274 |
+
|
| 275 |
+
# NPV (Negative Predictive Value)
|
| 276 |
+
metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
|
| 277 |
+
|
| 278 |
+
# FAR (False Alarm Rate) = FPR = 1 - specificity
|
| 279 |
+
metrics['far'] = 1 - metrics['specificity']
|
| 280 |
+
|
| 281 |
+
accumulated_results[category] = metrics
|
| 282 |
+
|
| 283 |
+
# ์ ์ฒด ์นดํ
๊ณ ๋ฆฌ์ ๋์ ๊ฐ์ผ๋ก ๊ณ์ฐ
|
| 284 |
+
total_tp = sum(accumulated_results[cat]['tp'] for cat in categories)
|
| 285 |
+
total_tn = sum(accumulated_results[cat]['tn'] for cat in categories)
|
| 286 |
+
total_fp = sum(accumulated_results[cat]['fp'] for cat in categories)
|
| 287 |
+
total_fn = sum(accumulated_results[cat]['fn'] for cat in categories)
|
| 288 |
+
|
| 289 |
+
# micro average ๊ณ์ฐ (์ ์ฒด ๋์ ๊ฐ์ผ๋ก ๊ณ์ฐ)
|
| 290 |
+
accumulated_results["micro_avg"] = {
|
| 291 |
+
'tp': int(total_tp),
|
| 292 |
+
'tn': int(total_tn),
|
| 293 |
+
'fp': int(total_fp),
|
| 294 |
+
'fn': int(total_fn),
|
| 295 |
+
'accuracy': (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn),
|
| 296 |
+
'precision': total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0,
|
| 297 |
+
'recall': total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0,
|
| 298 |
+
'f1': 2 * total_tp / (2 * total_tp + total_fp + total_fn) if (2 * total_tp + total_fp + total_fn) > 0 else 0,
|
| 299 |
+
# ... (๋ค๋ฅธ ๋ฉํธ๋ฆญ๋ค๋ ๋์ผํ ๋ฐฉ์์ผ๋ก ๊ณ์ฐ)
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
return accumulated_results
|
| 303 |
+
def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
|
| 304 |
+
"""์ฑ๋ฅ ์งํ ๊ณ์ฐ"""
|
| 305 |
+
tn = np.sum((y_true == 0) & (y_pred == 0))
|
| 306 |
+
fp = np.sum((y_true == 0) & (y_pred == 1))
|
| 307 |
+
fn = np.sum((y_true == 1) & (y_pred == 0))
|
| 308 |
+
tp = np.sum((y_true == 1) & (y_pred == 1))
|
| 309 |
+
|
| 310 |
+
metrics = {
|
| 311 |
+
'f1': f1_score(y_true, y_pred, zero_division=0),
|
| 312 |
+
'accuracy': accuracy_score(y_true, y_pred),
|
| 313 |
+
'precision': precision_score(y_true, y_pred, zero_division=0),
|
| 314 |
+
'recall': recall_score(y_true, y_pred, zero_division=0),
|
| 315 |
+
'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
|
| 316 |
+
'tp': int(tp),
|
| 317 |
+
'tn': int(tn),
|
| 318 |
+
'fp': int(fp),
|
| 319 |
+
'fn': int(fn)
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
return metrics
|
pia_bench/pipe_line/piepline.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pia_bench.checker.bench_checker import BenchChecker
|
| 2 |
+
from pia_bench.checker.sheet_checker import SheetChecker
|
| 3 |
+
from pia_bench.event_alarm import EventDetector
|
| 4 |
+
from pia_bench.metric import MetricsEvaluator
|
| 5 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 6 |
+
from pia_bench.bench import PiaBenchMark
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from typing import Optional, List , Dict
|
| 9 |
+
import os
|
| 10 |
+
load_dotenv()
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Dict, Tuple
|
| 13 |
+
from typing import Dict, Optional, Tuple
|
| 14 |
+
import logging
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from sheet_manager.sheet_checker.sheet_check import SheetChecker
|
| 17 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 18 |
+
from pia_bench.checker.bench_checker import BenchChecker
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO)
|
| 21 |
+
|
| 22 |
+
from enviroments.config import BASE_BENCH_PATH
|
| 23 |
+
|
| 24 |
+
@dataclass
class PipelineConfig:
    """Configuration for a single benchmark pipeline run."""
    # Model identifier (Hugging Face repo name, also the sheet row key).
    model_name: str
    # Benchmark identifier, e.g. "PIA"; used as the folder name under base_path.
    benchmark_name: str
    # Path to the prompt/config JSON file; its basename (without extension)
    # becomes the pipeline's cfg_prompt.
    cfg_target_path: str
    # Root directory containing all benchmark folders.
    base_path: str = BASE_BENCH_PATH
|
| 31 |
+
|
| 32 |
+
class BenchmarkPipelineStatus:
    """Tracks the pipeline's stage, intermediate check results and outcome."""

    def __init__(self):
        # (model_added, benchmark_exists) as reported by the sheet checker.
        self.sheet_status: Tuple[bool, bool] = (False, False)
        # Per-check results from the local benchmark-folder checker.
        self.bench_status: Dict[str, bool] = {}
        # Aggregated bench verdict, e.g. "all_passed" / "no_vectors" / "no bench".
        self.bench_result: str = ""
        # Last stage the pipeline reached.
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """True when the model already existed, the benchmark exists, and all checks passed."""
        model_added, benchmark_exists = self.sheet_status
        if model_added:
            return False
        if not benchmark_exists:
            return False
        return self.bench_result == "all_passed"

    def __str__(self) -> str:
        parts = (
            f"Current Stage: {self.current_stage}",
            f"Sheet Status: {self.sheet_status}",
            f"Bench Status: {self.bench_status}",
            f"Bench Result: {self.bench_result}",
        )
        return "\n".join(parts)
|
| 51 |
+
|
| 52 |
+
class BenchmarkPipeline:
    """Pipeline that runs a benchmark for a model.

    Flow: Google-sheet check (is a run needed at all?) -> local benchmark
    folder check -> execute whichever steps are missing and collect metrics.

    The three _execute_* paths previously duplicated the PiaBenchMark /
    EventDetector / MetricsEvaluator wiring verbatim; that wiring now lives
    in the _create_benchmark / _run_detection / _run_metrics helpers.
    """

    def __init__(self, config: PipelineConfig):
        """Initialize checkers and derive the prompt name from the config path."""
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        self.access_token = os.getenv("ACCESS_TOKEN")
        # "path/to/topk.json" -> "topk"
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]

        # Initialize checkers
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)

        # Result dict from MetricsEvaluator.evaluate(); None until a run completes.
        self.bench_result_dict = None

    def run(self) -> BenchmarkPipelineStatus:
        """Run the whole pipeline and return the final status object."""
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()

            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("๋ฒค์น๋งํฌ๊ฐ ์ด๋ฏธ ์กด์ฌํ์ฌ ์ถ๊ฐ ์์ ์ด ํ์ํ์ง ์์ต๋๋ค.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()

            self.status.current_stage = "completed"
            return self.status

        except Exception as e:
            self.logger.error(f"ํ์ดํ๋ผ์ธ ์คํ ์ค ์๋ฌ ๋ฐ์: {str(e)}")
            self.status.current_stage = "error"
            return self.status

    def _check_sheet(self) -> bool:
        """Check the Google sheet; return True only when a benchmark run is needed."""
        self.logger.info("์ํธ ์ํ ์ฒดํฌ ์์")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("์๋ก์ด ๋ชจ๋ธ์ด ์ถ๊ฐ๋์์ต๋๋ค")
        if not benchmark_exists:
            self.logger.info("๋ฒค์น๋งํฌ ์ธก์ ์ด ํ์ํฉ๋๋ค")
            return True  # proceed only when the benchmark value is missing

        self.logger.info("์ด๋ฏธ ๋ฒค์น๋งํฌ๊ฐ ์กด์ฌํฉ๋๋ค. ํ์ดํ๋ผ์ธ์ ์ข ๋ฃํฉ๋๋ค.")
        return False

    def _check_bench(self) -> bool:
        """Check the local benchmark folder structure and record the verdict."""
        self.logger.info("๋ฒค์น๋งํฌ ํ๊ฒฝ ์ฒดํฌ ์์")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )

        # NOTE(review): both branches return True, so "no bench" only logs an
        # error and never aborts; _execute_based_on_status handles that case
        # by regenerating from scratch. Kept as-is to preserve behavior.
        if self.status.bench_result == "no bench":
            self.logger.error("๋ฒค์น๋งํฌ ์คํ์ ํ์ํ ๊ธฐ๋ณธ ํด๋๊ตฌ์กฐ๊ฐ ์์ต๋๋ค.")
            return True

        return True

    def _execute_based_on_status(self):
        """Dispatch to the execution path matching the bench verdict."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            self._execute_vector_generation()
            self.logger.warning("ํด๋๊ตฌ์กฐ๊ฐ ์์ต๋๋ค")

    def _create_benchmark(self):
        """Build a PiaBenchMark for the configured run and prepare its folder structure."""
        pia_benchmark = PiaBenchMark(
            benchmark_path=f"{BASE_BENCH_PATH}/{self.config.benchmark_name}",
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token)
        pia_benchmark.preprocess_structure()
        return pia_benchmark

    def _run_detection(self, pia_benchmark):
        """Run EventDetector over extracted vectors and save alarm predictions."""
        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

    def _run_metrics(self, pia_benchmark):
        """Evaluate predictions against labels and store the result dict."""
        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

    def _execute_full_pipeline(self):
        """Everything already in place: just (re)compute the metrics."""
        self.logger.info("์ ์ฒด ํ์ดํ๋ผ์ธ ์คํ ์ค...")
        pia_benchmark = self._create_benchmark()
        print("Categories identified:", pia_benchmark.categories)
        self._run_metrics(pia_benchmark)

    def _execute_vector_generation(self):
        """Vectors missing: extract vectors, then run detection and metrics."""
        self.logger.info("๋ฒกํฐ ์์ฑ ์ค...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        pia_benchmark.extract_visual_vector()

        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)

    def _execute_metrics_generation(self):
        """Metrics missing: rerun detection on existing vectors, then metrics."""
        self.logger.info("๋ฉํธ๋ฆญ ์์ฑ ์ค...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
    # Assemble the pipeline settings for a sample run.
    settings = {
        "model_name": "T2V_CLIP4CLIP_MSRVTT",
        "benchmark_name": "PIA",
        "cfg_target_path": "topk.json",
        "base_path": f"{BASE_BENCH_PATH}",
    }
    config = PipelineConfig(**settings)

    # Execute the pipeline and report the outcome.
    result = BenchmarkPipeline(config).run()

    print(f"\nํ์ดํ๋ผ์ธ ์คํ ๊ฒฐ๊ณผ:")
    print(str(result))
|
sheet_manager/sheet_checker/sheet_check.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
import logging
|
| 3 |
+
import gspread
|
| 4 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 5 |
+
|
| 6 |
+
class SheetChecker:
    """Keeps the "model" worksheet in sync: model rows and benchmark columns."""

    def __init__(self, sheet_manager: SheetManager):
        """Initialize the SheetChecker with a base SheetManager whose
        spreadsheet URL is reused for the "model" worksheet manager."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None  # set by _init_bench_sheet()
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Create a sheet manager bound to the "model" worksheet."""
        # type(...) instead of SheetManager directly, so a subclassed manager
        # produces a manager of the same subclass.
        self.bench_sheet_manager = type(self.sheet_manager)(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )

    def add_benchmark_column(self, column_name: str):
        """Append a new benchmark column plus its "<name>*100" companion column."""
        try:
            headers = self.bench_sheet_manager.get_available_columns()

            # Column already present -> nothing to do.
            if column_name in headers:
                return

            # Write the new header into the first empty column of row 1.
            new_col_index = len(headers) + 1
            cell = gspread.utils.rowcol_to_a1(1, new_col_index)
            self.bench_sheet_manager.sheet.update(cell, [[column_name]])


            # Companion "<name>*100" column — presumably the same metric
            # scaled to a percentage; TODO confirm with the sheet consumers.
            next_col_index = new_col_index + 1
            next_cell = gspread.utils.rowcol_to_a1(1, next_col_index)
            self.bench_sheet_manager.sheet.update(next_cell, [[f"{column_name}*100"]])

            self.logger.info(f"์๋ก์ด ๋ฒค์น๋งํฌ ์ปฌ๋ผ๋ค ์ถ๊ฐ๋จ: {column_name}, {column_name}*100")
            # Reconnect so the manager's cached headers reflect the new columns.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"๋ฒค์น๋งํฌ ์ปฌ๋ผ {column_name} ์ถ๊ฐ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
            raise

    def check_model_and_benchmark(self, model_name: str, benchmark_name: str) -> Tuple[bool, bool]:
        """
        Check whether the model row and benchmark value exist, adding the
        model row and the benchmark column when missing.

        Args:
            model_name: Model name to look up / insert.
            benchmark_name: Benchmark column to look up / insert.

        Returns:
            Tuple[bool, bool]: (model was newly added, benchmark value already exists)
        """
        try:
            # Does the model row exist?
            model_exists = self._check_model_exists(model_name)
            model_added = False

            # Add the model row when missing.
            if not model_exists:
                self._add_new_model(model_name)
                model_added = True
                self.logger.info(f"์๋ก์ด ๋ชจ๋ธ ์ถ๊ฐ๋จ: {model_name}")

            # Add the benchmark column when missing.
            available_columns = self.bench_sheet_manager.get_available_columns()
            if benchmark_name not in available_columns:
                self.add_benchmark_column(benchmark_name)
                self.logger.info(f"์๋ก์ด ๋ฒค์น๋งํฌ ์ปฌ๋ผ ์ถ๊ฐ๋จ: {benchmark_name}")

            # Is the benchmark cell already filled for this model?
            benchmark_exists = self._check_benchmark_exists(model_name, benchmark_name)

            return model_added, benchmark_exists

        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ/๋ฒค์น๋งํฌ ํ์ธ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
            raise

    def _check_model_exists(self, model_name: str) -> bool:
        """Return True when the model name appears in the "Model name" column."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            values = self.bench_sheet_manager.get_all_values()
            return model_name in values
        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ ์กด์ฌ ์ฌ๋ถ ํ์ธ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
            raise

    def _add_new_model(self, model_name: str):
        """Append a new model row: name, HF link and an HTML anchor cell."""
        try:
            model_info = {
                "Model name": model_name,
                "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
                "Model": f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
            }

            # Push each field into its own column for the new row.
            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)

        except Exception as e:
            self.logger.error(f"๋ชจ๋ธ ์ ๋ณด ์ถ๊ฐ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
            raise

    def _check_benchmark_exists(self, model_name: str, benchmark_name: str) -> bool:
        """Return True when the model's benchmark cell holds a non-blank value."""
        try:
            # Locate the model's row via the "Model name" column.
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            # +2: sheet rows are 1-indexed and row 1 is the header — assumes
            # get_all_values() excludes the header row; TODO confirm.
            row_index = all_values.index(model_name) + 2

            self.bench_sheet_manager.change_column(benchmark_name)
            value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value

            return bool(value and value.strip())

        except Exception as e:
            self.logger.error(f"๋ฒค์น๋งํฌ ์กด์ฌ ์ฌ๋ถ ํ์ธ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
            raise
|
| 128 |
+
|
| 129 |
+
# ์ฌ์ฉ ์์
|
| 130 |
+
if __name__ == "__main__":
    # Usage example: check a sample model/benchmark pair and print the result.
    checker = SheetChecker(SheetManager())

    model_added, benchmark_exists = checker.check_model_and_benchmark(
        model_name="test-model",
        benchmark_name="COCO",
    )

    print(f"Model added: {model_added}")
    print(f"Benchmark exists: {benchmark_exists}")
|
sheet_manager/sheet_convert/json2sheet.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 3 |
+
import json
|
| 4 |
+
from typing import Optional, Dict
|
| 5 |
+
|
| 6 |
+
def update_benchmark_json(
    model_name: str,
    benchmark_data: dict,
    worksheet_name: str = "metric",
    target_column: str = "benchmark"
):
    """Store a model's benchmark data in the sheet as a JSON string.

    Args:
        model_name (str): Model whose row is updated.
        benchmark_data (dict): Benchmark results to serialize.
        worksheet_name (str): Worksheet to operate on (default "metric").
        target_column (str): Column that receives the JSON payload (default "benchmark").
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    # Serialize the dict, keeping non-ASCII (Korean) text readable.
    payload = json.dumps(benchmark_data, ensure_ascii=False)

    # Find the row whose "Model name" matches and write the payload there.
    matched_row = manager.update_cell_by_condition(
        condition_column="Model name",
        condition_value=model_name,
        target_column=target_column,
        target_value=payload
    )

    if not matched_row:
        print(f"Model {model_name} not found in the sheet")
        return
    print(f"Successfully updated {target_column} data for model: {model_name}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_benchmark_dict(
    model_name: str,
    worksheet_name: str = "metric",
    target_column: str = "benchmark",
    save_path: Optional[str] = None
) -> Dict:
    """Fetch a model's benchmark JSON from the sheet and return it as a dict.

    Args:
        model_name (str): Model to look up.
        worksheet_name (str): Worksheet to read (default "metric").
        target_column (str): Column holding the JSON string (default "benchmark").
        save_path (str, optional): If given, also dump the dict to this JSON file.

    Returns:
        Dict: Parsed benchmark dict; empty dict when the model/data is missing
        or the JSON cannot be parsed.
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    try:
        records = manager.sheet.get_all_records()

        # Linear scan for the row belonging to this model.
        found = None
        for record in records:
            if record.get("Model name") == model_name:
                found = record
                break

        if not found:
            print(f"Model {model_name} not found in the sheet")
            return {}

        raw = found.get(target_column)
        if not raw:
            print(f"No data found in {target_column} for model: {model_name}")
            return {}

        parsed = json.loads(raw)

        # Optionally persist the parsed dict to disk.
        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(parsed, f, ensure_ascii=False, indent=2)
            print(f"Successfully saved dictionary to: {save_path}")

        return parsed

    except json.JSONDecodeError:
        print(f"Failed to parse JSON data for model: {model_name}")
        return {}
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return {}
|
| 99 |
+
|
| 100 |
+
def str2json(json_str):
    """Parse a JSON-formatted string into a Python object.

    Args:
        json_str (str): String containing JSON.

    Returns:
        dict: Parsed JSON object, or None when parsing fails.
    """
    result = None
    try:
        result = json.loads(json_str)
    except json.JSONDecodeError as e:
        # Malformed JSON: report and fall through to return None.
        print(f"JSON Parsing Error: {e}")
    except Exception as e:
        # Anything unexpected (e.g. non-string input): report, return None.
        print(f"Unexpected Error: {e}")
    return result
|
sheet_manager/sheet_crud/create_col.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from huggingface_hub import HfApi
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from enviroments.convert import get_json_from_env_var
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
def push_model_names_to_sheet(spreadsheet_url, sheet_name, access_token, organization):
    """
    Fetches model names from Hugging Face and updates a Google Sheet with the names, links, and HTML links.

    Only models not already present in the sheet (by "Model name") are appended.

    Args:
        spreadsheet_url (str): URL of the Google Spreadsheet.
        sheet_name (str): Name of the sheet to update.
        access_token (str): Hugging Face access token.
        organization (str): Organization name on Hugging Face.
    """
    # Authorize Google Sheets API (service-account JSON from env var)
    scope = ['https://spreadsheets.google.com/feeds',
             'https://www.googleapis.com/auth/drive']
    json_key_dict =get_json_from_env_var("GOOGLE_CREDENTIALS")
    credential = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
    gc = gspread.authorize(credential)

    # Open the Google Spreadsheet
    doc = gc.open_by_url(spreadsheet_url)
    sheet = doc.worksheet(sheet_name)

    # Fetch existing data from the sheet
    existing_data = pd.DataFrame(sheet.get_all_records())

    # Fetch models from Hugging Face
    api = HfApi()
    models = api.list_models(author=organization, use_auth_token=access_token)

    # Extract model names, links, and HTML links.
    # NOTE: assumes modelId is "org/name" — index 1 is the bare model name.
    model_details = [{
        "Model name": model.modelId.split("/")[1],
        "Model link": f"https://huggingface.co/{model.modelId}",
        "Model": f"<a target=\"_blank\" href=\"https://huggingface.co/{model.modelId}\" style=\"color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;\">{model.modelId}</a>"
    } for model in models]

    new_data_df = pd.DataFrame(model_details)

    # Check for duplicates and update only new model names
    if "Model name" in existing_data.columns:
        existing_model_names = existing_data["Model name"].tolist()
    else:
        existing_model_names = []

    new_data_df = new_data_df[~new_data_df["Model name"].isin(existing_model_names)]

    if not new_data_df.empty:
        # Append new model names, links, and HTML links to the existing data
        updated_data = pd.concat([existing_data, new_data_df], ignore_index=True)

        # Push updated data back to the sheet
        updated_data = updated_data.replace([float('inf'), float('-inf')], None)  # convert Infinity values to None
        updated_data = updated_data.fillna('')  # convert NaN values to empty strings
        sheet.update([updated_data.columns.values.tolist()] + updated_data.values.tolist())
        print("New model names, links, and HTML links successfully added to Google Sheet.")
    else:
        print("No new model names to add.")
|
| 68 |
+
|
| 69 |
+
# Example usage
|
| 70 |
+
if __name__ == "__main__":
    # Example usage: pull targets/credentials from the environment and sync.
    push_model_names_to_sheet(
        os.getenv("SPREADSHEET_URL"),
        "์ํธ1",
        os.getenv("ACCESS_TOKEN"),
        "PIA-SPACE-LAB",
    )
|
sheet_manager/sheet_crud/sheet_crud.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from enviroments.convert import get_json_from_env_var
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
load_dotenv(override=True)
|
| 9 |
+
|
| 10 |
+
class SheetManager:
|
| 11 |
+
def __init__(self, spreadsheet_url: Optional[str] = None,
|
| 12 |
+
worksheet_name: str = "flag",
|
| 13 |
+
column_name: str = "huggingface_id"):
|
| 14 |
+
"""
|
| 15 |
+
Initialize SheetManager with Google Sheets credentials and connection.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
spreadsheet_url (str, optional): URL of the Google Spreadsheet.
|
| 19 |
+
If None, takes from environment variable.
|
| 20 |
+
worksheet_name (str): Name of the worksheet to operate on.
|
| 21 |
+
Defaults to "flag".
|
| 22 |
+
column_name (str): Name of the column to operate on.
|
| 23 |
+
Defaults to "huggingface_id".
|
| 24 |
+
"""
|
| 25 |
+
self.spreadsheet_url = spreadsheet_url or os.getenv("SPREADSHEET_URL")
|
| 26 |
+
if not self.spreadsheet_url:
|
| 27 |
+
raise ValueError("Spreadsheet URL not provided and not found in environment variables")
|
| 28 |
+
|
| 29 |
+
self.worksheet_name = worksheet_name
|
| 30 |
+
self.column_name = column_name
|
| 31 |
+
|
| 32 |
+
# Initialize credentials and client
|
| 33 |
+
self._init_google_client()
|
| 34 |
+
|
| 35 |
+
# Initialize sheet connection
|
| 36 |
+
self.doc = None
|
| 37 |
+
self.sheet = None
|
| 38 |
+
self.col_index = None
|
| 39 |
+
self._connect_to_sheet(validate_column=True)
|
| 40 |
+
|
| 41 |
+
    def _init_google_client(self):
        """Initialize Google Sheets client with credentials.

        Reads a service-account JSON key from the GOOGLE_CREDENTIALS
        environment variable and stores an authorized gspread client
        on self.client.
        """
        # Scopes required for reading/writing spreadsheets via Drive.
        scope = ['https://spreadsheets.google.com/feeds',
                 'https://www.googleapis.com/auth/drive']
        json_key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
        credentials = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
        self.client = gspread.authorize(credentials)
|
| 48 |
+
|
| 49 |
+
def _connect_to_sheet(self, validate_column: bool = True):
|
| 50 |
+
"""
|
| 51 |
+
Connect to the specified Google Sheet and initialize necessary attributes.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
validate_column (bool): Whether to validate the column name exists
|
| 55 |
+
"""
|
| 56 |
+
try:
|
| 57 |
+
self.doc = self.client.open_by_url(self.spreadsheet_url)
|
| 58 |
+
|
| 59 |
+
# Try to get the worksheet
|
| 60 |
+
try:
|
| 61 |
+
self.sheet = self.doc.worksheet(self.worksheet_name)
|
| 62 |
+
except gspread.exceptions.WorksheetNotFound:
|
| 63 |
+
raise ValueError(f"Worksheet '{self.worksheet_name}' not found in spreadsheet")
|
| 64 |
+
|
| 65 |
+
# Get headers
|
| 66 |
+
self.headers = self.sheet.row_values(1)
|
| 67 |
+
|
| 68 |
+
# Validate column only if requested
|
| 69 |
+
if validate_column:
|
| 70 |
+
try:
|
| 71 |
+
self.col_index = self.headers.index(self.column_name) + 1
|
| 72 |
+
except ValueError:
|
| 73 |
+
# If column not found, use first available column
|
| 74 |
+
if self.headers:
|
| 75 |
+
self.column_name = self.headers[0]
|
| 76 |
+
self.col_index = 1
|
| 77 |
+
print(f"Column '{self.column_name}' not found. Using first available column: '{self.headers[0]}'")
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError("No columns found in worksheet")
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
if isinstance(e, ValueError):
|
| 83 |
+
raise e
|
| 84 |
+
raise ConnectionError(f"Failed to connect to sheet: {str(e)}")
|
| 85 |
+
|
| 86 |
+
def change_worksheet(self, worksheet_name: str, column_name: Optional[str] = None):
|
| 87 |
+
"""
|
| 88 |
+
Change the current worksheet and optionally the column.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
worksheet_name (str): Name of the worksheet to switch to
|
| 92 |
+
column_name (str, optional): Name of the column to switch to
|
| 93 |
+
"""
|
| 94 |
+
old_worksheet = self.worksheet_name
|
| 95 |
+
old_column = self.column_name
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
self.worksheet_name = worksheet_name
|
| 99 |
+
if column_name:
|
| 100 |
+
self.column_name = column_name
|
| 101 |
+
|
| 102 |
+
# First connect without column validation
|
| 103 |
+
self._connect_to_sheet(validate_column=False)
|
| 104 |
+
|
| 105 |
+
# Then validate the column if specified
|
| 106 |
+
if column_name:
|
| 107 |
+
self.change_column(column_name)
|
| 108 |
+
else:
|
| 109 |
+
# Validate existing column in new worksheet
|
| 110 |
+
try:
|
| 111 |
+
self.col_index = self.headers.index(self.column_name) + 1
|
| 112 |
+
except ValueError:
|
| 113 |
+
# If column not found, use first available column
|
| 114 |
+
if self.headers:
|
| 115 |
+
self.column_name = self.headers[0]
|
| 116 |
+
self.col_index = 1
|
| 117 |
+
print(f"Column '{old_column}' not found in new worksheet. Using first available column: '{self.headers[0]}'")
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError("No columns found in worksheet")
|
| 120 |
+
|
| 121 |
+
print(f"Successfully switched to worksheet: {worksheet_name}, using column: {self.column_name}")
|
| 122 |
+
|
| 123 |
+
except Exception as e:
|
| 124 |
+
# Restore previous state on error
|
| 125 |
+
self.worksheet_name = old_worksheet
|
| 126 |
+
self.column_name = old_column
|
| 127 |
+
self._connect_to_sheet()
|
| 128 |
+
raise e
|
| 129 |
+
|
| 130 |
+
def change_column(self, column_name: str):
|
| 131 |
+
"""
|
| 132 |
+
Change the target column.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
column_name (str): Name of the column to switch to
|
| 136 |
+
"""
|
| 137 |
+
if not self.headers:
|
| 138 |
+
self.headers = self.sheet.row_values(1)
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
self.col_index = self.headers.index(column_name) + 1
|
| 142 |
+
self.column_name = column_name
|
| 143 |
+
print(f"Successfully switched to column: {column_name}")
|
| 144 |
+
except ValueError:
|
| 145 |
+
raise ValueError(f"Column '{column_name}' not found in worksheet. Available columns: {', '.join(self.headers)}")
|
| 146 |
+
|
| 147 |
+
def get_available_worksheets(self) -> List[str]:
|
| 148 |
+
"""Get list of all available worksheets in the spreadsheet."""
|
| 149 |
+
return [worksheet.title for worksheet in self.doc.worksheets()]
|
| 150 |
+
|
| 151 |
+
def get_available_columns(self) -> List[str]:
|
| 152 |
+
"""Get list of all available columns in the current worksheet."""
|
| 153 |
+
return self.headers if self.headers else self.sheet.row_values(1)
|
| 154 |
+
|
| 155 |
+
def _reconnect_if_needed(self):
|
| 156 |
+
"""Reconnect to the sheet if the connection is lost."""
|
| 157 |
+
try:
|
| 158 |
+
self.sheet.row_values(1)
|
| 159 |
+
except (gspread.exceptions.APIError, AttributeError):
|
| 160 |
+
self._init_google_client()
|
| 161 |
+
self._connect_to_sheet()
|
| 162 |
+
|
| 163 |
+
def _fetch_column_data(self) -> List[str]:
|
| 164 |
+
"""Fetch all data from the huggingface_id column."""
|
| 165 |
+
values = self.sheet.col_values(self.col_index)
|
| 166 |
+
return values[1:] # Exclude header
|
| 167 |
+
|
| 168 |
+
def _update_sheet(self, data: List[str]):
|
| 169 |
+
"""Update the entire column with new data."""
|
| 170 |
+
try:
|
| 171 |
+
# Prepare the range for update (excluding header)
|
| 172 |
+
start_cell = gspread.utils.rowcol_to_a1(2, self.col_index) # Start from row 2
|
| 173 |
+
end_cell = gspread.utils.rowcol_to_a1(len(data) + 2, self.col_index)
|
| 174 |
+
range_name = f"{start_cell}:{end_cell}"
|
| 175 |
+
|
| 176 |
+
# Convert data to 2D array format required by gspread
|
| 177 |
+
cells = [[value] for value in data]
|
| 178 |
+
|
| 179 |
+
# Update the range
|
| 180 |
+
self.sheet.update(range_name, cells)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f"Error updating sheet: {str(e)}")
|
| 183 |
+
raise
|
| 184 |
+
|
| 185 |
+
def push(self, text: str) -> int:
|
| 186 |
+
"""
|
| 187 |
+
Push a text value to the next empty cell in the huggingface_id column.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
text (str): Text to push to the sheet
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
int: The row number where the text was pushed
|
| 194 |
+
"""
|
| 195 |
+
try:
|
| 196 |
+
self._reconnect_if_needed()
|
| 197 |
+
|
| 198 |
+
# Get all values in the huggingface_id column
|
| 199 |
+
column_values = self.sheet.col_values(self.col_index)
|
| 200 |
+
|
| 201 |
+
# Find the next empty row
|
| 202 |
+
next_row = None
|
| 203 |
+
for i in range(1, len(column_values)):
|
| 204 |
+
if not column_values[i].strip():
|
| 205 |
+
next_row = i + 1
|
| 206 |
+
break
|
| 207 |
+
|
| 208 |
+
# If no empty row found, append to the end
|
| 209 |
+
if next_row is None:
|
| 210 |
+
next_row = len(column_values) + 1
|
| 211 |
+
|
| 212 |
+
# Update the cell
|
| 213 |
+
self.sheet.update_cell(next_row, self.col_index, text)
|
| 214 |
+
print(f"Successfully pushed value: {text} to row {next_row}")
|
| 215 |
+
return next_row
|
| 216 |
+
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print(f"Error pushing to sheet: {str(e)}")
|
| 219 |
+
raise
|
| 220 |
+
|
| 221 |
+
def pop(self) -> Optional[str]:
|
| 222 |
+
"""Remove and return the most recent value."""
|
| 223 |
+
try:
|
| 224 |
+
self._reconnect_if_needed()
|
| 225 |
+
data = self._fetch_column_data()
|
| 226 |
+
|
| 227 |
+
if not data or not data[0].strip():
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
value = data.pop(0) # Remove first value
|
| 231 |
+
data.append("") # Add empty string at the end to maintain sheet size
|
| 232 |
+
|
| 233 |
+
self._update_sheet(data)
|
| 234 |
+
print(f"Successfully popped value: {value}")
|
| 235 |
+
return value
|
| 236 |
+
|
| 237 |
+
except Exception as e:
|
| 238 |
+
print(f"Error popping from sheet: {str(e)}")
|
| 239 |
+
raise
|
| 240 |
+
|
| 241 |
+
def delete(self, value: str) -> List[int]:
|
| 242 |
+
"""Delete all occurrences of a value."""
|
| 243 |
+
try:
|
| 244 |
+
self._reconnect_if_needed()
|
| 245 |
+
data = self._fetch_column_data()
|
| 246 |
+
|
| 247 |
+
# Find all indices before deletion
|
| 248 |
+
indices = [i + 1 for i, v in enumerate(data) if v.strip() == value.strip()]
|
| 249 |
+
if not indices:
|
| 250 |
+
print(f"Value '{value}' not found in sheet")
|
| 251 |
+
return []
|
| 252 |
+
|
| 253 |
+
# Remove matching values and add empty strings at the end
|
| 254 |
+
data = [v for v in data if v.strip() != value.strip()]
|
| 255 |
+
data.extend([""] * len(indices)) # Add empty strings to maintain sheet size
|
| 256 |
+
|
| 257 |
+
self._update_sheet(data)
|
| 258 |
+
print(f"Successfully deleted value '{value}' from rows: {indices}")
|
| 259 |
+
return indices
|
| 260 |
+
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"Error deleting from sheet: {str(e)}")
|
| 263 |
+
raise
|
| 264 |
+
|
| 265 |
+
def update_cell_by_condition(self, condition_column: str, condition_value: str, target_column: str, target_value: str) -> Optional[int]:
|
| 266 |
+
"""
|
| 267 |
+
Update the value of a cell based on a condition in another column.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
condition_column (str): The column to check the condition on.
|
| 271 |
+
condition_value (str): The value to match in the condition column.
|
| 272 |
+
target_column (str): The column where the value should be updated.
|
| 273 |
+
target_value (str): The new value to set in the target column.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Optional[int]: The row number where the value was updated, or None if no matching row was found.
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
self._reconnect_if_needed()
|
| 280 |
+
|
| 281 |
+
# Get all column headers
|
| 282 |
+
headers = self.sheet.row_values(1)
|
| 283 |
+
|
| 284 |
+
# Find the indices for the condition and target columns
|
| 285 |
+
try:
|
| 286 |
+
condition_col_index = headers.index(condition_column) + 1
|
| 287 |
+
except ValueError:
|
| 288 |
+
raise ValueError(f"์กฐ๊ฑด ์นผ๋ผ '{condition_column}'์ด(๊ฐ) ์์ต๋๋ค.")
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
target_col_index = headers.index(target_column) + 1
|
| 292 |
+
except ValueError:
|
| 293 |
+
raise ValueError(f"๋ชฉํ ์นผ๋ผ '{target_column}'์ด(๊ฐ) ์์ต๋๋ค.")
|
| 294 |
+
|
| 295 |
+
# Get all rows of data
|
| 296 |
+
data = self.sheet.get_all_records()
|
| 297 |
+
|
| 298 |
+
# Find the row that matches the condition
|
| 299 |
+
for i, row in enumerate(data):
|
| 300 |
+
if row.get(condition_column) == condition_value:
|
| 301 |
+
# Update the target column in the matching row
|
| 302 |
+
row_number = i + 2 # Row index starts at 2 (1 is header)
|
| 303 |
+
self.sheet.update_cell(row_number, target_col_index, target_value)
|
| 304 |
+
print(f"Updated row {row_number}: Set {target_column} to '{target_value}' where {condition_column} is '{condition_value}'")
|
| 305 |
+
return row_number
|
| 306 |
+
|
| 307 |
+
print(f"์กฐ๊ฑด์ ๋ง๋ ํ์ ์ฐพ์ ์ ์์ต๋๋ค: {condition_column} = '{condition_value}'")
|
| 308 |
+
return None
|
| 309 |
+
|
| 310 |
+
except Exception as e:
|
| 311 |
+
print(f"Error updating cell by condition: {str(e)}")
|
| 312 |
+
raise
|
| 313 |
+
|
| 314 |
+
def get_all_values(self) -> List[str]:
|
| 315 |
+
"""Get all values from the huggingface_id column."""
|
| 316 |
+
self._reconnect_if_needed()
|
| 317 |
+
return [v for v in self._fetch_column_data() if v.strip()]
|
| 318 |
+
|
| 319 |
+
# Example usage
|
| 320 |
+
if __name__ == "__main__":
    # Manual smoke test — requires SPREADSHEET_URL and GOOGLE_CREDENTIALS
    # to be set in the environment (loaded via dotenv at import time).
    # Initialize sheet manager
    sheet_manager = SheetManager()

    # # Push some test values
    # sheet_manager.push("test-model-1")
    # sheet_manager.push("test-model-2")
    # sheet_manager.push("test-model-3")

    # print("Initial values:", sheet_manager.get_all_values())

    # # Pop the most recent value
    # popped = sheet_manager.pop()
    # print(f"Popped value: {popped}")
    # print("After pop:", sheet_manager.get_all_values())

    # # Delete a specific value
    # deleted_rows = sheet_manager.delete("test-model-2")
    # print(f"Deleted from rows: {deleted_rows}")
    # print("After delete:", sheet_manager.get_all_values())

    # Demo: set the "pia" column to "new_value" on the row whose "model"
    # column equals "msr".
    row_updated = sheet_manager.update_cell_by_condition(
        condition_column="model",
        condition_value="msr",
        target_column="pia",
        target_value="new_value"
    )
|
| 347 |
+
|
sheet_manager/sheet_loader/sheet2df.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from enviroments.convert import get_json_from_env_var
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
def sheet2df(sheet_name:str = "model"):
    """Load one worksheet of the configured Google Spreadsheet as a DataFrame.

    Authenticates with the service-account credentials stored in the
    GOOGLE_CREDENTIALS environment variable, opens the spreadsheet whose URL
    is in SPREADSHEET_URL, reads the worksheet `sheet_name`, and returns its
    contents with the sheet's header row applied as column labels.

    Args:
        sheet_name: Name of the worksheet to load (default "model").

    Returns:
        pd.DataFrame: Worksheet contents; the header row becomes the column
        labels and is removed from the data rows.

    Dependencies:
        - pandas
        - gspread
        - oauth2client
    """
    oauth_scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
    key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
    service_creds = ServiceAccountCredentials.from_json_keyfile_dict(key_dict, oauth_scope)
    client = gspread.authorize(service_creds)

    workbook = client.open_by_url(os.getenv("SPREADSHEET_URL"))
    worksheet = workbook.worksheet(sheet_name)

    # Convert to DataFrame, then promote the first row to column labels
    # and drop it from the data.
    frame = pd.DataFrame(worksheet.get_all_values())
    frame.rename(columns=frame.iloc[0], inplace=True)
    frame.drop(frame.index[0], inplace=True)

    return frame
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def add_scaled_columns(df: pd.DataFrame, origin_col_name: str) -> pd.DataFrame:
    """Coerce a column to numeric and append a copy scaled by 100.

    Entries that cannot be parsed as numbers become 0.0. The scaled values
    are stored in a new column named "<origin_col_name>*100"; the original
    column is converted to numeric in place.

    :param df: source DataFrame (mutated in place)
    :param origin_col_name: name of the column holding numeric-like values
    :return: the same DataFrame with the scaled column appended
    """
    numeric_values = pd.to_numeric(df[origin_col_name], errors='coerce').fillna(0.0)
    df[origin_col_name] = numeric_values

    # Derived column name, e.g. "score" -> "score*100".
    df[f"{origin_col_name}*100"] = numeric_values * 100

    return df
|
sheet_manager/sheet_monitor/sheet_sync.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
from typing import Optional, Callable
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
class SheetMonitor:
    """Background thread that polls a SheetManager for pending data.

    Sets `has_data` whenever the managed column contains values so a consumer
    (see MainLoop) can wait on the event instead of polling the sheet itself.
    pause()/resume() let the consumer stop polling while it mutates the sheet.
    """

    def __init__(self, sheet_manager, check_interval: float = 1.0):
        """
        Initialize SheetMonitor with a sheet manager instance.

        Args:
            sheet_manager: SheetManager-like object exposing get_all_values()
                and a column_name attribute.
            check_interval (float): Seconds between sheet polls.
        """
        self.sheet_manager = sheet_manager
        self.check_interval = check_interval

        # Threading control
        self.monitor_thread = None
        self.is_running = threading.Event()        # loop keeps going while set
        self.pause_monitoring = threading.Event()  # set => loop must idle
        self.monitor_paused = threading.Event()    # set by the loop once idle

        # Queue status
        self.has_data = threading.Event()

        # Logging setup
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def start_monitoring(self):
        """Start the monitoring thread."""
        if self.monitor_thread is not None and self.monitor_thread.is_alive():
            self.logger.warning("Monitoring thread is already running")
            return

        self.is_running.set()
        self.pause_monitoring.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
        self.logger.info("Started monitoring thread")

    def stop_monitoring(self):
        """Stop the monitoring thread and wait for it to exit."""
        self.is_running.clear()
        if self.monitor_thread:
            self.monitor_thread.join()
            self.logger.info("Stopped monitoring thread")

    def pause(self):
        """Pause the monitoring and block until the loop has actually idled."""
        self.pause_monitoring.set()
        self.monitor_paused.wait()
        self.logger.info("Monitoring paused")

    def resume(self):
        """Resume the monitoring."""
        self.pause_monitoring.clear()
        self.monitor_paused.clear()
        # Check immediately instead of waiting a full poll interval.
        self.logger.info("Monitoring resumed, checking for new data...")
        values = self.sheet_manager.get_all_values()
        if values:
            self.has_data.set()
            self.logger.info(f"Found data after resume: {values}")

    def _monitor_loop(self):
        """Main monitoring loop that checks for data in the sheet."""
        while self.is_running.is_set():
            if self.pause_monitoring.is_set():
                # BUGFIX: the old code called self.pause_monitoring.wait()
                # here, which returns immediately because that event is *set*
                # — so the loop never actually paused and kept hitting the
                # sheet while the consumer was mutating it. Idle until
                # resume() clears the flag (or the monitor is stopped).
                self.monitor_paused.set()
                while self.pause_monitoring.is_set() and self.is_running.is_set():
                    time.sleep(0.05)
                self.monitor_paused.clear()
                continue

            try:
                # Check if there's any data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"Monitoring: Current column={self.sheet_manager.column_name}, "
                                 f"Values found={len(values)}, "
                                 f"Has data={self.has_data.is_set()}")

                if values:  # If there's any non-empty value
                    self.has_data.set()
                    self.logger.info(f"Data detected: {values}")
                else:
                    self.has_data.clear()
                    self.logger.info("No data in sheet, waiting...")

                time.sleep(self.check_interval)

            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {str(e)}")
                time.sleep(self.check_interval)
|
| 93 |
+
|
| 94 |
+
class MainLoop:
    """Consumes queued jobs from the sheet whenever SheetMonitor signals data.

    Waits on the monitor's has_data event, pauses polling, pops one job
    (model id + benchmark name + prompt config name) from the sheet, hands it
    to the callback, then resumes polling.
    """

    def __init__(self, sheet_manager, sheet_monitor, callback_function: Callable = None):
        """
        Initialize MainLoop with sheet manager and monitor instances.

        Args:
            sheet_manager: SheetManager used to pop values from the sheet.
            sheet_monitor: SheetMonitor signalling when data is available.
            callback_function: Optional callable invoked as
                callback(model_id, benchmark_name, prompt_cfg_name).
        """
        self.sheet_manager = sheet_manager
        self.monitor = sheet_monitor
        self.callback = callback_function
        self.is_running = threading.Event()
        self.logger = logging.getLogger(__name__)

    def start(self):
        """Start the main processing loop (blocks the calling thread)."""
        self.is_running.set()
        self.monitor.start_monitoring()
        self._main_loop()

    def stop(self):
        """Stop the main processing loop and the monitor."""
        self.is_running.clear()
        self.monitor.stop_monitoring()

    def process_new_value(self):
        """Pop one job (three related columns) and invoke the callback.

        Returns:
            tuple | None: (model_id, benchmark_name, prompt_cfg_name), or
            None when the queue was empty or an error occurred.
        """
        try:
            # Remember the current column so we can restore it afterwards.
            original_column = self.sheet_manager.column_name

            # Pop from huggingface_id column
            model_id = self.sheet_manager.pop()

            if model_id:
                # Pop from benchmark_name column
                self.sheet_manager.change_column("benchmark_name")
                benchmark_name = self.sheet_manager.pop()

                # Pop from prompt_cfg_name column
                self.sheet_manager.change_column("prompt_cfg_name")
                prompt_cfg_name = self.sheet_manager.pop()

                # Return to original column
                self.sheet_manager.change_column(original_column)

                self.logger.info(f"Processed values - model_id: {model_id}, "
                                 f"benchmark_name: {benchmark_name}, "
                                 f"prompt_cfg_name: {prompt_cfg_name}")

                if self.callback:
                    # Pass all three values to callback
                    self.callback(model_id, benchmark_name, prompt_cfg_name)

                return model_id, benchmark_name, prompt_cfg_name

        except Exception as e:
            self.logger.error(f"Error processing values: {str(e)}")
            # Best-effort: return to the original column in case of error.
            try:
                self.sheet_manager.change_column(original_column)
            except Exception:
                # BUGFIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                pass
            return None

    def _main_loop(self):
        """Main processing loop: wait for data, process it, resume polling."""
        while self.is_running.is_set():
            # Wait for data; the timeout keeps the loop responsive to stop().
            if self.monitor.has_data.wait(timeout=1.0):
                # Pause monitoring while we mutate the sheet.
                self.monitor.pause()

                # Process the value
                self.process_new_value()

                # Check if there's still data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"After processing: Current column={self.sheet_manager.column_name}, "
                                 f"Values remaining={len(values)}")

                if not values:
                    self.monitor.has_data.clear()
                    self.logger.info("All data processed, clearing has_data flag")
                else:
                    self.logger.info(f"Remaining data: {values}")

                # Resume monitoring
                self.monitor.resume()
                ## TODO: on API rate-limit failures, back off and retry the
                # job instead of dropping it.
|
| 183 |
+
|
| 184 |
+
# Example usage
|
| 185 |
+
if __name__ == "__main__":
    import sys
    from pathlib import Path
    # Make the repository root importable when running this file directly.
    sys.path.append(str(Path(__file__).parent.parent.parent))
    from sheet_manager.sheet_crud.sheet_crud import SheetManager
    from pia_bench.pipe_line.piepline import PiaBenchMark
    def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
        # Run the PIA benchmark for the job popped from the sheet.
        piabenchmark = PiaBenchMark(huggingface_id, benchmark_name, prompt_cfg_name)
        piabenchmark.bench_start()

    # Initialize components
    sheet_manager = SheetManager()
    monitor = SheetMonitor(sheet_manager, check_interval=10.0)
    main_loop = MainLoop(sheet_manager, monitor, callback_function=my_custom_function)

    try:
        # NOTE(review): main_loop.start() blocks inside _main_loop(), so the
        # sleep loop below is only reached after the loop exits — confirm
        # whether start() was meant to run in a separate thread.
        main_loop.start()
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        main_loop.stop()
|
utils/bench_meta.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from utils.except_dir import cust_listdir
|
| 7 |
+
def get_video_metadata(video_path, category, benchmark):
    """Extract metadata from a video file.

    Args:
        video_path: Path to the video file.
        category: Category label attached to the returned metadata.
        benchmark: Benchmark label attached to the returned metadata.

    Returns:
        dict | None: Metadata dictionary, or None if the file cannot be
        opened as a video.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        return None
    # Extract metadata
    video_name = os.path.basename(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    resolution = f"{frame_width}x{frame_height}"
    duration_seconds = frame_count / fps if fps > 0 else 0
    aspect_ratio = round(frame_width / frame_height, 2) if frame_height > 0 else 0
    file_size = os.path.getsize(video_path) / (1024 * 1024)  # MB
    file_format = os.path.splitext(video_name)[1].lower()
    cap.release()

    return {
        "video_name": video_name,
        "resolution": resolution,
        # BUGFIX: seconds are now truncated and zero-padded ("1:05", not
        # "1:5"); the old "%.0f" rounding could also produce "0:60" for
        # durations just under a minute.
        "video_duration": f"{int(duration_seconds // 60)}:{int(duration_seconds % 60):02d}",
        "category": category,
        "benchmark": benchmark,
        "duration_seconds": duration_seconds,
        "total_frames": frame_count,
        "file_format": file_format,
        "file_size_mb": round(file_size, 2),
        "aspect_ratio": aspect_ratio,
        "fps": fps
    }
|
| 39 |
+
|
| 40 |
+
def process_videos_in_directory(root_dir):
    """Collect video metadata for every clip under a benchmark tree.

    Expected layout: <root_dir>/<benchmark>/dataset/<category>/<video files>.
    Unreadable videos are skipped (get_video_metadata returns None for them).

    Args:
        root_dir: Root directory containing benchmark folders.

    Returns:
        pd.DataFrame: One row per readable video file (possibly empty).
    """
    video_metadata_list = []

    # Walk each benchmark folder.
    for benchmark in cust_listdir(root_dir):
        benchmark_path = os.path.join(root_dir, benchmark)
        if not os.path.isdir(benchmark_path):
            continue

        # Videos live under the benchmark's "dataset" folder.
        dataset_path = os.path.join(benchmark_path, "dataset")
        if not os.path.isdir(dataset_path):
            continue

        # Each category folder inside "dataset" holds the actual video files.
        for category in cust_listdir(dataset_path):
            category_path = os.path.join(dataset_path, category)
            if not os.path.isdir(category_path):
                continue

            for file in cust_listdir(category_path):
                file_path = os.path.join(category_path, file)

                # BUGFIX: the old tuple also contained 'MOV', which can never
                # match a lower-cased path; '.mov' already covers that case.
                if file_path.lower().endswith(('.mp4', '.avi', '.mkv', '.mov')):
                    metadata = get_video_metadata(file_path, category, benchmark)
                    if metadata:
                        video_metadata_list.append(metadata)
    return pd.DataFrame(video_metadata_list)
|
| 72 |
+
|
utils/except_dir.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
import enviroments.config as config
|
| 4 |
+
|
| 5 |
+
def cust_listdir(directory: str) -> List[str]:
    """Like os.listdir, but omits entries listed in config.EXCLUDE_DIRS.

    Args:
        directory (str): Path of the directory to scan.

    Returns:
        List[str]: Directory entries not present in config.EXCLUDE_DIRS.
    """
    excluded = set(config.EXCLUDE_DIRS)
    return [entry for entry in os.listdir(directory) if entry not in excluded]
|
utils/hf_api.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import HfApi
|
| 2 |
+
from typing import Optional, List, Dict, Any
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
@dataclass
class ModelInfo:
    """Container for the metadata of a single HuggingFace Hub model."""
    model_id: str               # repo id, e.g. "PIA-SPACE-LAB/some-model"
    last_modified: Any          # last-modified timestamp from the Hub API
    downloads: int              # download count reported by the Hub
    private: bool               # True if the repository is private
    attributes: Dict[str, Any]  # all public attributes of the raw API object
|
| 13 |
+
|
| 14 |
+
class HuggingFaceInfoManager:
    """Fetches and caches the model listing of a HuggingFace organization.

    The constructor performs one list_models API call and caches both the raw
    API objects and a processed ModelInfo view; the accessors below read only
    from that cache. Call refresh_models() to re-fetch.
    """

    def __init__(self, access_token: Optional[str] = None, organization: str = "PIA-SPACE-LAB"):
        """
        Initialize the HuggingFace API manager.

        Args:
            access_token (str, optional): HuggingFace access token.
            organization (str): Organization name (default "PIA-SPACE-LAB").

        Raises:
            ValueError: If access_token is None.
        """
        if access_token is None:
            raise ValueError("์•ก์„ธ์Šค ํ† ํฐ์€ ํ•„์ˆ˜ ์ž…๋ ฅ๊ฐ’์ž…๋‹ˆ๋‹ค. HuggingFace์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ ํ† ํฐ์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")

        self.api = HfApi()
        self.access_token = access_token
        self.organization = organization

        # One API call; results are cached for the accessors below.
        api_models = self.api.list_models(author=self.organization, use_auth_token=self.access_token)
        self._stored_models = []
        self._model_infos = []

        for model in api_models:
            # Snapshot every public attribute of the raw API object so the
            # flattened dicts can expose them alongside the base fields.
            model_attrs = {}
            for attr in dir(model):
                if not attr.startswith("_"):
                    model_attrs[attr] = getattr(model, attr)

            model_info = ModelInfo(
                model_id=model.modelId,
                last_modified=model.lastModified,
                downloads=model.downloads,
                private=model.private,
                attributes=model_attrs
            )
            self._model_infos.append(model_info)
            self._stored_models.append(model)

    @staticmethod
    def _info_to_dict(info) -> Dict[str, Any]:
        """Flatten a ModelInfo into one dict (base fields + raw attributes).

        Extracted helper: this dict shape was previously duplicated in
        get_model_info / get_private_models / get_public_models.
        """
        return {
            'model_id': info.model_id,
            'last_modified': info.last_modified,
            'downloads': info.downloads,
            'private': info.private,
            **info.attributes
        }

    def get_model_info(self) -> List[Dict[str, Any]]:
        """Return the flattened info of every cached model."""
        return [self._info_to_dict(info) for info in self._model_infos]

    def get_model_ids(self) -> List[str]:
        """Return the IDs of every cached model."""
        return [info.model_id for info in self._model_infos]

    def get_private_models(self) -> List[Dict[str, Any]]:
        """Return the flattened info of private models only."""
        return [self._info_to_dict(info) for info in self._model_infos if info.private]

    def get_public_models(self) -> List[Dict[str, Any]]:
        """Return the flattened info of public models only."""
        return [self._info_to_dict(info) for info in self._model_infos if not info.private]

    def refresh_models(self) -> None:
        """Refresh the cache by re-running the constructor (new API call)."""
        self.__init__(self.access_token, self.organization)
|
utils/logger.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def custom_logger(name: str) -> logging.Logger:
    """
    Create a customised logger.

    A console handler and a file handler work side by side, each with its own
    log level.

    Args:
        name (str): Logger name (usually __name__).

    Returns:
        logging.Logger: the configured Logger object.

    Level configuration:
        - Console handler: INFO level (INFO, WARNING, ERROR, CRITICAL only)
        - File handler: DEBUG level (records every level)
        - Log files are stored per-date in the 'logs' directory and only the
          10 most recent are kept.

    Usage:
        ```python
        logger = custom_logger(__name__)

        logger.debug("debug message")    # file only
        logger.info("info message")      # console + file
        logger.warning("warning message")# console + file
        logger.error("error message")    # console + file
        ```

    Output format:
        Console: [HH:MM:SS] [LEVEL] [module:line] [funcName] message
        File:    [YYYY-MM-DD HH:MM:SS] [LEVEL] [module:line] [funcName] message
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # Already configured: return as-is so repeated calls never stack
    # duplicate handlers on the same named logger.
    if logger.handlers:
        return logger

    # Ensure the log directory exists.
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    # Shared message layout; only the date format differs per handler.
    fmt = "[%(asctime)s] [%(levelname)s] [%(module)s:%(lineno)d] [%(funcName)s] %(message)s"
    console_formatter = logging.Formatter(fmt, datefmt="%H:%M:%S")
    file_formatter = logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S")

    # Console handler: INFO and above only.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(console_formatter)

    # File handler: date-stamped file, rotated at midnight, 10 backups kept.
    today = datetime.now().strftime("%Y%m%d")
    file_handler = TimedRotatingFileHandler(
        filename=log_dir / f"{today}.log",
        when="midnight",
        interval=1,
        backupCount=10,  # keep only 10 days
        encoding="utf-8",
    )
    # BUG FIX: level was WARNING, contradicting the documented contract that
    # the file captures every level (DEBUG and above).
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(file_formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    # Prune old log files, keeping at most the 10 newest by creation time.
    log_files = sorted(log_dir.glob("*.log"), key=os.path.getctime)
    for stale in log_files[:-10]:
        stale.unlink()

    return logger
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def cleanup_old_logs(log_dir: Path, max_files: int = 10) -> None:
    """
    Delete the oldest log files when more than ``max_files`` are present.

    Args:
        log_dir (Path): directory containing the ``*.log`` files.
        max_files (int): number of most-recent files to keep (default 10,
            matching the file handler's ``backupCount``).
    """
    # Oldest first, by filesystem creation time.
    log_files = sorted(log_dir.glob("*.log"), key=os.path.getctime)
    excess = len(log_files) - max_files
    # Single slice replaces the original while-loop-with-reslicing.
    for stale in log_files[:max(excess, 0)]:
        stale.unlink()  # remove the oldest files first
|
utils/parser.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
def load_config(config_path: str) -> Dict:
    """
    Read a JSON configuration file and return it as a dictionary.

    Args:
        config_path (str): path to the JSON configuration file.

    Returns:
        Dict: dictionary holding the configuration.
    """
    with open(config_path, encoding='utf-8') as config_file:
        parsed = json.load(config_file)
    return parsed
|
| 16 |
+
|
| 17 |
+
class PromptManager:
    """Indexes every prompt sentence in a config and maps sentences to their positions."""

    def __init__(self, config_path: str):
        self.config = load_config(config_path)
        self.sentences, self.index_mapping = self._extract_all_sentences_with_index()
        self.reverse_mapping = self._create_reverse_mapping()

    def _extract_all_sentences_with_index(self) -> Tuple[List[str], Dict]:
        """Collect every sentence and a (event_idx, status, prompt_idx) -> sentence map."""
        collected: List[str] = []
        mapping: Dict = {}

        for ev_idx, ev_cfg in enumerate(self.config.get('PROMPT_CFG', [])):
            prompt_groups = ev_cfg.get('prompts', {})
            for state in ('normal', 'abnormal'):
                for p_idx, entry in enumerate(prompt_groups.get(state, [])):
                    text = entry.get('sentence', '')
                    collected.append(text)
                    mapping[(ev_idx, state, p_idx)] = text

        return collected, mapping

    def _create_reverse_mapping(self) -> Dict:
        """Build the sentence -> list-of-index-triples reverse lookup."""
        lookup: Dict = {}
        for position, text in self.index_mapping.items():
            lookup.setdefault(text, []).append(position)
        return lookup

    def get_sentence_indices(self, sentence: str) -> List[Tuple[int, str, int]]:
        """Return every index position at which the given sentence occurs."""
        return self.reverse_mapping.get(sentence, [])

    def get_details_by_sentence(self, sentence: str) -> List[Dict]:
        """Return the detail record for every occurrence of the given sentence."""
        return [
            self.get_details_by_index(*position)
            for position in self.get_sentence_indices(sentence)
        ]

    def get_details_by_index(self, event_idx: int, status: str, prompt_idx: int) -> Dict:
        """Look up the detail record addressed by an index triple."""
        ev_cfg = self.config['PROMPT_CFG'][event_idx]
        entry = ev_cfg['prompts'][status][prompt_idx]

        return {
            'event': ev_cfg['event'],
            'status': status,
            'sentence': entry['sentence'],
            'top_candidates': ev_cfg['top_candidates'],
            'alert_threshold': ev_cfg['alert_threshold'],
            'event_idx': event_idx,
            'prompt_idx': prompt_idx,
        }

    def get_all_sentences(self) -> List[str]:
        """Return the full sentence list in extraction order."""
        return self.sentences
|