Spaces:
Running
Running
Markus Pobitzer committed on
Commit ·
b6d1c13
1
Parent(s): 7896daf
app
Browse files- .gitattributes +2 -0
- README.md +7 -7
- app.py +14 -0
- requirements.txt +3 -0
- src/gecora/__init__.py +1 -0
- src/gecora/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/app/__init__.py +1 -0
- src/gecora/app/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/app/__pycache__/i_to_v_app.cpython-312.pyc +0 -0
- src/gecora/app/i_to_v_app.py +249 -0
- src/gecora/cli/__init__.py +1 -0
- src/gecora/cli/loomis_painter.py +22 -0
- src/gecora/dataset/__init__.py +1 -0
- src/gecora/dataset/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/dataset/__pycache__/base_manager.cpython-312.pyc +0 -0
- src/gecora/dataset/__pycache__/video_manager.cpython-312.pyc +0 -0
- src/gecora/dataset/__pycache__/video_pkl_manager.cpython-312.pyc +0 -0
- src/gecora/dataset/__pycache__/vieo_pkl_manager.cpython-312.pyc +0 -0
- src/gecora/dataset/base_manager.py +82 -0
- src/gecora/dataset/create_test_dataset.py +48 -0
- src/gecora/dataset/sub_dir_manager.py +117 -0
- src/gecora/dataset/video_manager.py +156 -0
- src/gecora/dataset/video_pkl_manager.py +207 -0
- src/gecora/dataset_converting/__init__.py +1 -0
- src/gecora/dataset_converting/video_pkl_to_video.py +67 -0
- src/gecora/db/__init__.py +1 -0
- src/gecora/db/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/db/__pycache__/hf_jsonl.cpython-312.pyc +0 -0
- src/gecora/db/__pycache__/sqlite.cpython-312.pyc +0 -0
- src/gecora/db/hf_jsonl.py +385 -0
- src/gecora/db/sqlite.py +279 -0
- src/gecora/logging/__init__.py +1 -0
- src/gecora/logging/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/logging/__pycache__/logger.cpython-312.pyc +0 -0
- src/gecora/logging/logger.py +16 -0
- src/gecora/logic/__init__.py +1 -0
- src/gecora/logic/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/logic/__pycache__/base.cpython-312.pyc +0 -0
- src/gecora/logic/__pycache__/loomis_painter.cpython-312.pyc +0 -0
- src/gecora/logic/__pycache__/utils.cpython-312.pyc +0 -0
- src/gecora/logic/base.py +53 -0
- src/gecora/logic/loomis_painter.py +205 -0
- src/gecora/logic/utils.py +37 -0
- src/gecora/py.typed +0 -0
- src/gecora/ranking/__init__.py +1 -0
- src/gecora/ranking/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gecora/ranking/__pycache__/ranking_system.cpython-312.pyc +0 -0
- src/gecora/ranking/ranking_system.py +178 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
short_description:
|
| 11 |
---
|
| 12 |
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Test User Study
|
| 3 |
+
emoji: 🏢
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.5.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: Test for a user study
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# 1. Add the 'src' folder to the Python path so we can import 'gecora'
|
| 6 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
|
| 7 |
+
|
| 8 |
+
from gecora.logic.loomis_painter import LoomisPainterApp
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
app = LoomisPainterApp(
|
| 12 |
+
root_path="./", dataset_path="data/", hf_repo_id="Markus-Pobitzer/gecora-wlp", desired_num_selections=40
|
| 13 |
+
)
|
| 14 |
+
app.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
imageio>=2.37.0
|
| 2 |
+
pillow
|
| 3 |
+
huggingface_hub
|
src/gecora/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init of project."""
|
src/gecora/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
src/gecora/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/app/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
src/gecora/app/__pycache__/i_to_v_app.cpython-312.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src/gecora/app/i_to_v_app.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from gecora.logic.base import LogicBase
|
| 6 |
+
from gecora.logic.utils import cleanup_list, create_temp_file
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ItoVApp:
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
logic_class: LogicBase,
|
| 13 |
+
task_desc: str = "Select the best video for the reference image.",
|
| 14 |
+
ref_img_label: str = "Reference Image",
|
| 15 |
+
left_media_label: str = "Left Video",
|
| 16 |
+
right_media_label: str = "Right Video",
|
| 17 |
+
desired_num_selections: Optional[int] = None,
|
| 18 |
+
):
|
| 19 |
+
self.logic_class = logic_class
|
| 20 |
+
self.desired_num_selections = desired_num_selections
|
| 21 |
+
# User specific input
|
| 22 |
+
self.user_input = gr.Accordion(label="1. Enter your name")
|
| 23 |
+
self.username_input = gr.Textbox(label=None, show_label=False)
|
| 24 |
+
self.select_button = gr.Button("Select Name")
|
| 25 |
+
# Preference task
|
| 26 |
+
self.task = gr.Accordion(label="2. Task", visible=False)
|
| 27 |
+
self.task_description: str = task_desc
|
| 28 |
+
self.reference_image = gr.Image(label=ref_img_label, height=512)
|
| 29 |
+
self.left_media = gr.Video(label=left_media_label, height=512, autoplay=True)
|
| 30 |
+
self.right_media = gr.Video(label=right_media_label, height=512, autoplay=True)
|
| 31 |
+
self.left_button = gr.Button("←")
|
| 32 |
+
self.tie_button = gr.Button("Tie")
|
| 33 |
+
self.right_button = gr.Button("→")
|
| 34 |
+
self.tmp_video_path_left = create_temp_file()
|
| 35 |
+
self.tmp_video_path_right = create_temp_file()
|
| 36 |
+
|
| 37 |
+
def set_username(self, username):
|
| 38 |
+
user_id = self.logic_class.set_username(username=username)
|
| 39 |
+
if username and user_id is not None:
|
| 40 |
+
next_comp = self.logic_class.get_next_comparison(user_id=user_id)
|
| 41 |
+
if next_comp is None:
|
| 42 |
+
gr.Error("Error: Loading the content failed.")
|
| 43 |
+
return 0, gr.update(visible=False), "", "", "", "", None, None, None
|
| 44 |
+
|
| 45 |
+
(
|
| 46 |
+
(ret_reference_id, ret_model_left_id, ret_model_right_id),
|
| 47 |
+
(reference_image, left_video, right_video),
|
| 48 |
+
(num_preferences, total_num_comparison),
|
| 49 |
+
) = next_comp
|
| 50 |
+
|
| 51 |
+
progress_str = str(num_preferences) + " preferences selected!"
|
| 52 |
+
if total_num_comparison > 0:
|
| 53 |
+
perc_num = (
|
| 54 |
+
self.desired_num_selections if self.desired_num_selections is not None else total_num_comparison
|
| 55 |
+
)
|
| 56 |
+
progress_str = (
|
| 57 |
+
str(int(num_preferences / perc_num * 100))
|
| 58 |
+
+ f"% ({num_preferences} / {total_num_comparison} total) preferences selected!"
|
| 59 |
+
)
|
| 60 |
+
return (
|
| 61 |
+
user_id,
|
| 62 |
+
gr.update(visible=True),
|
| 63 |
+
ret_reference_id,
|
| 64 |
+
ret_model_left_id,
|
| 65 |
+
ret_model_right_id,
|
| 66 |
+
progress_str,
|
| 67 |
+
reference_image,
|
| 68 |
+
left_video,
|
| 69 |
+
right_video,
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
gr.Error(
|
| 73 |
+
f"Failed to set Username {username}. Make sure to put in a value in the textfield, otherwise try to reload webpage."
|
| 74 |
+
)
|
| 75 |
+
return 0, gr.update(visible=False), "", "", "", "", None, None, None
|
| 76 |
+
|
| 77 |
+
def set_preference(
|
| 78 |
+
self, user_id, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
|
| 79 |
+
) -> Tuple[str, str, str, str, Any, Any, Any]:
|
| 80 |
+
"""Returns reference_id: str, model_left_id: str, model_right_id: str, reference_image: PIL.Image.Image, left_video:str, right_video:str."""
|
| 81 |
+
user_id = int(user_id) # type: ignore
|
| 82 |
+
succ, msg = self.logic_class.set_preference(
|
| 83 |
+
user_id=user_id,
|
| 84 |
+
reference_id=reference_id,
|
| 85 |
+
model_left_id=model_left_id,
|
| 86 |
+
model_right_id=model_right_id,
|
| 87 |
+
preferred_side=preferred_side,
|
| 88 |
+
)
|
| 89 |
+
if not succ:
|
| 90 |
+
gr.Info(f"Something went wrong: {msg}")
|
| 91 |
+
next_comp = self.logic_class.get_next_comparison(user_id=user_id)
|
| 92 |
+
if next_comp is None:
|
| 93 |
+
gr.Error("We are sorry, something went wrong! Please try to reload the page.")
|
| 94 |
+
return "", "", "", "", None, None, None # type: ignore
|
| 95 |
+
|
| 96 |
+
if preferred_side == "left":
|
| 97 |
+
gr.Success("You chose the left side!\n💪🙂")
|
| 98 |
+
elif preferred_side == "right":
|
| 99 |
+
gr.Success("You chose the right side!\n🙂💪")
|
| 100 |
+
elif preferred_side == "tie":
|
| 101 |
+
gr.Success("It's a tie!\n🤝")
|
| 102 |
+
|
| 103 |
+
(
|
| 104 |
+
(ret_reference_id, ret_model_left_id, ret_model_right_id),
|
| 105 |
+
(reference_image, left_video, right_video),
|
| 106 |
+
(num_preferences, total_num_comparison),
|
| 107 |
+
) = next_comp
|
| 108 |
+
progress_str = str(num_preferences) + " preferences selected!"
|
| 109 |
+
if total_num_comparison > 0:
|
| 110 |
+
perc_num = self.desired_num_selections if self.desired_num_selections is not None else total_num_comparison
|
| 111 |
+
progress_str = (
|
| 112 |
+
str(int(num_preferences / perc_num * 100))
|
| 113 |
+
+ f"% ({num_preferences} / {total_num_comparison} total) preferences selected!"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
return (
|
| 117 |
+
ret_reference_id,
|
| 118 |
+
ret_model_left_id,
|
| 119 |
+
ret_model_right_id,
|
| 120 |
+
progress_str,
|
| 121 |
+
reference_image,
|
| 122 |
+
left_video,
|
| 123 |
+
right_video,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def choose_left(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
|
| 127 |
+
return self.set_preference(
|
| 128 |
+
user_id=user_id,
|
| 129 |
+
reference_id=reference_id,
|
| 130 |
+
model_left_id=model_left_id,
|
| 131 |
+
model_right_id=model_right_id,
|
| 132 |
+
preferred_side="left",
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def choose_tie(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
|
| 136 |
+
return self.set_preference(
|
| 137 |
+
user_id=user_id,
|
| 138 |
+
reference_id=reference_id,
|
| 139 |
+
model_left_id=model_left_id,
|
| 140 |
+
model_right_id=model_right_id,
|
| 141 |
+
preferred_side="tie",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def choose_right(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
|
| 145 |
+
return self.set_preference(
|
| 146 |
+
user_id=user_id,
|
| 147 |
+
reference_id=reference_id,
|
| 148 |
+
model_left_id=model_left_id,
|
| 149 |
+
model_right_id=model_right_id,
|
| 150 |
+
preferred_side="right",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def build_interface(self):
|
| 154 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 155 |
+
local_storage = gr.BrowserState(["", ""])
|
| 156 |
+
with self.user_input.render():
|
| 157 |
+
with gr.Row():
|
| 158 |
+
user_id = gr.Textbox(show_label=False, visible=False)
|
| 159 |
+
self.username_input.render()
|
| 160 |
+
self.select_button.render()
|
| 161 |
+
|
| 162 |
+
with self.task.render():
|
| 163 |
+
gr.Markdown(self.task_description)
|
| 164 |
+
reference_id = gr.Textbox(show_label=False, visible=False)
|
| 165 |
+
model_left_id = gr.Textbox(show_label=False, visible=False)
|
| 166 |
+
model_right_id = gr.Textbox(show_label=False, visible=False)
|
| 167 |
+
progress = gr.Textbox(show_label=False, text_align="right")
|
| 168 |
+
|
| 169 |
+
with gr.Row():
|
| 170 |
+
self.left_media.render()
|
| 171 |
+
self.reference_image.render()
|
| 172 |
+
self.right_media.render()
|
| 173 |
+
|
| 174 |
+
with gr.Row():
|
| 175 |
+
self.left_button.render()
|
| 176 |
+
self.tie_button.render()
|
| 177 |
+
self.right_button.render()
|
| 178 |
+
|
| 179 |
+
@demo.load(inputs=[local_storage], outputs=[user_id, self.username_input])
|
| 180 |
+
def load_from_local_storage(saved_values):
|
| 181 |
+
return saved_values[0], saved_values[1]
|
| 182 |
+
|
| 183 |
+
@gr.on([user_id.change], inputs=[user_id, self.username_input], outputs=[local_storage])
|
| 184 |
+
def save_to_local_storage(user_id, username):
|
| 185 |
+
return [user_id, username]
|
| 186 |
+
|
| 187 |
+
self.select_button.click(
|
| 188 |
+
self.set_username,
|
| 189 |
+
inputs=[self.username_input],
|
| 190 |
+
outputs=[
|
| 191 |
+
user_id,
|
| 192 |
+
self.task,
|
| 193 |
+
reference_id,
|
| 194 |
+
model_left_id,
|
| 195 |
+
model_right_id,
|
| 196 |
+
progress,
|
| 197 |
+
self.reference_image,
|
| 198 |
+
self.left_media,
|
| 199 |
+
self.right_media,
|
| 200 |
+
],
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
self.left_button.click(
|
| 204 |
+
self.choose_left,
|
| 205 |
+
inputs=[user_id, reference_id, model_left_id, model_right_id],
|
| 206 |
+
outputs=[
|
| 207 |
+
reference_id,
|
| 208 |
+
model_left_id,
|
| 209 |
+
model_right_id,
|
| 210 |
+
progress,
|
| 211 |
+
self.reference_image,
|
| 212 |
+
self.left_media,
|
| 213 |
+
self.right_media,
|
| 214 |
+
],
|
| 215 |
+
)
|
| 216 |
+
self.tie_button.click(
|
| 217 |
+
self.choose_tie,
|
| 218 |
+
inputs=[user_id, reference_id, model_left_id, model_right_id],
|
| 219 |
+
outputs=[
|
| 220 |
+
reference_id,
|
| 221 |
+
model_left_id,
|
| 222 |
+
model_right_id,
|
| 223 |
+
progress,
|
| 224 |
+
self.reference_image,
|
| 225 |
+
self.left_media,
|
| 226 |
+
self.right_media,
|
| 227 |
+
],
|
| 228 |
+
)
|
| 229 |
+
self.right_button.click(
|
| 230 |
+
self.choose_right,
|
| 231 |
+
inputs=[user_id, reference_id, model_left_id, model_right_id],
|
| 232 |
+
outputs=[
|
| 233 |
+
reference_id,
|
| 234 |
+
model_left_id,
|
| 235 |
+
model_right_id,
|
| 236 |
+
progress,
|
| 237 |
+
self.reference_image,
|
| 238 |
+
self.left_media,
|
| 239 |
+
self.right_media,
|
| 240 |
+
],
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
demo.unload(fn=lambda: cleanup_list([self.tmp_video_path_left, self.tmp_video_path_right]))
|
| 244 |
+
|
| 245 |
+
return demo
|
| 246 |
+
|
| 247 |
+
def launch(self):
|
| 248 |
+
app = self.build_interface()
|
| 249 |
+
app.launch(server_name="0.0.0.0")
|
src/gecora/cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/cli/loomis_painter.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from gecora.logic.loomis_painter import LoomisPainterApp
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(
|
| 8 |
+
description="Loomis Painter Ranking CLI - Evaluate painting processes and update model rankings."
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
parser.add_argument("root_path", type=str, help="Root directory containing Hugging Face datasets.")
|
| 12 |
+
parser.add_argument("--dataset_path", type=str, help="Optional path to the dataset.")
|
| 13 |
+
|
| 14 |
+
args = parser.parse_args()
|
| 15 |
+
|
| 16 |
+
app = LoomisPainterApp(args.root_path, args.dataset_path)
|
| 17 |
+
print("App initialized.")
|
| 18 |
+
app.launch()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
|
| 22 |
+
main()
|
src/gecora/dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/dataset/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
src/gecora/dataset/__pycache__/base_manager.cpython-312.pyc
ADDED
|
Binary file (4.69 kB). View file
|
|
|
src/gecora/dataset/__pycache__/video_manager.cpython-312.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
src/gecora/dataset/__pycache__/video_pkl_manager.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
src/gecora/dataset/__pycache__/vieo_pkl_manager.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
src/gecora/dataset/base_manager.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import asdict, dataclass
|
| 5 |
+
from typing import Any, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from gecora.logging.logger import setup_file_logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class DatasetManagerConfig:
|
| 12 |
+
root_path: str
|
| 13 |
+
gt_dataset: str = "gt"
|
| 14 |
+
dataset_split: str = "test"
|
| 15 |
+
entry_id: str = "entry_id"
|
| 16 |
+
reference_column_name: str = "reference"
|
| 17 |
+
genereated_column_name: str = "generated"
|
| 18 |
+
logging_path: Optional[str] = None
|
| 19 |
+
|
| 20 |
+
def to_json(self, path: str):
|
| 21 |
+
with open(path, "w") as f:
|
| 22 |
+
json.dump(asdict(self), f)
|
| 23 |
+
|
| 24 |
+
@classmethod
|
| 25 |
+
def from_json(cls, path: str):
|
| 26 |
+
with open(path, "r") as f:
|
| 27 |
+
data = json.load(f)
|
| 28 |
+
return cls(**data)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BaseDatasetManager:
|
| 32 |
+
"""
|
| 33 |
+
A class to manage Hugging Face datasets stored in subdirectories of a given root path.
|
| 34 |
+
|
| 35 |
+
Attributes:
|
| 36 |
+
root_path (str): The root directory containing subfolders with Hugging Face datasets.
|
| 37 |
+
dataset_split (str): If the datasets are a Dict, select specified split.
|
| 38 |
+
common_entry_ids (List[str]): List of 'entry_id's present in all datasets.
|
| 39 |
+
partial_entry_ids (List[str]): List of 'entry_id's present in only some datasets.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, config: DatasetManagerConfig):
|
| 43 |
+
self.config = config
|
| 44 |
+
self.common_entry_ids: List[str] = []
|
| 45 |
+
self.partial_entry_ids: List[str] = []
|
| 46 |
+
|
| 47 |
+
logger_name = f"{self.__class__.__name__}_logger"
|
| 48 |
+
log_file = os.path.join(self.config.logging_path, f"{logger_name}.txt") if self.config.logging_path else None
|
| 49 |
+
self.logger = setup_file_logger(logger_name, log_file) if log_file else logging.getLogger(logger_name)
|
| 50 |
+
|
| 51 |
+
self.logger = logging.getLogger()
|
| 52 |
+
|
| 53 |
+
def get_dataset_names(self) -> List[str]:
|
| 54 |
+
"""
|
| 55 |
+
Returns a list of all loaded dataset names.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
List[str]: Names of the datasets.
|
| 59 |
+
"""
|
| 60 |
+
raise ValueError("Must be implemented by child class.")
|
| 61 |
+
|
| 62 |
+
def get_entries_by_id(
|
| 63 |
+
self, entry_id: str, dataset_name1: str, dataset_name2: str
|
| 64 |
+
) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
|
| 65 |
+
"""
|
| 66 |
+
Retrieves entries from two datasets by a specific 'entry_id'.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
entry_id (str): The 'entry_id' to search for.
|
| 70 |
+
dataset_name1 (str): Name of the first dataset.
|
| 71 |
+
dataset_name2 (str): Name of the second dataset.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Tuple[Optional[Any], Optional[Any], Optional[Any]]:
|
| 75 |
+
Reference content, or None if not found.
|
| 76 |
+
Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
|
| 77 |
+
Generated entry from dataset_name2 matching the 'entry_id', or None if not found.
|
| 78 |
+
|
| 79 |
+
Raises:
|
| 80 |
+
ValueError: If one or both dataset names are not found.
|
| 81 |
+
"""
|
| 82 |
+
raise ValueError("Must be implemented by child class.")
|
src/gecora/dataset/create_test_dataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import imageio
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from datasets import Dataset
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def generate_random_image(size=(512, 512)):
|
| 11 |
+
array = np.random.randint(0, 256, size + (3,), dtype=np.uint8)
|
| 12 |
+
return Image.fromarray(array)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def generate_random_video(num_frames=10, size=(512, 512)):
|
| 16 |
+
frames = [np.random.randint(0, 256, size + (3,), dtype=np.uint8) for _ in range(num_frames)]
|
| 17 |
+
return frames
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def save_video(frames, path):
|
| 21 |
+
imageio.mimsave(path, frames, fps=5)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def create_i_to_v_huggingface_datasets(output_dir, num_datasets=3, num_entries=5):
|
| 25 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 26 |
+
entry_ids = [f"id_{i}" for i in range(num_entries)]
|
| 27 |
+
|
| 28 |
+
for dataset_index in range(num_datasets):
|
| 29 |
+
data = {"entry_id": [], "reference_image": [], "video": []}
|
| 30 |
+
dataset_path = os.path.join(output_dir, f"dataset_{dataset_index}")
|
| 31 |
+
os.makedirs(dataset_path, exist_ok=True)
|
| 32 |
+
|
| 33 |
+
for entry_id in entry_ids:
|
| 34 |
+
img = generate_random_image()
|
| 35 |
+
img_path = os.path.join(dataset_path, f"{entry_id}_image.png")
|
| 36 |
+
img.save(img_path)
|
| 37 |
+
|
| 38 |
+
video_frames = generate_random_video()
|
| 39 |
+
video_path = os.path.join(dataset_path, f"{entry_id}_video.mp4")
|
| 40 |
+
save_video(video_frames, video_path)
|
| 41 |
+
|
| 42 |
+
data["entry_id"].append(entry_id)
|
| 43 |
+
data["reference_image"].append(img_path)
|
| 44 |
+
data["video"].append(video_path)
|
| 45 |
+
|
| 46 |
+
df = pd.DataFrame(data)
|
| 47 |
+
hf_dataset = Dataset.from_pandas(df)
|
| 48 |
+
hf_dataset.save_to_disk(dataset_path)
|
src/gecora/dataset/sub_dir_manager.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
from datasets import Dataset, DatasetDict, load_from_disk
|
| 5 |
+
|
| 6 |
+
from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SubDirDatasetManager(BaseDatasetManager):
|
| 10 |
+
"""
|
| 11 |
+
A class to manage Hugging Face datasets stored in subdirectories of a given root path.
|
| 12 |
+
|
| 13 |
+
Attributes:
|
| 14 |
+
root_path (str): The root directory containing subfolders with Hugging Face datasets.
|
| 15 |
+
dataset_split (str): If the datasets are a Dict, select specified split.
|
| 16 |
+
datasets (Dict[str, Union[Dataset, DatasetDict]]): Dictionary mapping dataset names to loaded datasets.
|
| 17 |
+
common_entry_ids (List[str]): List of 'entry_id's present in all datasets.
|
| 18 |
+
partial_entry_ids (List[str]): List of 'entry_id's present in only some datasets.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: DatasetManagerConfig) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Initializes the dataset manager and loads datasets from subdirectories.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
config (DatasetManagerConfig): Configuration.
|
| 27 |
+
"""
|
| 28 |
+
super().__init__(config=config)
|
| 29 |
+
self.datasets: Dict[str, Union[Dataset, DatasetDict]] = {}
|
| 30 |
+
self._load_datasets()
|
| 31 |
+
self._analyze_entry_ids()
|
| 32 |
+
self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
|
| 33 |
+
self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
|
| 34 |
+
self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")
|
| 35 |
+
|
| 36 |
+
def _load_datasets(self) -> None:
|
| 37 |
+
"""
|
| 38 |
+
Loads all Hugging Face datasets from subdirectories in the root path.
|
| 39 |
+
"""
|
| 40 |
+
for subdir in os.listdir(self.config.root_path):
|
| 41 |
+
full_path = os.path.join(self.config.root_path, subdir)
|
| 42 |
+
if os.path.isdir(full_path):
|
| 43 |
+
try:
|
| 44 |
+
dataset = load_from_disk(full_path)
|
| 45 |
+
self.datasets[subdir] = dataset
|
| 46 |
+
except Exception as e:
|
| 47 |
+
self.logger.info(f"Skipping {subdir}: {e}")
|
| 48 |
+
|
| 49 |
+
def get_dataset_names(self) -> List[str]:
|
| 50 |
+
"""
|
| 51 |
+
Returns a list of all loaded dataset names.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
List[str]: Names of the datasets.
|
| 55 |
+
"""
|
| 56 |
+
return list(self.datasets.keys())
|
| 57 |
+
|
| 58 |
+
def _analyze_entry_ids(self) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Analyzes all datasets to find common and partial 'entry_id's.
|
| 61 |
+
Updates `common_entry_ids` and `partial_entry_ids` attributes.
|
| 62 |
+
"""
|
| 63 |
+
entry_id_sets: List[set] = []
|
| 64 |
+
|
| 65 |
+
for dataset in self.datasets.values():
|
| 66 |
+
if isinstance(dataset, DatasetDict):
|
| 67 |
+
dataset = dataset[self.config.dataset_split]
|
| 68 |
+
entry_ids = set(dataset[self.config.entry_id])
|
| 69 |
+
entry_id_sets.append(entry_ids)
|
| 70 |
+
|
| 71 |
+
if entry_id_sets:
|
| 72 |
+
self.common_entry_ids = list(set.intersection(*entry_id_sets))
|
| 73 |
+
all_entry_ids = set.union(*entry_id_sets)
|
| 74 |
+
self.partial_entry_ids = list(all_entry_ids - set(self.common_entry_ids))
|
| 75 |
+
|
| 76 |
+
def get_entries_by_id(
|
| 77 |
+
self, entry_id: str, dataset_name1: str, dataset_name2: str
|
| 78 |
+
) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
|
| 79 |
+
"""
|
| 80 |
+
Retrieves entries from two datasets by a specific 'entry_id'.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
entry_id (str): The 'entry_id' to search for.
|
| 84 |
+
dataset_name1 (str): Name of the first dataset.
|
| 85 |
+
dataset_name2 (str): Name of the second dataset.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Tuple[Optional[Any], Optional[Any], Optional[Any]]:
|
| 89 |
+
Reference content, or None if not found.
|
| 90 |
+
Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
|
| 91 |
+
Generated entry from dataset_name2 matching the 'entry_id', or None if not found.
|
| 92 |
+
|
| 93 |
+
Raises:
|
| 94 |
+
ValueError: If one or both dataset names are not found.
|
| 95 |
+
"""
|
| 96 |
+
if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
|
| 97 |
+
raise ValueError("One or both dataset names not found.")
|
| 98 |
+
|
| 99 |
+
def find_entry(dataset: Union[Dataset, DatasetDict], entry_id: str) -> Optional[Dict]:
|
| 100 |
+
if isinstance(dataset, DatasetDict):
|
| 101 |
+
dataset = dataset[self.config.dataset_split]
|
| 102 |
+
for entry in dataset:
|
| 103 |
+
if entry.get(self.config.entry_id) == entry_id:
|
| 104 |
+
return entry
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
entry1 = find_entry(self.datasets[dataset_name1], entry_id)
|
| 108 |
+
entry2 = find_entry(self.datasets[dataset_name2], entry_id)
|
| 109 |
+
if entry1 is not None:
|
| 110 |
+
reference_image = entry1[self.config.reference_column_name]
|
| 111 |
+
else:
|
| 112 |
+
reference_image = None
|
| 113 |
+
|
| 114 |
+
ret_entry1 = entry1[self.config.genereated_column_name] if entry1 is not None else None
|
| 115 |
+
ret_entry2 = entry2[self.config.genereated_column_name] if entry2 is not None else None
|
| 116 |
+
|
| 117 |
+
return reference_image, ret_entry1, ret_entry2
|
src/gecora/dataset/video_manager.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(
|
| 9 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 10 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 11 |
+
level=logging.INFO,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VideoManager(BaseDatasetManager):
    """Manages video datasets stored as subdirectories of a root path.

    Each dataset lives under ``<root>/<dataset_name>/<split>/<source>/<video>/``
    and contributes a ``video.mp4`` file per entry; the ground-truth dataset
    (``config.gt_dataset``) contributes ``reference_image.png`` files instead.
    Entry ids have the form ``"<source>_<video>"``.
    """

    def __init__(self, config: DatasetManagerConfig) -> None:
        """
        Initializes the dataset manager and loads datasets from subdirectories.

        Args:
            config (DatasetManagerConfig): Configuration.
        """
        super().__init__(config=config)
        self.datasets: Dict[str, Dict[str, Dict[str, str]]] = {}
        self.reference_images: Dict[str, Dict[str, str]] = {}
        self.root_path = Path(self.config.root_path)
        self.common_entry_ids: List[str] = []
        self.partial_entry_ids: List[str] = []
        self._load_datasets()
        self._analyze_entry_ids()
        self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
        self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
        self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")

    def _collect_entries(
        self, dataset_dir: Path, split: str, filename: str, key: str
    ) -> Dict[str, Dict[str, str]]:
        """Collects all ``<split>/<source>/<video>/<filename>`` entries.

        Shared walker for reference images and videos (the two original
        collectors only differed in the file looked for and the dict key).

        Args:
            dataset_dir (Path): Dataset root directory.
            split (str): Split subdirectory name (e.g. 'test').
            filename (str): File that must exist inside each video directory.
            key (str): Entry key under which the file path is stored.

        Returns:
            Dict[str, Dict[str, str]]: Mapping from entry id ('<source>_<url>')
            to a dict with 'source', 'url', 'id' and the file path under *key*.
        """
        split_dir = dataset_dir / split
        dataset: Dict[str, Dict[str, str]] = {}
        for source_dir in split_dir.iterdir():
            if not source_dir.is_dir():
                continue
            for video_dir in source_dir.iterdir():
                if not video_dir.is_dir():
                    continue
                target_file = video_dir / filename
                if target_file.is_file():
                    source = source_dir.stem
                    url = video_dir.stem
                    entry_id = source + "_" + url
                    dataset[entry_id] = {
                        "source": source,
                        "url": url,
                        "id": entry_id,
                        key: str(target_file),
                    }
        return dataset

    def _collect_reference_images(self, dataset_dir: Path, split: str) -> Dict[str, Dict[str, str]]:
        """Collects all reference images under the split."""
        return self._collect_entries(dataset_dir, split, "reference_image.png", "reference_image")

    def _collect_video_dirs(self, dataset_dir: Path, split: str) -> Dict[str, Dict[str, str]]:
        """Collects all video subdirectories under the split."""
        return self._collect_entries(dataset_dir, split, "video.mp4", "video")

    def _load_datasets(self) -> None:
        """
        Loads all datasets from subdirectories in the root path.

        The subdirectory matching ``config.gt_dataset`` provides reference
        images; every other subdirectory is treated as a generated-video
        dataset. Empty non-GT datasets are logged and skipped.

        Raises:
            ValueError: If the ground-truth dataset contains no entries.
        """
        for subdir in self.root_path.iterdir():
            if not subdir.is_dir():
                continue
            dataset_name = subdir.stem
            if dataset_name == self.config.gt_dataset:
                # GT dataset only use reference images
                self.reference_images = self._collect_reference_images(subdir, split=self.config.dataset_split)
                if len(self.reference_images) == 0:
                    raise ValueError(f"No entries found for ground truth dataset {dataset_name} under {subdir}.")
            else:
                try:
                    dataset = self._collect_video_dirs(subdir, split=self.config.dataset_split)
                    if len(dataset) == 0:
                        # Fixed typo and log the dataset name (the original
                        # logged the empty dict itself).
                        self.logger.info(f"No entries found for dataset {dataset_name}")
                    else:
                        self.datasets[dataset_name] = dataset
                except Exception as e:
                    self.logger.info(f"Skipping {subdir}: {e}")

    def get_dataset_names(self) -> List[str]:
        """
        Returns a list of all loaded dataset names.

        Returns:
            List[str]: Names of the datasets.
        """
        return list(self.datasets.keys())

    def _analyze_entry_ids(self) -> None:
        """
        Analyzes all datasets to find common and partial 'entry_id's.
        Updates `common_entry_ids` and `partial_entry_ids` attributes.
        """
        entry_id_sets: List[set] = []

        # GT reference images
        entry_id_sets.append(set(self.reference_images.keys()))
        # The other datasets
        for dataset in self.datasets.values():
            entry_id_sets.append(set(dataset.keys()))

        if entry_id_sets:
            self.common_entry_ids = list(set.intersection(*entry_id_sets))
            all_entry_ids = set.union(*entry_id_sets)
            self.partial_entry_ids = list(all_entry_ids - set(self.common_entry_ids))

    def get_entries_by_id(
        self, entry_id: str, dataset_name1: str, dataset_name2: str
    ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        """
        Retrieves entries from two datasets by a specific 'entry_id'.

        Args:
            entry_id (str): The 'entry_id' to search for.
            dataset_name1 (str): Name of the first dataset.
            dataset_name2 (str): Name of the second dataset.

        Returns:
            Tuple[Optional[Any], Optional[Any], Optional[Any]]:
                Reference image path, or None if not found.
                Video path from dataset_name1 matching the 'entry_id', or None if not found.
                Video path from dataset_name2 matching the 'entry_id', or None if not found.

        Raises:
            ValueError: If one or both dataset names are not found.
        """
        if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
            raise ValueError("One or both dataset names not found.")

        # Use .get() so a missing entry yields None (as documented) instead of
        # raising KeyError like the original direct indexing did.
        entry1 = self.datasets[dataset_name1].get(entry_id)
        entry2 = self.datasets[dataset_name2].get(entry_id)
        ref_entry = self.reference_images.get(entry_id)

        reference_image = ref_entry["reference_image"] if ref_entry is not None else None
        ret_entry1 = entry1["video"] if entry1 is not None else None
        ret_entry2 = entry2["video"] if entry2 is not None else None

        return reference_image, ret_entry1, ret_entry2
|
src/gecora/dataset/video_pkl_manager.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class VideoFrameDataset:
    """
    PyTorch-style dataset for loading video frames and progress values from a
    custom directory structure (``<root>/<split>/<source>/<video>/``).

    Each item corresponds to one video and returns a dict with:
        - video_dir: path of the video directory (str)
        - video: list of decoded RGB PIL images
        - reference_image: RGB reference frame (falls back to the last frame)
        - progress_steps: list of float progress values

    Args:
        dataset_dir (str): Root directory of the dataset.
        split (str): One of 'train', 'test' or 'eval'.
    """

    def __init__(
        self,
        dataset_dir: str,
        split: str = "test",
    ):
        assert split in ["train", "test", "eval"], "Split must be 'train', 'test' or 'eval'"
        self.dataset_dir = Path(dataset_dir)
        self.split = split
        self.video_dirs: List[Path] = self._collect_video_dirs()
        # Map '<source>_<video>' ids to dataset indices for O(1) lookup.
        self.video_id_to_idx: Dict[str, int] = {}
        for idx, video_dir in enumerate(self.video_dirs):
            url = video_dir.stem
            source = video_dir.parts[-2]
            # Renamed from 'id' to avoid shadowing the builtin.
            video_id = source + "_" + url
            self.video_id_to_idx[video_id] = idx

    def _collect_video_dirs(self) -> List[Path]:
        """Collects all video subdirectories under the split (sorted for a stable order)."""
        split_dir = self.dataset_dir / self.split
        video_dirs = []
        for source_dir in split_dir.iterdir():
            if source_dir.is_dir():
                for video_dir in source_dir.iterdir():
                    if video_dir.is_dir():
                        video_dirs.append(video_dir)
        return sorted(video_dirs)

    def _load_ref_frame(self, video_dir: Path, frame_data: List[bytes]) -> Image.Image:
        """Loads reference_frame.png, falling back to the last video frame.

        Args:
            video_dir (Path): Directory that may contain 'reference_frame.png'.
            frame_data (List[bytes]): Compressed frames; the last one is the fallback.

        Returns:
            Image.Image: The reference frame converted to RGB.
        """
        reference_frame_path = video_dir / "reference_frame.png"
        reference_frame = None
        try:
            reference_frame = Image.open(reference_frame_path)
        except Exception:
            # Best effort: a missing/unreadable reference frame is not fatal,
            # we fall back to the last video frame below.
            pass

        # Load last frame if reference frame is not set
        if reference_frame is None:
            reference_frame = Image.open(BytesIO(frame_data[-1]))

        return reference_frame.convert("RGB")

    def _prepare_frame(self, frame: bytes) -> Image.Image:
        """Decodes one compressed frame into an RGB PIL image."""
        return Image.open(BytesIO(frame)).convert("RGB")

    def __len__(self) -> int:
        return len(self.video_dirs)

    def __getitem__(self, idx: int) -> Any:
        """Loads frames, reference image and progress values for one video.

        Raises:
            ValueError: If any of the pickled files cannot be loaded/decoded.
        """
        video_dir = self.video_dirs[idx]
        frame_path = video_dir / "frame_data.pkl"
        progress_path = video_dir / "frame_progress.pkl"

        try:
            with open(frame_path, "rb") as f:
                frame_data: List[bytes] = pickle.load(f)

            reference_image = self._load_ref_frame(video_dir=video_dir, frame_data=frame_data)

            with open(progress_path, "rb") as f:
                frame_progress: List[float] = pickle.load(f)

            frame_img_list = [self._prepare_frame(fd) for fd in frame_data]

        except Exception as e:
            logger.error(f"While loading data for video {video_dir} an error occurred: {e}")
            # Better to raise an error as long as we do not have a workaround.
            # Chain the original exception so the root cause stays visible.
            raise ValueError(e) from e

        return {
            "video_dir": str(video_dir),
            "video": frame_img_list,
            "reference_image": reference_image,
            "progress_steps": frame_progress,
        }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class VideoPklManager(BaseDatasetManager):
    """
    A class to manage datasets stored in video pkl subdirectories of a given root path.

    Attributes:
        root_path (str): The root directory containing subfolders with Hugging Face datasets.
        dataset_split (str): If the datasets are a Dict, select specified split.
        datasets (Dict[str, VideoFrameDataset]): Dictionary mapping dataset names to loaded datasets.
        common_entry_ids (List[str]): List of 'entry_id's present in all datasets.
        partial_entry_ids (List[str]): List of 'entry_id's present in only some datasets.
    """

    def __init__(self, config: DatasetManagerConfig) -> None:
        """
        Initializes the dataset manager and loads datasets from subdirectories.

        Args:
            config (DatasetManagerConfig): Configuration.
        """
        super().__init__(config=config)
        self.datasets: Dict[str, VideoFrameDataset] = {}
        self._load_datasets()
        self._analyze_entry_ids()
        self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
        self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
        self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")

    def _load_datasets(self) -> None:
        """
        Loads all Hugging Face datasets from subdirectories in the root path.
        """
        for subdir in os.listdir(self.config.root_path):
            full_path = os.path.join(self.config.root_path, subdir)
            if not os.path.isdir(full_path):
                continue
            try:
                self.datasets[subdir] = VideoFrameDataset(full_path, split=self.config.dataset_split)
            except Exception as e:
                # A broken/empty subdirectory should not abort loading the rest.
                self.logger.info(f"Skipping {subdir}: {e}")

    def get_dataset_names(self) -> List[str]:
        """
        Returns a list of all loaded dataset names.

        Returns:
            List[str]: Names of the datasets.
        """
        return list(self.datasets.keys())

    def _analyze_entry_ids(self) -> None:
        """
        Analyzes all datasets to find common and partial 'entry_id's.
        Updates `common_entry_ids` and `partial_entry_ids` attributes.
        """
        entry_id_sets: List[set] = [
            set(dataset.video_id_to_idx.keys()) for dataset in self.datasets.values()
        ]

        if entry_id_sets:
            common = set.intersection(*entry_id_sets)
            self.common_entry_ids = list(common)
            self.partial_entry_ids = list(set.union(*entry_id_sets) - common)

    def get_entries_by_id(
        self, entry_id: str, dataset_name1: str, dataset_name2: str
    ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        """
        Retrieves entries from two datasets by a specific 'entry_id'.

        Args:
            entry_id (str): The 'entry_id' to search for.
            dataset_name1 (str): Name of the first dataset.
            dataset_name2 (str): Name of the second dataset.

        Returns:
            Tuple[Optional[Any], Optional[Any], Optional[Any]]:
                Reference content, or None if not found.
                Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
                Generated entry from dataset_name2 matching the 'entry_id', or None if not found.

        Raises:
            ValueError: If one or both dataset names are not found.
        """
        if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
            raise ValueError("One or both dataset names not found.")

        def lookup(dataset: VideoFrameDataset) -> Optional[Dict]:
            idx = dataset.video_id_to_idx.get(entry_id)
            return None if idx is None else dataset[idx]

        entry1 = lookup(self.datasets[dataset_name1])
        entry2 = lookup(self.datasets[dataset_name2])

        reference_image = None if entry1 is None else entry1["reference_image"]
        ret_entry1 = None if entry1 is None else entry1["video"]
        ret_entry2 = None if entry2 is None else entry2["video"]

        return reference_image, ret_entry1, ret_entry2
|
src/gecora/dataset_converting/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/dataset_converting/video_pkl_to_video.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from gecora.dataset.video_pkl_manager import VideoFrameDataset
|
| 7 |
+
from gecora.logic.utils import save_video_from_frames
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert(
    dataset_dir: str,
    output_dir: str,
    split: str = "test",
    fps: int = 3,
):
    """
    Converts a VideoFrameDataset into video files.

    For every entry, writes ``reference_image.png`` and ``video.mp4`` under
    ``<output_dir>/<split>/<source>/<id>/``.

    Args:
        dataset_dir: Path to the dataset directory.
        output_dir: Path where the output videos will be saved.
        split: The dataset split to process (e.g., 'train', 'test', 'eval').
        fps: Frames per second for the output video.
    """
    print(f"Loading dataset from: {dataset_dir} (Split: {split})")
    dataset = VideoFrameDataset(dataset_dir=dataset_dir, split=split)

    num_entries = len(dataset)
    print(f"Found {num_entries} entries. Starting conversion...")

    for idx in tqdm(range(num_entries)):
        entry = dataset[idx]
        source_path = Path(entry["video_dir"])

        # <source>/<id> are the last two components of the source path.
        source = source_path.parts[-2]
        entry_name = source_path.parts[-1]

        # Mirror the layout below the output directory.
        target_dir = Path(output_dir) / split / source / entry_name
        target_dir.mkdir(parents=True, exist_ok=True)

        # Save reference_image
        entry["reference_image"].save(target_dir / "reference_image.png")

        # Save the video
        save_video_from_frames(entry["video"], video_output_path=str(target_dir / "video.mp4"), fps=fps)

    print("Conversion complete.")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert VideoFrameDataset entries to MP4 video files.")

    # Required arguments
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the root directory of the dataset.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory where output videos will be saved.")

    # Optional arguments
    parser.add_argument("--split", type=str, default="test", help="Dataset split to process (default: 'test').")
    # Help text fixed to match the actual default value (was documented as 2).
    parser.add_argument("--fps", type=int, default=3, help="Frames per second for the output video (default: 3).")

    args = parser.parse_args()

    convert(dataset_dir=args.dataset_dir, output_dir=args.output_dir, split=args.split, fps=args.fps)
|
src/gecora/db/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/db/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
src/gecora/db/__pycache__/hf_jsonl.cpython-312.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
src/gecora/db/__pycache__/sqlite.cpython-312.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
src/gecora/db/hf_jsonl.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Class for Database handling using HuggingFace Datasets with JSONL storage."""
|
| 2 |
+
|
| 3 |
+
import json
import logging
import os
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from huggingface_hub import HfApi, hf_hub_download, upload_file
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HFJsonlDB:
|
| 13 |
+
"""
|
| 14 |
+
Handles database operations using HuggingFace datasets with JSONL files.
|
| 15 |
+
|
| 16 |
+
Stores data in JSONL files on HuggingFace datasets for persistent storage in HF Spaces.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
repo_id (str): HuggingFace repository ID (e.g., "username/dataset-name").
|
| 20 |
+
experiment_name (str): Name of the experiment (used to name files).
|
| 21 |
+
token (Optional[str]): HuggingFace API token for authentication.
|
| 22 |
+
users_filename (str): Filename for users JSONL file.
|
| 23 |
+
preferences_filename (str): Filename for preferences JSONL file.
|
| 24 |
+
logger (logging.Logger): Logger instance for logging database operations.
|
| 25 |
+
hf_api (HfApi): HuggingFace API client.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
    def __init__(
        self,
        repo_id: str,
        experiment_name: str = "arena",
        token: Optional[str] = None,
        log_folder_path: Optional[str] = None,
    ):
        """
        Initializes the HFJsonlDB instance.

        Args:
            repo_id (str): HuggingFace repository ID (e.g., "username/dataset-name").
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
            token (Optional[str], optional): HuggingFace API token. If None, uses HF_TOKEN env var.
            log_folder_path (Optional[str], optional): Path for log file. Defaults to current directory.
        """
        self.repo_id = repo_id
        self.experiment_name = experiment_name
        # Fall back to the HF_TOKEN environment variable when no token is given.
        self.token = token or os.environ.get("HF_TOKEN")
        # Per-experiment JSONL filenames inside the dataset repo.
        self.users_filename = f"{experiment_name.lower()}_users.jsonl"
        self.preferences_filename = f"{experiment_name.lower()}_preferences.jsonl"

        # Setup logging
        if log_folder_path is None:
            log_folder_path = "."
        self.log_path = os.path.join(log_folder_path, "log_hf_db.txt")
        # NOTE(review): basicConfig reconfigures process-wide logging on every
        # construction of this class; confirm that is intended.
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize HF API
        self.hf_api = HfApi(token=self.token)

        # Cache for data
        # In-memory mirror of the remote JSONL files; populated by _load_data().
        self._users_cache: List[Dict[str, Any]] = []
        self._preferences_cache: List[Dict[str, Any]] = []
        self._cache_loaded = False
|
| 64 |
+
|
| 65 |
+
    def initialize_database(self) -> bool:
        """
        Initializes the database by ensuring the HF dataset exists and files are present.

        Creates the (private) dataset repository if needed, then loads existing
        JSONL files into the cache, creating empty files if none exist yet.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            # Check if repository exists, if not create it
            try:
                self.hf_api.repo_info(repo_id=self.repo_id, repo_type="dataset")
                self.logger.info(f"Repository {self.repo_id} already exists")
            except Exception:
                # repo_info raises when the repo is missing or inaccessible.
                self.logger.info(f"Creating repository {self.repo_id}")
                self.hf_api.create_repo(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    exist_ok=True,
                    private=True,
                )

            # Try to load existing files or create new ones
            self._load_data()

            # If files don't exist, create them
            # _load_data leaves _cache_loaded False only on an unexpected
            # failure; missing individual files are already treated as empty.
            if not self._cache_loaded:
                self._save_users([])
                self._save_preferences([])
                self._users_cache = []
                self._preferences_cache = []
                self._cache_loaded = True

            self.logger.info("Database initialized successfully")
            return True

        except Exception as e:
            self.logger.error(f"Initializing the database failed with error: {e}")
            return False
