Spaces:
Runtime error
Runtime error
first
Browse files- .gitignore +181 -0
- app.py +21 -0
- enviroments/.gitkeep +0 -0
- enviroments/config.py +53 -0
- enviroments/convert.py +54 -0
- leaderboard_ui/tab/dataset_visual_tab.py +160 -0
- leaderboard_ui/tab/leaderboard_tab.py +52 -0
- leaderboard_ui/tab/metric_visaul_tab.py +418 -0
- leaderboard_ui/tab/submit_tab.py +103 -0
- main.py +76 -0
- pia_bench/bench.py +157 -0
- pia_bench/checker/bench_checker.py +184 -0
- pia_bench/checker/sheet_checker.py +284 -0
- pia_bench/event_alarm.py +225 -0
- pia_bench/metric.py +322 -0
- pia_bench/pipe_line/piepline.py +227 -0
- requirements.txt +15 -0
- sample.csv +99 -0
- sheet_manager/sheet_checker/sheet_check.py +140 -0
- sheet_manager/sheet_convert/json2sheet.py +117 -0
- sheet_manager/sheet_crud/create_col.py +76 -0
- sheet_manager/sheet_crud/sheet_crud.py +347 -0
- sheet_manager/sheet_loader/sheet2df.py +52 -0
- sheet_manager/sheet_monitor/sheet_sync.py +205 -0
- topk.json +88 -0
- utils/bench_meta.py +72 -0
- utils/except_dir.py +15 -0
- utils/hf_api.py +103 -0
- utils/parser.py +65 -0
.gitignore
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# PyPI configuration file
|
| 171 |
+
.pypirc
|
| 172 |
+
|
| 173 |
+
# *.json
|
| 174 |
+
|
| 175 |
+
assets
|
| 176 |
+
DevMACS-AI-solution-devmacs
|
| 177 |
+
Research-AI-research-t2v_f1score_evaluator
|
| 178 |
+
.env
|
| 179 |
+
enviroments/abnormal-situation-leaderboard-3ca42d06719e.json
|
| 180 |
+
leaderboard_test
|
| 181 |
+
enviroments/deep-byte-352904-a072fdf439e7.json
|
app.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from leaderboard_ui.tab.submit_tab import submit_tab
|
| 4 |
+
from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
|
| 5 |
+
from leaderboard_ui.tab.dataset_visual_tab import visual_tab
|
| 6 |
+
from leaderboard_ui.tab.metric_visaul_tab import metric_visual_tab
|
| 7 |
+
|
| 8 |
+
abs_path = Path(__file__).parent
|
| 9 |
+
|
| 10 |
+
with gr.Blocks() as demo:
|
| 11 |
+
gr.Markdown("""
|
| 12 |
+
# 🥇 PIA_leaderboard
|
| 13 |
+
""")
|
| 14 |
+
with gr.Tabs():
|
| 15 |
+
leaderboard_tab()
|
| 16 |
+
submit_tab()
|
| 17 |
+
visual_tab()
|
| 18 |
+
metric_visual_tab()
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
demo.launch()
|
enviroments/.gitkeep
ADDED
|
File without changes
|
enviroments/config.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
EXCLUDE_DIRS = {"@eaDir", 'temp'}
|
| 4 |
+
|
| 5 |
+
TYPES = [
|
| 6 |
+
"markdown",
|
| 7 |
+
"markdown",
|
| 8 |
+
"number",
|
| 9 |
+
"number",
|
| 10 |
+
"number",
|
| 11 |
+
"number",
|
| 12 |
+
"number",
|
| 13 |
+
"number",
|
| 14 |
+
"number",
|
| 15 |
+
"str",
|
| 16 |
+
"str",
|
| 17 |
+
"str",
|
| 18 |
+
"str",
|
| 19 |
+
"bool",
|
| 20 |
+
"str",
|
| 21 |
+
"number",
|
| 22 |
+
"number",
|
| 23 |
+
"bool",
|
| 24 |
+
"str",
|
| 25 |
+
"bool",
|
| 26 |
+
"bool",
|
| 27 |
+
"str",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
ON_LOAD_COLUMNS = [
|
| 31 |
+
"TASK",
|
| 32 |
+
"Model",
|
| 33 |
+
"PIA" # 모델 이름
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
OFF_LOAD_COLUMNS = ["Model link", "PIA", "PIA * 100" , "Model name" ]
|
| 37 |
+
|
| 38 |
+
HIDE_COLUMNS = ["PIA * 100"]
|
| 39 |
+
|
| 40 |
+
FILTER_COLUMNS = ["T"]
|
| 41 |
+
|
| 42 |
+
NUMERIC_COLUMNS = ["PIA"]
|
| 43 |
+
|
| 44 |
+
NUMERIC_INTERVALS = {
|
| 45 |
+
"?": pd.Interval(-1, 0, closed="right"),
|
| 46 |
+
"~1.5": pd.Interval(0, 2, closed="right"),
|
| 47 |
+
"~3": pd.Interval(2, 4, closed="right"),
|
| 48 |
+
"~7": pd.Interval(4, 9, closed="right"),
|
| 49 |
+
"~13": pd.Interval(9, 20, closed="right"),
|
| 50 |
+
"~35": pd.Interval(20, 45, closed="right"),
|
| 51 |
+
"~60": pd.Interval(45, 70, closed="right"),
|
| 52 |
+
"70+": pd.Interval(70, 10000, closed="right"),
|
| 53 |
+
}
|
enviroments/convert.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
def get_json_from_env_var(env_var_name):
|
| 7 |
+
"""
|
| 8 |
+
환경 변수에서 JSON 데이터를 가져와 딕셔너리로 변환하는 함수.
|
| 9 |
+
:param env_var_name: 환경 변수 이름
|
| 10 |
+
:return: 딕셔너리 형태의 JSON 데이터
|
| 11 |
+
"""
|
| 12 |
+
json_string = os.getenv(env_var_name)
|
| 13 |
+
if not json_string:
|
| 14 |
+
raise EnvironmentError(f"환경 변수 '{env_var_name}'가 설정되지 않았습니다.")
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
# 줄바꿈(\n)을 이스케이프 문자(\\n)로 변환
|
| 18 |
+
json_string = json_string.replace("\n", "\\n")
|
| 19 |
+
|
| 20 |
+
# JSON 문자열을 딕셔너리로 변환
|
| 21 |
+
json_data = json.loads(json_string)
|
| 22 |
+
except json.JSONDecodeError as e:
|
| 23 |
+
raise ValueError(f"JSON 변환 실패: {e}")
|
| 24 |
+
|
| 25 |
+
return json_data
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def json_to_env_var(json_file_path, env_var_name="JSON_ENV_VAR"):
|
| 30 |
+
"""
|
| 31 |
+
주어진 JSON 파일의 데이터를 환경 변수 형태로 변환하여 출력하는 함수.
|
| 32 |
+
|
| 33 |
+
:param json_file_path: JSON 파일 경로
|
| 34 |
+
:param env_var_name: 환경 변수 이름 (기본값: JSON_ENV_VAR)
|
| 35 |
+
:return: None
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
# JSON 파일 읽기
|
| 39 |
+
with open(json_file_path, 'r') as json_file:
|
| 40 |
+
json_data = json.load(json_file)
|
| 41 |
+
|
| 42 |
+
# JSON 데이터를 문자열로 변환
|
| 43 |
+
json_string = json.dumps(json_data)
|
| 44 |
+
|
| 45 |
+
# 환경 변수 형태로 출력
|
| 46 |
+
env_variable = f'{env_var_name}={json_string}'
|
| 47 |
+
print("\n환경 변수로 사용할 수 있는 출력값:\n")
|
| 48 |
+
print(env_variable)
|
| 49 |
+
print("\n위 값을 .env 파일에 복사하여 붙여넣으세요.")
|
| 50 |
+
except FileNotFoundError:
|
| 51 |
+
print(f"파일을 찾을 수 없습니다: {json_file_path}")
|
| 52 |
+
except json.JSONDecodeError:
|
| 53 |
+
print(f"유효한 JSON 파일이 아닙니다: {json_file_path}")
|
| 54 |
+
|
leaderboard_ui/tab/dataset_visual_tab.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from leaderboard_ui.tab.submit_tab import submit_tab
|
| 4 |
+
from leaderboard_ui.tab.leaderboard_tab import leaderboard_tab
|
| 5 |
+
abs_path = Path(__file__).parent
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from utils.bench_meta import process_videos_in_directory
|
| 11 |
+
# Mock 데이터 생성
|
| 12 |
+
def create_mock_data():
|
| 13 |
+
benchmarks = ['VQA-2023', 'ImageQuality-2024', 'VideoEnhance-2024']
|
| 14 |
+
categories = ['Animation', 'Game', 'Movie', 'Sports', 'Vlog']
|
| 15 |
+
|
| 16 |
+
data_list = []
|
| 17 |
+
|
| 18 |
+
for benchmark in benchmarks:
|
| 19 |
+
n_videos = np.random.randint(50, 100)
|
| 20 |
+
for _ in range(n_videos):
|
| 21 |
+
category = np.random.choice(categories)
|
| 22 |
+
|
| 23 |
+
data_list.append({
|
| 24 |
+
"video_name": f"video_{np.random.randint(1000, 9999)}.mp4",
|
| 25 |
+
"resolution": np.random.choice(["1920x1080", "3840x2160", "1280x720"]),
|
| 26 |
+
"video_duration": f"{np.random.randint(0, 10)}:{np.random.randint(0, 60)}",
|
| 27 |
+
"category": category,
|
| 28 |
+
"benchmark": benchmark,
|
| 29 |
+
"duration_seconds": np.random.randint(30, 600),
|
| 30 |
+
"total_frames": np.random.randint(1000, 10000),
|
| 31 |
+
"file_format": ".mp4",
|
| 32 |
+
"file_size_mb": round(np.random.uniform(10, 1000), 2),
|
| 33 |
+
"aspect_ratio": 16/9,
|
| 34 |
+
"fps": np.random.choice([24, 30, 60])
|
| 35 |
+
})
|
| 36 |
+
|
| 37 |
+
return pd.DataFrame(data_list)
|
| 38 |
+
|
| 39 |
+
# Mock 데이터 생성
|
| 40 |
+
# df = process_videos_in_directory("/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench")
|
| 41 |
+
df = pd.read_csv("sample.csv")
|
| 42 |
+
print("DataFrame shape:", df.shape)
|
| 43 |
+
print("DataFrame columns:", df.columns)
|
| 44 |
+
print("DataFrame head:\n", df.head())
|
| 45 |
+
def create_category_pie_chart(df, selected_benchmark, selected_categories=None):
|
| 46 |
+
filtered_df = df[df['benchmark'] == selected_benchmark]
|
| 47 |
+
|
| 48 |
+
if selected_categories:
|
| 49 |
+
filtered_df = filtered_df[filtered_df['category'].isin(selected_categories)]
|
| 50 |
+
|
| 51 |
+
category_counts = filtered_df['category'].value_counts()
|
| 52 |
+
|
| 53 |
+
fig = px.pie(
|
| 54 |
+
values=category_counts.values,
|
| 55 |
+
names=category_counts.index,
|
| 56 |
+
title=f'{selected_benchmark} - Video Distribution by Category',
|
| 57 |
+
hole=0.3
|
| 58 |
+
)
|
| 59 |
+
fig.update_traces(textposition='inside', textinfo='percent+label')
|
| 60 |
+
|
| 61 |
+
return fig
|
| 62 |
+
|
| 63 |
+
###TODO 스트링일경우 어케 처리
|
| 64 |
+
|
| 65 |
+
def create_bar_chart(df, selected_benchmark, selected_categories, selected_column):
|
| 66 |
+
# Filter by benchmark and categories
|
| 67 |
+
filtered_df = df[df['benchmark'] == selected_benchmark]
|
| 68 |
+
if selected_categories:
|
| 69 |
+
filtered_df = filtered_df[filtered_df['category'].isin(selected_categories)]
|
| 70 |
+
|
| 71 |
+
# Create bar chart for selected column
|
| 72 |
+
fig = px.bar(
|
| 73 |
+
filtered_df,
|
| 74 |
+
x=selected_column,
|
| 75 |
+
y='video_name',
|
| 76 |
+
color='category', # Color by category
|
| 77 |
+
title=f'{selected_benchmark} - Video {selected_column}',
|
| 78 |
+
orientation='h', # Horizontal bar chart
|
| 79 |
+
color_discrete_sequence=px.colors.qualitative.Set3 # Color palette
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Adjust layout
|
| 83 |
+
fig.update_layout(
|
| 84 |
+
height=max(400, len(filtered_df) * 30), # Adjust height based on data
|
| 85 |
+
yaxis={'categoryorder': 'total ascending'}, # Sort by value
|
| 86 |
+
margin=dict(l=200), # Margin for long video names
|
| 87 |
+
showlegend=True, # Show legend
|
| 88 |
+
legend=dict(
|
| 89 |
+
orientation="h", # Horizontal legend
|
| 90 |
+
yanchor="bottom",
|
| 91 |
+
y=1.02, # Place legend above graph
|
| 92 |
+
xanchor="right",
|
| 93 |
+
x=1
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return fig
|
| 98 |
+
|
| 99 |
+
def submit_tab():
|
| 100 |
+
with gr.Tab("🚀 Submit here! "):
|
| 101 |
+
with gr.Row():
|
| 102 |
+
gr.Markdown("# ✉️✨ Submit your Result here!")
|
| 103 |
+
|
| 104 |
+
def visual_tab():
|
| 105 |
+
with gr.Tab("📊 Bench Info"):
|
| 106 |
+
with gr.Row():
|
| 107 |
+
benchmark_dropdown = gr.Dropdown(
|
| 108 |
+
choices=sorted(df['benchmark'].unique().tolist()),
|
| 109 |
+
value=sorted(df['benchmark'].unique().tolist())[0],
|
| 110 |
+
label="Select Benchmark",
|
| 111 |
+
interactive=True
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
category_multiselect = gr.CheckboxGroup(
|
| 115 |
+
choices=sorted(df['category'].unique().tolist()),
|
| 116 |
+
label="Select Categories (empty for all)",
|
| 117 |
+
interactive=True
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Pie chart
|
| 121 |
+
pie_plot_output = gr.Plot(label="pie")
|
| 122 |
+
|
| 123 |
+
# Column selection dropdown
|
| 124 |
+
column_options = [
|
| 125 |
+
"video_duration", "duration_seconds", "total_frames",
|
| 126 |
+
"file_size_mb", "aspect_ratio", "fps", "file_format"
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
column_dropdown = gr.Dropdown(
|
| 130 |
+
choices=column_options,
|
| 131 |
+
value=column_options[0],
|
| 132 |
+
label="Select Data to Compare",
|
| 133 |
+
interactive=True
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Bar chart
|
| 137 |
+
bar_plot_output = gr.Plot(label="video")
|
| 138 |
+
|
| 139 |
+
def update_plots(benchmark, categories, selected_column):
|
| 140 |
+
pie_chart = create_category_pie_chart(df, benchmark, categories)
|
| 141 |
+
bar_chart = create_bar_chart(df, benchmark, categories, selected_column)
|
| 142 |
+
return pie_chart, bar_chart
|
| 143 |
+
|
| 144 |
+
# Connect event handlers
|
| 145 |
+
benchmark_dropdown.change(
|
| 146 |
+
fn=update_plots,
|
| 147 |
+
inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
|
| 148 |
+
outputs=[pie_plot_output, bar_plot_output]
|
| 149 |
+
)
|
| 150 |
+
category_multiselect.change(
|
| 151 |
+
fn=update_plots,
|
| 152 |
+
inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
|
| 153 |
+
outputs=[pie_plot_output, bar_plot_output]
|
| 154 |
+
)
|
| 155 |
+
column_dropdown.change(
|
| 156 |
+
fn=update_plots,
|
| 157 |
+
inputs=[benchmark_dropdown, category_multiselect, column_dropdown],
|
| 158 |
+
outputs=[pie_plot_output, bar_plot_output]
|
| 159 |
+
)
|
| 160 |
+
|
leaderboard_ui/tab/leaderboard_tab.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from gradio_leaderboard import Leaderboard, SelectColumns, ColumnFilter,SearchColumns
|
| 3 |
+
import enviroments.config as config
|
| 4 |
+
from sheet_manager.sheet_loader.sheet2df import sheet2df
|
| 5 |
+
|
| 6 |
+
def leaderboard_tab():
|
| 7 |
+
with gr.Tab("🏆Leaderboard"):
|
| 8 |
+
leaderboard = Leaderboard(
|
| 9 |
+
value=sheet2df(),
|
| 10 |
+
select_columns=SelectColumns(
|
| 11 |
+
default_selection=config.ON_LOAD_COLUMNS,
|
| 12 |
+
cant_deselect=config.OFF_LOAD_COLUMNS,
|
| 13 |
+
label="Select Columns to Display:",
|
| 14 |
+
info="Check"
|
| 15 |
+
),
|
| 16 |
+
|
| 17 |
+
search_columns=SearchColumns(
|
| 18 |
+
primary_column="Model name",
|
| 19 |
+
secondary_columns=["TASK"],
|
| 20 |
+
placeholder="Search",
|
| 21 |
+
label="Search"
|
| 22 |
+
),
|
| 23 |
+
hide_columns=config.HIDE_COLUMNS,
|
| 24 |
+
filter_columns=[
|
| 25 |
+
ColumnFilter(
|
| 26 |
+
column= "TASK",
|
| 27 |
+
),
|
| 28 |
+
ColumnFilter(
|
| 29 |
+
column="PIA * 100",
|
| 30 |
+
type="slider",
|
| 31 |
+
min=0, # 77
|
| 32 |
+
max=100, # 92
|
| 33 |
+
# default=[min_val, max_val],
|
| 34 |
+
default = [77 ,92],
|
| 35 |
+
label="PIA" # 실제 값의 100배로 표시됨,
|
| 36 |
+
)
|
| 37 |
+
],
|
| 38 |
+
|
| 39 |
+
datatype=config.TYPES,
|
| 40 |
+
# column_widths=["33%", "10%"],
|
| 41 |
+
)
|
| 42 |
+
refresh_button = gr.Button("🔄 Refresh Leaderboard")
|
| 43 |
+
|
| 44 |
+
def refresh_leaderboard():
|
| 45 |
+
return sheet2df()
|
| 46 |
+
|
| 47 |
+
refresh_button.click(
|
| 48 |
+
refresh_leaderboard,
|
| 49 |
+
inputs=[],
|
| 50 |
+
outputs=leaderboard,
|
| 51 |
+
)
|
| 52 |
+
|
leaderboard_ui/tab/metric_visaul_tab.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
abs_path = Path(__file__).parent
|
| 4 |
+
import plotly.express as px
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sheet_manager.sheet_loader.sheet2df import sheet2df
|
| 9 |
+
from sheet_manager.sheet_convert.json2sheet import str2json
|
| 10 |
+
# Mock 데이터 생성
|
| 11 |
+
def calculate_avg_metrics(df):
|
| 12 |
+
"""
|
| 13 |
+
각 모델의 카테고리별 평균 성능 지표를 계산
|
| 14 |
+
"""
|
| 15 |
+
metrics_data = []
|
| 16 |
+
|
| 17 |
+
for _, row in df.iterrows():
|
| 18 |
+
model_name = row['Model name']
|
| 19 |
+
|
| 20 |
+
# PIA가 비어있거나 다른 값인 경우 건너뛰기
|
| 21 |
+
if pd.isna(row['PIA']) or not isinstance(row['PIA'], str):
|
| 22 |
+
print(f"Skipping model {model_name}: Invalid PIA data")
|
| 23 |
+
continue
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
metrics = str2json(row['PIA'])
|
| 27 |
+
|
| 28 |
+
# metrics가 None이거나 dict가 아닌 경우 건너뛰기
|
| 29 |
+
if not metrics or not isinstance(metrics, dict):
|
| 30 |
+
print(f"Skipping model {model_name}: Invalid JSON format")
|
| 31 |
+
continue
|
| 32 |
+
|
| 33 |
+
# 필요한 카테고리가 모두 있는지 확인
|
| 34 |
+
required_categories = ['falldown', 'violence', 'fire']
|
| 35 |
+
if not all(cat in metrics for cat in required_categories):
|
| 36 |
+
print(f"Skipping model {model_name}: Missing required categories")
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
# 필요한 메트릭이 모두 있는지 확인
|
| 40 |
+
required_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
|
| 41 |
+
'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']
|
| 42 |
+
|
| 43 |
+
avg_metrics = {}
|
| 44 |
+
for metric in required_metrics:
|
| 45 |
+
try:
|
| 46 |
+
values = [metrics[cat][metric] for cat in required_categories
|
| 47 |
+
if metric in metrics[cat]]
|
| 48 |
+
if values: # 값이 있는 경우만 평균 계산
|
| 49 |
+
avg_metrics[metric] = sum(values) / len(values)
|
| 50 |
+
else:
|
| 51 |
+
avg_metrics[metric] = 0 # 또는 다른 기본값 설정
|
| 52 |
+
except (KeyError, TypeError) as e:
|
| 53 |
+
print(f"Error calculating {metric} for {model_name}: {str(e)}")
|
| 54 |
+
avg_metrics[metric] = 0 # 에러 발생 시 기본값 설정
|
| 55 |
+
|
| 56 |
+
metrics_data.append({
|
| 57 |
+
'model_name': model_name,
|
| 58 |
+
**avg_metrics
|
| 59 |
+
})
|
| 60 |
+
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f"Error processing model {model_name}: {str(e)}")
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
return pd.DataFrame(metrics_data)
|
| 66 |
+
|
| 67 |
+
def create_performance_chart(df, selected_metrics):
|
| 68 |
+
"""
|
| 69 |
+
모델별 선택된 성능 지표의 수평 막대 그래프 생성
|
| 70 |
+
"""
|
| 71 |
+
fig = go.Figure()
|
| 72 |
+
|
| 73 |
+
# 모델 이름 길이에 따른 마진 계산
|
| 74 |
+
max_name_length = max([len(name) for name in df['model_name']])
|
| 75 |
+
left_margin = min(max_name_length * 7, 500) # 글자 수에 따라 마진 조정, 최대 500
|
| 76 |
+
|
| 77 |
+
for metric in selected_metrics:
|
| 78 |
+
fig.add_trace(go.Bar(
|
| 79 |
+
name=metric,
|
| 80 |
+
y=df['model_name'], # y축에 모델 이름
|
| 81 |
+
x=df[metric], # x축에 성능 지표 값
|
| 82 |
+
text=[f'{val:.3f}' for val in df[metric]],
|
| 83 |
+
textposition='auto',
|
| 84 |
+
orientation='h' # 수평 방향 막대
|
| 85 |
+
))
|
| 86 |
+
|
| 87 |
+
fig.update_layout(
|
| 88 |
+
title='Model Performance Comparison',
|
| 89 |
+
yaxis_title='Model Name',
|
| 90 |
+
xaxis_title='Performance',
|
| 91 |
+
barmode='group',
|
| 92 |
+
height=max(400, len(df) * 40), # 모델 수에 따라 높이 조정
|
| 93 |
+
margin=dict(l=left_margin, r=50, t=50, b=50), # 왼쪽 마진 동적 조정
|
| 94 |
+
showlegend=True,
|
| 95 |
+
legend=dict(
|
| 96 |
+
orientation="h",
|
| 97 |
+
yanchor="bottom",
|
| 98 |
+
y=1.02,
|
| 99 |
+
xanchor="right",
|
| 100 |
+
x=1
|
| 101 |
+
),
|
| 102 |
+
yaxis={'categoryorder': 'total ascending'} # 성능 순으로 정렬
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# y축 레이블 스타일 조정
|
| 106 |
+
fig.update_yaxes(tickfont=dict(size=10)) # 글자 크기 조정
|
| 107 |
+
|
| 108 |
+
return fig
|
| 109 |
+
def create_confusion_matrix(metrics_data, selected_category):
|
| 110 |
+
"""혼동 행렬 시각화 생성"""
|
| 111 |
+
# 선택된 카테고리의 혼동 행렬 데이터
|
| 112 |
+
tp = metrics_data[selected_category]['tp']
|
| 113 |
+
tn = metrics_data[selected_category]['tn']
|
| 114 |
+
fp = metrics_data[selected_category]['fp']
|
| 115 |
+
fn = metrics_data[selected_category]['fn']
|
| 116 |
+
|
| 117 |
+
# 혼동 행렬 데이터
|
| 118 |
+
z = [[tn, fp], [fn, tp]]
|
| 119 |
+
x = ['Negative', 'Positive']
|
| 120 |
+
y = ['Negative', 'Positive']
|
| 121 |
+
|
| 122 |
+
# 히트맵 생성
|
| 123 |
+
fig = go.Figure(data=go.Heatmap(
|
| 124 |
+
z=z,
|
| 125 |
+
x=x,
|
| 126 |
+
y=y,
|
| 127 |
+
colorscale=[[0, '#f7fbff'], [1, '#08306b']],
|
| 128 |
+
showscale=False,
|
| 129 |
+
text=[[str(val) for val in row] for row in z],
|
| 130 |
+
texttemplate="%{text}",
|
| 131 |
+
textfont={"color": "black", "size": 16}, # 글자 색���을 검정색으로 고정
|
| 132 |
+
))
|
| 133 |
+
|
| 134 |
+
# 레이아웃 업데이트
|
| 135 |
+
fig.update_layout(
|
| 136 |
+
title={
|
| 137 |
+
'text': f'Confusion Matrix - {selected_category}',
|
| 138 |
+
'y':0.9,
|
| 139 |
+
'x':0.5,
|
| 140 |
+
'xanchor': 'center',
|
| 141 |
+
'yanchor': 'top'
|
| 142 |
+
},
|
| 143 |
+
xaxis_title='Predicted',
|
| 144 |
+
yaxis_title='Actual',
|
| 145 |
+
width=600, # 너비 증가
|
| 146 |
+
height=600, # 높이 증가
|
| 147 |
+
margin=dict(l=80, r=80, t=100, b=80), # 여백 조정
|
| 148 |
+
paper_bgcolor='white',
|
| 149 |
+
plot_bgcolor='white',
|
| 150 |
+
font=dict(size=14) # 전체 폰트 크기 조정
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# 축 설정
|
| 154 |
+
fig.update_xaxes(side="bottom", tickfont=dict(size=14))
|
| 155 |
+
fig.update_yaxes(side="left", tickfont=dict(size=14))
|
| 156 |
+
|
| 157 |
+
return fig
|
| 158 |
+
|
| 159 |
+
def get_metrics_for_model(df, model_name, benchmark_name):
|
| 160 |
+
"""특정 모델과 벤치마크에 대한 메트릭스 데이터 추출"""
|
| 161 |
+
row = df[(df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)]
|
| 162 |
+
if not row.empty:
|
| 163 |
+
metrics = str2json(row['PIA'].iloc[0])
|
| 164 |
+
return metrics
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
def metric_visual_tab():
|
| 168 |
+
# 데이터 로드
|
| 169 |
+
df = sheet2df(sheet_name="metric")
|
| 170 |
+
avg_metrics_df = calculate_avg_metrics(df)
|
| 171 |
+
|
| 172 |
+
# 가능한 모든 메트릭 리스트
|
| 173 |
+
all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
|
| 174 |
+
'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']
|
| 175 |
+
|
| 176 |
+
with gr.Tab("📊 Performance Visualization"):
|
| 177 |
+
with gr.Row():
|
| 178 |
+
metrics_multiselect = gr.CheckboxGroup(
|
| 179 |
+
choices=all_metrics,
|
| 180 |
+
value=[], # 초기 선택 없음
|
| 181 |
+
label="Select Performance Metrics",
|
| 182 |
+
interactive=True
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Performance comparison chart (초기값 없음)
|
| 186 |
+
performance_plot = gr.Plot()
|
| 187 |
+
|
| 188 |
+
def update_plot(selected_metrics):
|
| 189 |
+
if not selected_metrics: # 선택된 메트릭이 없는 경우
|
| 190 |
+
return None
|
| 191 |
+
|
| 192 |
+
try:
|
| 193 |
+
# accuracy 기준으로 정렬
|
| 194 |
+
sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
|
| 195 |
+
return create_performance_chart(sorted_df, selected_metrics)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(f"Error in update_plot: {str(e)}")
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
# Connect event handler
|
| 201 |
+
metrics_multiselect.change(
|
| 202 |
+
fn=update_plot,
|
| 203 |
+
inputs=[metrics_multiselect],
|
| 204 |
+
outputs=[performance_plot]
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def create_category_metrics_chart(metrics_data, selected_metrics):
    """Build a grouped bar chart of the chosen metrics per event category.

    One bar group is drawn per category (falldown / violence / fire),
    with one bar per selected metric; each bar is annotated with its
    value to three decimal places.

    :param metrics_data: mapping category -> {metric name -> value}
    :param selected_metrics: metric names to plot
    :return: plotly Figure
    """
    category_names = ['falldown', 'violence', 'fire']
    figure = go.Figure()

    for metric_name in selected_metrics:
        scores = [metrics_data[cat][metric_name] for cat in category_names]
        figure.add_trace(go.Bar(
            name=metric_name,
            x=category_names,
            y=scores,
            text=[f'{score:.3f}' for score in scores],
            textposition='auto',
        ))

    figure.update_layout(
        title='Performance Metrics by Category',
        xaxis_title='Category',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return figure
|
| 244 |
+
|
| 245 |
+
def metric_visual_tab():
    # Builds the "Performance Visualization" tab: an averaged-metrics bar
    # chart plus a per-model drill-down (confusion matrix and per-category
    # metric bars). This is the effective definition; an earlier function
    # of the same name in this module is shadowed by it.
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All metrics that can be plotted
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # no initial selection
                label="Select Performance Metrics",
                interactive=True
            )

        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            # Nothing selected -> clear the plot.
            if not selected_metrics:
                return None
            try:
                # Sort models by accuracy before charting.
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )

        # Second visualization section: per-model analysis
        gr.Markdown("## Detailed Model Analysis")

        with gr.Row():
            # Model selector
            model_dropdown = gr.Dropdown(
                choices=sorted(df['Model name'].unique().tolist()),
                label="Select Model",
                interactive=True
            )

            # Column selector (excluding 'Model name')
            column_dropdown = gr.Dropdown(
                choices=[col for col in df.columns if col != 'Model name'],
                label="Select Metric Column",
                interactive=True
            )

            # Category selector (used by the confusion matrix only)
            category_dropdown = gr.Dropdown(
                choices=['falldown', 'violence', 'fire'],
                label="Select Category",
                interactive=True
            )

        # Confusion matrix visualization, centered via spacer columns
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer

        with gr.Column(scale=2):
            # Metric selection for the per-category bar chart
            metrics_select = gr.CheckboxGroup(
                choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
                        'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                value=['accuracy'],  # default
                label="Select Metrics to Display",
                interactive=True
            )
            category_metrics_plot = gr.Plot()

        def update_visualizations(model, column, category, selected_metrics):
            # category is only required for the confusion matrix; model and
            # column are required for both plots.
            if not all([model, column]):
                return None, None

            try:
                # Fetch the selected model's serialized metrics cell.
                selected_data = df[df['Model name'] == model][column].iloc[0]
                metrics = str2json(selected_data)

                if not metrics:
                    return None, None

                # Confusion matrix (left)
                confusion_fig = create_confusion_matrix(metrics, category) if category else None

                # Per-category metric bars (right)
                if not selected_metrics:
                    selected_metrics = ['accuracy']
                category_fig = create_category_metrics_chart(metrics, selected_metrics)

                return confusion_fig, category_fig

            except Exception as e:
                print(f"Error updating visualizations: {str(e)}")
                return None, None

        # Wire every input control to the same update callback.
        for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
            input_component.change(
                fn=update_visualizations,
                inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
                outputs=[confusion_matrix_plot, category_metrics_plot]
            )
        # NOTE(review): a legacy commented-out `update_confusion_matrix`
        # implementation (superseded by update_visualizations above) was
        # removed here; see version control history if needed.
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
leaderboard_ui/tab/submit_tab.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
def list_to_dataframe(data):
    """Render a queue list as a single-row DataFrame.

    Each element of *data* becomes its own column named ``Queue i``,
    so the whole queue is displayed on one row.

    :param data: queue entries as a plain list
    :return: pandas.DataFrame with one row and len(data) columns
    :raises ValueError: if *data* is not a list
    """
    if not isinstance(data, list):
        raise ValueError("입력 데이터는 리스트 형태여야 합니다.")

    column_names = [f"Queue {idx}" for idx in range(len(data))]
    return pd.DataFrame([data], columns=column_names)
|
| 20 |
+
|
| 21 |
+
def model_submit(model_id , benchmark_name, prompt_cfg_name):
    """Queue a model evaluation request in the spreadsheet.

    Pushes the bare model name (text after the last '/'), then the
    benchmark name and prompt-config name into their respective sheet
    columns. Returns the queue snapshot taken right after the model
    name was pushed, for display in the UI.
    """
    short_name = model_id.split("/")[-1]
    manager = SheetManager()
    manager.push(short_name)
    # Snapshot the queue immediately after pushing the model name.
    queue_df = list_to_dataframe(manager.get_all_values())
    manager.change_column("benchmark_name")
    manager.push(benchmark_name)
    manager.change_column("prompt_cfg_name")
    manager.push(prompt_cfg_name)

    return queue_df
|
| 32 |
+
|
| 33 |
+
def read_queue():
    """Return the current evaluation queue as a one-row DataFrame."""
    manager = SheetManager()
    return list_to_dataframe(manager.get_all_values())
|
| 36 |
+
|
| 37 |
+
def submit_tab():
    # Builds the "Submit" tab of the leaderboard UI: a form to queue a
    # model evaluation plus a read-only view of the current queue, and a
    # (not-yet-wired) prompt submission sub-tab.
    with gr.Tab("🚀 Submit here! "):
        with gr.Row():
            gr.Markdown("# ✉️✨ Submit your Result here!")

        with gr.Row():
            with gr.Tab("Model"):
                with gr.Row():
                    with gr.Column():
                        # HuggingFace repo id; only the suffix after '/' is queued.
                        model_id_textbox = gr.Textbox(
                            label="huggingface_id",
                            placeholder="PIA-SPACE-LAB/T2V_CLIP4Clip",
                            interactive = True
                        )
                        benchmark_name_textbox = gr.Textbox(
                            label="benchmark_name",
                            placeholder="PiaFSV",
                            interactive = True,
                            value="PIA"
                        )
                        prompt_cfg_name_textbox = gr.Textbox(
                            label="prompt_cfg_name",
                            placeholder="topk",
                            interactive = True,
                            value="topk"
                        )
                    with gr.Column():
                        gr.Markdown("## 평가를 받아보세요 반드시 허깅페이스에 업로드된 모델이어야 합니다.")
                        gr.Markdown("#### 현재 평가 대기중 모델입니다.")
                        # Current queue; refreshed on demand.
                        model_queue = gr.Dataframe()
                        refresh_button = gr.Button("refresh")
                        refresh_button.click(
                            fn=read_queue,
                            outputs=model_queue
                        )
                with gr.Row():
                    model_submit_button = gr.Button("Submit Eval")
                    model_submit_button.click(
                        fn=model_submit,
                        inputs=[model_id_textbox,
                                benchmark_name_textbox ,
                                prompt_cfg_name_textbox],
                        outputs=model_queue
                    )
            with gr.Tab("Prompt"):
                with gr.Row():
                    with gr.Column():
                        prompt_cfg_selector = gr.Dropdown(
                            choices=["전부"],
                            label="Prompt_CFG",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                        weight_type = gr.Dropdown(
                            choices=["전부"],
                            label="Weights type",
                            multiselect=False,
                            value=None,
                            interactive=True,
                        )
                    with gr.Column():
                        gr.Markdown("## 평가를 받아보세요 반드시 허깅페이스에 업로드된 모델이어야 합니다.")

                with gr.Row():
                    # NOTE(review): no click handler is wired to this button yet.
                    prompt_submit_button = gr.Button("Submit Eval")
|
| 103 |
+
|
main.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 5 |
+
from sheet_manager.sheet_monitor.sheet_sync import SheetMonitor, MainLoop
|
| 6 |
+
import time
|
| 7 |
+
from pia_bench.pipe_line.piepline import BenchmarkPipeline, PipelineConfig
|
| 8 |
+
from sheet_manager.sheet_convert.json2sheet import update_benchmark_json
|
| 9 |
+
import os
|
| 10 |
+
import shutil
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
def calculate_total_accuracy(metrics: dict) -> float:
    """
    Average the per-category accuracy values, skipping 'micro_avg'.

    Args:
        metrics (dict): Mapping of category name -> metrics dict; each
            metrics dict may contain an "accuracy" entry.

    Returns:
        float: Mean accuracy over all categories that report one.

    Raises:
        ValueError: If no category contains an "accuracy" value.
    """
    accuracies = [
        values["accuracy"]
        for category, values in metrics.items()
        if category != "micro_avg" and "accuracy" in values
    ]

    if not accuracies:
        raise ValueError("No accuracy values found in the provided metrics dictionary.")

    return sum(accuracies) / len(accuracies)
|
| 38 |
+
|
| 39 |
+
def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
    # Sheet-monitor callback: runs the benchmark pipeline for one queued
    # model and writes both the aggregate score and the full metrics JSON
    # back to the spreadsheet.
    # NOTE(review): paths below are hard-coded to the NAS mount — confirm
    # they match the deployment environment.
    model_name = huggingface_id.split("/")[-1]
    config = PipelineConfig(
        model_name=model_name,
        benchmark_name=benchmark_name,
        cfg_target_path=f"/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench/{benchmark_name}/CFG/{prompt_cfg_name}.json",
        base_path="/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
    )
    pipeline = BenchmarkPipeline(config)
    pipeline.run()
    result = pipeline.bench_result_dict
    # Aggregate accuracy across categories (excluding micro_avg).
    value = calculate_total_accuracy(result)
    print("---"*50)
    sheet = SheetManager()
    sheet.change_worksheet("model")
    # Write the aggregate score into this model's benchmark column.
    sheet.update_cell_by_condition(condition_column="Model name",
                               condition_value=model_name ,
                               target_column=benchmark_name,
                               target_value=value)

    # Store the full per-category metrics JSON as well.
    update_benchmark_json(
        model_name = model_name,
        benchmark_data = result,
        target_column = benchmark_name  # target column parameter
    )

    print(f"\n파이프라인 실행 결과:")
|
| 66 |
+
|
| 67 |
+
# Wire the sheet monitor to the pipeline callback and run until interrupted.
# Guarded so that importing this module does not start background threads
# or the polling loop as a side effect (previously these statements ran at
# import time).
if __name__ == "__main__":
    sheet_manager = SheetManager()
    monitor = SheetMonitor(sheet_manager, check_interval=60.0)
    main_loop = MainLoop(sheet_manager, monitor, callback_function=my_custom_function)

    try:
        main_loop.start()
        # Keep the main thread alive; the monitor works in the background.
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        # Graceful shutdown on Ctrl-C.
        main_loop.stop()
|
pia_bench/bench.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from devmacs_core.devmacs_core import DevMACSCore
|
| 4 |
+
import json
|
| 5 |
+
from typing import Dict, List, Tuple
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from utils.except_dir import cust_listdir
|
| 9 |
+
def load_config(config_path: str) -> Dict:
    """Load a JSON configuration file into a dictionary.

    Args:
        config_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON content as a dict.
    """
    with open(config_path, encoding='utf-8') as fp:
        return json.load(fp)
|
| 13 |
+
|
| 14 |
+
# Directory-layout constants for the benchmark folder structure.
DATA_SET = "dataset"   # per-category videos and label files
CFG = "CFG"            # prompt configuration JSONs
VECTOR = "vector"      # extracted embedding vectors
TEXT = "text"          # text-vector subfolder under vector/
VIDEO = "video"        # video-vector subfolder under vector/
EXECPT = ["@eaDir", "README.md"]  # entries to skip (name typo kept for compatibility)
ALRAM = "alarm"        # alarm outputs (name typo kept for compatibility)
METRIC = "metric"      # metric outputs
MSRVTT = "MSRVTT"      # default model name
MODEL = "models"       # per-model results root folder
|
| 24 |
+
|
| 25 |
+
class PiaBenchMark:
    """Prepares a benchmark folder layout and extracts visual vectors.

    Given a benchmark root containing per-category video folders, this
    class reorganizes them into dataset/CFG/models/vector subfolders,
    converts JSON clip labels into frame-level CSVs, and extracts visual
    embeddings with a model pulled from HuggingFace.
    """

    def __init__(self, benchmark_path , cfg_target_path : str = None , model_name : str = MSRVTT , token:str =None):
        # benchmark_path: root of the benchmark assets
        # cfg_target_path: path to the prompt-config JSON driving this run
        # model_name: HuggingFace repo suffix (PIA-SPACE-LAB/<model_name>)
        # token: HuggingFace access token
        self.benchmark_path = benchmark_path
        self.token = token
        self.model_name = model_name
        self.devmacs_core = None  # created lazily in extract_visual_vector()
        self.cfg_target_path = cfg_target_path
        self.cfg_name = Path(cfg_target_path).stem
        self.cfg_dict = load_config(self.cfg_target_path)

        self.dataset_path = os.path.join(benchmark_path, DATA_SET)
        self.cfg_path = os.path.join(benchmark_path , CFG)

        # models/<model>/CFG/<cfg>/{alarm,metric} output tree
        self.model_path = os.path.join(self.benchmark_path , MODEL)
        self.model_name_path = os.path.join(self.model_path ,self.model_name)
        self.model_name_cfg_path = os.path.join(self.model_name_path , CFG)
        self.model_name_cfg_name_path = os.path.join(self.model_name_cfg_path , self.cfg_name)
        self.alram_path = os.path.join(self.model_name_cfg_name_path , ALRAM)
        self.metric_path = os.path.join(self.model_name_cfg_name_path , METRIC)

        # models/<model>/vector/{text,video} embedding tree
        self.vector_path = os.path.join(self.model_name_path , VECTOR)
        self.vector_text_path = os.path.join(self.vector_path , TEXT)
        self.vector_video_path = os.path.join(self.vector_path , VIDEO)

        # Category folder names discovered under dataset/
        self.categories = []

    def _create_frame_labels(self, label_data: Dict, total_frames: int) -> pd.DataFrame:
        """Build a frame-indexed one-hot label DataFrame from clip annotations."""
        colmuns = ['frame'] + sorted(self.categories)  # (sic: local name kept as-is)
        df = pd.DataFrame(0, index=range(total_frames), columns=colmuns)
        df['frame'] = range(total_frames)

        for clip_info in label_data['clips'].values():
            category = clip_info['category']
            if category in self.categories:  # only process known categories
                start_frame, end_frame = clip_info['timestamp']
                # Mark every frame of the clip interval (inclusive) positive.
                df.loc[start_frame:end_frame, category] = 1

        return df

    def preprocess_label_to_csv(self):
        """Convert every JSON label in the dataset to a frame-based CSV."""
        json_files = []
        csv_files = []

        # Populate categories only when still empty.
        if not self.categories:
            for cate in cust_listdir(self.dataset_path):
                if os.path.isdir(os.path.join(self.dataset_path, cate)):
                    self.categories.append(cate)

        for category in self.categories:
            category_path = os.path.join(self.dataset_path, category)
            category_jsons = [os.path.join(category, f) for f in cust_listdir(category_path) if f.endswith('.json')]
            json_files.extend(category_jsons)
            category_csvs = [os.path.join(category, f) for f in cust_listdir(category_path) if f.endswith('.csv')]
            csv_files.extend(category_csvs)

        if not json_files:
            raise ValueError("No JSON files found in any category directory")

        # If every JSON already has a CSV counterpart, skip conversion.
        if len(json_files) == len(csv_files):
            print("All JSON files have already been processed to CSV. No further processing needed.")
            return

        for json_file in json_files:
            json_path = os.path.join(self.dataset_path, json_file)
            video_name = os.path.splitext(json_file)[0]

            label_info = load_config(json_path)
            video_info = label_info['video_info']
            total_frames = video_info['total_frame']

            df = self._create_frame_labels( label_info, total_frames)

            output_path = os.path.join(self.dataset_path, f"{video_name}.csv")
            df.to_csv(output_path , index=False)
        print("Complete !")

    def preprocess_structure(self):
        """Create the expected folder layout; move category folders into dataset/."""
        os.makedirs(self.dataset_path, exist_ok=True)
        os.makedirs(self.cfg_path, exist_ok=True)
        os.makedirs(self.vector_text_path, exist_ok=True)
        os.makedirs(self.vector_video_path, exist_ok=True)
        os.makedirs(self.alram_path, exist_ok=True)
        os.makedirs(self.metric_path, exist_ok=True)
        os.makedirs(self.model_name_cfg_name_path , exist_ok=True)


        # If dataset/ already contains category folders, reuse them.
        if os.path.exists(self.dataset_path) and any(os.path.isdir(os.path.join(self.dataset_path, d)) for d in cust_listdir(self.dataset_path)):
            # Structure already prepared: read categories from dataset/.
            self.categories = [d for d in cust_listdir(self.dataset_path) if os.path.isdir(os.path.join(self.dataset_path, d))]
        else:
            # First run: move loose category folders into dataset/.
            for item in cust_listdir(self.benchmark_path):
                item_path = os.path.join(self.benchmark_path, item)

                # Skip system/reserved entries and non-directories.
                if item.startswith("@") or item in [METRIC ,"README.md",MODEL, CFG, DATA_SET, VECTOR, ALRAM] or not os.path.isdir(item_path):
                    continue
                target_path = os.path.join(self.dataset_path, item)
                if not os.path.exists(target_path):
                    shutil.move(item_path, target_path)
                self.categories.append(item)

        # Mirror the category layout under the video-vector folder.
        for category in self.categories:
            category_path = os.path.join(self.vector_video_path, category)
            os.makedirs(category_path, exist_ok=True)

        print("Folder preprocessing completed.")

    def extract_visual_vector(self):
        """Download the model from HuggingFace and extract per-video vectors."""
        self.devmacs_core = DevMACSCore.from_huggingface(token=self.token, repo_id=f"PIA-SPACE-LAB/{self.model_name}")
        self.devmacs_core.save_visual_results(
            vid_dir = self.dataset_path,
            result_dir = self.vector_video_path
        )
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
    # Manual smoke test: prepare the PIA benchmark folder structure and
    # convert labels, using paths local to the development machine.
    from dotenv import load_dotenv
    import os
    load_dotenv()

    access_token = os.getenv("ACCESS_TOKEN")
    model_name = "T2V_CLIP4CLIP_MSRVTT"

    benchmark_path = "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA"
    cfg_target_path= "/home/jungseoik/data/Abnormal_situation_leader_board/assets/PIA/CFG/topk.json"

    pia_benchmark = PiaBenchMark(benchmark_path ,model_name=model_name, cfg_target_path= cfg_target_path , token=access_token )
    pia_benchmark.preprocess_structure()
    pia_benchmark.preprocess_label_to_csv()
    print("Categories identified:", pia_benchmark.categories)
|
pia_bench/checker/bench_checker.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Dict, Optional, Tuple
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
logging.basicConfig(level=logging.INFO)
|
| 8 |
+
|
| 9 |
+
class BenchChecker:
|
| 10 |
+
def __init__(self, base_path: str):
|
| 11 |
+
"""Initialize BenchChecker with base assets path.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
base_path (str): Base path to assets directory containing benchmark folders
|
| 15 |
+
"""
|
| 16 |
+
self.base_path = Path(base_path)
|
| 17 |
+
self.logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
def check_benchmark_exists(self, benchmark_name: str) -> bool:
|
| 20 |
+
"""Check if benchmark folder exists."""
|
| 21 |
+
benchmark_path = self.base_path / benchmark_name
|
| 22 |
+
exists = benchmark_path.exists() and benchmark_path.is_dir()
|
| 23 |
+
if exists:
|
| 24 |
+
self.logger.info(f"Found benchmark directory: {benchmark_name}")
|
| 25 |
+
else:
|
| 26 |
+
self.logger.error(f"Benchmark directory not found: {benchmark_name}")
|
| 27 |
+
return exists
|
| 28 |
+
|
| 29 |
+
def get_video_list(self, benchmark_name: str) -> List[str]:
|
| 30 |
+
"""Get list of videos from benchmark's dataset directory. Return empty list if no videos found."""
|
| 31 |
+
dataset_path = self.base_path / benchmark_name / "dataset"
|
| 32 |
+
videos = []
|
| 33 |
+
|
| 34 |
+
if not dataset_path.exists():
|
| 35 |
+
self.logger.info(f"Dataset directory exists but no videos found for {benchmark_name}")
|
| 36 |
+
return videos # 빈 리스트 반환
|
| 37 |
+
|
| 38 |
+
# Recursively find all .mp4 files
|
| 39 |
+
for category in dataset_path.glob("*"):
|
| 40 |
+
if category.is_dir():
|
| 41 |
+
for video_file in category.glob("*.mp4"):
|
| 42 |
+
videos.append(video_file.stem)
|
| 43 |
+
|
| 44 |
+
self.logger.info(f"Found {len(videos)} videos in {benchmark_name} dataset")
|
| 45 |
+
return videos
|
| 46 |
+
|
| 47 |
+
def check_model_exists(self, benchmark_name: str, model_name: str) -> bool:
|
| 48 |
+
"""Check if model directory exists in benchmark's models directory."""
|
| 49 |
+
model_path = self.base_path / benchmark_name / "models" / model_name
|
| 50 |
+
exists = model_path.exists() and model_path.is_dir()
|
| 51 |
+
if exists:
|
| 52 |
+
self.logger.info(f"Found model directory: {model_name}")
|
| 53 |
+
else:
|
| 54 |
+
self.logger.error(f"Model directory not found: {model_name}")
|
| 55 |
+
return exists
|
| 56 |
+
|
| 57 |
+
def check_cfg_files(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Tuple[bool, bool]:
|
| 58 |
+
"""Check if CFG files/directories exist in both benchmark and model directories."""
|
| 59 |
+
# Check benchmark CFG json
|
| 60 |
+
benchmark_cfg = self.base_path / benchmark_name / "CFG" / f"{cfg_prompt}.json"
|
| 61 |
+
benchmark_cfg_exists = benchmark_cfg.exists() and benchmark_cfg.is_file()
|
| 62 |
+
|
| 63 |
+
# Check model CFG directory
|
| 64 |
+
model_cfg = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt
|
| 65 |
+
model_cfg_exists = model_cfg.exists() and model_cfg.is_dir()
|
| 66 |
+
|
| 67 |
+
if benchmark_cfg_exists:
|
| 68 |
+
self.logger.info(f"Found benchmark CFG file: {cfg_prompt}.json")
|
| 69 |
+
else:
|
| 70 |
+
self.logger.error(f"Benchmark CFG file not found: {cfg_prompt}.json")
|
| 71 |
+
|
| 72 |
+
if model_cfg_exists:
|
| 73 |
+
self.logger.info(f"Found model CFG directory: {cfg_prompt}")
|
| 74 |
+
else:
|
| 75 |
+
self.logger.error(f"Model CFG directory not found: {cfg_prompt}")
|
| 76 |
+
|
| 77 |
+
return benchmark_cfg_exists, model_cfg_exists
|
| 78 |
+
def check_vector_files(self, benchmark_name: str, model_name: str, video_list: List[str]) -> bool:
|
| 79 |
+
"""Check if video vectors match with dataset."""
|
| 80 |
+
vector_path = self.base_path / benchmark_name / "models" / model_name / "vector" / "video"
|
| 81 |
+
|
| 82 |
+
# 비디오가 없는 경우는 무조건 False
|
| 83 |
+
if not video_list:
|
| 84 |
+
self.logger.error("No videos found in dataset - cannot proceed")
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
# 벡터 디렉토리가 있는지 확인
|
| 88 |
+
if not vector_path.exists():
|
| 89 |
+
self.logger.error("Vector directory doesn't exist")
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
# 벡터 파일 리스트 가져오기
|
| 93 |
+
# vector_files = [f.stem for f in vector_path.glob("*.npy")]
|
| 94 |
+
vector_files = [f.stem for f in vector_path.rglob("*.npy")]
|
| 95 |
+
|
| 96 |
+
missing_vectors = set(video_list) - set(vector_files)
|
| 97 |
+
extra_vectors = set(vector_files) - set(video_list)
|
| 98 |
+
|
| 99 |
+
if missing_vectors:
|
| 100 |
+
self.logger.error(f"Missing vectors for videos: {missing_vectors}")
|
| 101 |
+
return False
|
| 102 |
+
if extra_vectors:
|
| 103 |
+
self.logger.error(f"Extra vectors found: {extra_vectors}")
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
self.logger.info(f"Vector status: videos={len(video_list)}, vectors={len(vector_files)}")
|
| 107 |
+
return len(video_list) == len(vector_files)
|
| 108 |
+
|
| 109 |
+
def check_metrics_file(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> bool:
|
| 110 |
+
"""Check if overall_metrics.json exists in the model's CFG/metrics directory."""
|
| 111 |
+
metrics_path = self.base_path / benchmark_name / "models" / model_name / "CFG" / cfg_prompt / "metric" / "overall_metrics.json"
|
| 112 |
+
exists = metrics_path.exists() and metrics_path.is_file()
|
| 113 |
+
|
| 114 |
+
if exists:
|
| 115 |
+
self.logger.info(f"Found overall metrics file for {model_name}")
|
| 116 |
+
else:
|
| 117 |
+
self.logger.error(f"Overall metrics file not found for {model_name}")
|
| 118 |
+
return exists
|
| 119 |
+
|
| 120 |
+
def check_benchmark(self, benchmark_name: str, model_name: str, cfg_prompt: str) -> Dict[str, bool]:
|
| 121 |
+
"""
|
| 122 |
+
Perform all benchmark checks and return status.
|
| 123 |
+
"""
|
| 124 |
+
status = {
|
| 125 |
+
'benchmark_exists': False,
|
| 126 |
+
'model_exists': False,
|
| 127 |
+
'cfg_files_exist': False,
|
| 128 |
+
'vectors_match': False,
|
| 129 |
+
'metrics_exist': False
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
# Check benchmark directory
|
| 133 |
+
status['benchmark_exists'] = self.check_benchmark_exists(benchmark_name)
|
| 134 |
+
if not status['benchmark_exists']:
|
| 135 |
+
return status
|
| 136 |
+
|
| 137 |
+
# Get video list
|
| 138 |
+
video_list = self.get_video_list(benchmark_name)
|
| 139 |
+
|
| 140 |
+
# Check model directory
|
| 141 |
+
status['model_exists'] = self.check_model_exists(benchmark_name, model_name)
|
| 142 |
+
if not status['model_exists']:
|
| 143 |
+
return status
|
| 144 |
+
|
| 145 |
+
# Check CFG files
|
| 146 |
+
benchmark_cfg, model_cfg = self.check_cfg_files(benchmark_name, model_name, cfg_prompt)
|
| 147 |
+
status['cfg_files_exist'] = benchmark_cfg and model_cfg
|
| 148 |
+
if not status['cfg_files_exist']:
|
| 149 |
+
return status
|
| 150 |
+
|
| 151 |
+
# Check vectors
|
| 152 |
+
status['vectors_match'] = self.check_vector_files(benchmark_name, model_name, video_list)
|
| 153 |
+
|
| 154 |
+
# Check metrics file (only if vectors match)
|
| 155 |
+
if status['vectors_match']:
|
| 156 |
+
status['metrics_exist'] = self.check_metrics_file(benchmark_name, model_name, cfg_prompt)
|
| 157 |
+
|
| 158 |
+
return status
|
| 159 |
+
|
| 160 |
+
def get_benchmark_status(self, check_status: Dict[str, bool]) -> str:
    """Translate a check-result dict into an execution-path label.

    Returns one of: "cannot_execute" (a basic prerequisite failed),
    "no_vectors" (vectors must be regenerated), "no_metrics" (vectors are
    fine but metrics must be computed), or "all_passed".
    """
    prerequisites = ('benchmark_exists', 'model_exists', 'cfg_files_exist')
    if not all(check_status[name] for name in prerequisites):
        return "cannot_execute"
    if not check_status['vectors_match']:
        return "no_vectors"
    if check_status['metrics_exist']:
        return "all_passed"
    return "no_metrics"
|
| 171 |
+
|
| 172 |
+
# Example usage
if __name__ == "__main__":
    checker = BenchChecker("assets")
    check_results = checker.check_benchmark(
        benchmark_name="huggingface_benchmarks_dataset",
        model_name="MSRVTT",
        cfg_prompt="topk",
    )

    # Decide which pipeline stage should run next from the check results.
    path = checker.get_benchmark_status(check_results)
    print(f"Checks completed. Execution path: {path}")
    print(f"Status: {check_results}")
|
pia_bench/checker/sheet_checker.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional, Set, Tuple
|
| 2 |
+
import logging
|
| 3 |
+
import gspread
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
class SheetChecker:
    """Validates and updates model/benchmark entries in a Google Sheet.

    Wraps a SheetManager-style object; a second manager bound to the "model"
    worksheet is created internally and used for all reads/writes.
    """

    def __init__(self, sheet_manager):
        """Initialize SheetChecker with a sheet manager instance."""
        self.sheet_manager = sheet_manager
        # Lazily replaced by _init_bench_sheet with a manager for the
        # "model" worksheet.
        self.bench_sheet_manager = None
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Initialize sheet manager for the model sheet."""
        # Re-instantiate whatever concrete manager class was passed in,
        # pointed at the "model" worksheet and its "Model name" column.
        self.bench_sheet_manager = type(self.sheet_manager)(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name"
        )
    def add_benchmark_column(self, column_name: str):
        """Add a new benchmark column to the sheet.

        No-op if the column header already exists. Raises on any sheet API
        failure after logging it.
        """
        try:
            # Get current headers
            headers = self.bench_sheet_manager.get_available_columns()

            # If column already exists, return
            if column_name in headers:
                return

            # Add new column header: headers live in row 1; the new header
            # goes one past the last occupied column.
            new_col_index = len(headers) + 1
            cell = gspread.utils.rowcol_to_a1(1, new_col_index)
            # Update with 2D array format.
            # NOTE(review): gspread 6.x swapped the (range, values) argument
            # order of Worksheet.update — confirm the pinned gspread version.
            self.bench_sheet_manager.sheet.update(cell, [[column_name]])  # value passed as a 2D array
            self.logger.info(f"Added new benchmark column: {column_name}")

            # Update headers in bench_sheet_manager so the new column is
            # visible to subsequent lookups (reconnect without validation).
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"Error adding benchmark column {column_name}: {str(e)}")
            raise
    def validate_benchmark_columns(self, benchmark_columns: List[str]) -> Tuple[List[str], List[str]]:
        """
        Validate benchmark columns and add missing ones.

        Missing columns are created on the fly; a column only lands in the
        invalid list when its creation raised.

        Args:
            benchmark_columns: List of benchmark column names to validate

        Returns:
            Tuple[List[str], List[str]]: (valid columns, invalid columns)
        """
        available_columns = self.bench_sheet_manager.get_available_columns()
        valid_columns = []
        invalid_columns = []

        for col in benchmark_columns:
            if col in available_columns:
                valid_columns.append(col)
            else:
                try:
                    self.add_benchmark_column(col)
                    valid_columns.append(col)
                    self.logger.info(f"Added new benchmark column: {col}")
                except Exception as e:
                    invalid_columns.append(col)
                    self.logger.error(f"Failed to add benchmark column '{col}': {str(e)}")

        return valid_columns, invalid_columns

    def check_model_and_benchmarks(self, model_name: str, benchmark_columns: List[str]) -> Dict[str, List[str]]:
        """
        Check model existence and which benchmarks need to be filled.

        Args:
            model_name: Name of the model to check
            benchmark_columns: List of benchmark column names to check

        Returns:
            Dict with keys:
                'status': 'model_not_found' or 'model_exists'
                'empty_benchmarks': List of benchmark columns that need to be filled
                'filled_benchmarks': List of benchmark columns that are already filled
                'invalid_benchmarks': List of benchmark columns that don't exist
        """
        result = {
            'status': '',
            'empty_benchmarks': [],
            'filled_benchmarks': [],
            'invalid_benchmarks': []
        }

        # First check if model exists
        exists = self.check_model_exists(model_name)
        if not exists:
            result['status'] = 'model_not_found'
            return result

        result['status'] = 'model_exists'

        # Validate benchmark columns
        valid_columns, invalid_columns = self.validate_benchmark_columns(benchmark_columns)
        result['invalid_benchmarks'] = invalid_columns

        if not valid_columns:
            return result

        # Check which valid benchmarks are empty
        self.bench_sheet_manager.change_column("Model name")
        all_values = self.bench_sheet_manager.get_all_values()
        # +2 maps the 0-based position to a 1-based sheet row past the header
        # — assumes get_all_values excludes the header row; TODO confirm.
        row_index = all_values.index(model_name) + 2

        for column in valid_columns:
            try:
                self.bench_sheet_manager.change_column(column)
                value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value
                if not value or not value.strip():
                    result['empty_benchmarks'].append(column)
                else:
                    result['filled_benchmarks'].append(column)
            except Exception as e:
                # On a read failure, treat the cell as empty so it gets
                # (re)computed rather than silently skipped.
                self.logger.error(f"Error checking column {column}: {str(e)}")
                result['empty_benchmarks'].append(column)

        return result

    def update_model_info(self, model_name: str, model_info: Dict[str, str]):
        """Update basic model information columns.

        Pushes one value per (column, value) pair in *model_info* via the
        underlying manager's push(); raises on any failure after logging.
        """
        try:
            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)
            self.logger.info(f"Successfully added new model: {model_name}")
        except Exception as e:
            self.logger.error(f"Error updating model info: {str(e)}")
            raise

    def update_benchmarks(self, model_name: str, benchmark_values: Dict[str, str]):
        """
        Update benchmark values.

        Writes each value into the model's row (located via the "Model name"
        column) under the corresponding benchmark column.

        Args:
            model_name: Name of the model
            benchmark_values: Dictionary of benchmark column names and their values
        """
        try:
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            # Same +2 row-offset convention as check_model_and_benchmarks.
            row_index = all_values.index(model_name) + 2

            for column, value in benchmark_values.items():
                self.bench_sheet_manager.change_column(column)
                self.bench_sheet_manager.sheet.update_cell(row_index, self.bench_sheet_manager.col_index, value)
                self.logger.info(f"Updated benchmark {column} for model {model_name}")

        except Exception as e:
            self.logger.error(f"Error updating benchmarks: {str(e)}")
            raise

    def check_model_exists(self, model_name: str) -> bool:
        """Check if model exists in the sheet.

        Returns False (rather than raising) on any sheet access error.
        """
        try:
            self.bench_sheet_manager.change_column("Model name")
            values = self.bench_sheet_manager.get_all_values()
            return model_name in values
        except Exception as e:
            self.logger.error(f"Error checking model existence: {str(e)}")
            return False
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def process_model_benchmarks(
    model_name: str,
    bench_checker: SheetChecker,
    model_info_func,
    benchmark_processor_func: callable,
    benchmark_columns: List[str],
    cfg_prompt: str
) -> None:
    """Ensure *model_name* has a sheet row and fill its empty benchmark columns.

    Workflow: inspect the sheet, create the model row if it is missing,
    re-inspect, then compute and write values for every benchmark column
    that is still empty. Any exception is logged and re-raised.

    Args:
        model_name: Name of the model to process.
        bench_checker: SheetChecker instance used for all sheet access.
        model_info_func: Callable mapping a model name to its info dict.
        benchmark_processor_func: Callable producing values for empty benchmarks.
        benchmark_columns: Benchmark columns to check.
        cfg_prompt: Prompt configuration forwarded to the processor.
    """
    log = bench_checker.logger
    try:
        report = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        # Columns that could not be created are reported but never written.
        if report['invalid_benchmarks']:
            log.warning(
                f"Skipping invalid benchmark columns: {', '.join(report['invalid_benchmarks'])}"
            )

        # A missing model gets a fresh row, then the sheet is re-inspected.
        if report['status'] == 'model_not_found':
            bench_checker.update_model_info(model_name, model_info_func(model_name))
            log.info(f"Added new model: {model_name}")
            report = bench_checker.check_model_and_benchmarks(model_name, benchmark_columns)

        if report['filled_benchmarks']:
            log.info(
                f"Skipping filled benchmark columns: {', '.join(report['filled_benchmarks'])}"
            )

        if not report['empty_benchmarks']:
            log.info("No empty benchmark columns to process")
            return

        log.info(
            f"Processing empty benchmark columns: {', '.join(report['empty_benchmarks'])}"
        )
        # Delegate the actual measurement, then persist the results.
        scores = benchmark_processor_func(
            model_name,
            report['empty_benchmarks'],
            cfg_prompt
        )
        bench_checker.update_benchmarks(model_name, scores)

    except Exception as e:
        log.error(f"Error processing model {model_name}: {str(e)}")
        raise
|
| 236 |
+
|
| 237 |
+
def get_model_info(model_name: str) -> Dict[str, str]:
    """Build the standard sheet fields for a model hosted under PIA-SPACE-LAB.

    Returns the plain name, the HuggingFace repo URL, and an HTML anchor
    used by the leaderboard UI.
    """
    repo_url = f"https://huggingface.co/PIA-SPACE-LAB/{model_name}"
    anchor = (
        f'<a target="_blank" href="{repo_url}" '
        f'style="color: var(--link-text-color); text-decoration: '
        f'underline;text-decoration-style: dotted;">{model_name}</a>'
    )
    return {
        "Model name": model_name,
        "Model link": repo_url,
        "Model": anchor,
    }
|
| 244 |
+
|
| 245 |
+
def process_benchmarks(
    model_name: str,
    empty_benchmarks: List[str],
    cfg_prompt: str
) -> Dict[str, str]:
    """
    Measure benchmark scores for given model with specific configuration.

    Currently a placeholder: returns hard-coded scores for known benchmarks.
    Unknown benchmarks are skipped (and omitted from the result) instead of
    crashing.

    Args:
        model_name: Name of the model to evaluate
        empty_benchmarks: List of benchmarks to measure
        cfg_prompt: Prompt configuration for evaluation

    Returns:
        Dict[str, str]: Dictionary mapping benchmark names to their scores
    """
    # Placeholder values until real measurement is wired in:
    # score = measure_benchmark(model_name, benchmark, cfg_prompt)
    placeholder_scores = {"COCO": 0.5, "ImageNet": 15.0}

    result: Dict[str, str] = {}
    for benchmark in empty_benchmarks:
        score = placeholder_scores.get(benchmark)
        if score is None:
            # BUG FIX: the original if/elif left `score` unbound for the first
            # unknown benchmark (NameError) and reused a stale value for later
            # ones. Skip unknown benchmarks explicitly instead.
            continue
        result[benchmark] = str(score)
    return result
|
| 271 |
+
# Example usage
if __name__ == "__main__":
    manager = SheetManager()
    checker = SheetChecker(manager)

    # Run the full check/create/fill workflow for a throwaway model.
    process_model_benchmarks(
        model_name="test-model",
        bench_checker=checker,
        model_info_func=get_model_info,
        benchmark_processor_func=process_benchmarks,
        benchmark_columns=["COCO", "ImageNet"],
        cfg_prompt="cfg_prompt_value",
    )
|
pia_bench/event_alarm.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
from devmacs_core.devmacs_core import DevMACSCore
|
| 6 |
+
# from devmacs_core.devmacs_core_copy import DevMACSCore
|
| 7 |
+
|
| 8 |
+
from devmacs_core.utils.common.cal import loose_similarity
|
| 9 |
+
from utils.parser import load_config, PromptManager
|
| 10 |
+
import json
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import logging
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from utils.except_dir import cust_listdir
|
| 16 |
+
|
| 17 |
+
class EventDetector:
    """Detects events in pre-extracted video vectors via text-prompt similarity.

    Prompts come from the config's PROMPT_CFG; visual vectors are .npy files
    produced upstream (one vector every 15 frames). For each sampled frame the
    detector scores the visual vector against every prompt vector and raises a
    per-event alarm when enough of the top-k prompts are "abnormal".
    """

    def __init__(self, config_path: str, model_name: str = None, token: str = None):
        """Load the config, the model weights, and pre-compute prompt vectors.

        Args:
            config_path: Path to the benchmark configuration file.
            model_name: Model repo name under the PIA-SPACE-LAB HF organization.
            token: HuggingFace access token used to fetch the model.
        """
        self.config = load_config(config_path)
        self.macs = DevMACSCore.from_huggingface(token=token, repo_id=f"PIA-SPACE-LAB/{model_name}")

        self.prompt_manager = PromptManager(config_path)
        # All prompt sentences; _get_event_prompts assumes this is the
        # concatenation of every event's prompts in PROMPT_CFG order —
        # TODO confirm against PromptManager.
        self.sentences = self.prompt_manager.sentences
        self.text_vectors = self.macs.get_text_vector(self.sentences)

    def process_and_save_predictions(self, vector_base_dir: str, label_base_dir: str, save_base_dir: str):
        """Process every <category>/<video>.npy vector and write per-video CSVs.

        Args:
            vector_base_dir: Root of the per-category .npy vector folders.
            label_base_dir: Root of the matching per-category label JSONs.
            save_base_dir: Root under which prediction CSVs are written.
        """
        # Count .npy files up front so the progress bar spans the whole run.
        total_videos = sum(
            len([f for f in cust_listdir(os.path.join(vector_base_dir, d)) if f.endswith('.npy')])
            for d in cust_listdir(vector_base_dir)
            if os.path.isdir(os.path.join(vector_base_dir, d))
        )
        pbar = tqdm(total=total_videos, desc="Processing videos")

        for category in cust_listdir(vector_base_dir):
            category_path = os.path.join(vector_base_dir, category)
            if not os.path.isdir(category_path):
                continue

            save_category_dir = os.path.join(save_base_dir, category)
            os.makedirs(save_category_dir, exist_ok=True)

            for file in cust_listdir(category_path):
                if file.endswith('.npy'):
                    video_name = os.path.splitext(file)[0]
                    vector_path = os.path.join(category_path, file)

                    # The label JSON supplies the real frame count so sparse
                    # predictions can be expanded to one row per frame.
                    label_path = os.path.join(label_base_dir, category, f"{video_name}.json")
                    with open(label_path, 'r') as f:
                        label_data = json.load(f)
                    total_frames = label_data['video_info']['total_frame']

                    self._process_and_save_single_video(
                        vector_path=vector_path,
                        total_frames=total_frames,
                        save_path=os.path.join(save_category_dir, f"{video_name}.csv")
                    )
                    pbar.update(1)
        pbar.close()

    def _process_and_save_single_video(self, vector_path: str, total_frames: int, save_path: str):
        """Predict alarms for one video and persist them as a dense CSV."""
        sparse_predictions = self._process_single_vector(vector_path)
        df = self._expand_predictions(sparse_predictions, total_frames)
        df.to_csv(save_path, index=False)

    def _process_single_vector(self, vector_path: str) -> Dict:
        """Score one video's vectors.

        Returns:
            {sampled_frame_index: per-event alarm dict} — one entry every
            ``frame_interval`` frames.
        """
        video_vector = np.load(vector_path)
        frame_interval = 15  # vectors were sampled every 15 frames upstream

        processed_vectors = []
        for vector in video_vector:
            v = vector.squeeze(0)                        # numpy array
            v = torch.from_numpy(v).unsqueeze(0).cuda()  # to torch tensor on GPU
            processed_vectors.append(v)

        frame_results = {}
        for vector_idx, v in enumerate(processed_vectors):
            actual_frame = vector_idx * frame_interval
            sim_scores = loose_similarity(
                sequence_output=self.text_vectors.cuda(),
                visual_output=v.unsqueeze(1)
            )
            frame_results[actual_frame] = self._calculate_alarms(sim_scores)

        return frame_results

    def _expand_predictions(self, sparse_predictions: Dict, total_frames: int) -> pd.DataFrame:
        """Expand sparse per-sample predictions to one row per frame.

        Each sampled frame's alarm values are held constant until the next
        sampled frame (step function).
        """
        # Category list comes from the first sampled frame's alarm dict.
        first_frame = list(sparse_predictions.keys())[0]
        categories = list(sparse_predictions[first_frame].keys())

        # One row per frame, alarms initialized to 0.
        df = pd.DataFrame({'frame': range(total_frames)})
        for category in categories:
            df[category] = 0

        # Fill each [current, next) frame span with the sampled alarm value.
        frame_keys = sorted(sparse_predictions.keys())
        for i in range(len(frame_keys)):
            current_frame = frame_keys[i]
            next_frame = frame_keys[i + 1] if i + 1 < len(frame_keys) else total_frames

            for category in categories:
                alarm_value = sparse_predictions[current_frame][category]['alarm']
                df.loc[current_frame:next_frame - 1, category] = alarm_value

        return df

    def _calculate_alarms(self, sim_scores: torch.Tensor) -> Dict:
        """Compute the alarm state of every configured event from similarity scores.

        An event alarms (1) when at least ``alert_threshold`` of its top
        ``top_candidates`` scoring prompts are of type 'abnormal'.

        FIX: the original configured a fresh logging file handler on every
        call (one log file per sampled frame) and then called
        ``logging.shutdown()``, which closes ALL logging handlers
        process-wide — while every logger call in the body was commented
        out. That dead, harmful debug scaffolding has been removed.
        """
        event_alarms = {}

        for event_config in self.config['PROMPT_CFG']:
            event = event_config['event']
            top_k = event_config['top_candidates']
            threshold = event_config['alert_threshold']

            event_prompts = self._get_event_prompts(event)

            # Select this event's rows of the similarity matrix and drop the
            # trailing singleton dimension.
            event_scores = sim_scores[event_prompts['indices']].squeeze(-1)

            top_k_values, top_k_indices = torch.topk(event_scores, min(top_k, len(event_scores)))

            # top_k_indices are positions within this event's own prompt
            # list, so they index 'types' directly.
            abnormal_count = sum(
                1 for idx in top_k_indices
                if event_prompts['types'][idx] == 'abnormal'
            )
            alarm_result = 1 if abnormal_count >= threshold else 0

            event_alarms[event] = {
                'alarm': alarm_result,
                'scores': top_k_values.tolist(),
                'top_k_types': [event_prompts['types'][idx.item()] for idx in top_k_indices]
            }

        return event_alarms

    def _get_event_prompts(self, event: str) -> Dict:
        """Return global prompt indices and normal/abnormal types for *event*.

        'indices' index into self.sentences / the similarity matrix; 'types'
        is the matching per-event list of 'normal'/'abnormal' labels.

        BUG FIX: the running index now also advances over non-matching
        events, so events declared after the first map onto the correct rows
        of the similarity matrix. (The original never advanced past other
        events' prompts, making every event reuse the first event's rows.
        Assumes self.sentences concatenates all events' prompts in config
        order — the evident intent of the shared counter; TODO confirm.)
        """
        indices = []
        types = []
        current_idx = 0

        for event_config in self.config['PROMPT_CFG']:
            is_target = event_config['event'] == event
            for status in ('normal', 'abnormal'):
                for _ in event_config['prompts'][status]:
                    if is_target:
                        indices.append(current_idx)
                        types.append(status)
                    current_idx += 1

        return {'indices': indices, 'types': types}
|
| 224 |
+
|
| 225 |
+
|
pia_bench/metric.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
import json
|
| 7 |
+
from utils.except_dir import cust_listdir
|
| 8 |
+
|
| 9 |
+
class MetricsEvaluator:
|
| 10 |
+
def __init__(self, pred_dir: str, label_dir: str, save_dir: str):
    """
    Store the evaluation directory layout.

    Args:
        pred_dir: Directory containing the prediction CSV files
        label_dir: Directory containing the ground-truth CSV files
        save_dir: Directory where evaluation results are written
    """
    self.pred_dir = pred_dir
    self.label_dir = label_dir
    self.save_dir = save_dir
| 20 |
+
|
| 21 |
+
def evaluate(self) -> Dict:
|
| 22 |
+
"""전체 평가 수행"""
|
| 23 |
+
category_metrics = {} # 카테고리별 평균 성능 저장
|
| 24 |
+
all_metrics = { # 모든 카테고리 통합 메트릭
|
| 25 |
+
'falldown': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
|
| 26 |
+
'violence': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []},
|
| 27 |
+
'fire': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'specificity': []}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
# 모든 카테고리의 metrics를 저장할 DataFrame 리스트
|
| 31 |
+
all_categories_metrics = []
|
| 32 |
+
|
| 33 |
+
for category in cust_listdir(self.pred_dir):
|
| 34 |
+
if not os.path.isdir(os.path.join(self.pred_dir, category)):
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
pred_category_path = os.path.join(self.pred_dir, category)
|
| 38 |
+
label_category_path = os.path.join(self.label_dir, category)
|
| 39 |
+
save_category_path = os.path.join(self.save_dir, category)
|
| 40 |
+
os.makedirs(save_category_path, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
# 결과 저장을 위한 데이터프레임 생성
|
| 43 |
+
metrics_df = self._evaluate_category(category, pred_category_path, label_category_path)
|
| 44 |
+
|
| 45 |
+
metrics_df['category'] = category
|
| 46 |
+
|
| 47 |
+
metrics_df.to_csv(os.path.join(save_category_path, f"{category}_metrics.csv"), index=False)
|
| 48 |
+
|
| 49 |
+
all_categories_metrics.append(metrics_df)
|
| 50 |
+
|
| 51 |
+
# 카테고리별 평균 성능 저장
|
| 52 |
+
category_metrics[category] = metrics_df.iloc[-1].to_dict() # 마지막 row(평균)
|
| 53 |
+
|
| 54 |
+
# 전체 평균을 위한 메트릭 수집
|
| 55 |
+
# for col in metrics_df.columns:
|
| 56 |
+
# if col != 'video_name':
|
| 57 |
+
# event_type, metric_type = col.split('_')
|
| 58 |
+
# all_metrics[event_type][metric_type].append(category_metrics[category][col])
|
| 59 |
+
|
| 60 |
+
for col in metrics_df.columns:
|
| 61 |
+
if col != 'video_name':
|
| 62 |
+
try:
|
| 63 |
+
# 첫 번째 언더스코어를 기준으로 이벤트 타입과 메트릭 타입 분리
|
| 64 |
+
parts = col.split('_', 1) # maxsplit=1로 첫 번째 언더스코어에서만 분리
|
| 65 |
+
if len(parts) == 2:
|
| 66 |
+
event_type, metric_type = parts
|
| 67 |
+
if event_type in all_metrics and metric_type in all_metrics[event_type]:
|
| 68 |
+
all_metrics[event_type][metric_type].append(category_metrics[category][col])
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Warning: Could not process column {col}: {str(e)}")
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
# 각 DataFrame에서 마지막 행(average)을 제거
|
| 74 |
+
all_categories_metrics_without_avg = [df.iloc[:-1] for df in all_categories_metrics]
|
| 75 |
+
# 모든 카테고리의 metrics를 하나의 DataFrame으로 합치기
|
| 76 |
+
combined_metrics_df = pd.concat(all_categories_metrics_without_avg, ignore_index=True)
|
| 77 |
+
# 합쳐진 metrics를 json 파일과 같은 위치에 저장
|
| 78 |
+
combined_metrics_df.to_csv(os.path.join(self.save_dir, "all_categories_metrics.csv"), index=False)
|
| 79 |
+
# 결과 출력
|
| 80 |
+
# print("\nCategory-wise Average Metrics:")
|
| 81 |
+
# for category, metrics in category_metrics.items():
|
| 82 |
+
# print(f"\n{category}:")
|
| 83 |
+
# for metric_name, value in metrics.items():
|
| 84 |
+
# if metric_name != "video_name":
|
| 85 |
+
# print(f"{metric_name}: {value:.3f}")
|
| 86 |
+
|
| 87 |
+
print("\nCategory-wise Average Metrics:")
|
| 88 |
+
for category, metrics in category_metrics.items():
|
| 89 |
+
print(f"\n{category}:")
|
| 90 |
+
for metric_name, value in metrics.items():
|
| 91 |
+
if metric_name != "video_name":
|
| 92 |
+
try:
|
| 93 |
+
if isinstance(value, str):
|
| 94 |
+
print(f"{metric_name}: {value}")
|
| 95 |
+
elif metric_name in ['tp', 'tn', 'fp', 'fn']:
|
| 96 |
+
print(f"{metric_name}: {int(value)}")
|
| 97 |
+
else:
|
| 98 |
+
print(f"{metric_name}: {float(value):.3f}")
|
| 99 |
+
except (ValueError, TypeError):
|
| 100 |
+
print(f"{metric_name}: {value}")
|
| 101 |
+
# 전체 평균 계산 및 출력
|
| 102 |
+
print("\n" + "="*50)
|
| 103 |
+
print("Overall Average Metrics Across All Categories:")
|
| 104 |
+
print("="*50)
|
| 105 |
+
|
| 106 |
+
# for event_type in all_metrics:
|
| 107 |
+
# print(f"\n{event_type}:")
|
| 108 |
+
# for metric_type, values in all_metrics[event_type].items():
|
| 109 |
+
# avg_value = np.mean(values)
|
| 110 |
+
# print(f"{metric_type}: {avg_value:.3f}")
|
| 111 |
+
|
| 112 |
+
for event_type in all_metrics:
|
| 113 |
+
print(f"\n{event_type}:")
|
| 114 |
+
for metric_type, values in all_metrics[event_type].items():
|
| 115 |
+
avg_value = np.mean(values)
|
| 116 |
+
if metric_type in ['tp', 'tn', 'fp', 'fn']: # 정수 값
|
| 117 |
+
print(f"{metric_type}: {int(avg_value)}")
|
| 118 |
+
else: # 소수점 값
|
| 119 |
+
print(f"{metric_type}: {avg_value:.3f}")
|
| 120 |
+
##################################################################################################
|
| 121 |
+
# 최종 결과를 저장할 딕셔너리
|
| 122 |
+
final_results = {
|
| 123 |
+
"category_metrics": {},
|
| 124 |
+
"overall_metrics": {}
|
| 125 |
+
}
|
| 126 |
+
# 카테고리별 메트릭 저장
|
| 127 |
+
|
| 128 |
+
for category, metrics in category_metrics.items():
|
| 129 |
+
final_results["category_metrics"][category] = {}
|
| 130 |
+
for metric_name, value in metrics.items():
|
| 131 |
+
if metric_name != "video_name":
|
| 132 |
+
if isinstance(value, (int, float)):
|
| 133 |
+
final_results["category_metrics"][category][metric_name] = float(value)
|
| 134 |
+
|
| 135 |
+
# 전체 평균 계산 및 저장
|
| 136 |
+
for event_type in all_metrics:
|
| 137 |
+
# print(f"\n{event_type}:")
|
| 138 |
+
final_results["overall_metrics"][event_type] = {}
|
| 139 |
+
for metric_type, values in all_metrics[event_type].items():
|
| 140 |
+
avg_value = float(np.mean(values))
|
| 141 |
+
# print(f"{metric_type}: {avg_value:.3f}")
|
| 142 |
+
final_results["overall_metrics"][event_type][metric_type] = avg_value
|
| 143 |
+
|
| 144 |
+
# JSON 파일로 저장
|
| 145 |
+
json_path = os.path.join(self.save_dir, "overall_metrics.json")
|
| 146 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 147 |
+
json.dump(final_results, f, indent=4)
|
| 148 |
+
|
| 149 |
+
# return category_metrics
|
| 150 |
+
|
| 151 |
+
# 누적 메트릭 계산
|
| 152 |
+
accumulated_metrics = self.calculate_accumulated_metrics(combined_metrics_df)
|
| 153 |
+
|
| 154 |
+
# JSON에 누적 메트릭 추가
|
| 155 |
+
final_results["accumulated_metrics"] = accumulated_metrics
|
| 156 |
+
|
| 157 |
+
# 누적 메트릭만 따로 저장
|
| 158 |
+
accumulated_json_path = os.path.join(self.save_dir, "accumulated_metrics.json")
|
| 159 |
+
with open(accumulated_json_path, 'w', encoding='utf-8') as f:
|
| 160 |
+
json.dump(accumulated_metrics, f, indent=4)
|
| 161 |
+
|
| 162 |
+
return accumulated_metrics
|
| 163 |
+
|
| 164 |
+
def _evaluate_category(self, category: str, pred_path: str, label_path: str) -> pd.DataFrame:
|
| 165 |
+
"""카테고리별 평가 수행"""
|
| 166 |
+
results = []
|
| 167 |
+
metrics_columns = ['video_name']
|
| 168 |
+
|
| 169 |
+
for pred_file in cust_listdir(pred_path):
|
| 170 |
+
if not pred_file.endswith('.csv'):
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
video_name = os.path.splitext(pred_file)[0]
|
| 174 |
+
pred_df = pd.read_csv(os.path.join(pred_path, pred_file))
|
| 175 |
+
|
| 176 |
+
# 해당 비디오의 정답 CSV 파일 로드
|
| 177 |
+
label_file = f"{video_name}.csv"
|
| 178 |
+
label_path_full = os.path.join(label_path, label_file)
|
| 179 |
+
|
| 180 |
+
if not os.path.exists(label_path_full):
|
| 181 |
+
print(f"Warning: Label file not found for {video_name}")
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
label_df = pd.read_csv(label_path_full)
|
| 185 |
+
|
| 186 |
+
# 각 카테고리별 메트릭 계산
|
| 187 |
+
video_metrics = {'video_name': video_name}
|
| 188 |
+
categories = [col for col in pred_df.columns if col != 'frame']
|
| 189 |
+
|
| 190 |
+
for cat in categories:
|
| 191 |
+
# 정답값과 예측값
|
| 192 |
+
y_true = label_df[cat].values
|
| 193 |
+
y_pred = pred_df[cat].values
|
| 194 |
+
|
| 195 |
+
# 메트릭 계산
|
| 196 |
+
metrics = self._calculate_metrics(y_true, y_pred)
|
| 197 |
+
|
| 198 |
+
# 결과 저장
|
| 199 |
+
for metric_name, value in metrics.items():
|
| 200 |
+
col_name = f"{cat}_{metric_name}"
|
| 201 |
+
video_metrics[col_name] = value
|
| 202 |
+
if col_name not in metrics_columns:
|
| 203 |
+
metrics_columns.append(col_name)
|
| 204 |
+
|
| 205 |
+
results.append(video_metrics)
|
| 206 |
+
|
| 207 |
+
# 결과를 데이터프레임으로 변환
|
| 208 |
+
metrics_df = pd.DataFrame(results, columns=metrics_columns)
|
| 209 |
+
|
| 210 |
+
# 평균 계산하여 추가
|
| 211 |
+
avg_metrics = {'video_name': 'average'}
|
| 212 |
+
for col in metrics_columns[1:]: # video_name 제외
|
| 213 |
+
avg_metrics[col] = metrics_df[col].mean()
|
| 214 |
+
|
| 215 |
+
metrics_df = pd.concat([metrics_df, pd.DataFrame([avg_metrics])], ignore_index=True)
|
| 216 |
+
|
| 217 |
+
return metrics_df
|
| 218 |
+
|
| 219 |
+
# def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
|
| 220 |
+
# """성능 지표 계산"""
|
| 221 |
+
# tn = np.sum((y_true == 0) & (y_pred == 0))
|
| 222 |
+
# fp = np.sum((y_true == 0) & (y_pred == 1))
|
| 223 |
+
|
| 224 |
+
# metrics = {
|
| 225 |
+
# 'f1': f1_score(y_true, y_pred, zero_division=0),
|
| 226 |
+
# 'accuracy': accuracy_score(y_true, y_pred),
|
| 227 |
+
# 'precision': precision_score(y_true, y_pred, zero_division=0),
|
| 228 |
+
# 'recall': recall_score(y_true, y_pred, zero_division=0),
|
| 229 |
+
# 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0
|
| 230 |
+
# }
|
| 231 |
+
|
| 232 |
+
# return metrics
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def calculate_accumulated_metrics(self, all_categories_metrics_df: pd.DataFrame) -> Dict:
|
| 236 |
+
"""누적된 혼동행렬로 각 카테고리별 성능 지표 계산"""
|
| 237 |
+
accumulated_results = {"micro_avg": {}}
|
| 238 |
+
categories = ['falldown', 'violence', 'fire']
|
| 239 |
+
|
| 240 |
+
for category in categories:
|
| 241 |
+
# 해당 카테고리의 혼동행렬 값들 누적
|
| 242 |
+
tp = all_categories_metrics_df[f'{category}_tp'].sum()
|
| 243 |
+
tn = all_categories_metrics_df[f'{category}_tn'].sum()
|
| 244 |
+
fp = all_categories_metrics_df[f'{category}_fp'].sum()
|
| 245 |
+
fn = all_categories_metrics_df[f'{category}_fn'].sum()
|
| 246 |
+
|
| 247 |
+
# 기본 메트릭 계산
|
| 248 |
+
metrics = {
|
| 249 |
+
'tp': int(tp),
|
| 250 |
+
'tn': int(tn),
|
| 251 |
+
'fp': int(fp),
|
| 252 |
+
'fn': int(fn),
|
| 253 |
+
'accuracy': (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0,
|
| 254 |
+
'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
|
| 255 |
+
'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
|
| 256 |
+
'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
|
| 257 |
+
'f1': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
# 추가 메트릭 계산
|
| 261 |
+
tpr = metrics['recall'] # TPR = recall
|
| 262 |
+
tnr = metrics['specificity'] # TNR = specificity
|
| 263 |
+
|
| 264 |
+
# Balanced Accuracy
|
| 265 |
+
metrics['balanced_accuracy'] = (tpr + tnr) / 2
|
| 266 |
+
|
| 267 |
+
# G-Mean
|
| 268 |
+
metrics['g_mean'] = np.sqrt(tpr * tnr) if (tpr * tnr) > 0 else 0
|
| 269 |
+
|
| 270 |
+
# MCC (Matthews Correlation Coefficient)
|
| 271 |
+
numerator = (tp * tn) - (fp * fn)
|
| 272 |
+
denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
|
| 273 |
+
metrics['mcc'] = numerator / denominator if denominator > 0 else 0
|
| 274 |
+
|
| 275 |
+
# NPV (Negative Predictive Value)
|
| 276 |
+
metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
|
| 277 |
+
|
| 278 |
+
# FAR (False Alarm Rate) = FPR = 1 - specificity
|
| 279 |
+
metrics['far'] = 1 - metrics['specificity']
|
| 280 |
+
|
| 281 |
+
accumulated_results[category] = metrics
|
| 282 |
+
|
| 283 |
+
# 전체 카테고리의 누적 값으로 계산
|
| 284 |
+
total_tp = sum(accumulated_results[cat]['tp'] for cat in categories)
|
| 285 |
+
total_tn = sum(accumulated_results[cat]['tn'] for cat in categories)
|
| 286 |
+
total_fp = sum(accumulated_results[cat]['fp'] for cat in categories)
|
| 287 |
+
total_fn = sum(accumulated_results[cat]['fn'] for cat in categories)
|
| 288 |
+
|
| 289 |
+
# micro average 계산 (전체 누적 값으로 계산)
|
| 290 |
+
accumulated_results["micro_avg"] = {
|
| 291 |
+
'tp': int(total_tp),
|
| 292 |
+
'tn': int(total_tn),
|
| 293 |
+
'fp': int(total_fp),
|
| 294 |
+
'fn': int(total_fn),
|
| 295 |
+
'accuracy': (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn),
|
| 296 |
+
'precision': total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0,
|
| 297 |
+
'recall': total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0,
|
| 298 |
+
'f1': 2 * total_tp / (2 * total_tp + total_fp + total_fn) if (2 * total_tp + total_fp + total_fn) > 0 else 0,
|
| 299 |
+
# ... (다른 메트릭들도 동일한 방식으로 계산)
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
return accumulated_results
|
| 303 |
+
def _calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
|
| 304 |
+
"""성능 지표 계산"""
|
| 305 |
+
tn = np.sum((y_true == 0) & (y_pred == 0))
|
| 306 |
+
fp = np.sum((y_true == 0) & (y_pred == 1))
|
| 307 |
+
fn = np.sum((y_true == 1) & (y_pred == 0))
|
| 308 |
+
tp = np.sum((y_true == 1) & (y_pred == 1))
|
| 309 |
+
|
| 310 |
+
metrics = {
|
| 311 |
+
'f1': f1_score(y_true, y_pred, zero_division=0),
|
| 312 |
+
'accuracy': accuracy_score(y_true, y_pred),
|
| 313 |
+
'precision': precision_score(y_true, y_pred, zero_division=0),
|
| 314 |
+
'recall': recall_score(y_true, y_pred, zero_division=0),
|
| 315 |
+
'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
|
| 316 |
+
'tp': int(tp),
|
| 317 |
+
'tn': int(tn),
|
| 318 |
+
'fp': int(fp),
|
| 319 |
+
'fn': int(fn)
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
return metrics
|
pia_bench/pipe_line/piepline.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pia_bench.checker.bench_checker import BenchChecker
|
| 2 |
+
from pia_bench.checker.sheet_checker import SheetChecker
|
| 3 |
+
from pia_bench.event_alarm import EventDetector
|
| 4 |
+
from pia_bench.metric import MetricsEvaluator
|
| 5 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 6 |
+
from pia_bench.bench import PiaBenchMark
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from typing import Optional, List , Dict
|
| 9 |
+
import os
|
| 10 |
+
load_dotenv()
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Dict, Tuple
|
| 13 |
+
from typing import Dict, Optional, Tuple
|
| 14 |
+
import logging
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from sheet_manager.sheet_checker.sheet_check import SheetChecker
|
| 17 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 18 |
+
from pia_bench.checker.bench_checker import BenchChecker
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO)
|
| 21 |
+
|
| 22 |
+
@dataclass
class PipelineConfig:
    """Configuration values for one benchmark pipeline run."""
    # Model identifier, passed to the sheet checker and to PiaBenchMark.
    model_name: str
    # Benchmark dataset name (e.g. "PIA"); also the dataset folder name.
    benchmark_name: str
    # Path to the detection config file; its stem becomes the cfg prompt id.
    cfg_target_path: str
    # Root directory holding all benchmark datasets (passed to BenchChecker).
    base_path: str = "/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
|
| 29 |
+
|
| 30 |
+
class BenchmarkPipelineStatus:
    """Tracks the state and outcome of one benchmark pipeline run."""

    def __init__(self):
        # (model_added, benchmark_exists)
        self.sheet_status: Tuple[bool, bool] = (False, False)
        self.bench_status: Dict[str, bool] = {}
        self.bench_result: str = ""
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """True when the whole pipeline succeeded: the model already existed,
        its benchmark entry exists, and every bench check passed."""
        model_added, benchmark_exists = self.sheet_status
        if model_added:
            return False
        return benchmark_exists and self.bench_result == "all_passed"

    def __str__(self) -> str:
        summary = [
            f"Current Stage: {self.current_stage}",
            f"Sheet Status: {self.sheet_status}",
            f"Bench Status: {self.bench_status}",
            f"Bench Result: {self.bench_result}",
        ]
        return "\n".join(summary)
|
| 49 |
+
|
| 50 |
+
class BenchmarkPipeline:
    """Pipeline that checks sheet/benchmark state and runs the benchmark.

    Flow: sheet check (is a benchmark run needed?) -> local bench-folder check
    -> execution path chosen from the bench status. Results are exposed via
    ``self.status`` (stage/state) and ``self.bench_result_dict`` (metrics).
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        self.access_token = os.getenv("ACCESS_TOKEN")
        # Prompt id = config file name without extension ("topk" for "topk.json")
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]

        # Initialize checkers
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)

        # Populated with the MetricsEvaluator result after a successful run
        self.bench_result_dict = None

    def run(self) -> BenchmarkPipelineStatus:
        """Execute the whole pipeline and return its status object.

        Never raises: any exception is logged and reflected in
        ``status.current_stage == "error"``.
        """
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()

            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("벤치마크가 이미 존재하여 추가 작업이 필요하지 않습니다.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()

            self.status.current_stage = "completed"
            return self.status

        except Exception as e:
            self.logger.error(f"파이프라인 실행 중 에러 발생: {str(e)}")
            self.status.current_stage = "error"
            return self.status

    def _check_sheet(self) -> bool:
        """Check the Google sheet; return True only when a benchmark run is needed."""
        self.logger.info("시트 상태 체크 시작")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("새로운 모델이 추가되었습니다")
        if not benchmark_exists:
            self.logger.info("벤치마크 측정이 필요합니다")
            return True  # proceed only when the benchmark is missing

        self.logger.info("이미 벤치마크가 존재합니다. 파이프라인을 종료합니다.")
        return False  # benchmark already recorded; stop here

    def _check_bench(self) -> bool:
        """Check the local benchmark folder structure and record its status."""
        self.logger.info("벤치마크 환경 체크 시작")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )

        # "no bench": the benchmark was never run and the folder structure is
        # missing. The pipeline still proceeds; _execute_based_on_status falls
        # through to vector generation for this case.
        if self.status.bench_result == "no bench":
            self.logger.error("벤치마크 실행에 필요한 기본 폴더구조가 없습니다.")
            return True

        return True  # every other state also proceeds to execution

    def _execute_based_on_status(self):
        """Dispatch to the execution path matching the bench status."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            self._execute_vector_generation()
            self.logger.warning("폴더구조가 없습니다")

    # ---- shared building blocks -----------------------------------------

    def _create_benchmark(self):
        """Build a PiaBenchMark rooted at the configured base path.

        Fix: the benchmark path was previously hard-coded to the NAS root in
        three places, silently ignoring ``config.base_path``; it is now derived
        from the config (identical result for the default base_path).
        """
        return PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token,
        )

    def _run_detection(self, pia_benchmark) -> None:
        """Run event detection over extracted vectors and save predictions."""
        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

    def _run_metrics(self, pia_benchmark) -> None:
        """Evaluate predictions against labels and store the result dict."""
        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

    # ---- execution paths -------------------------------------------------

    def _execute_full_pipeline(self):
        """Everything already in place: recompute the metrics only."""
        self.logger.info("전체 파이프라인 실행 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        print("Categories identified:", pia_benchmark.categories)
        self._run_metrics(pia_benchmark)

    def _execute_vector_generation(self):
        """No vectors yet: extract vectors, run detection, then metrics."""
        self.logger.info("벡터 생성 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        pia_benchmark.extract_visual_vector()
        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)

    def _execute_metrics_generation(self):
        """Vectors exist but metrics are missing: run detection then metrics."""
        self.logger.info("메트릭 생성 중...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        self._run_detection(pia_benchmark)
        self._run_metrics(pia_benchmark)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
if __name__ == "__main__":
    # Assemble the pipeline configuration for a standalone run.
    pipeline_settings = {
        "model_name": "T2V_CLIP4CLIP_MSRVTT",
        "benchmark_name": "PIA",
        "cfg_target_path": "topk.json",
        "base_path": "/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench",
    }

    # Build and execute the pipeline, then report the resulting status.
    outcome = BenchmarkPipeline(PipelineConfig(**pipeline_settings)).run()

    print(f"\n파이프라인 실행 결과:")
    print(str(outcome))
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
oauth2client
|
| 2 |
+
gspread
|
| 3 |
+
gradio
|
| 4 |
+
python-dotenv
|
| 5 |
+
APScheduler
|
| 6 |
+
black
|
| 7 |
+
gradio[oauth]
|
| 8 |
+
gradio_leaderboard==0.0.13
|
| 9 |
+
gradio_client
|
| 10 |
+
huggingface-hub>=0.18.0
|
| 11 |
+
matplotlib
|
| 12 |
+
numpy
|
| 13 |
+
pandas
|
| 14 |
+
python-dateutil
|
| 15 |
+
tqdm
|
sample.csv
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
video_name,resolution,video_duration,category,benchmark,duration_seconds,total_frames,file_format,file_size_mb,aspect_ratio,fps
|
| 2 |
+
417-2_cam02_assault01_place03_night_spring.mp4,3840x2160,5:15,violence,PIA,315.14816666666667,9445,.mp4,351.66,1.78,29.97002997002997
|
| 3 |
+
439-5_cam01_assault01_place03_day_summer.mp4,3840x2160,5:19,violence,PIA,318.6516666666667,9550,.mp4,342.09,1.78,29.97002997002997
|
| 4 |
+
24-2_cam01_assault01_place09_night_winter.mp4,3840x2160,4:33,violence,PIA,272.7725,8175,.mp4,299.35,1.78,29.97002997002997
|
| 5 |
+
6-1_cam01_assault01_place03_night_summer.mp4,3840x2160,5:19,violence,PIA,318.6516666666667,9550,.mp4,342.81,1.78,29.97002997002997
|
| 6 |
+
fight_0026.mp4,640x360,3:20,violence,PIA,199.93306666666666,5992,.mp4,16.19,1.78,29.97002997002997
|
| 7 |
+
10-1_cam01_assault03_place07_night_winter.mp4,3840x2160,5:10,violence,PIA,310.07643333333334,9293,.mp4,333.36,1.78,29.97002997002997
|
| 8 |
+
22-2_cam01_assault01_place07_night_winter.mp4,3840x2160,5:5,violence,PIA,305.305,9150,.mp4,368.69,1.78,29.97002997002997
|
| 9 |
+
407-6_cam01_assault01_place04_day_winter.mp4,3840x2160,5:13,violence,PIA,312.77913333333333,9374,.mp4,747.29,1.78,29.97002997002997
|
| 10 |
+
407-6_cam01_assault01_place04_day_spring.mp4,3840x2160,5:9,violence,PIA,309.10880000000003,9264,.mp4,738.48,1.78,29.97002997002997
|
| 11 |
+
411-3_cam01_assault01_place08_night_winter.mp4,3840x2160,5:15,violence,PIA,314.7477666666667,9433,.mp4,337.82,1.78,29.97002997002997
|
| 12 |
+
412-1_cam01_assault01_place09_night_winter.mp4,3840x2160,5:30,violence,PIA,330.26326666666665,9898,.mp4,380.09,1.78,29.97002997002997
|
| 13 |
+
416-5_cam03_assault01_place02_night_spring.mp4,3840x2160,5:9,violence,PIA,308.97533333333337,9260,.mp4,342.81,1.78,29.97002997002997
|
| 14 |
+
12-1_cam01_assault01_place09_day_summer.mp4,3840x2160,4:54,violence,PIA,294.02706666666666,8812,.mp4,712.79,1.78,29.97002997002997
|
| 15 |
+
13-4_cam02_assault02_place08_day_spring.mp4,3840x2160,4:56,violence,PIA,295.96233333333333,8870,.mp4,711.64,1.78,29.97002997002997
|
| 16 |
+
16-3_cam01_assault01_place02_night_summer.mp4,3840x2160,5:0,violence,PIA,300.3333666666667,9001,.mp4,358.8,1.78,29.97002997002997
|
| 17 |
+
17-2_cam01_assault03_place03_night_spring.mp4,3840x2160,5:13,violence,PIA,312.8125,9375,.mp4,758.24,1.78,29.97002997002997
|
| 18 |
+
2-3_cam01_assault01_place04_night_spring.mp4,3840x2160,5:12,violence,PIA,312.4454666666667,9364,.mp4,750.96,1.78,29.97002997002997
|
| 19 |
+
20-3_cam02_assault01_place02_night_summer.mp4,3840x2160,5:14,violence,PIA,314.2139,9417,.mp4,450.17,1.78,29.97002997002997
|
| 20 |
+
23-3_cam01_assault01_place02_night_summer.mp4,3840x2160,5:0,violence,PIA,300.3,9000,.mp4,358.84,1.78,29.97002997002997
|
| 21 |
+
406-1_cam01_assault01_place03_day_summer.mp4,3840x2160,5:0,violence,PIA,300.3333666666667,9001,.mp4,428.29,1.78,29.97002997002997
|
| 22 |
+
8-1_cam01_assault03_place05_day_spring.mp4,3840x2160,5:0,violence,PIA,300.3,9000,.mp4,717.83,1.78,29.97002997002997
|
| 23 |
+
fight_0035.mp4,406x720,1:47,violence,PIA,107.27383333333333,3215,.mp4,11.04,0.56,29.97002997002997
|
| 24 |
+
fight_0062.mp4,640x360,1:12,violence,PIA,71.70496666666666,2149,.mp4,3.05,1.78,29.97002997002997
|
| 25 |
+
fight_0051.mp4,1280x720,2:0,violence,PIA,120.43333333333334,3599,.mp4,30.26,1.78,29.88375311375588
|
| 26 |
+
fight_0097.mp4,1280x720,1:8,violence,PIA,68.00126666666667,2038,.mp4,6.47,1.78,29.97002997002997
|
| 27 |
+
fight_0125.mp4,1280x720,1:26,violence,PIA,86.01926666666667,2578,.mp4,9.09,1.78,29.97002997002997
|
| 28 |
+
fight_0141.mp4,1280x720,3:11,violence,PIA,191.42456666666666,5737,.mp4,19.21,1.78,29.97002997002997
|
| 29 |
+
fight_0147.mp4,640x360,2:35,violence,PIA,154.73333333333335,4624,.mp4,9.45,1.78,29.88367083153813
|
| 30 |
+
fight_0156.mp4,1280x720,1:16,violence,PIA,76.0,2280,.mp4,10.58,1.78,30.0
|
| 31 |
+
fight_0162.mp4,1280x720,1:8,violence,PIA,67.86666666666666,2036,.mp4,8.28,1.78,30.0
|
| 32 |
+
20190102_013314A.mp4,3840x2160,15:0,fire,PIA,900.39,27013,.mp4,2477.09,1.78,30.001443818789635
|
| 33 |
+
화재 - 불피우기.mp4,1920x1080,2:40,fire,PIA,159.535675,4786,.mp4,76.23,1.78,29.999559659618452
|
| 34 |
+
화재 - 토치.mp4,1920x1080,8:44,fire,PIA,523.871349,15716,.mp4,249.98,1.78,29.999731861648346
|
| 35 |
+
Video34.mp4,292x240,15:1,fire,PIA,901.0485436893204,7734,.mp4,1.71,1.22,8.583333333333334
|
| 36 |
+
Video5.mp4,320x240,3:7,fire,PIA,187.33333333333334,4496,.mp4,4.82,1.33,24.0
|
| 37 |
+
Video49.mp4,320x240,1:9,fire,PIA,69.08333333333333,1658,.mp4,2.7,1.33,24.0
|
| 38 |
+
Video149.mp4,854x480,0:30,fire,PIA,29.996663329996665,899,.mp4,5.37,1.78,29.97
|
| 39 |
+
Video261.mp4,292x240,15:0,fire,PIA,900.1199999999999,7501,.mp4,4.09,1.22,8.333333333333334
|
| 40 |
+
fire_general-fire_rgb_0002_cctv1.mp4,1920x1080,0:7,fire,PIA,7.0,189,.mp4,5.58,1.78,27.0
|
| 41 |
+
fire_general-fire_rgb_0065_cctv4.mp4,1920x1080,0:9,fire,PIA,8.525191858525192,511,.mp4,47.03,1.78,59.94
|
| 42 |
+
fire_general-fire_rgb_0070_cctv3.mp4,1920x1080,0:12,fire,PIA,12.479145812479146,748,.mp4,73.64,1.78,59.94
|
| 43 |
+
fire_general-fire_rgb_0083_cctv2.mp4,1920x1080,0:10,fire,PIA,9.743076409743077,584,.mp4,55.54,1.78,59.94
|
| 44 |
+
fire_general-fire_rgb_0559_cctv1.mp4,1920x1080,0:10,fire,PIA,10.477143810477145,628,.mp4,81.13,1.78,59.94
|
| 45 |
+
fire_general-fire_rgb_0556_cctv2.mp4,1920x1080,0:9,fire,PIA,9.376042709376042,562,.mp4,65.8,1.78,59.94
|
| 46 |
+
fire_general-fire_rgb_0530_cctv1.mp4,1920x1080,0:12,fire,PIA,12.028695362028696,721,.mp4,87.26,1.78,59.94
|
| 47 |
+
fire_general-fire_rgb_0514_cctv2.mp4,1920x1080,0:8,fire,PIA,8.241574908241574,494,.mp4,59.32,1.78,59.94
|
| 48 |
+
fire_general-fire_rgb_0475_cctv1.mp4,1920x1080,0:9,fire,PIA,9.426092759426092,565,.mp4,73.66,1.78,59.94
|
| 49 |
+
fire_general-fire_rgb_0460_cctv2.mp4,1920x1080,0:11,fire,PIA,11.16116116116116,669,.mp4,80.93,1.78,59.94
|
| 50 |
+
fire_general-fire_rgb_0356_cctv1.mp4,1920x1080,0:9,fire,PIA,9.376042709376042,562,.mp4,79.11,1.78,59.94
|
| 51 |
+
fire_general-fire_rgb_0331_cctv2.mp4,1920x1080,0:7,fire,PIA,6.773440106773441,406,.mp4,52.11,1.78,59.94
|
| 52 |
+
fire_general-fire_rgb_0305_cctv3.mp4,1920x1080,0:10,fire,PIA,10.226893560226895,613,.mp4,84.28,1.78,59.94
|
| 53 |
+
fire_general-fire_rgb_0291_cctv1.mp4,1920x1080,0:4,fire,PIA,4.170837504170838,250,.mp4,22.08,1.78,59.94
|
| 54 |
+
fire_general-fire_rgb_0280_cctv4.mp4,1920x1080,0:4,fire,PIA,4.087420754087421,245,.mp4,21.92,1.78,59.94
|
| 55 |
+
fire_general-fire_rgb_0562_cctv7.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,7.62,1.78,29.0
|
| 56 |
+
fire_general-fire_rgb_0337_cctv2.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,5.37,1.78,29.0
|
| 57 |
+
fire_general-fire_rgb_0289_cctv2.mp4,1920x1080,0:7,fire,PIA,7.0,203,.mp4,7.74,1.78,29.0
|
| 58 |
+
fire_oil-fire_rgb_0002_cctv2.mp4,1920x1080,0:6,fire,PIA,6.0,180,.mp4,5.28,1.78,30.0
|
| 59 |
+
fire_oil-fire_rgb_0222_cctv4.mp4,1920x1080,0:6,fire,PIA,6.0,180,.mp4,6.37,1.78,30.0
|
| 60 |
+
fire_oil-fire_rgb_0445_cctv7.mp4,1920x1080,0:6,fire,PIA,6.0,174,.mp4,4.19,1.78,29.0
|
| 61 |
+
Explosion004_x264.mp4,320x240,1:3,fire,PIA,63.4,1902,.mp4,8.04,1.33,30.0
|
| 62 |
+
Explosion005_x264.mp4,320x240,0:23,fire,PIA,23.1,693,.mp4,5.46,1.33,30.0
|
| 63 |
+
Explosion009_x264.mp4,320x240,0:37,fire,PIA,36.7,1101,.mp4,9.13,1.33,30.0
|
| 64 |
+
Explosion010_x264.mp4,320x240,1:23,fire,PIA,83.26666666666667,2498,.mp4,11.58,1.33,30.0
|
| 65 |
+
Explosion013_x264.mp4,320x240,1:51,fire,PIA,110.56666666666666,3317,.mp4,15.66,1.33,30.0
|
| 66 |
+
Explosion014_x264.mp4,320x240,0:43,fire,PIA,43.06666666666667,1292,.mp4,6.24,1.33,30.0
|
| 67 |
+
Explosion017_x264.mp4,320x240,0:55,fire,PIA,54.766666666666666,1643,.mp4,12.31,1.33,30.0
|
| 68 |
+
Explosion002_x264.mp4,320x240,2:14,fire,PIA,133.76666666666668,4013,.mp4,18.32,1.33,30.0
|
| 69 |
+
Explosion051_x264.mp4,320x240,1:34,fire,PIA,94.0,2820,.mp4,13.52,1.33,30.0
|
| 70 |
+
119-1_cam01_swoon01_place03_day_summer.mp4,3840x2160,4:60,falldown,PIA,299.9997,8991,.mp4,810.22,1.78,29.97002997002997
|
| 71 |
+
100-5_cam02_swoon01_place02_day_summer.mp4,3840x2160,5:21,falldown,PIA,321.38773333333336,9632,.mp4,451.85,1.78,29.97002997002997
|
| 72 |
+
FILE210101-012606F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.23,1.78,29.97002997002997
|
| 73 |
+
118-2_cam01_swoon02_place10_day_spring.mp4,3840x2160,5:0,falldown,PIA,300.3,9000,.mp4,442.05,1.78,29.97002997002997
|
| 74 |
+
108-5_cam02_swoon01_place06_night_spring.mp4,3840x2160,4:56,falldown,PIA,296.22926666666666,8878,.mp4,718.68,1.78,29.97002997002997
|
| 75 |
+
FILE210101-003713F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.27,1.78,29.97002997002997
|
| 76 |
+
FILE210101-010727F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.04,1.78,29.97002997002997
|
| 77 |
+
245-5_cam02_swoon01_place04_night_spring.mp4,3840x2160,5:7,falldown,PIA,306.97333333333336,9200,.mp4,369.43,1.78,29.97002997002997
|
| 78 |
+
115-1_cam01_swoon01_place02_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.2746333333333,9239,.mp4,362.95,1.78,29.97002997002997
|
| 79 |
+
110-2_cam01_swoon01_place01_day_spring.mp4,3840x2160,4:55,falldown,PIA,295.06143333333335,8843,.mp4,709.46,1.78,29.97002997002997
|
| 80 |
+
FILE210101-005228F.MOV,3840x2160,15:0,falldown,PIA,900.0324666666667,26974,.mov,3362.06,1.78,29.97002997002997
|
| 81 |
+
114-2_cam01_swoon03_place03_day_spring.mp4,3840x2160,5:6,falldown,PIA,306.306,9180,.mp4,438.87,1.78,29.97002997002997
|
| 82 |
+
117-1_cam02_swoon01_place04_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.4081,9243,.mp4,747.54,1.78,29.97002997002997
|
| 83 |
+
104-2_cam01_swoon01_place04_day_spring.mp4,3840x2160,5:6,falldown,PIA,305.7054,9162,.mp4,742.15,1.78,29.97002997002997
|
| 84 |
+
103-1_cam01_swoon01_place04_day_spring.mp4,3840x2160,5:8,falldown,PIA,307.5072,9216,.mp4,746.19,1.78,29.97002997002997
|
| 85 |
+
99-4_cam03_swoon01_place03_day_winter.mp4,3840x2160,5:7,falldown,PIA,306.6396666666667,9190,.mp4,329.27,1.78,29.97002997002997
|
| 86 |
+
240-3_cam01_swoon01_place02_night_spring.mp4,3840x2160,5:8,falldown,PIA,308.04106666666667,9232,.mp4,330.92,1.78,29.97002997002997
|
| 87 |
+
115-5_cam02_swoon01_place02_night_spring.mp4,3840x2160,5:10,falldown,PIA,309.57593333333335,9278,.mp4,750.26,1.78,29.97002997002997
|
| 88 |
+
640-2_cam01_swoon01_place01_day_summer.mp4,3840x2160,5:0,falldown,PIA,300.3333666666667,9001,.mp4,425.3,1.78,29.97002997002997
|
| 89 |
+
517-3_cam03_swoon03_place04_night_winter.mp4,3840x2160,5:7,falldown,PIA,306.9066,9198,.mp4,330.13,1.78,29.97002997002997
|
| 90 |
+
107-1_cam01_swoon01_place06_night_spring.mp4,3840x2160,4:0,falldown,PIA,240.24,7200,.mp4,580.49,1.78,29.97002997002997
|
| 91 |
+
106-1_cam02_swoon01_place05_day_spring.mp4,3840x2160,5:1,falldown,PIA,300.56693333333334,9008,.mp4,430.69,1.78,29.97002997002997
|
| 92 |
+
110-1_cam01_swoon01_place01_day_spring.mp4,3840x2160,4:55,falldown,PIA,294.6276666666667,8830,.mp4,708.43,1.78,29.97002997002997
|
| 93 |
+
116-4_cam02_swoon01_place03_day_summer.mp4,3840x2160,5:7,falldown,PIA,306.53956666666664,9187,.mp4,340.57,1.78,29.97002997002997
|
| 94 |
+
109-5_cam02_swoon02_place01_night_summer.mp4,3840x2160,5:12,falldown,PIA,312.312,9360,.mp4,757.09,1.78,29.97002997002997
|
| 95 |
+
105-5_cam01_swoon01_place05_night_summer.mp4,3840x2160,5:11,falldown,PIA,311.34436666666664,9331,.mp4,333.82,1.78,29.97002997002997
|
| 96 |
+
112-6_cam01_swoon02_place01_night_summer.mp4,3840x2160,4:0,falldown,PIA,240.24,7200,.mp4,582.44,1.78,29.97002997002997
|
| 97 |
+
108-2_cam01_swoon01_place06_night_summer.mp4,3840x2160,4:60,falldown,PIA,299.7327666666667,8983,.mp4,726.91,1.78,29.97002997002997
|
| 98 |
+
120-1_cam02_swoon02_place06_day_summer.mp4,3840x2160,5:22,falldown,PIA,322.2219,9657,.mp4,781.36,1.78,29.97002997002997
|
| 99 |
+
113-2_cam02_swoon01_place08_day_summer.mp4,3840x2160,5:14,falldown,PIA,313.7134,9402,.mp4,450.0,1.78,29.97002997002997
|
sheet_manager/sheet_checker/sheet_check.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
import logging
|
| 3 |
+
import gspread
|
| 4 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 5 |
+
|
| 6 |
+
class SheetChecker:
    """Maintains model rows and benchmark columns on the "model" worksheet.

    A second SheetManager (same class as the one supplied) is created
    against the "model" worksheet so model entries and benchmark columns
    can be looked up — and created on demand.
    """

    def __init__(self, sheet_manager: SheetManager):
        """Keep the base manager and open a manager bound to the "model" sheet."""
        self.sheet_manager = sheet_manager
        self.bench_sheet_manager = None
        self.logger = logging.getLogger(__name__)
        self._init_bench_sheet()

    def _init_bench_sheet(self):
        """Create a sibling SheetManager pointed at the "model" worksheet."""
        manager_cls = type(self.sheet_manager)
        self.bench_sheet_manager = manager_cls(
            spreadsheet_url=self.sheet_manager.spreadsheet_url,
            worksheet_name="model",
            column_name="Model name",
        )

    def add_benchmark_column(self, column_name: str):
        """Append a benchmark column plus its "<name>*100" companion, if absent."""
        try:
            headers = self.bench_sheet_manager.get_available_columns()
            if column_name in headers:
                return  # column already present; nothing to do

            base_index = len(headers) + 1
            self.bench_sheet_manager.sheet.update(
                gspread.utils.rowcol_to_a1(1, base_index), [[column_name]]
            )
            # Companion column holding the benchmark value scaled by 100.
            self.bench_sheet_manager.sheet.update(
                gspread.utils.rowcol_to_a1(1, base_index + 1),
                [[f"{column_name}*100"]],
            )

            self.logger.info(f"새로운 벤치마크 컬럼들 추가됨: {column_name}, {column_name}*100")
            # Reconnect so the cached header row includes the new columns.
            self.bench_sheet_manager._connect_to_sheet(validate_column=False)

        except Exception as e:
            self.logger.error(f"벤치마크 컬럼 {column_name} 추가 중 오류 발생: {str(e)}")
            raise

    def check_model_and_benchmark(self, model_name: str, benchmark_name: str) -> Tuple[bool, bool]:
        """Ensure the model row and benchmark column exist; report their state.

        Args:
            model_name: model to look up (a row is appended when missing).
            benchmark_name: benchmark to look up (a column is appended when missing).

        Returns:
            Tuple[bool, bool]: (model was newly added,
                                benchmark cell already holds a value)
        """
        try:
            newly_added = False
            if not self._check_model_exists(model_name):
                self._add_new_model(model_name)
                newly_added = True
                self.logger.info(f"새로운 모델 추가됨: {model_name}")

            # Create the benchmark column when it does not exist yet.
            if benchmark_name not in self.bench_sheet_manager.get_available_columns():
                self.add_benchmark_column(benchmark_name)
                self.logger.info(f"새로운 벤치마크 컬럼 추가됨: {benchmark_name}")

            has_value = self._check_benchmark_exists(model_name, benchmark_name)
            return newly_added, has_value

        except Exception as e:
            self.logger.error(f"모델/벤치마크 확인 중 오류 발생: {str(e)}")
            raise

    def _check_model_exists(self, model_name: str) -> bool:
        """Return True when *model_name* appears in the "Model name" column."""
        try:
            self.bench_sheet_manager.change_column("Model name")
            return model_name in self.bench_sheet_manager.get_all_values()
        except Exception as e:
            self.logger.error(f"모델 존재 여부 확인 중 오류 발생: {str(e)}")
            raise

    def _add_new_model(self, model_name: str):
        """Append name, repo link and HTML link cells for a new model row."""
        try:
            model_info = {
                "Model name": model_name,
                "Model link": f"https://huggingface.co/PIA-SPACE-LAB/{model_name}",
                "Model": f'<a target="_blank" href="https://huggingface.co/PIA-SPACE-LAB/{model_name}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
            }

            for column_name, value in model_info.items():
                self.bench_sheet_manager.change_column(column_name)
                self.bench_sheet_manager.push(value)

        except Exception as e:
            self.logger.error(f"모델 정보 추가 중 오류 발생: {str(e)}")
            raise

    def _check_benchmark_exists(self, model_name: str, benchmark_name: str) -> bool:
        """Return True when the model's benchmark cell holds a non-blank value."""
        try:
            # Locate the model's sheet row: get_all_values() excludes the
            # header, so +2 converts the list index to a 1-based sheet row.
            self.bench_sheet_manager.change_column("Model name")
            all_values = self.bench_sheet_manager.get_all_values()
            row_index = all_values.index(model_name) + 2

            self.bench_sheet_manager.change_column(benchmark_name)
            cell_value = self.bench_sheet_manager.sheet.cell(row_index, self.bench_sheet_manager.col_index).value

            return bool(cell_value and cell_value.strip())

        except Exception as e:
            self.logger.error(f"벤치마크 존재 여부 확인 중 오류 발생: {str(e)}")
            raise
|
| 128 |
+
|
| 129 |
+
# 사용 예시
|
| 130 |
+
if __name__ == "__main__":
    # Usage example: ensure the model row exists and report benchmark status.
    manager = SheetManager()
    inspector = SheetChecker(manager)

    model_added, benchmark_exists = inspector.check_model_and_benchmark(
        model_name="test-model",
        benchmark_name="COCO",
    )

    print(f"Model added: {model_added}")
    print(f"Benchmark exists: {benchmark_exists}")
|
sheet_manager/sheet_convert/json2sheet.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from sheet_manager.sheet_crud.sheet_crud import SheetManager
|
| 3 |
+
import json
|
| 4 |
+
from typing import Optional, Dict
|
| 5 |
+
|
| 6 |
+
def update_benchmark_json(
    model_name: str,
    benchmark_data: dict,
    worksheet_name: str = "metric",
    target_column: str = "benchmark"
):
    """Serialize *benchmark_data* to JSON and store it in the model's row.

    The row is located by matching *model_name* against the "Model name"
    column; the JSON string is written into *target_column*.

    Args:
        model_name (str): model whose row should be updated.
        benchmark_data (dict): benchmark payload to serialize.
        worksheet_name (str): worksheet to operate on (default "metric").
        target_column (str): column receiving the JSON (default "benchmark").
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    # Keep non-ASCII characters readable in the stored JSON.
    payload = json.dumps(benchmark_data, ensure_ascii=False)

    matched_row = manager.update_cell_by_condition(
        condition_column="Model name",
        condition_value=model_name,
        target_column=target_column,
        target_value=payload,
    )

    if matched_row:
        print(f"Successfully updated {target_column} data for model: {model_name}")
    else:
        print(f"Model {model_name} not found in the sheet")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_benchmark_dict(
    model_name: str,
    worksheet_name: str = "metric",
    target_column: str = "benchmark",
    save_path: Optional[str] = None
) -> Dict:
    """Fetch a model's benchmark JSON from the sheet and return it as a dict.

    Args:
        model_name (str): model whose row is read.
        worksheet_name (str): worksheet to operate on (default "metric").
        target_column (str): column holding the JSON (default "benchmark").
        save_path (str, optional): when given, the parsed dict is also
            written to this path as pretty-printed JSON.

    Returns:
        Dict: parsed benchmark data; {} when the model/cell is missing or
        the stored text is not valid JSON.
    """
    manager = SheetManager(worksheet_name=worksheet_name)

    try:
        records = manager.sheet.get_all_records()

        # First row whose "Model name" matches wins.
        matching = next(
            (record for record in records if record.get("Model name") == model_name),
            None,
        )
        if not matching:
            print(f"Model {model_name} not found in the sheet")
            return {}

        raw_json = matching.get(target_column)
        if not raw_json:
            print(f"No data found in {target_column} for model: {model_name}")
            return {}

        parsed = json.loads(raw_json)

        if save_path:
            with open(save_path, 'w', encoding='utf-8') as fh:
                json.dump(parsed, fh, ensure_ascii=False, indent=2)
            print(f"Successfully saved dictionary to: {save_path}")

        return parsed

    except json.JSONDecodeError:
        print(f"Failed to parse JSON data for model: {model_name}")
        return {}
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return {}
|
| 99 |
+
|
| 100 |
+
def str2json(json_str):
    """Parse a JSON-formatted string into a Python object.

    Args:
        json_str (str): text expected to contain JSON.

    Returns:
        The decoded object (dict, list, ...), or None when decoding —
        or anything else — fails.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as exc:
        print(f"JSON Parsing Error: {exc}")
        return None
    except Exception as exc:
        # e.g. a non-string argument raises TypeError; swallow and signal None.
        print(f"Unexpected Error: {exc}")
        return None
|
sheet_manager/sheet_crud/create_col.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from huggingface_hub import HfApi
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from enviroments.convert import get_json_from_env_var
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
def push_model_names_to_sheet(spreadsheet_url, sheet_name, access_token, organization):
    """
    Sync the organization's Hugging Face model names/links into a Google Sheet.

    Only models not already listed (by "Model name") are appended.

    Args:
        spreadsheet_url (str): URL of the Google Spreadsheet.
        sheet_name (str): Name of the sheet to update.
        access_token (str): Hugging Face access token.
        organization (str): Organization name on Hugging Face.
    """
    # Authorize the Sheets/Drive APIs from service-account credentials
    # stored in the GOOGLE_CREDENTIALS environment variable.
    scopes = ['https://spreadsheets.google.com/feeds',
              'https://www.googleapis.com/auth/drive']
    key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
    creds = ServiceAccountCredentials.from_json_keyfile_dict(key_dict, scopes)
    client = gspread.authorize(creds)

    workbook = client.open_by_url(spreadsheet_url)
    worksheet = workbook.worksheet(sheet_name)

    existing_data = pd.DataFrame(worksheet.get_all_records())

    # Pull every model published by the organization on Hugging Face.
    hf_api = HfApi()
    models = hf_api.list_models(author=organization, use_auth_token=access_token)

    # One row per model: plain name, repo URL, and a styled HTML anchor.
    model_rows = [{
        "Model name": model.modelId.split("/")[1],
        "Model link": f"https://huggingface.co/{model.modelId}",
        "Model": f"<a target=\"_blank\" href=\"https://huggingface.co/{model.modelId}\" style=\"color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;\">{model.modelId}</a>"
    } for model in models]
    new_rows = pd.DataFrame(model_rows)

    # Drop models already present in the sheet.
    if "Model name" in existing_data.columns:
        known_names = existing_data["Model name"].tolist()
    else:
        known_names = []
    new_rows = new_rows[~new_rows["Model name"].isin(known_names)]

    if not new_rows.empty:
        combined = pd.concat([existing_data, new_rows], ignore_index=True)

        # The Sheets API rejects Infinity/NaN, so sanitize before uploading.
        combined = combined.replace([float('inf'), float('-inf')], None)
        combined = combined.fillna('')
        worksheet.update([combined.columns.values.tolist()] + combined.values.tolist())
        print("New model names, links, and HTML links successfully added to Google Sheet.")
    else:
        print("No new model names to add.")
|
| 68 |
+
|
| 69 |
+
# Example usage
|
| 70 |
+
if __name__ == "__main__":
    # Example usage: credentials come from the environment (.env).
    target_url = os.getenv("SPREADSHEET_URL")
    hf_token = os.getenv("ACCESS_TOKEN")

    push_model_names_to_sheet(target_url, "시트1", hf_token, "PIA-SPACE-LAB")
|
sheet_manager/sheet_crud/sheet_crud.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from enviroments.convert import get_json_from_env_var
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
load_dotenv(override=True)
|
| 9 |
+
|
| 10 |
+
class SheetManager:
|
| 11 |
+
def __init__(self, spreadsheet_url: Optional[str] = None,
|
| 12 |
+
worksheet_name: str = "flag",
|
| 13 |
+
column_name: str = "huggingface_id"):
|
| 14 |
+
"""
|
| 15 |
+
Initialize SheetManager with Google Sheets credentials and connection.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
spreadsheet_url (str, optional): URL of the Google Spreadsheet.
|
| 19 |
+
If None, takes from environment variable.
|
| 20 |
+
worksheet_name (str): Name of the worksheet to operate on.
|
| 21 |
+
Defaults to "flag".
|
| 22 |
+
column_name (str): Name of the column to operate on.
|
| 23 |
+
Defaults to "huggingface_id".
|
| 24 |
+
"""
|
| 25 |
+
self.spreadsheet_url = spreadsheet_url or os.getenv("SPREADSHEET_URL")
|
| 26 |
+
if not self.spreadsheet_url:
|
| 27 |
+
raise ValueError("Spreadsheet URL not provided and not found in environment variables")
|
| 28 |
+
|
| 29 |
+
self.worksheet_name = worksheet_name
|
| 30 |
+
self.column_name = column_name
|
| 31 |
+
|
| 32 |
+
# Initialize credentials and client
|
| 33 |
+
self._init_google_client()
|
| 34 |
+
|
| 35 |
+
# Initialize sheet connection
|
| 36 |
+
self.doc = None
|
| 37 |
+
self.sheet = None
|
| 38 |
+
self.col_index = None
|
| 39 |
+
self._connect_to_sheet(validate_column=True)
|
| 40 |
+
|
| 41 |
+
def _init_google_client(self):
|
| 42 |
+
"""Initialize Google Sheets client with credentials."""
|
| 43 |
+
scope = ['https://spreadsheets.google.com/feeds',
|
| 44 |
+
'https://www.googleapis.com/auth/drive']
|
| 45 |
+
json_key_dict = get_json_from_env_var("GOOGLE_CREDENTIALS")
|
| 46 |
+
credentials = ServiceAccountCredentials.from_json_keyfile_dict(json_key_dict, scope)
|
| 47 |
+
self.client = gspread.authorize(credentials)
|
| 48 |
+
|
| 49 |
+
def _connect_to_sheet(self, validate_column: bool = True):
|
| 50 |
+
"""
|
| 51 |
+
Connect to the specified Google Sheet and initialize necessary attributes.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
validate_column (bool): Whether to validate the column name exists
|
| 55 |
+
"""
|
| 56 |
+
try:
|
| 57 |
+
self.doc = self.client.open_by_url(self.spreadsheet_url)
|
| 58 |
+
|
| 59 |
+
# Try to get the worksheet
|
| 60 |
+
try:
|
| 61 |
+
self.sheet = self.doc.worksheet(self.worksheet_name)
|
| 62 |
+
except gspread.exceptions.WorksheetNotFound:
|
| 63 |
+
raise ValueError(f"Worksheet '{self.worksheet_name}' not found in spreadsheet")
|
| 64 |
+
|
| 65 |
+
# Get headers
|
| 66 |
+
self.headers = self.sheet.row_values(1)
|
| 67 |
+
|
| 68 |
+
# Validate column only if requested
|
| 69 |
+
if validate_column:
|
| 70 |
+
try:
|
| 71 |
+
self.col_index = self.headers.index(self.column_name) + 1
|
| 72 |
+
except ValueError:
|
| 73 |
+
# If column not found, use first available column
|
| 74 |
+
if self.headers:
|
| 75 |
+
self.column_name = self.headers[0]
|
| 76 |
+
self.col_index = 1
|
| 77 |
+
print(f"Column '{self.column_name}' not found. Using first available column: '{self.headers[0]}'")
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError("No columns found in worksheet")
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
if isinstance(e, ValueError):
|
| 83 |
+
raise e
|
| 84 |
+
raise ConnectionError(f"Failed to connect to sheet: {str(e)}")
|
| 85 |
+
|
| 86 |
+
def change_worksheet(self, worksheet_name: str, column_name: Optional[str] = None):
|
| 87 |
+
"""
|
| 88 |
+
Change the current worksheet and optionally the column.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
worksheet_name (str): Name of the worksheet to switch to
|
| 92 |
+
column_name (str, optional): Name of the column to switch to
|
| 93 |
+
"""
|
| 94 |
+
old_worksheet = self.worksheet_name
|
| 95 |
+
old_column = self.column_name
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
self.worksheet_name = worksheet_name
|
| 99 |
+
if column_name:
|
| 100 |
+
self.column_name = column_name
|
| 101 |
+
|
| 102 |
+
# First connect without column validation
|
| 103 |
+
self._connect_to_sheet(validate_column=False)
|
| 104 |
+
|
| 105 |
+
# Then validate the column if specified
|
| 106 |
+
if column_name:
|
| 107 |
+
self.change_column(column_name)
|
| 108 |
+
else:
|
| 109 |
+
# Validate existing column in new worksheet
|
| 110 |
+
try:
|
| 111 |
+
self.col_index = self.headers.index(self.column_name) + 1
|
| 112 |
+
except ValueError:
|
| 113 |
+
# If column not found, use first available column
|
| 114 |
+
if self.headers:
|
| 115 |
+
self.column_name = self.headers[0]
|
| 116 |
+
self.col_index = 1
|
| 117 |
+
print(f"Column '{old_column}' not found in new worksheet. Using first available column: '{self.headers[0]}'")
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError("No columns found in worksheet")
|
| 120 |
+
|
| 121 |
+
print(f"Successfully switched to worksheet: {worksheet_name}, using column: {self.column_name}")
|
| 122 |
+
|
| 123 |
+
except Exception as e:
|
| 124 |
+
# Restore previous state on error
|
| 125 |
+
self.worksheet_name = old_worksheet
|
| 126 |
+
self.column_name = old_column
|
| 127 |
+
self._connect_to_sheet()
|
| 128 |
+
raise e
|
| 129 |
+
|
| 130 |
+
def change_column(self, column_name: str):
|
| 131 |
+
"""
|
| 132 |
+
Change the target column.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
column_name (str): Name of the column to switch to
|
| 136 |
+
"""
|
| 137 |
+
if not self.headers:
|
| 138 |
+
self.headers = self.sheet.row_values(1)
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
self.col_index = self.headers.index(column_name) + 1
|
| 142 |
+
self.column_name = column_name
|
| 143 |
+
print(f"Successfully switched to column: {column_name}")
|
| 144 |
+
except ValueError:
|
| 145 |
+
raise ValueError(f"Column '{column_name}' not found in worksheet. Available columns: {', '.join(self.headers)}")
|
| 146 |
+
|
| 147 |
+
def get_available_worksheets(self) -> List[str]:
|
| 148 |
+
"""Get list of all available worksheets in the spreadsheet."""
|
| 149 |
+
return [worksheet.title for worksheet in self.doc.worksheets()]
|
| 150 |
+
|
| 151 |
+
def get_available_columns(self) -> List[str]:
|
| 152 |
+
"""Get list of all available columns in the current worksheet."""
|
| 153 |
+
return self.headers if self.headers else self.sheet.row_values(1)
|
| 154 |
+
|
| 155 |
+
def _reconnect_if_needed(self):
|
| 156 |
+
"""Reconnect to the sheet if the connection is lost."""
|
| 157 |
+
try:
|
| 158 |
+
self.sheet.row_values(1)
|
| 159 |
+
except (gspread.exceptions.APIError, AttributeError):
|
| 160 |
+
self._init_google_client()
|
| 161 |
+
self._connect_to_sheet()
|
| 162 |
+
|
| 163 |
+
def _fetch_column_data(self) -> List[str]:
|
| 164 |
+
"""Fetch all data from the huggingface_id column."""
|
| 165 |
+
values = self.sheet.col_values(self.col_index)
|
| 166 |
+
return values[1:] # Exclude header
|
| 167 |
+
|
| 168 |
+
def _update_sheet(self, data: List[str]):
|
| 169 |
+
"""Update the entire column with new data."""
|
| 170 |
+
try:
|
| 171 |
+
# Prepare the range for update (excluding header)
|
| 172 |
+
start_cell = gspread.utils.rowcol_to_a1(2, self.col_index) # Start from row 2
|
| 173 |
+
end_cell = gspread.utils.rowcol_to_a1(len(data) + 2, self.col_index)
|
| 174 |
+
range_name = f"{start_cell}:{end_cell}"
|
| 175 |
+
|
| 176 |
+
# Convert data to 2D array format required by gspread
|
| 177 |
+
cells = [[value] for value in data]
|
| 178 |
+
|
| 179 |
+
# Update the range
|
| 180 |
+
self.sheet.update(range_name, cells)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f"Error updating sheet: {str(e)}")
|
| 183 |
+
raise
|
| 184 |
+
|
| 185 |
+
def push(self, text: str) -> int:
|
| 186 |
+
"""
|
| 187 |
+
Push a text value to the next empty cell in the huggingface_id column.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
text (str): Text to push to the sheet
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
int: The row number where the text was pushed
|
| 194 |
+
"""
|
| 195 |
+
try:
|
| 196 |
+
self._reconnect_if_needed()
|
| 197 |
+
|
| 198 |
+
# Get all values in the huggingface_id column
|
| 199 |
+
column_values = self.sheet.col_values(self.col_index)
|
| 200 |
+
|
| 201 |
+
# Find the next empty row
|
| 202 |
+
next_row = None
|
| 203 |
+
for i in range(1, len(column_values)):
|
| 204 |
+
if not column_values[i].strip():
|
| 205 |
+
next_row = i + 1
|
| 206 |
+
break
|
| 207 |
+
|
| 208 |
+
# If no empty row found, append to the end
|
| 209 |
+
if next_row is None:
|
| 210 |
+
next_row = len(column_values) + 1
|
| 211 |
+
|
| 212 |
+
# Update the cell
|
| 213 |
+
self.sheet.update_cell(next_row, self.col_index, text)
|
| 214 |
+
print(f"Successfully pushed value: {text} to row {next_row}")
|
| 215 |
+
return next_row
|
| 216 |
+
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print(f"Error pushing to sheet: {str(e)}")
|
| 219 |
+
raise
|
| 220 |
+
|
| 221 |
+
def pop(self) -> Optional[str]:
    """Remove and return the most recent value."""
    try:
        self._reconnect_if_needed()
        column = self._fetch_column_data()

        # Nothing to pop when the column is empty or its head is blank.
        if not column or not column[0].strip():
            return None

        head, remainder = column[0], column[1:]
        # Shift everything up and pad with one blank so the sheet
        # column keeps its original length.
        remainder.append("")
        self._update_sheet(remainder)

        print(f"Successfully popped value: {head}")
        return head

    except Exception as e:
        print(f"Error popping from sheet: {str(e)}")
        raise
|
| 240 |
+
|
| 241 |
+
def delete(self, value: str) -> List[int]:
    """Delete all occurrences of a value."""
    try:
        self._reconnect_if_needed()
        column = self._fetch_column_data()

        target = value.strip()
        # Record 1-based positions of every match before rewriting.
        indices = [pos + 1 for pos, cell in enumerate(column) if cell.strip() == target]
        if not indices:
            print(f"Value '{value}' not found in sheet")
            return []

        # Keep the surviving cells and pad with blanks so the sheet
        # column keeps its original length.
        survivors = [cell for cell in column if cell.strip() != target]
        survivors += [""] * len(indices)

        self._update_sheet(survivors)
        print(f"Successfully deleted value '{value}' from rows: {indices}")
        return indices

    except Exception as e:
        print(f"Error deleting from sheet: {str(e)}")
        raise
|
| 264 |
+
|
| 265 |
+
def update_cell_by_condition(self, condition_column: str, condition_value: str, target_column: str, target_value: str) -> Optional[int]:
    """
    Update the value of a cell based on a condition in another column.

    Args:
        condition_column (str): The column to check the condition on.
        condition_value (str): The value to match in the condition column.
        target_column (str): The column where the value should be updated.
        target_value (str): The new value to set in the target column.

    Returns:
        Optional[int]: The row number where the value was updated, or None if no matching row was found.
    """
    try:
        self._reconnect_if_needed()

        header_row = self.sheet.row_values(1)

        # Validate both headers and resolve the target to a 1-based index.
        if condition_column not in header_row:
            raise ValueError(f"조건 칼럼 '{condition_column}'이(가) 없습니다.")
        if target_column not in header_row:
            raise ValueError(f"목표 칼럼 '{target_column}'이(가) 없습니다.")
        target_col_index = header_row.index(target_column) + 1

        # Records start at sheet row 2 (row 1 holds the headers);
        # update the first row whose condition column matches.
        for offset, record in enumerate(self.sheet.get_all_records()):
            if record.get(condition_column) == condition_value:
                row_number = offset + 2
                self.sheet.update_cell(row_number, target_col_index, target_value)
                print(f"Updated row {row_number}: Set {target_column} to '{target_value}' where {condition_column} is '{condition_value}'")
                return row_number

        print(f"조건에 맞는 행을 찾을 수 없습니다: {condition_column} = '{condition_value}'")
        return None

    except Exception as e:
        print(f"Error updating cell by condition: {str(e)}")
        raise
|
| 313 |
+
|
| 314 |
+
def get_all_values(self) -> List[str]:
    """Return every non-blank value in the huggingface_id column."""
    self._reconnect_if_needed()
    return [cell for cell in self._fetch_column_data() if cell.strip()]
|
| 318 |
+
|
| 319 |
+
# Example usage
# Manual smoke test: requires live Google Sheets credentials, so it only
# runs when the module is executed directly.
if __name__ == "__main__":
    # Initialize sheet manager
    sheet_manager = SheetManager()

    # # Push some test values
    # sheet_manager.push("test-model-1")
    # sheet_manager.push("test-model-2")
    # sheet_manager.push("test-model-3")

    # print("Initial values:", sheet_manager.get_all_values())

    # # Pop the most recent value
    # popped = sheet_manager.pop()
    # print(f"Popped value: {popped}")
    # print("After pop:", sheet_manager.get_all_values())

    # # Delete a specific value
    # deleted_rows = sheet_manager.delete("test-model-2")
    # print(f"Deleted from rows: {deleted_rows}")
    # print("After delete:", sheet_manager.get_all_values())

    # Set the 'pia' cell to "new_value" on the first row whose 'model'
    # column equals "msr".
    row_updated = sheet_manager.update_cell_by_condition(
        condition_column="model",
        condition_value="msr",
        target_column="pia",
        target_value="new_value"
    )
|
| 347 |
+
|
sheet_manager/sheet_loader/sheet2df.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from oauth2client.service_account import ServiceAccountCredentials
|
| 3 |
+
import gspread
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from enviroments.convert import get_json_from_env_var
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
def sheet2df(sheet_name: str = "model"):
    """
    Load a worksheet of the configured Google Spreadsheet into a DataFrame.

    Authenticates with the service-account credentials stored in the
    GOOGLE_CREDENTIALS environment variable, opens the spreadsheet at
    SPREADSHEET_URL, reads *sheet_name*, and promotes the first row to
    column headers (the resulting index therefore starts at 1).

    Args:
        sheet_name (str): Name of the worksheet to load. Defaults to "model".

    Returns:
        pd.DataFrame: A Pandas DataFrame containing the cleaned data
        from the spreadsheet.

    Dependencies:
        - pandas
        - gspread
        - oauth2client
    """
    oauth_scopes = [
        'https://spreadsheets.google.com/feeds',
        'https://www.googleapis.com/auth/drive',
    ]
    service_key = get_json_from_env_var("GOOGLE_CREDENTIALS")
    creds = ServiceAccountCredentials.from_json_keyfile_dict(service_key, oauth_scopes)
    client = gspread.authorize(creds)

    workbook = client.open_by_url(os.getenv("SPREADSHEET_URL"))
    worksheet = workbook.worksheet(sheet_name)

    # Convert to DataFrame, then promote the first row to column names
    # and drop it (the original 1-based index is kept).
    frame = pd.DataFrame(worksheet.get_all_values())
    frame.rename(columns=frame.iloc[0], inplace=True)
    frame.drop(frame.index[0], inplace=True)

    return frame
|
sheet_manager/sheet_monitor/sheet_sync.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
from typing import Optional, Callable
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
class SheetMonitor:
    def __init__(self, sheet_manager, check_interval: float = 1.0):
        """
        Initialize SheetMonitor with a sheet manager instance.

        Args:
            sheet_manager: Object exposing get_all_values() and column_name.
            check_interval (float): Seconds between sheet polls.
        """
        self.sheet_manager = sheet_manager
        self.check_interval = check_interval

        # Threading control
        self.monitor_thread = None
        self.is_running = threading.Event()        # loop keeps going while set
        self.pause_monitoring = threading.Event()  # request: stop polling
        self.monitor_paused = threading.Event()    # ack: loop is parked

        # Queue status: set while the sheet holds at least one value
        self.has_data = threading.Event()

        # Logging setup
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def start_monitoring(self):
        """Start the monitoring thread."""
        if self.monitor_thread is not None and self.monitor_thread.is_alive():
            self.logger.warning("Monitoring thread is already running")
            return

        self.is_running.set()
        self.pause_monitoring.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
        self.logger.info("Started monitoring thread")

    def stop_monitoring(self):
        """Stop the monitoring thread."""
        self.is_running.clear()
        # Release a parked loop so join() cannot hang on a paused monitor.
        self.pause_monitoring.clear()
        if self.monitor_thread:
            self.monitor_thread.join()
        self.logger.info("Stopped monitoring thread")

    def pause(self):
        """Pause the monitoring; blocks until the loop acknowledges."""
        self.pause_monitoring.set()
        self.monitor_paused.wait()
        self.logger.info("Monitoring paused")

    def resume(self):
        """Resume the monitoring."""
        self.pause_monitoring.clear()
        # Immediately re-check so callers see fresh state right away.
        self.logger.info("Monitoring resumed, checking for new data...")
        values = self.sheet_manager.get_all_values()
        if values:
            self.has_data.set()
            self.logger.info(f"Found data after resume: {values}")

    def _monitor_loop(self):
        """Main monitoring loop that checks for data in sheet."""
        while self.is_running.is_set():
            if self.pause_monitoring.is_set():
                # Acknowledge the pause, then park until resume() clears
                # the request (or the monitor is stopped). The original
                # waited on pause_monitoring being SET — which it already
                # was — so wait() returned at once and the loop never
                # actually paused.
                self.monitor_paused.set()
                while self.pause_monitoring.is_set() and self.is_running.is_set():
                    time.sleep(0.05)
                self.monitor_paused.clear()
                continue

            try:
                # Check if there's any data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"Monitoring: Current column={self.sheet_manager.column_name}, "
                                 f"Values found={len(values)}, "
                                 f"Has data={self.has_data.is_set()}")

                if values:  # If there's any non-empty value
                    self.has_data.set()
                    self.logger.info(f"Data detected: {values}")
                else:
                    self.has_data.clear()
                    self.logger.info("No data in sheet, waiting...")

                time.sleep(self.check_interval)

            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {str(e)}")
                time.sleep(self.check_interval)