|
| 105 |
+
|
| 106 |
+
    def _load_data(self):
        """Load users and preferences from the HF dataset into the in-memory caches.

        A missing remote file is treated as empty (logged at info level); only
        an unexpected outer failure leaves ``_cache_loaded`` set to False.
        """
        try:
            # Load users
            try:
                users_path = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.users_filename,
                    repo_type="dataset",
                    token=self.token,
                )
                with open(users_path, "r") as f:
                    # One JSON object per non-empty line (JSONL).
                    self._users_cache = [json.loads(line) for line in f if line.strip()]
            except Exception as e:
                self.logger.info(f"Users file not found, will create new: {e}")
                self._users_cache = []

            # Load preferences
            try:
                prefs_path = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.preferences_filename,
                    repo_type="dataset",
                    token=self.token,
                )
                with open(prefs_path, "r") as f:
                    self._preferences_cache = [json.loads(line) for line in f if line.strip()]
            except Exception as e:
                self.logger.info(f"Preferences file not found, will create new: {e}")
                self._preferences_cache = []

            self._cache_loaded = True

        except Exception as e:
            self.logger.error(f"Error loading data: {e}")
            self._cache_loaded = False
|
| 142 |
+
|
| 143 |
+
def _save_users(self, users: List[Dict[str, Any]]):
|
| 144 |
+
"""Save users to HF dataset."""
|
| 145 |
+
temp_path = f"/tmp/{self.users_filename}"
|
| 146 |
+
with open(temp_path, "w") as f:
|
| 147 |
+
for user in users:
|
| 148 |
+
f.write(json.dumps(user) + "\n")
|
| 149 |
+
|
| 150 |
+
upload_file(
|
| 151 |
+
path_or_fileobj=temp_path,
|
| 152 |
+
path_in_repo=self.users_filename,
|
| 153 |
+
repo_id=self.repo_id,
|
| 154 |
+
repo_type="dataset",
|
| 155 |
+
token=self.token,
|
| 156 |
+
)
|
| 157 |
+
os.remove(temp_path)
|
| 158 |
+
|
| 159 |
+
def _save_preferences(self, preferences: List[Dict[str, Any]]):
|
| 160 |
+
"""Save preferences to HF dataset."""
|
| 161 |
+
temp_path = f"/tmp/{self.preferences_filename}"
|
| 162 |
+
with open(temp_path, "w") as f:
|
| 163 |
+
for pref in preferences:
|
| 164 |
+
f.write(json.dumps(pref) + "\n")
|
| 165 |
+
|
| 166 |
+
upload_file(
|
| 167 |
+
path_or_fileobj=temp_path,
|
| 168 |
+
path_in_repo=self.preferences_filename,
|
| 169 |
+
repo_id=self.repo_id,
|
| 170 |
+
repo_type="dataset",
|
| 171 |
+
token=self.token,
|
| 172 |
+
)
|
| 173 |
+
os.remove(temp_path)
|
| 174 |
+
|
| 175 |
+
def create_user(self, username: str) -> Tuple[Optional[int], str]:
|
| 176 |
+
"""
|
| 177 |
+
Creates a new user in the database.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
username (str): The username of the new user.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Tuple[Optional[int], str]:
|
| 184 |
+
The user_id of the newly created user, or None if creation failed.
|
| 185 |
+
If first entry is None then the second contains the exception message.
|
| 186 |
+
"""
|
| 187 |
+
try:
|
| 188 |
+
# Reload data to ensure we have latest
|
| 189 |
+
self._load_data()
|
| 190 |
+
|
| 191 |
+
# Check if user already exists
|
| 192 |
+
for user in self._users_cache:
|
| 193 |
+
if user["username"] == username:
|
| 194 |
+
msg = f"User '{username}' already exists."
|
| 195 |
+
self.logger.warning(msg)
|
| 196 |
+
return None, msg
|
| 197 |
+
|
| 198 |
+
# Generate new user_id
|
| 199 |
+
user_id = max([u["user_id"] for u in self._users_cache], default=0) + 1
|
| 200 |
+
|
| 201 |
+
# Create new user
|
| 202 |
+
new_user = {
|
| 203 |
+
"user_id": user_id,
|
| 204 |
+
"username": username,
|
| 205 |
+
"created_at": datetime.now().isoformat(),
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
self._users_cache.append(new_user)
|
| 209 |
+
self._save_users(self._users_cache)
|
| 210 |
+
|
| 211 |
+
self.logger.info(f"User '{username}' created with user_id {user_id}")
|
| 212 |
+
return user_id, username
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
msg = f"Failed to create user '{username}': {e}"
|
| 216 |
+
self.logger.error(msg)
|
| 217 |
+
return None, msg
|
| 218 |
+
|
| 219 |
+
def get_user_id_by_username(self, username: str) -> Optional[int]:
|
| 220 |
+
"""
|
| 221 |
+
Checks if a username exists in the database and returns the associated user_id.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
username (str): The username to look up.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
Optional[int]: The user_id if the username exists, None otherwise.
|
| 228 |
+
"""
|
| 229 |
+
try:
|
| 230 |
+
# Reload data to ensure we have latest
|
| 231 |
+
self._load_data()
|
| 232 |
+
|
| 233 |
+
for user in self._users_cache:
|
| 234 |
+
if user["username"] == username:
|
| 235 |
+
return user["user_id"]
|
| 236 |
+
|
| 237 |
+
return None
|
| 238 |
+
|
| 239 |
+
except Exception as e:
|
| 240 |
+
self.logger.error(f"Error checking username '{username}': {e}")
|
| 241 |
+
return None
|
| 242 |
+
|
| 243 |
+
def insert_preference(
    self,
    user_id: int,
    reference_id: str,
    model_left_id: str,
    model_right_id: str,
    preferred_side: str,
) -> Tuple[bool, str]:
    """
    Inserts a new preference entry into the database.

    Args:
        user_id (int): ID of the user making the preference.
        reference_id (str): ID of the reference image.
        model_left_id (str): ID of the left model's generated image.
        model_right_id (str): ID of the right model's generated image.
        preferred_side (str): The preferred side ('left', 'right', or 'tie').

    Returns:
        Tuple[bool, str]: True if insertion was successful, False otherwise with a
            string message describing the exception.
    """
    msg = ""
    # Reject anything outside the closed vocabulary before touching storage.
    if preferred_side not in {"left", "right", "tie"}:
        msg = f"Invalid preferred_side value: {preferred_side}"
        self.logger.error(msg)
        return False, msg

    try:
        # Reload so the generated id accounts for entries written since the
        # cache was last populated.
        self._load_data()

        next_id = 1 + max(
            (entry["preference_id"] for entry in self._preferences_cache),
            default=0,
        )

        record = {
            "preference_id": next_id,
            "user_id": user_id,
            "reference_id": reference_id,
            "model_left_id": model_left_id,
            "model_right_id": model_right_id,
            "preferred_side": preferred_side,
            "timestamp": datetime.now().isoformat(),
        }

        self._preferences_cache.append(record)
        self._save_preferences(self._preferences_cache)

        self.logger.info(f"Preference inserted for user_id {user_id}")
        return True, msg

    except Exception as e:
        msg = f"Failed to insert preference: {e}"
        self.logger.error(msg)
        return False, msg
|
| 299 |
+
|
| 300 |
+
def get_all_preferences(self) -> List[Tuple]:
    """
    Retrieves all preference entries from the database.

    Returns:
        List[Tuple]: A list of tuples representing all preference entries,
        in the same column order as the SQLite backend.
    """
    # Field order mirrors the SQLite Preference table schema.
    fields = (
        "preference_id",
        "user_id",
        "reference_id",
        "model_left_id",
        "model_right_id",
        "preferred_side",
        "timestamp",
    )
    try:
        self._load_data()
        return [
            tuple(entry[name] for name in fields)
            for entry in self._preferences_cache
        ]
    except Exception as e:
        self.logger.error(f"Failed to retrieve preferences: {e}")
        return []
|
| 329 |
+
|
| 330 |
+
def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
    """
    Retrieves all preference entries for a specific user.

    Args:
        user_id (int): The ID of the user.

    Returns:
        List[Tuple]: A list of tuples representing the user's preference entries,
        in the same column order as the SQLite backend.
    """
    fields = (
        "preference_id",
        "user_id",
        "reference_id",
        "model_left_id",
        "model_right_id",
        "preferred_side",
        "timestamp",
    )
    try:
        self._load_data()

        result = [
            tuple(entry[name] for name in fields)
            for entry in self._preferences_cache
            if entry["user_id"] == user_id
        ]

        self.logger.info(f"Retrieved {len(result)} preferences for user_id {user_id}.")
        return result

    except Exception as e:
        self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
        return []
|
| 365 |
+
|
| 366 |
+
def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
    """
    Maps a list of preference tuples to a list of dictionaries using the Preference schema.

    Args:
        preferences (List[Tuple]): List of tuples from the Preference table.

    Returns:
        List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
    """
    schema = (
        "preference_id",
        "user_id",
        "reference_id",
        "model_left_id",
        "model_right_id",
        "preferred_side",
        "timestamp",
    )
    return [
        {field: value for field, value in zip(schema, row)}
        for row in preferences
    ]
|
src/gecora/db/sqlite.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Class for Database handling using SQLite."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sqlite3
|
| 6 |
+
from sqlite3 import Connection
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SQLiteDB:
    """
    Handles SQLite database operations for an image generation comparison experiment.

    Attributes:
        db_folder_path (str): Directory where the SQLite database file is stored.
        experiment_name (str): Name of the experiment (used to name the database file).
        db_filename (str): Filename of the SQLite database.
        db_path (str): Full path to the SQLite database file.
        log_path (str): Full path to the database log file.
        logger (logging.Logger): Logger instance for logging database operations.
        conn (Optional[Connection]): Lazily created SQLite connection.
    """

    def __init__(self, db_folder_path: str, experiment_name: str = "arena"):
        """
        Initializes the SQLiteDB instance with dataset and database configuration.

        Args:
            db_folder_path (str): Directory where the SQLite database file is
                stored. Created if it does not exist yet.
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
        """
        self.experiment_name = experiment_name
        self.db_folder_path = db_folder_path
        self.db_filename = f"{experiment_name.lower()}.db"
        self.db_path = os.path.join(db_folder_path, self.db_filename)
        self.log_path = os.path.join(db_folder_path, "log_db.txt")
        # Fix: the folder must exist before the logging FileHandler (and later
        # the database file) can be created inside it. Previously makedirs only
        # ran inside initialize_database, after sqlite3.connect had already
        # failed for a missing folder.
        os.makedirs(db_folder_path, exist_ok=True)
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger()
        self.conn: Optional[Connection] = None

    def _connect(self) -> Connection:
        """
        Returns the open connection, creating it on first use.

        check_same_thread=False is used consistently for every connection:
        the app may call into the database from worker threads (e.g. a web
        UI), and sqlite3's default of True would raise ProgrammingError
        there. Previously only initialize_database disabled the check.
        """
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        return self.conn

    def __del__(self):
        """Best-effort close of the connection; never raises from a destructor."""
        if self.conn is not None:
            try:
                self.conn.close()
            except Exception:
                pass

    def initialize_database(self) -> bool:
        """
        Initializes the SQLite database and creates required tables if they do not exist.

        Tables:
            - User: Stores user information.
            - Preference: Stores user preferences between generated images.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            # Check before connecting: sqlite3.connect creates the file.
            db_exists = os.path.exists(self.db_path)
            conn = self._connect()

            if db_exists:
                self.logger.info(f"Database already exists at {self.db_path}")
                return True

            self.logger.info(f"Creating new database at {self.db_path}")
            with conn:  # transaction scope: commits on success, rolls back on error
                cursor = conn.cursor()
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS User (
                        user_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        username TEXT UNIQUE,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS Preference (
                        preference_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER,
                        reference_id TEXT,
                        model_left_id TEXT,
                        model_right_id TEXT,
                        preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')),
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY (user_id) REFERENCES User(user_id)
                    )
                """)
                cursor.close()
        except Exception as e:
            self.logger.info(f"Creating the database failed with following error: {e}")
            return False

        return True

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                The user_id of the newly created user, or None if creation failed.
                If first entry is None then the second contains the exception message;
                on success the second entry is the username.
        """
        ret: Optional[int] = None
        msg = ""
        try:
            conn = self._connect()
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT INTO User (username) VALUES (?)
                    """,
                    (username,),
                )
                user_id = cursor.lastrowid
                cursor.close()
            self.logger.info(f"User '{username}' created with user_id {user_id}")
            ret = user_id
            msg = username
        except sqlite3.IntegrityError:
            # UNIQUE constraint on username: the user already exists.
            msg = f"User '{username}' already exists."
            self.logger.warning(msg)
        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
        return ret, msg

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise.
        """
        try:
            conn = self._connect()
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT user_id FROM User WHERE username = ?
                """,
                (username,),
            )
            result = cursor.fetchone()
            cursor.close()

            return result[0] if result else None
        except Exception as e:
            self.logger.error(f"Error checking username '{username}': {e}")
            return None

    def insert_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: True if insertion was successful, False otherwise with a
                string message describing the exception.
        """
        msg = ""
        # Validate up front; the table's CHECK constraint would also reject it,
        # but this gives a clearer message without touching the database.
        if preferred_side not in {"left", "right", "tie"}:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg

        try:
            conn = self._connect()
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT INTO Preference (
                        user_id, reference_id, model_left_id, model_right_id, preferred_side
                    ) VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id, reference_id, model_left_id, model_right_id, preferred_side),
                )
                cursor.close()
            self.logger.info(f"Preference inserted for user_id {user_id}")
            return True, msg
        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database.

        Returns:
            List[Tuple]: A list of tuples representing all preference entries.
        """
        preferences: List[Tuple] = []
        try:
            conn = self._connect()
            with conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM Preference")
                preferences = cursor.fetchall()
                cursor.close()
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences: {e}")
        return preferences

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: A list of tuples representing the user's preference entries.
        """
        preferences: List[Tuple] = []
        try:
            conn = self._connect()
            with conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM Preference WHERE user_id = ?", (user_id,))
                preferences = cursor.fetchall()
                cursor.close()
            self.logger.info(f"Retrieved {len(preferences)} preferences for user_id {user_id}.")
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
        return preferences

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table.

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        keys = [
            "preference_id",
            "user_id",
            "reference_id",
            "model_left_id",
            "model_right_id",
            "preferred_side",
            "timestamp",
        ]
        return [dict(zip(keys, row)) for row in preferences]
|
src/gecora/logging/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/logging/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
src/gecora/logging/__pycache__/logger.cpython-312.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
src/gecora/logging/logger.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def setup_file_logger(name: str, log_file: str, level=logging.INFO) -> logging.Logger:
    """Creates a logger with a specific name and log file."""
    log = logging.getLogger(name)
    log.setLevel(level)

    # logging.getLogger returns a shared singleton per name, so only attach a
    # handler on the first call — otherwise repeated calls duplicate output.
    if not log.handlers:
        handler = logging.FileHandler(log_file)
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        log.addHandler(handler)

    return log
|
src/gecora/logic/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/logic/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (168 Bytes). View file
|
|
|
src/gecora/logic/__pycache__/base.cpython-312.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
src/gecora/logic/__pycache__/loomis_painter.cpython-312.pyc
ADDED
|
Binary file (8.98 kB). View file
|
|
|
src/gecora/logic/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (1.9 kB). View file
|
|
|
src/gecora/logic/base.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LogicBase:
    """
    Ranking App.

    Abstract base for comparison/ranking applications: subclasses wire up a
    database, a dataset manager, and a UI, and implement the hooks below.

    Attributes:
        root_path (str): Root directory for the application.
        dataset_path (Optional[str]): Path to the dataset, if provided.
        ranking_system (ModelRankingSystem): Instance of the model ranking system.
        gecora_db (Optional[GecoraDB]): Placeholder for GecoraDB integration.
        dataset_manager (Optional[GecoraDatasetManager]): Placeholder for dataset manager.
        itov_app (Optional[ItoVApp]): Placeholder for ItoVApp integration.
    """

    def __init__(self, root_path: str, dataset_path: Optional[str] = None, dataset_split: str = "test"):
        """Initializes Base Class"""
        pass

    def launch(self):
        """Launches the App."""
        raise NotImplementedError("This method should be implemented by the child class.")

    def set_username(self, username: str) -> Optional[int]:
        """
        Creates the User in the database if not already present.
        """
        raise NotImplementedError("This method should be implemented by the child class.")

    def set_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Sets a new preference entry.
        """
        raise NotImplementedError("This method should be implemented by the child class.")

    def get_next_comparison(
        self, user_id: int
    ) -> Optional[Tuple[Tuple[str, str, str], Tuple[Image.Image, Dict, Dict], Tuple[int, int]]]:
        """
        Selects the next model pair for comparison.

        Args:
            user_id (int): ID of the current user.

        Returns:
            Optional[Tuple[Tuple[str, str, str], Tuple[Dict, Dict], Tuple[int, int]]]:
                ((reference_id, model_left_id, model_right_id),
                (left_entry, right_entry), (num_preferences, total_num_comparison))
        """
        raise NotImplementedError("This method should be implemented by the child class.")
|
src/gecora/logic/loomis_painter.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
from gecora.app.i_to_v_app import ItoVApp
|
| 6 |
+
from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
|
| 7 |
+
from gecora.dataset.video_manager import VideoManager
|
| 8 |
+
from gecora.db.hf_jsonl import HFJsonlDB
|
| 9 |
+
from gecora.db.sqlite import SQLiteDB
|
| 10 |
+
from gecora.logic.base import LogicBase
|
| 11 |
+
from gecora.logic.utils import save_video_from_frames
|
| 12 |
+
from gecora.ranking.ranking_system import RankingSystem
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LoomisPainterApp(LogicBase):
    """
    Loomis Painter Ranking App.

    Wires together the preference database (SQLite or HF JSONL), the video
    dataset manager, the ranking system, and the image-to-video comparison UI.

    Attributes:
        root_path (str): Root directory for the application.
        dataset_path (Optional[str]): Path to the dataset, if provided.
        ranking_system (RankingSystem): Instance of the model ranking system.
        gecora_db (Union[SQLiteDB, HFJsonlDB]): Database instance (SQLite or HF JSONL).
        dataset_manager (BaseDatasetManager): Manager serving reference images and videos.
        itov_app (ItoVApp): Image-to-video comparison UI.
        force_model_id (Optional[str]): If set, comparisons always include this model.
        desired_num_selections (Optional[int]): Target number of selections per user.
    """

    def __init__(
        self,
        root_path: str,
        dataset_path: Optional[str] = None,
        use_hf_db: bool = True,
        hf_repo_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        force_model_id: Optional[str] = None,
        desired_num_selections: Optional[int] = 25,
    ):
        """Initializes LoomisPainterApp.

        Args:
            root_path (str): The root directory containing subfolders with datasets.
            dataset_path (Optional[str]): Path to the dataset, if provided.
                Falls back to root_path when None.
            use_hf_db (bool): Whether to use HuggingFace JSONL database instead of SQLite.
            hf_repo_id (Optional[str]): HF repository ID (required if use_hf_db=True).
            hf_token (Optional[str]): HF API token (if not provided, uses HF_TOKEN env var).
            force_model_id (Optional[str]): If set select preferences where force_model_id is always included.
            desired_num_selections (Optional[int]): The desired number of comparison the user should at least select.

        Raises:
            ValueError: If use_hf_db is True but hf_repo_id is None.
        """
        self.root_path = root_path
        self.dataset_path = dataset_path
        self.force_model_id = force_model_id
        self.desired_num_selections = desired_num_selections

        # Core logic
        self.ranking_system = RankingSystem()

        # External integrations (to be implemented or injected)
        # Dataset defaults to the application root when no explicit path given.
        if dataset_path is None:
            dataset_path = root_path

        self.gecora_db: Union[HFJsonlDB, SQLiteDB]
        # Initialize database based on configuration
        if use_hf_db:
            if hf_repo_id is None:
                raise ValueError("hf_repo_id must be provided when use_hf_db=True")
            self.gecora_db = HFJsonlDB(
                repo_id=hf_repo_id,
                experiment_name="arena",
                token=hf_token,
                log_folder_path=root_path,
            )
        else:
            self.gecora_db = SQLiteDB(
                db_folder_path=root_path,
            )

        self.gecora_db.initialize_database()

        # NOTE(review): "genereated_column_name" is the keyword expected by
        # DatasetManagerConfig (typo lives in its definition) — do not "fix"
        # only this call site.
        dm_conf = DatasetManagerConfig(
            root_path=dataset_path,
            reference_column_name="reference_image",
            genereated_column_name="video",
            logging_path=root_path,
        )
        self.dataset_manager: BaseDatasetManager = VideoManager(config=dm_conf)
        # Markdown instructions rendered by the UI.
        task_desc: str = """
        ### 🎨 Choose the Best Step-By-Step Painting video
        You will be shown a **reference painting** in the center, with two generated videos (Left and Right) that depict the painting process.
        Your goal is to decide which step-by-step process is better based on the criteria below.

        Please consider the following when making your decision:
        * **Process Completeness:** Choose the video that best captures the entire painting process, from start to the finished work.
        * **Visual Fidelity:** The final frame of the video should match the reference painting as closely as possible.
        * **Ignore Duration:** The videos may be of different lengths. Please do not let the video duration influence your decision.

        #### How to Vote
        * Click **←** if the **Left** video is better.
        * Click **→** if the **Right** video is better.
        * Click **Tie** if neither video is clearly superior.
        """
        ref_img_label: str = "Reference Painting"
        left_media_label: str = "Left Painting Process"
        right_media_label: str = "Right Painting Process"
        self.itov_app: ItoVApp = ItoVApp(
            logic_class=self,
            task_desc=task_desc,
            ref_img_label=ref_img_label,
            left_media_label=left_media_label,
            right_media_label=right_media_label,
            desired_num_selections=self.desired_num_selections,
        )

        self.log_path = os.path.join(root_path, "log.txt")
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger()

    def launch(self):
        """Launches the App."""
        self.itov_app.launch()

    def set_username(self, username: str) -> Optional[int]:
        """
        Creates the User in the database if not already present.

        Args:
            username (str): Username to look up or create.

        Returns:
            Optional[int]: The user's id, or None if creation failed.
        """
        user_id = self.gecora_db.get_user_id_by_username(username=username)
        if user_id is None:
            user_id, msg = self.gecora_db.create_user(username=username)
            if user_id is None:
                self.logger.error(f"Error while creating user with username {username}: {msg}")
                return None
        return user_id

    def set_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Sets a new preference entry.

        Args:
            user_id (int): ID of the voting user.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model.
            model_right_id (str): ID of the right model.
            preferred_side (str): 'left', 'right', or 'tie'.

        Returns:
            Tuple[bool, str]: Success flag and an error message on failure.
        """
        return self.gecora_db.insert_preference(
            user_id=user_id,
            reference_id=reference_id,
            model_left_id=model_left_id,
            model_right_id=model_right_id,
            preferred_side=preferred_side,
        )

    def run_ranking_update(self):
        """
        Loads preferences and updates model rankings.

        NOTE(review): `self.load_preferences` is defined neither on this class
        nor on LogicBase in the visible code — calling this method likely
        raises AttributeError. Verify, or replace with
        `self.gecora_db.get_all_preferences()` + `map_preferences_to_dicts`.
        """
        preferences = self.load_preferences()
        self.ranking_system.update_rankings(preferences)

    def get_next_comparison(
        self, user_id: int
    ) -> Optional[Tuple[Tuple[str, str, str], Tuple[Any, Any, Any], Tuple[int, int]]]:
        """
        Selects the next model pair for comparison.

        Args:
            user_id (int): ID of the current user.

        Returns:
            Optional[Tuple[Tuple[str, str, str], Tuple[Any, Any, Any], Tuple[int, int]]]:
                (reference_id, model_left_id, model_right_id),
                (reference_image, left_entry, right_entry),
                (num_preferences, total_num_comparison); None on any failure.
        """
        reference_ids = self.dataset_manager.common_entry_ids
        model_ids = self.dataset_manager.get_dataset_names()
        # Having a data base call here every time may not be ideal. The assumption is that one user
        # does have a small number of set preferences, therefore runtime is still good.
        preferences_tuple = self.gecora_db.get_preferences_by_user(user_id=user_id)
        # Mapping from List[Tuple] to List[Dict[str, Any]]
        preferences = self.gecora_db.map_preferences_to_dicts(preferences=preferences_tuple)
        selected = self.ranking_system.select_next_comparison(
            preferences, reference_ids, model_ids, force_model_id=self.force_model_id
        )
        num_preferences, total_num_comparison = self.ranking_system.calculate_progress(
            preferences, reference_ids, model_ids, force_model_id=self.force_model_id
        )
        try:
            if selected is None:
                self.logger.warning(f"Was not able to get next comparison for user {user_id}")
                return None
            reference_id, model_left_id, model_right_id = selected
            reference_image, left_entry, right_entry = self.dataset_manager.get_entries_by_id(
                entry_id=reference_id, dataset_name1=model_left_id, dataset_name2=model_right_id
            )
            if reference_image is None or left_entry is None or right_entry is None:
                return None

            # Entries returned as frame lists are encoded to temporary video
            # files so the UI can play them.
            if isinstance(left_entry, list):
                left_entry = save_video_from_frames(
                    left_entry, video_output_path=self.itov_app.tmp_video_path_left, fps=2
                )
            if isinstance(right_entry, list):
                right_entry = save_video_from_frames(
                    right_entry, video_output_path=self.itov_app.tmp_video_path_right, fps=2
                )
        except Exception as e:
            self.logger.error(f"Error in get_next_comparison({user_id}) of LoomisPainter: {e}")
            return None
        return (
            (reference_id, model_left_id, model_right_id),
            (reference_image, left_entry, right_entry),
            (num_preferences, total_num_comparison),
        )
|
src/gecora/logic/utils.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import imageio.v3 as iio
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_temp_file(suffix: str = ".mp4") -> str:
    """Create an empty temporary file on disk and return its path.

    The file is NOT deleted automatically; callers are responsible for
    removing it when done (see ``cleanup`` / ``cleanup_list``).
    """
    handle, path = tempfile.mkstemp(suffix=suffix)
    # mkstemp hands back an open OS-level descriptor; close it so other
    # code (e.g. a video writer) can reopen the path by name.
    os.close(handle)
    return path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def cleanup(temp_file_path: str):
    """Delete the file at *temp_file_path*; a missing file is a no-op."""
    try:
        os.remove(temp_file_path)
    except FileNotFoundError:
        # Already gone (or never created) -- nothing to clean up.
        pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def cleanup_list(temp_file_path_list: List[str]):
    """Delete every still-existing file in *temp_file_path_list*."""
    # Missing paths are silently skipped; only existing files are removed.
    for existing_path in filter(os.path.exists, temp_file_path_list):
        os.remove(existing_path)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_video_from_frames(
    video: List[Image.Image], video_output_path: str, fps: int = 2
) -> Optional[str]:
    """Encode a sequence of PIL frames as an H.264 video file.

    Args:
        video: Ordered frames to encode. Frames are assumed to share one
            size/mode -- TODO confirm against callers.
        video_output_path: Destination path; the container is chosen from
            the suffix (typically ``.mp4``).
        fps: Playback frame rate.

    Returns:
        ``video_output_path`` on success, or ``None`` if encoding failed.
        Errors are swallowed deliberately (best-effort): callers treat a
        ``None`` video as "skip this entry".
    """
    # NOTE: the original annotation was ``List[Image]``, which annotates the
    # PIL *module*; ``Image.Image`` is the actual frame class.
    try:
        iio.imwrite(video_output_path, video, fps=fps, codec="libx264")
    except Exception:
        return None
    return video_output_path
|
src/gecora/py.typed
ADDED
|
File without changes
|
src/gecora/ranking/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init."""
|
src/gecora/ranking/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
src/gecora/ranking/__pycache__/ranking_system.cpython-312.pyc
ADDED
|
Binary file (5.58 kB). View file
|
|
|
src/gecora/ranking/ranking_system.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import random
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RankingSystem:
    """Tracks pairwise model comparisons via ELO ratings and win counts,
    and schedules which (reference, model pair) a user should see next."""

    def __init__(self, base_rating: int = 1500, k: int = 32):
        """
        Initializes the ranking system.

        Args:
            base_rating (int): Initial ELO rating for all models.
            k (int): ELO adjustment factor.
        """
        self.base_rating = base_rating
        self.k = k
        # Unknown models lazily start at the base rating / zero counts.
        self.elo_ratings: Dict[str, float] = defaultdict(lambda: base_rating)
        self.win_counts: Dict[str, int] = defaultdict(int)
        self.match_counts: Dict[str, int] = defaultdict(int)

    def update_rankings(self, preferences: List[Dict[str, str]]) -> None:
        """
        Updates ELO ratings and win counts based on preferences.

        Args:
            preferences (List[Dict[str, str]]): Preference records; each needs
                "model_left_id", "model_right_id" and "preferred_side"
                ("left", "right", or anything else counting as a tie).
        """

        def expected_score(rating_a: float, rating_b: float) -> float:
            # Standard ELO expectation: probability that A beats B.
            return 1 / (1 + 10 ** ((rating_b - rating_a) / 400))

        for pref in preferences:
            left = pref["model_left_id"]
            right = pref["model_right_id"]
            result = pref["preferred_side"]

            rating_left = self.elo_ratings[left]
            rating_right = self.elo_ratings[right]

            expected_left = expected_score(rating_left, rating_right)
            expected_right = expected_score(rating_right, rating_left)

            if result == "left":
                score_left, score_right = 1.0, 0.0
                self.win_counts[left] += 1
            elif result == "right":
                score_left, score_right = 0.0, 1.0
                self.win_counts[right] += 1
            else:  # tie
                score_left = score_right = 0.5

            self.elo_ratings[left] += self.k * (score_left - expected_left)
            self.elo_ratings[right] += self.k * (score_right - expected_right)

            self.match_counts[left] += 1
            self.match_counts[right] += 1

    def get_elo_ratings(self) -> Dict[str, float]:
        """Returns the current ELO ratings."""
        return dict(self.elo_ratings)

    def get_winrates(self) -> Dict[str, float]:
        """Returns the win rate (wins / matches) for each model seen so far."""
        return {
            model: self.win_counts[model] / self.match_counts[model] if self.match_counts[model] > 0 else 0.0
            for model in self.match_counts
        }

    def calculate_progress(
        self,
        preferences: List[Dict[str, str]],
        reference_ids: List[str],
        model_ids: List[str],
        force_model_id: Optional[str] = None,
    ) -> Tuple[int, int]:
        """
        Calculates the number of unique comparisons completed and the total possible comparisons.

        Args:
            preferences: List of preference dictionaries (assumed to be for a single user).
            reference_ids: List of all possible reference IDs.
            model_ids: List of all possible model IDs.
            force_model_id: If set, restricts total calculation to pairs involving this model.

        Returns:
            (seen_count, total_possible_count)
        """
        # We assume preferences are unique
        seen_count = len(preferences)

        num_refs = len(reference_ids)
        num_models = len(model_ids)

        # Safety check to avoid math errors if lists are empty
        if num_refs == 0 or num_models < 2:
            return seen_count, 0

        if force_model_id:
            if force_model_id not in model_ids:
                total_possible = 0
            else:
                # Forced model is paired against every other model per reference.
                total_possible = num_refs * (num_models - 1)
        else:
            # Number of unique unordered pairs = N * (N - 1) / 2
            unique_pairs_count = (num_models * (num_models - 1)) // 2
            total_possible = num_refs * unique_pairs_count

        return seen_count, total_possible

    def select_next_comparison(
        self,
        preferences: List[Dict[str, str]],
        reference_ids: List[str],
        model_ids: List[str],
        force_model_id: Optional[str] = None,
    ) -> Optional[Tuple[str, str, str]]:
        """
        Selects the next pair of models for comparison for a given reference.

        Ensures the same (reference_id, {model_left_id, model_right_id}) is not
        repeated for the user whose preferences are passed in.

        Args:
            preferences (List[Dict[str, str]]): Existing preferences.
            reference_ids (List[str]): List of possible reference IDs.
            model_ids (List[str]): List of model IDs to choose from.
            force_model_id (Optional[str]): If set, always includes this
                model_id in the comparison. Defaults to None (any pair).

        Returns:
            Optional[Tuple[str, str, str]]: A tuple of (reference_id, model_left_id, model_right_id),
                or None if no valid pair is found.
        """
        # Order within a pair does not matter for "already seen", hence frozenset.
        seen_comparisons: Set[Tuple[str, frozenset]] = {
            (pref["reference_id"], frozenset({pref["model_left_id"], pref["model_right_id"]})) for pref in preferences
        }

        # Copy list by slicing and shuffle
        shuffled_refs = reference_ids[:]
        random.shuffle(shuffled_refs)
        shuffled_models = model_ids[:]
        random.shuffle(shuffled_models)

        # Handling force_model_id
        if force_model_id:
            if force_model_id not in model_ids:
                return None

            opponents = [m for m in shuffled_models if m != force_model_id]
            for ref_id in shuffled_refs:
                for opponent in opponents:
                    current_pair = frozenset({force_model_id, opponent})

                    if (ref_id, current_pair) not in seen_comparisons:
                        # Randomize left/right placement for UI neutrality
                        return (
                            (ref_id, force_model_id, opponent)
                            if random.random() > 0.5
                            else (ref_id, opponent, force_model_id)
                        )
            return None

        # Generate all possible unique pairs from the shuffled list
        # combinations('ABCD', 2) --> AB AC AD BC BD CD
        possible_pairs = list(itertools.combinations(shuffled_models, 2))

        # Shuffle pairs again so 'AB' isn't always checked before 'CD'
        random.shuffle(possible_pairs)

        for ref_id in shuffled_refs:
            for m1, m2 in possible_pairs:
                current_pair = frozenset({m1, m2})

                if (ref_id, current_pair) not in seen_comparisons:
                    return (ref_id, m1, m2)

        return None
|