|
| 93 |
+
|
| 94 |
+
class MainLoop:
    def __init__(self, sheet_manager, sheet_monitor, callback_function: Callable = None):
        """
        Initialize MainLoop with sheet manager and monitor instances.

        Args:
            sheet_manager: SheetManager-like object exposing pop(),
                change_column(), get_all_values() and column_name.
            sheet_monitor: SheetMonitor driving the has_data event.
            callback_function (Callable, optional): Invoked with
                (model_id, benchmark_name, prompt_cfg_name) per processed row.
        """
        self.sheet_manager = sheet_manager
        self.monitor = sheet_monitor
        self.callback = callback_function
        self.is_running = threading.Event()
        self.logger = logging.getLogger(__name__)

    def start(self):
        """Start the main processing loop.

        NOTE(review): _main_loop() runs on the calling thread, so this
        call blocks until stop() is invoked from another thread.
        """
        self.is_running.set()
        self.monitor.start_monitoring()
        self._main_loop()

    def stop(self):
        """Stop the main processing loop."""
        self.is_running.clear()
        self.monitor.stop_monitoring()

    def process_new_value(self):
        """Process values by calling pop function for multiple columns and custom callback.

        Returns:
            tuple[str, str, str] | None: The popped (model_id,
            benchmark_name, prompt_cfg_name) triple, or None when the
            first pop yields nothing or an error occurs.
        """
        try:
            # Store original column so it can be restored afterwards.
            original_column = self.sheet_manager.column_name

            # Pop from huggingface_id column
            model_id = self.sheet_manager.pop()

            if model_id:
                # Pop from benchmark_name column
                self.sheet_manager.change_column("benchmark_name")
                benchmark_name = self.sheet_manager.pop()

                # Pop from prompt_cfg_name column
                self.sheet_manager.change_column("prompt_cfg_name")
                prompt_cfg_name = self.sheet_manager.pop()

                # Return to original column
                self.sheet_manager.change_column(original_column)

                self.logger.info(f"Processed values - model_id: {model_id}, "
                            f"benchmark_name: {benchmark_name}, "
                            f"prompt_cfg_name: {prompt_cfg_name}")

                if self.callback:
                    # Pass all three values to callback
                    self.callback(model_id, benchmark_name, prompt_cfg_name)

                return model_id, benchmark_name, prompt_cfg_name

        except Exception as e:
            self.logger.error(f"Error processing values: {str(e)}")
            # Return to original column in case of error.
            # NOTE(review): if reading column_name itself raised above,
            # original_column is unbound here — the bare except hides that.
            try:
                self.sheet_manager.change_column(original_column)
            except:
                pass
            return None

    def _main_loop(self):
        """Main processing loop: waits for data, processes it with the
        monitor paused, then resumes monitoring."""
        while self.is_running.is_set():
            # Wait for data to be available (1s timeout keeps the loop
            # responsive to stop()).
            if self.monitor.has_data.wait(timeout=1.0):
                # Pause monitoring
                self.monitor.pause()

                # Process the value
                self.process_new_value()

                # Check if there's still data in the sheet
                values = self.sheet_manager.get_all_values()
                self.logger.info(f"After processing: Current column={self.sheet_manager.column_name}, "
                            f"Values remaining={len(values)}")

                if not values:
                    self.monitor.has_data.clear()
                    self.logger.info("All data processed, clearing has_data flag")
                else:
                    self.logger.info(f"Remaining data: {values}")

                # Resume monitoring
                self.monitor.resume()
|
| 180 |
+
## TODO
# Due to the per-minute API rate limit, design the loop so that a failed
# sheet read waits and then retries instead of failing outright.
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Example usage
if __name__ == "__main__":
    import sys
    from pathlib import Path
    # Make the repository root importable when run as a script.
    sys.path.append(str(Path(__file__).parent.parent.parent))
    from sheet_manager.sheet_crud.sheet_crud import SheetManager
    from pia_bench.pipe_line.piepline import PiaBenchMark
    def my_custom_function(huggingface_id, benchmark_name, prompt_cfg_name):
        # Callback: run the benchmark for each popped (model, benchmark,
        # prompt-config) triple.
        piabenchmark = PiaBenchMark(huggingface_id, benchmark_name, prompt_cfg_name)
        piabenchmark.bench_start()

    # Initialize components
    sheet_manager = SheetManager()
    monitor = SheetMonitor(sheet_manager, check_interval=10.0)
    main_loop = MainLoop(sheet_manager, monitor, callback_function=my_custom_function)

    try:
        # NOTE(review): main_loop.start() runs the processing loop on this
        # thread and blocks; the while-loop below is only reached after it
        # returns.
        main_loop.start()
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        main_loop.stop()
|
topk.json
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"VIDEO_CFG": {
|
| 3 |
+
"window_size": 6,
|
| 4 |
+
"time_sampling": 15,
|
| 5 |
+
"tile_size": null
|
| 6 |
+
},
|
| 7 |
+
"MODEL_CFG": {
|
| 8 |
+
"_comment": "",
|
| 9 |
+
"_link": "",
|
| 10 |
+
"name": "assets/c7.pt",
|
| 11 |
+
"type": "clip4clip"
|
| 12 |
+
},
|
| 13 |
+
"PROMPT_CFG": [
|
| 14 |
+
{
|
| 15 |
+
"event": "falldown",
|
| 16 |
+
"top_candidates": 1,
|
| 17 |
+
"alert_threshold": 1,
|
| 18 |
+
"prompts": {
|
| 19 |
+
"normal": [
|
| 20 |
+
{
|
| 21 |
+
"sentence": "typical"
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
+
"abnormal": [
|
| 25 |
+
{
|
| 26 |
+
"sentence": "falldown"
|
| 27 |
+
}
|
| 28 |
+
]
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"event": "violence",
|
| 33 |
+
"top_candidates": 1,
|
| 34 |
+
"alert_threshold": 1,
|
| 35 |
+
"prompts": {
|
| 36 |
+
"normal": [
|
| 37 |
+
{
|
| 38 |
+
"sentence": "normal"
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"sentence": "average"
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"sentence": "typical"
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"abnormal": [
|
| 48 |
+
{
|
| 49 |
+
"sentence": "violence with kicking and punching"
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"sentence": "physical confrontation between people"
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"sentence": "violence"
|
| 56 |
+
}
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"event": "fire",
|
| 62 |
+
"top_candidates": 1,
|
| 63 |
+
"alert_threshold": 1,
|
| 64 |
+
"prompts": {
|
| 65 |
+
"normal": [
|
| 66 |
+
{
|
| 67 |
+
"sentence": "tomato"
|
| 68 |
+
}
|
| 69 |
+
],
|
| 70 |
+
"abnormal": [
|
| 71 |
+
{
|
| 72 |
+
"sentence": "fire"
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"sentence": "video of be on fire with a stove"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"sentence": "a fire is burning"
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"sentence": "embers are burning"
|
| 82 |
+
}
|
| 83 |
+
]
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
]
|
| 88 |
+
}
|
utils/bench_meta.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from utils.except_dir import cust_listdir
|
| 7 |
+
def get_video_metadata(video_path, category, benchmark):
    """Extract metadata from a video file.

    Args:
        video_path: Path to the video file.
        category: Category label stored alongside the metadata.
        benchmark: Benchmark name stored alongside the metadata.

    Returns:
        dict with name/resolution/duration/frame/file statistics, or
        None when OpenCV cannot open the file.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        return None
    # Extract metadata
    video_name = os.path.basename(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    resolution = f"{frame_width}x{frame_height}"
    duration_seconds = frame_count / fps if fps > 0 else 0
    aspect_ratio = round(frame_width / frame_height, 2) if frame_height > 0 else 0
    file_size = os.path.getsize(video_path) / (1024 * 1024)  # MB
    file_format = os.path.splitext(video_name)[1].lower()
    cap.release()

    return {
        "video_name": video_name,
        "resolution": resolution,
        # Zero-pad the seconds so e.g. 65s renders as "1:05", not "1:5".
        "video_duration": f"{duration_seconds // 60:.0f}:{duration_seconds % 60:02.0f}",
        "category": category,
        "benchmark": benchmark,
        "duration_seconds": duration_seconds,
        "total_frames": frame_count,
        "file_format": file_format,
        "file_size_mb": round(file_size, 2),
        "aspect_ratio": aspect_ratio,
        "fps": fps
    }
|
| 39 |
+
|
| 40 |
+
def process_videos_in_directory(root_dir):
    """Collect metadata for every video under *root_dir*.

    Expected layout: root_dir/<benchmark>/dataset/<category>/<videos>.

    Args:
        root_dir: Root folder containing one directory per benchmark.

    Returns:
        pd.DataFrame: One row of get_video_metadata() output per readable
        video file (empty frame when nothing is found).
    """
    video_metadata_list = []

    # Walk each benchmark folder.
    for benchmark in cust_listdir(root_dir):
        benchmark_path = os.path.join(root_dir, benchmark)
        if not os.path.isdir(benchmark_path):
            continue

        # Videos live under the benchmark's "dataset" folder.
        dataset_path = os.path.join(benchmark_path, "dataset")
        if not os.path.isdir(dataset_path):
            continue

        # Each category folder inside "dataset" holds the video files.
        for category in cust_listdir(dataset_path):
            category_path = os.path.join(dataset_path, category)
            if not os.path.isdir(category_path):
                continue

            for file in cust_listdir(category_path):
                file_path = os.path.join(category_path, file)

                # The path is lower-cased first, so one lowercase suffix
                # per format suffices; the original's bare 'MOV' entry
                # (uppercase, no dot) could never match and was dead.
                if file_path.lower().endswith(('.mp4', '.avi', '.mkv', '.mov')):
                    metadata = get_video_metadata(file_path, category, benchmark)
                    if metadata:
                        video_metadata_list.append(metadata)

    return pd.DataFrame(video_metadata_list)
|
| 72 |
+
|
utils/except_dir.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
import enviroments.config as config
|
| 4 |
+
|
| 5 |
+
def cust_listdir(directory: str) -> List[str]:
    """
    Works like os.listdir but omits the folders/files configured in
    config.EXCLUDE_DIRS.

    Args:
        directory (str): Directory path to scan.

    Returns:
        List[str]: Entries of *directory* minus config.EXCLUDE_DIRS.
    """
    excluded = config.EXCLUDE_DIRS
    entries = []
    for entry in os.listdir(directory):
        if entry not in excluded:
            entries.append(entry)
    return entries
|
utils/hf_api.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import HfApi
|
| 2 |
+
from typing import Optional, List, Dict, Any
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
@dataclass
class ModelInfo:
    """Data class holding a snapshot of one Hub model's metadata."""
    # Full repo id, e.g. "PIA-SPACE-LAB/model-name".
    model_id: str
    # Last-modified timestamp as returned by the Hub API.
    last_modified: Any
    downloads: int
    private: bool
    # Every non-underscore attribute of the raw API model object.
    attributes: Dict[str, Any]
|
| 13 |
+
|
| 14 |
+
class HuggingFaceInfoManager:
    def __init__(self, access_token: Optional[str] = None, organization: str = "PIA-SPACE-LAB"):
        """
        HuggingFace API manager: fetches and caches the organization's models.

        Args:
            access_token (str, optional): HuggingFace access token.
            organization (str): Organization name (default: "PIA-SPACE-LAB").

        Raises:
            ValueError: If access_token is None.
        """
        if access_token is None:
            raise ValueError("액세스 토큰은 필수 입력값입니다. HuggingFace에서 발급받은 토큰을 입력해주세요.")

        self.api = HfApi()
        self.access_token = access_token
        self.organization = organization

        # Fetch once and cache both raw API objects and processed infos.
        api_models = self.api.list_models(author=self.organization, use_auth_token=self.access_token)
        self._stored_models = []
        self._model_infos = []

        for model in api_models:
            # Capture every public attribute of the raw API object.
            model_attrs = {}
            for attr in dir(model):
                if not attr.startswith("_"):
                    model_attrs[attr] = getattr(model, attr)

            model_info = ModelInfo(
                model_id=model.modelId,
                last_modified=model.lastModified,
                downloads=model.downloads,
                private=model.private,
                attributes=model_attrs
            )
            self._model_infos.append(model_info)
            self._stored_models.append(model)

    @staticmethod
    def _as_dict(info: "ModelInfo") -> Dict[str, Any]:
        """Flatten a ModelInfo into the dict shape the getters return.

        (The same construction was previously triplicated across
        get_model_info / get_private_models / get_public_models.)
        """
        return {
            'model_id': info.model_id,
            'last_modified': info.last_modified,
            'downloads': info.downloads,
            'private': info.private,
            **info.attributes
        }

    def get_model_info(self) -> List[Dict[str, Any]]:
        """Return info dicts for every cached model."""
        return [self._as_dict(info) for info in self._model_infos]

    def get_model_ids(self) -> List[str]:
        """Return the id of every cached model."""
        return [info.model_id for info in self._model_infos]

    def get_private_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for private models only."""
        return [self._as_dict(info) for info in self._model_infos if info.private]

    def get_public_models(self) -> List[Dict[str, Any]]:
        """Return info dicts for public models only."""
        return [self._as_dict(info) for info in self._model_infos if not info.private]

    def refresh_models(self) -> None:
        """Refresh the cached model info (performs a new API call)."""
        self.__init__(self.access_token, self.organization)
|
utils/parser.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
def load_config(config_path: str) -> Dict:
    """Read a JSON configuration file and return it as a dict."""
    with open(config_path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
|
| 8 |
+
|
| 9 |
+
class PromptManager:
    def __init__(self, config_path: str):
        """Load the prompt config and build sentence/index lookup tables."""
        self.config = load_config(config_path)
        self.sentences, self.index_mapping = self._extract_all_sentences_with_index()
        self.reverse_mapping = self._create_reverse_mapping()

    def _extract_all_sentences_with_index(self) -> Tuple[List[str], Dict]:
        """Collect every prompt sentence plus its (event, status, prompt) index."""
        sentences: List[str] = []
        index_mapping: Dict = {}

        for event_idx, event_config in enumerate(self.config.get('PROMPT_CFG', [])):
            prompt_groups = event_config.get('prompts', {})
            for status in ['normal', 'abnormal']:
                for prompt_idx, prompt in enumerate(prompt_groups.get(status, [])):
                    text = prompt.get('sentence', '')
                    sentences.append(text)
                    index_mapping[(event_idx, status, prompt_idx)] = text

        return sentences, index_mapping

    def _create_reverse_mapping(self) -> Dict:
        """Build the sentence -> list-of-indices inverse of index_mapping."""
        reverse_map: Dict = {}
        for position, text in self.index_mapping.items():
            reverse_map.setdefault(text, []).append(position)
        return reverse_map

    def get_sentence_indices(self, sentence: str) -> List[Tuple[int, str, int]]:
        """Return every (event_idx, status, prompt_idx) where *sentence* occurs."""
        return self.reverse_mapping.get(sentence, [])

    def get_details_by_sentence(self, sentence: str) -> List[Dict]:
        """Return the detail record for every occurrence of *sentence*."""
        return [self.get_details_by_index(*pos) for pos in self.get_sentence_indices(sentence)]

    def get_details_by_index(self, event_idx: int, status: str, prompt_idx: int) -> Dict:
        """Return the full detail record for one prompt position."""
        event_config = self.config['PROMPT_CFG'][event_idx]
        prompt = event_config['prompts'][status][prompt_idx]

        return {
            'event': event_config['event'],
            'status': status,
            'sentence': prompt['sentence'],
            'top_candidates': event_config['top_candidates'],
            'alert_threshold': event_config['alert_threshold'],
            'event_idx': event_idx,
            'prompt_idx': prompt_idx,
        }

    def get_all_sentences(self) -> List[str]:
        """Return every sentence in config order."""
        return self.sentences
|