Markus Pobitzer committed on
Commit
b6d1c13
·
1 Parent(s): 7896daf
Files changed (48) hide show
  1. .gitattributes +2 -0
  2. README.md +7 -7
  3. app.py +14 -0
  4. requirements.txt +3 -0
  5. src/gecora/__init__.py +1 -0
  6. src/gecora/__pycache__/__init__.cpython-312.pyc +0 -0
  7. src/gecora/app/__init__.py +1 -0
  8. src/gecora/app/__pycache__/__init__.cpython-312.pyc +0 -0
  9. src/gecora/app/__pycache__/i_to_v_app.cpython-312.pyc +0 -0
  10. src/gecora/app/i_to_v_app.py +249 -0
  11. src/gecora/cli/__init__.py +1 -0
  12. src/gecora/cli/loomis_painter.py +22 -0
  13. src/gecora/dataset/__init__.py +1 -0
  14. src/gecora/dataset/__pycache__/__init__.cpython-312.pyc +0 -0
  15. src/gecora/dataset/__pycache__/base_manager.cpython-312.pyc +0 -0
  16. src/gecora/dataset/__pycache__/video_manager.cpython-312.pyc +0 -0
  17. src/gecora/dataset/__pycache__/video_pkl_manager.cpython-312.pyc +0 -0
  18. src/gecora/dataset/__pycache__/vieo_pkl_manager.cpython-312.pyc +0 -0
  19. src/gecora/dataset/base_manager.py +82 -0
  20. src/gecora/dataset/create_test_dataset.py +48 -0
  21. src/gecora/dataset/sub_dir_manager.py +117 -0
  22. src/gecora/dataset/video_manager.py +156 -0
  23. src/gecora/dataset/video_pkl_manager.py +207 -0
  24. src/gecora/dataset_converting/__init__.py +1 -0
  25. src/gecora/dataset_converting/video_pkl_to_video.py +67 -0
  26. src/gecora/db/__init__.py +1 -0
  27. src/gecora/db/__pycache__/__init__.cpython-312.pyc +0 -0
  28. src/gecora/db/__pycache__/hf_jsonl.cpython-312.pyc +0 -0
  29. src/gecora/db/__pycache__/sqlite.cpython-312.pyc +0 -0
  30. src/gecora/db/hf_jsonl.py +385 -0
  31. src/gecora/db/sqlite.py +279 -0
  32. src/gecora/logging/__init__.py +1 -0
  33. src/gecora/logging/__pycache__/__init__.cpython-312.pyc +0 -0
  34. src/gecora/logging/__pycache__/logger.cpython-312.pyc +0 -0
  35. src/gecora/logging/logger.py +16 -0
  36. src/gecora/logic/__init__.py +1 -0
  37. src/gecora/logic/__pycache__/__init__.cpython-312.pyc +0 -0
  38. src/gecora/logic/__pycache__/base.cpython-312.pyc +0 -0
  39. src/gecora/logic/__pycache__/loomis_painter.cpython-312.pyc +0 -0
  40. src/gecora/logic/__pycache__/utils.cpython-312.pyc +0 -0
  41. src/gecora/logic/base.py +53 -0
  42. src/gecora/logic/loomis_painter.py +205 -0
  43. src/gecora/logic/utils.py +37 -0
  44. src/gecora/py.typed +0 -0
  45. src/gecora/ranking/__init__.py +1 -0
  46. src/gecora/ranking/__pycache__/__init__.cpython-312.pyc +0 -0
  47. src/gecora/ranking/__pycache__/ranking_system.cpython-312.pyc +0 -0
  48. src/gecora/ranking/ranking_system.py +178 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Wlp User Study
3
- emoji: 🐢
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.6.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: User study for step-by-step painting videos
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test User Study
3
+ emoji: 🏢
4
+ colorFrom: red
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Test for a user study
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+
5
+ # 1. Add the 'src' folder to the Python path so we can import 'gecora'
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
7
+
8
+ from gecora.logic.loomis_painter import LoomisPainterApp
9
+
10
+
11
+ app = LoomisPainterApp(
12
+ root_path="./", dataset_path="data/", hf_repo_id="Markus-Pobitzer/gecora-wlp", desired_num_selections=40
13
+ )
14
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ imageio>=2.37.0
2
+ pillow
3
+ huggingface_hub
src/gecora/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init of project."""
src/gecora/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes). View file
 
src/gecora/app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/app/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (166 Bytes). View file
 
src/gecora/app/__pycache__/i_to_v_app.cpython-312.pyc ADDED
Binary file (10.2 kB). View file
 
src/gecora/app/i_to_v_app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple
2
+
3
+ import gradio as gr
4
+
5
+ from gecora.logic.base import LogicBase
6
+ from gecora.logic.utils import cleanup_list, create_temp_file
7
+
8
+
9
+ class ItoVApp:
10
+ def __init__(
11
+ self,
12
+ logic_class: LogicBase,
13
+ task_desc: str = "Select the best video for the reference image.",
14
+ ref_img_label: str = "Reference Image",
15
+ left_media_label: str = "Left Video",
16
+ right_media_label: str = "Right Video",
17
+ desired_num_selections: Optional[int] = None,
18
+ ):
19
+ self.logic_class = logic_class
20
+ self.desired_num_selections = desired_num_selections
21
+ # User specific input
22
+ self.user_input = gr.Accordion(label="1. Enter your name")
23
+ self.username_input = gr.Textbox(label=None, show_label=False)
24
+ self.select_button = gr.Button("Select Name")
25
+ # Preference task
26
+ self.task = gr.Accordion(label="2. Task", visible=False)
27
+ self.task_description: str = task_desc
28
+ self.reference_image = gr.Image(label=ref_img_label, height=512)
29
+ self.left_media = gr.Video(label=left_media_label, height=512, autoplay=True)
30
+ self.right_media = gr.Video(label=right_media_label, height=512, autoplay=True)
31
+ self.left_button = gr.Button("←")
32
+ self.tie_button = gr.Button("Tie")
33
+ self.right_button = gr.Button("→")
34
+ self.tmp_video_path_left = create_temp_file()
35
+ self.tmp_video_path_right = create_temp_file()
36
+
37
+ def set_username(self, username):
38
+ user_id = self.logic_class.set_username(username=username)
39
+ if username and user_id is not None:
40
+ next_comp = self.logic_class.get_next_comparison(user_id=user_id)
41
+ if next_comp is None:
42
+ gr.Error("Error: Loading the content failed.")
43
+ return 0, gr.update(visible=False), "", "", "", "", None, None, None
44
+
45
+ (
46
+ (ret_reference_id, ret_model_left_id, ret_model_right_id),
47
+ (reference_image, left_video, right_video),
48
+ (num_preferences, total_num_comparison),
49
+ ) = next_comp
50
+
51
+ progress_str = str(num_preferences) + " preferences selected!"
52
+ if total_num_comparison > 0:
53
+ perc_num = (
54
+ self.desired_num_selections if self.desired_num_selections is not None else total_num_comparison
55
+ )
56
+ progress_str = (
57
+ str(int(num_preferences / perc_num * 100))
58
+ + f"% ({num_preferences} / {total_num_comparison} total) preferences selected!"
59
+ )
60
+ return (
61
+ user_id,
62
+ gr.update(visible=True),
63
+ ret_reference_id,
64
+ ret_model_left_id,
65
+ ret_model_right_id,
66
+ progress_str,
67
+ reference_image,
68
+ left_video,
69
+ right_video,
70
+ )
71
+ else:
72
+ gr.Error(
73
+ f"Failed to set Username {username}. Make sure to put in a value in the textfield, otherwise try to reload webpage."
74
+ )
75
+ return 0, gr.update(visible=False), "", "", "", "", None, None, None
76
+
77
+ def set_preference(
78
+ self, user_id, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
79
+ ) -> Tuple[str, str, str, str, Any, Any, Any]:
80
+ """Returns reference_id: str, model_left_id: str, model_right_id: str, reference_image: PIL.Image.Image, left_video:str, right_video:str."""
81
+ user_id = int(user_id) # type: ignore
82
+ succ, msg = self.logic_class.set_preference(
83
+ user_id=user_id,
84
+ reference_id=reference_id,
85
+ model_left_id=model_left_id,
86
+ model_right_id=model_right_id,
87
+ preferred_side=preferred_side,
88
+ )
89
+ if not succ:
90
+ gr.Info(f"Something went wrong: {msg}")
91
+ next_comp = self.logic_class.get_next_comparison(user_id=user_id)
92
+ if next_comp is None:
93
+ gr.Error("We are sorry, something went wrong! Please try to reload the page.")
94
+ return "", "", "", "", None, None, None # type: ignore
95
+
96
+ if preferred_side == "left":
97
+ gr.Success("You chose the left side!\n💪🙂")
98
+ elif preferred_side == "right":
99
+ gr.Success("You chose the right side!\n🙂💪")
100
+ elif preferred_side == "tie":
101
+ gr.Success("It's a tie!\n🤝")
102
+
103
+ (
104
+ (ret_reference_id, ret_model_left_id, ret_model_right_id),
105
+ (reference_image, left_video, right_video),
106
+ (num_preferences, total_num_comparison),
107
+ ) = next_comp
108
+ progress_str = str(num_preferences) + " preferences selected!"
109
+ if total_num_comparison > 0:
110
+ perc_num = self.desired_num_selections if self.desired_num_selections is not None else total_num_comparison
111
+ progress_str = (
112
+ str(int(num_preferences / perc_num * 100))
113
+ + f"% ({num_preferences} / {total_num_comparison} total) preferences selected!"
114
+ )
115
+
116
+ return (
117
+ ret_reference_id,
118
+ ret_model_left_id,
119
+ ret_model_right_id,
120
+ progress_str,
121
+ reference_image,
122
+ left_video,
123
+ right_video,
124
+ )
125
+
126
+ def choose_left(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
127
+ return self.set_preference(
128
+ user_id=user_id,
129
+ reference_id=reference_id,
130
+ model_left_id=model_left_id,
131
+ model_right_id=model_right_id,
132
+ preferred_side="left",
133
+ )
134
+
135
+ def choose_tie(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
136
+ return self.set_preference(
137
+ user_id=user_id,
138
+ reference_id=reference_id,
139
+ model_left_id=model_left_id,
140
+ model_right_id=model_right_id,
141
+ preferred_side="tie",
142
+ )
143
+
144
+ def choose_right(self, user_id: str, reference_id: str, model_left_id: str, model_right_id: str):
145
+ return self.set_preference(
146
+ user_id=user_id,
147
+ reference_id=reference_id,
148
+ model_left_id=model_left_id,
149
+ model_right_id=model_right_id,
150
+ preferred_side="right",
151
+ )
152
+
153
+ def build_interface(self):
154
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
155
+ local_storage = gr.BrowserState(["", ""])
156
+ with self.user_input.render():
157
+ with gr.Row():
158
+ user_id = gr.Textbox(show_label=False, visible=False)
159
+ self.username_input.render()
160
+ self.select_button.render()
161
+
162
+ with self.task.render():
163
+ gr.Markdown(self.task_description)
164
+ reference_id = gr.Textbox(show_label=False, visible=False)
165
+ model_left_id = gr.Textbox(show_label=False, visible=False)
166
+ model_right_id = gr.Textbox(show_label=False, visible=False)
167
+ progress = gr.Textbox(show_label=False, text_align="right")
168
+
169
+ with gr.Row():
170
+ self.left_media.render()
171
+ self.reference_image.render()
172
+ self.right_media.render()
173
+
174
+ with gr.Row():
175
+ self.left_button.render()
176
+ self.tie_button.render()
177
+ self.right_button.render()
178
+
179
+ @demo.load(inputs=[local_storage], outputs=[user_id, self.username_input])
180
+ def load_from_local_storage(saved_values):
181
+ return saved_values[0], saved_values[1]
182
+
183
+ @gr.on([user_id.change], inputs=[user_id, self.username_input], outputs=[local_storage])
184
+ def save_to_local_storage(user_id, username):
185
+ return [user_id, username]
186
+
187
+ self.select_button.click(
188
+ self.set_username,
189
+ inputs=[self.username_input],
190
+ outputs=[
191
+ user_id,
192
+ self.task,
193
+ reference_id,
194
+ model_left_id,
195
+ model_right_id,
196
+ progress,
197
+ self.reference_image,
198
+ self.left_media,
199
+ self.right_media,
200
+ ],
201
+ )
202
+
203
+ self.left_button.click(
204
+ self.choose_left,
205
+ inputs=[user_id, reference_id, model_left_id, model_right_id],
206
+ outputs=[
207
+ reference_id,
208
+ model_left_id,
209
+ model_right_id,
210
+ progress,
211
+ self.reference_image,
212
+ self.left_media,
213
+ self.right_media,
214
+ ],
215
+ )
216
+ self.tie_button.click(
217
+ self.choose_tie,
218
+ inputs=[user_id, reference_id, model_left_id, model_right_id],
219
+ outputs=[
220
+ reference_id,
221
+ model_left_id,
222
+ model_right_id,
223
+ progress,
224
+ self.reference_image,
225
+ self.left_media,
226
+ self.right_media,
227
+ ],
228
+ )
229
+ self.right_button.click(
230
+ self.choose_right,
231
+ inputs=[user_id, reference_id, model_left_id, model_right_id],
232
+ outputs=[
233
+ reference_id,
234
+ model_left_id,
235
+ model_right_id,
236
+ progress,
237
+ self.reference_image,
238
+ self.left_media,
239
+ self.right_media,
240
+ ],
241
+ )
242
+
243
+ demo.unload(fn=lambda: cleanup_list([self.tmp_video_path_left, self.tmp_video_path_right]))
244
+
245
+ return demo
246
+
247
+ def launch(self):
248
+ app = self.build_interface()
249
+ app.launch(server_name="0.0.0.0")
src/gecora/cli/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/cli/loomis_painter.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from gecora.logic.loomis_painter import LoomisPainterApp
4
+
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(
8
+ description="Loomis Painter Ranking CLI - Evaluate painting processes and update model rankings."
9
+ )
10
+
11
+ parser.add_argument("root_path", type=str, help="Root directory containing Hugging Face datasets.")
12
+ parser.add_argument("--dataset_path", type=str, help="Optional path to the dataset.")
13
+
14
+ args = parser.parse_args()
15
+
16
+ app = LoomisPainterApp(args.root_path, args.dataset_path)
17
+ print("App initialized.")
18
+ app.launch()
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()
src/gecora/dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/dataset/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes). View file
 
src/gecora/dataset/__pycache__/base_manager.cpython-312.pyc ADDED
Binary file (4.69 kB). View file
 
src/gecora/dataset/__pycache__/video_manager.cpython-312.pyc ADDED
Binary file (8.44 kB). View file
 
src/gecora/dataset/__pycache__/video_pkl_manager.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
src/gecora/dataset/__pycache__/vieo_pkl_manager.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
src/gecora/dataset/base_manager.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from dataclasses import asdict, dataclass
5
+ from typing import Any, List, Optional, Tuple
6
+
7
+ from gecora.logging.logger import setup_file_logger
8
+
9
+
10
+ @dataclass
11
+ class DatasetManagerConfig:
12
+ root_path: str
13
+ gt_dataset: str = "gt"
14
+ dataset_split: str = "test"
15
+ entry_id: str = "entry_id"
16
+ reference_column_name: str = "reference"
17
+ genereated_column_name: str = "generated"
18
+ logging_path: Optional[str] = None
19
+
20
+ def to_json(self, path: str):
21
+ with open(path, "w") as f:
22
+ json.dump(asdict(self), f)
23
+
24
+ @classmethod
25
+ def from_json(cls, path: str):
26
+ with open(path, "r") as f:
27
+ data = json.load(f)
28
+ return cls(**data)
29
+
30
+
31
+ class BaseDatasetManager:
32
+ """
33
+ A class to manage Hugging Face datasets stored in subdirectories of a given root path.
34
+
35
+ Attributes:
36
+ root_path (str): The root directory containing subfolders with Hugging Face datasets.
37
+ dataset_split (str): If the datasets are a Dict, select specified split.
38
+ common_entry_ids (List[str]): List of 'entry_id's present in all datasets.
39
+ partial_entry_ids (List[str]): List of 'entry_id's present in only some datasets.
40
+ """
41
+
42
+ def __init__(self, config: DatasetManagerConfig):
43
+ self.config = config
44
+ self.common_entry_ids: List[str] = []
45
+ self.partial_entry_ids: List[str] = []
46
+
47
+ logger_name = f"{self.__class__.__name__}_logger"
48
+ log_file = os.path.join(self.config.logging_path, f"{logger_name}.txt") if self.config.logging_path else None
49
+ self.logger = setup_file_logger(logger_name, log_file) if log_file else logging.getLogger(logger_name)
50
+
51
+ self.logger = logging.getLogger()
52
+
53
+ def get_dataset_names(self) -> List[str]:
54
+ """
55
+ Returns a list of all loaded dataset names.
56
+
57
+ Returns:
58
+ List[str]: Names of the datasets.
59
+ """
60
+ raise ValueError("Must be implemented by child class.")
61
+
62
+ def get_entries_by_id(
63
+ self, entry_id: str, dataset_name1: str, dataset_name2: str
64
+ ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
65
+ """
66
+ Retrieves entries from two datasets by a specific 'entry_id'.
67
+
68
+ Args:
69
+ entry_id (str): The 'entry_id' to search for.
70
+ dataset_name1 (str): Name of the first dataset.
71
+ dataset_name2 (str): Name of the second dataset.
72
+
73
+ Returns:
74
+ Tuple[Optional[Any], Optional[Any], Optional[Any]]:
75
+ Reference content, or None if not found.
76
+ Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
77
+ Generated entry from dataset_name2 matching the 'entry_id', or None if not found.
78
+
79
+ Raises:
80
+ ValueError: If one or both dataset names are not found.
81
+ """
82
+ raise ValueError("Must be implemented by child class.")
src/gecora/dataset/create_test_dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import imageio
4
+ import numpy as np
5
+ import pandas as pd
6
+ from datasets import Dataset
7
+ from PIL import Image
8
+
9
+
10
+ def generate_random_image(size=(512, 512)):
11
+ array = np.random.randint(0, 256, size + (3,), dtype=np.uint8)
12
+ return Image.fromarray(array)
13
+
14
+
15
+ def generate_random_video(num_frames=10, size=(512, 512)):
16
+ frames = [np.random.randint(0, 256, size + (3,), dtype=np.uint8) for _ in range(num_frames)]
17
+ return frames
18
+
19
+
20
+ def save_video(frames, path):
21
+ imageio.mimsave(path, frames, fps=5)
22
+
23
+
24
+ def create_i_to_v_huggingface_datasets(output_dir, num_datasets=3, num_entries=5):
25
+ os.makedirs(output_dir, exist_ok=True)
26
+ entry_ids = [f"id_{i}" for i in range(num_entries)]
27
+
28
+ for dataset_index in range(num_datasets):
29
+ data = {"entry_id": [], "reference_image": [], "video": []}
30
+ dataset_path = os.path.join(output_dir, f"dataset_{dataset_index}")
31
+ os.makedirs(dataset_path, exist_ok=True)
32
+
33
+ for entry_id in entry_ids:
34
+ img = generate_random_image()
35
+ img_path = os.path.join(dataset_path, f"{entry_id}_image.png")
36
+ img.save(img_path)
37
+
38
+ video_frames = generate_random_video()
39
+ video_path = os.path.join(dataset_path, f"{entry_id}_video.mp4")
40
+ save_video(video_frames, video_path)
41
+
42
+ data["entry_id"].append(entry_id)
43
+ data["reference_image"].append(img_path)
44
+ data["video"].append(video_path)
45
+
46
+ df = pd.DataFrame(data)
47
+ hf_dataset = Dataset.from_pandas(df)
48
+ hf_dataset.save_to_disk(dataset_path)
src/gecora/dataset/sub_dir_manager.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ from datasets import Dataset, DatasetDict, load_from_disk
5
+
6
+ from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
7
+
8
+
9
+ class SubDirDatasetManager(BaseDatasetManager):
10
+ """
11
+ A class to manage Hugging Face datasets stored in subdirectories of a given root path.
12
+
13
+ Attributes:
14
+ root_path (str): The root directory containing subfolders with Hugging Face datasets.
15
+ dataset_split (str): If the datasets are a Dict, select specified split.
16
+ datasets (Dict[str, Union[Dataset, DatasetDict]]): Dictionary mapping dataset names to loaded datasets.
17
+ common_entry_ids (List[str]): List of 'entry_id's present in all datasets.
18
+ partial_entry_ids (List[str]): List of 'entry_id's present in only some datasets.
19
+ """
20
+
21
+ def __init__(self, config: DatasetManagerConfig) -> None:
22
+ """
23
+ Initializes the dataset manager and loads datasets from subdirectories.
24
+
25
+ Args:
26
+ config (DatasetManagerConfig): Configuration.
27
+ """
28
+ super().__init__(config=config)
29
+ self.datasets: Dict[str, Union[Dataset, DatasetDict]] = {}
30
+ self._load_datasets()
31
+ self._analyze_entry_ids()
32
+ self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
33
+ self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
34
+ self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")
35
+
36
+ def _load_datasets(self) -> None:
37
+ """
38
+ Loads all Hugging Face datasets from subdirectories in the root path.
39
+ """
40
+ for subdir in os.listdir(self.config.root_path):
41
+ full_path = os.path.join(self.config.root_path, subdir)
42
+ if os.path.isdir(full_path):
43
+ try:
44
+ dataset = load_from_disk(full_path)
45
+ self.datasets[subdir] = dataset
46
+ except Exception as e:
47
+ self.logger.info(f"Skipping {subdir}: {e}")
48
+
49
+ def get_dataset_names(self) -> List[str]:
50
+ """
51
+ Returns a list of all loaded dataset names.
52
+
53
+ Returns:
54
+ List[str]: Names of the datasets.
55
+ """
56
+ return list(self.datasets.keys())
57
+
58
+ def _analyze_entry_ids(self) -> None:
59
+ """
60
+ Analyzes all datasets to find common and partial 'entry_id's.
61
+ Updates `common_entry_ids` and `partial_entry_ids` attributes.
62
+ """
63
+ entry_id_sets: List[set] = []
64
+
65
+ for dataset in self.datasets.values():
66
+ if isinstance(dataset, DatasetDict):
67
+ dataset = dataset[self.config.dataset_split]
68
+ entry_ids = set(dataset[self.config.entry_id])
69
+ entry_id_sets.append(entry_ids)
70
+
71
+ if entry_id_sets:
72
+ self.common_entry_ids = list(set.intersection(*entry_id_sets))
73
+ all_entry_ids = set.union(*entry_id_sets)
74
+ self.partial_entry_ids = list(all_entry_ids - set(self.common_entry_ids))
75
+
76
+ def get_entries_by_id(
77
+ self, entry_id: str, dataset_name1: str, dataset_name2: str
78
+ ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
79
+ """
80
+ Retrieves entries from two datasets by a specific 'entry_id'.
81
+
82
+ Args:
83
+ entry_id (str): The 'entry_id' to search for.
84
+ dataset_name1 (str): Name of the first dataset.
85
+ dataset_name2 (str): Name of the second dataset.
86
+
87
+ Returns:
88
+ Tuple[Optional[Any], Optional[Any], Optional[Any]]:
89
+ Reference content, or None if not found.
90
+ Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
91
+ Generated entry from dataset_name2 matching the 'entry_id', or None if not found.
92
+
93
+ Raises:
94
+ ValueError: If one or both dataset names are not found.
95
+ """
96
+ if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
97
+ raise ValueError("One or both dataset names not found.")
98
+
99
+ def find_entry(dataset: Union[Dataset, DatasetDict], entry_id: str) -> Optional[Dict]:
100
+ if isinstance(dataset, DatasetDict):
101
+ dataset = dataset[self.config.dataset_split]
102
+ for entry in dataset:
103
+ if entry.get(self.config.entry_id) == entry_id:
104
+ return entry
105
+ return None
106
+
107
+ entry1 = find_entry(self.datasets[dataset_name1], entry_id)
108
+ entry2 = find_entry(self.datasets[dataset_name2], entry_id)
109
+ if entry1 is not None:
110
+ reference_image = entry1[self.config.reference_column_name]
111
+ else:
112
+ reference_image = None
113
+
114
+ ret_entry1 = entry1[self.config.genereated_column_name] if entry1 is not None else None
115
+ ret_entry2 = entry2[self.config.genereated_column_name] if entry2 is not None else None
116
+
117
+ return reference_image, ret_entry1, ret_entry2
src/gecora/dataset/video_manager.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
6
+
7
+
8
+ logging.basicConfig(
9
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
10
+ datefmt="%m/%d/%Y %H:%M:%S",
11
+ level=logging.INFO,
12
+ )
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class VideoManager(BaseDatasetManager):
18
+ def __init__(self, config: DatasetManagerConfig) -> None:
19
+ """
20
+ Initializes the dataset manager and loads datasets from subdirectories.
21
+
22
+ Args:
23
+ config (DatasetManagerConfig): Configuration.
24
+ """
25
+ super().__init__(config=config)
26
+ self.datasets: Dict[str, Dict[str, Dict[str, str]]] = {}
27
+ self.reference_images: Dict[str, Dict[str, str]] = {}
28
+ self.root_path = Path(self.config.root_path)
29
+ self.common_entry_ids = []
30
+ self.partial_entry_ids = []
31
+ self._load_datasets()
32
+ self._analyze_entry_ids()
33
+ self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
34
+ self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
35
+ self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")
36
+
37
+ def _collect_reference_images(self, dataset_dir: Path, split: str) -> Dict[str, Dict[str, str]]:
38
+ """Collects all reference images under the split."""
39
+ split_dir = dataset_dir / split
40
+ dataset: Dict[str, Dict[str, str]] = {}
41
+ for source_dir in split_dir.iterdir():
42
+ if source_dir.is_dir():
43
+ for video_dir in source_dir.iterdir():
44
+ if video_dir.is_dir():
45
+ reference_image = video_dir / "reference_image.png"
46
+ if reference_image.is_file():
47
+ source = source_dir.stem
48
+ url = video_dir.stem
49
+ id = source + "_" + url
50
+ dataset[id] = {
51
+ "source": source,
52
+ "url": url,
53
+ "id": id,
54
+ "reference_image": str(reference_image),
55
+ }
56
+ return dataset
57
+
58
+ def _collect_video_dirs(self, dataset_dir: Path, split: str) -> Dict[str, Dict[str, str]]:
59
+ """Collects all video subdirectories under the split."""
60
+ split_dir = dataset_dir / split
61
+ dataset: Dict[str, Dict[str, str]] = {}
62
+ for source_dir in split_dir.iterdir():
63
+ if source_dir.is_dir():
64
+ for video_dir in source_dir.iterdir():
65
+ if video_dir.is_dir():
66
+ video_file = video_dir / "video.mp4"
67
+ if video_file.is_file():
68
+ source = source_dir.stem
69
+ url = video_dir.stem
70
+ id = source + "_" + url
71
+ dataset[id] = {
72
+ "source": source,
73
+ "url": url,
74
+ "id": id,
75
+ "video": str(video_file),
76
+ }
77
+ return dataset
78
+
79
+ def _load_datasets(self) -> None:
80
+ """
81
+ Loads all datasets from subdirectories in the root path.
82
+ """
83
+ for subdir in self.root_path.iterdir():
84
+ if subdir.is_dir():
85
+ dataset_name = subdir.stem
86
+ if dataset_name == self.config.gt_dataset:
87
+ # GT dataset only use reference images
88
+ self.reference_images = self._collect_reference_images(subdir, split=self.config.dataset_split)
89
+ if len(self.reference_images.keys()) == 0:
90
+ raise ValueError(f"No entries found for ground truth dataset {dataset_name} under {subdir}.")
91
+ else:
92
+ try:
93
+ dataset = self._collect_video_dirs(subdir, split=self.config.dataset_split)
94
+ if len(dataset.keys()) == 0:
95
+ self.logger.info(f"No entries found for datataset {dataset}")
96
+ else:
97
+ self.datasets[dataset_name] = dataset
98
+ except Exception as e:
99
+ self.logger.info(f"Skipping {subdir}: {e}")
100
+
101
+ def get_dataset_names(self) -> List[str]:
102
+ """
103
+ Returns a list of all loaded dataset names.
104
+
105
+ Returns:
106
+ List[str]: Names of the datasets.
107
+ """
108
+ return list(self.datasets.keys())
109
+
110
+ def _analyze_entry_ids(self) -> None:
111
+ """
112
+ Analyzes all datasets to find common and partial 'entry_id's.
113
+ Updates `common_entry_ids` and `partial_entry_ids` attributes.
114
+ """
115
+ entry_id_sets: List[set] = []
116
+
117
+ # GT reference images
118
+ entry_id_sets.append(set(self.reference_images.keys()))
119
+ # The other datasets
120
+ for dataset in self.datasets.values():
121
+ entry_ids = set(dataset.keys())
122
+ entry_id_sets.append(entry_ids)
123
+
124
+ if entry_id_sets:
125
+ self.common_entry_ids = list(set.intersection(*entry_id_sets))
126
+ all_entry_ids = set.union(*entry_id_sets)
127
+ self.partial_entry_ids = list(all_entry_ids - set(self.common_entry_ids))
128
+
129
+ def get_entries_by_id(
130
+ self, entry_id: str, dataset_name1: str, dataset_name2: str
131
+ ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
132
+ """
133
+ Retrieves entries from two datasets by a specific 'entry_id'.
134
+
135
+ Args:
136
+ entry_id (str): The 'entry_id' to search for.
137
+ dataset_name1 (str): Name of the first dataset.
138
+ dataset_name2 (str): Name of the second dataset.
139
+
140
+ Returns:
141
+ Tuple[Optional[Any], Optional[Any], Optional[Any]]:
142
+ Reference content, or None if not found.
143
+ Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
144
+ Generated entry from dataset_name2 matching the 'entry_id', or None if not found.
145
+
146
+ Raises:
147
+ ValueError: If one or both dataset names are not found.
148
+ """
149
+ if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
150
+ raise ValueError("One or both dataset names not found.")
151
+
152
+ entry1 = self.datasets[dataset_name1][entry_id]
153
+ entry2 = self.datasets[dataset_name2][entry_id]
154
+ reference_image = self.reference_images[entry_id]["reference_image"]
155
+
156
+ return reference_image, entry1["video"], entry2["video"]
src/gecora/dataset/video_pkl_manager.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import pickle
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from PIL import Image
9
+
10
+ from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class VideoFrameDataset:
    """
    PyTorch-style dataset for loading video frames and progress values from a
    custom directory structure (dataset_dir/<split>/<source>/<video>/).

    Each item corresponds to one video and returns:
        - frames: List of decoded RGB PIL images
        - progress: List of float progress values

    Args:
        dataset_dir (str): Root directory of the dataset.
        split (str): Either 'train', 'test' or 'eval'.
    """

    def __init__(
        self,
        dataset_dir: str,
        split: str = "test",
    ):
        assert split in ["train", "test", "eval"], "Split must be 'train', 'test' or 'eval'"
        self.dataset_dir = Path(dataset_dir)
        self.split = split
        self.video_dirs: List[Path] = self._collect_video_dirs()
        # Maps "<source>_<video-name>" -> index into self.video_dirs.
        self.video_id_to_idx: Dict[str, int] = {}
        for idx, video_dir in enumerate(self.video_dirs):
            url = video_dir.stem
            source = video_dir.parts[-2]
            # Renamed from `id` to avoid shadowing the builtin.
            video_id = source + "_" + url
            self.video_id_to_idx[video_id] = idx

    def _collect_video_dirs(self) -> List[Path]:
        """Collects all video subdirectories under the split."""
        split_dir = self.dataset_dir / self.split
        video_dirs = []
        for source_dir in split_dir.iterdir():
            if source_dir.is_dir():
                for video_dir in source_dir.iterdir():
                    if video_dir.is_dir():
                        video_dirs.append(video_dir)
        # Sorted so indices are deterministic across runs.
        return sorted(video_dirs)

    def _load_ref_frame(self, video_dir: Path, frame_data: List[bytes]) -> Image.Image:
        """Loads the reference frame from disk, falling back to the last video frame."""
        reference_frame_path = video_dir / "reference_frame.png"
        reference_frame = None
        try:
            reference_frame = Image.open(reference_frame_path)
        except Exception:
            # Best effort: a missing/unreadable reference_frame.png is expected
            # for some videos; we fall back to the last frame below.
            pass

        if reference_frame is None:
            reference_frame = Image.open(BytesIO(frame_data[-1]))

        return reference_frame.convert("RGB")

    def _prepare_frame(self, frame: bytes) -> Image.Image:
        """Decodes one compressed frame into an RGB PIL image."""
        return Image.open(BytesIO(frame)).convert("RGB")

    def __len__(self) -> int:
        return len(self.video_dirs)

    def __getitem__(self, idx: int) -> Any:
        """
        Loads one video entry.

        Returns:
            dict with keys 'video_dir' (str), 'video' (list of PIL images),
            'reference_image' (PIL image) and 'progress_steps' (list of floats).

        Raises:
            ValueError: If the pickled frame/progress data cannot be loaded.
        """
        video_dir = self.video_dirs[idx]
        frame_path = video_dir / "frame_data.pkl"
        progress_path = video_dir / "frame_progress.pkl"

        try:
            with open(frame_path, "rb") as f:
                frame_data: List[bytes] = pickle.load(f)

            reference_image = self._load_ref_frame(video_dir=video_dir, frame_data=frame_data)

            with open(progress_path, "rb") as f:
                frame_progress: List[float] = pickle.load(f)

            frame_img_list = [self._prepare_frame(fd) for fd in frame_data]

        except Exception as e:
            logger.error(f"While loading data for video {video_dir} an error occurred: {e}")
            # Fix: chain the original exception so the root cause stays visible
            # (the previous `raise (ValueError(e))` discarded the context).
            raise ValueError(e) from e

        return {
            "video_dir": str(video_dir),
            "video": frame_img_list,
            "reference_image": reference_image,
            "progress_steps": frame_progress,
        }
104
+
105
+
106
class VideoPklManager(BaseDatasetManager):
    """
    Manages datasets stored as video-pkl subdirectories below a common root path.

    Attributes:
        root_path (str): The root directory containing subfolders with Hugging Face datasets.
        dataset_split (str): If the datasets are a Dict, select specified split.
        datasets (Dict[str, VideoFrameDataset]): Maps dataset names to loaded datasets.
        common_entry_ids (List[str]): 'entry_id's present in every dataset.
        partial_entry_ids (List[str]): 'entry_id's present in only some datasets.
    """

    def __init__(self, config: DatasetManagerConfig) -> None:
        """
        Initializes the manager and loads all datasets found below the root path.

        Args:
            config (DatasetManagerConfig): Configuration.
        """
        super().__init__(config=config)
        self.datasets: Dict[str, VideoFrameDataset] = {}
        self._load_datasets()
        self._analyze_entry_ids()
        self.logger.info(f"Loaded following datasets: {self.get_dataset_names()}")
        self.logger.info(f"Found {len(self.common_entry_ids)} entries in all datasets.")
        self.logger.info(f"Found {len(self.partial_entry_ids)} entries only in some datasets.")

    def _load_datasets(self) -> None:
        """Loads a VideoFrameDataset from every subdirectory of the root path."""
        for subdir in os.listdir(self.config.root_path):
            full_path = os.path.join(self.config.root_path, subdir)
            if not os.path.isdir(full_path):
                continue
            try:
                self.datasets[subdir] = VideoFrameDataset(full_path, split=self.config.dataset_split)
            except Exception as e:
                self.logger.info(f"Skipping {subdir}: {e}")

    def get_dataset_names(self) -> List[str]:
        """
        Returns the names of all loaded datasets.

        Returns:
            List[str]: Names of the datasets.
        """
        return [name for name in self.datasets]

    def _analyze_entry_ids(self) -> None:
        """
        Computes which 'entry_id's are shared by all datasets and which appear
        only in some, updating `common_entry_ids` and `partial_entry_ids`.
        """
        id_sets = [set(dataset.video_id_to_idx.keys()) for dataset in self.datasets.values()]

        if id_sets:
            common = set.intersection(*id_sets)
            self.common_entry_ids = list(common)
            self.partial_entry_ids = list(set.union(*id_sets) - common)

    def get_entries_by_id(
        self, entry_id: str, dataset_name1: str, dataset_name2: str
    ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        """
        Retrieves entries from two datasets by a specific 'entry_id'.

        Args:
            entry_id (str): The 'entry_id' to search for.
            dataset_name1 (str): Name of the first dataset.
            dataset_name2 (str): Name of the second dataset.

        Returns:
            Tuple[Optional[Any], Optional[Any], Optional[Any]]:
                Reference content, or None if not found.
                Generated entry from dataset_name1 matching the 'entry_id', or None if not found.
                Generated entry from dataset_name2 matching the 'entry_id', or None if not found.

        Raises:
            ValueError: If one or both dataset names are not found.
        """
        if dataset_name1 not in self.datasets or dataset_name2 not in self.datasets:
            raise ValueError("One or both dataset names not found.")

        # Resolve the entry in both datasets; an absent id yields None.
        resolved = []
        for name in (dataset_name1, dataset_name2):
            dataset = self.datasets[name]
            idx = dataset.video_id_to_idx.get(entry_id)
            resolved.append(dataset[idx] if idx is not None else None)
        entry1, entry2 = resolved

        # The reference image is taken from the first dataset's entry.
        reference_image = None if entry1 is None else entry1["reference_image"]
        ret_entry1 = None if entry1 is None else entry1["video"]
        ret_entry2 = None if entry2 is None else entry2["video"]

        return reference_image, ret_entry1, ret_entry2
src/gecora/dataset_converting/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/dataset_converting/video_pkl_to_video.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ from tqdm import tqdm
5
+
6
+ from gecora.dataset.video_pkl_manager import VideoFrameDataset
7
+ from gecora.logic.utils import save_video_from_frames
8
+
9
+
10
def convert(
    dataset_dir: str,
    output_dir: str,
    split: str = "test",
    fps: int = 3,
):
    """
    Converts a VideoFrameDataset into video files.

    Args:
        dataset_dir: Path to the dataset directory.
        output_dir: Path where the output videos will be saved.
        split: The dataset split to process (e.g., 'train', 'test', 'eval').
        fps: Frames per second for the output video.
    """
    print(f"Loading dataset from: {dataset_dir} (Split: {split})")
    dataset = VideoFrameDataset(dataset_dir=dataset_dir, split=split)

    num_entries = len(dataset)
    print(f"Found {num_entries} entries. Starting conversion...")

    for idx in tqdm(range(num_entries)):
        entry = dataset[idx]

        # The last two path components encode <source>/<video-name>.
        parts = Path(entry["video_dir"]).parts
        source, video_name = parts[-2], parts[-1]

        # Mirror the source layout below the output directory.
        out_path = Path(output_dir) / split / source / video_name
        out_path.mkdir(parents=True, exist_ok=True)

        # Save reference_image
        entry["reference_image"].save(out_path / "reference_image.png")

        # Save the video
        save_video_from_frames(entry["video"], video_output_path=str(out_path / "video.mp4"), fps=fps)

    print("Conversion complete.")
52
+
53
+
54
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert VideoFrameDataset entries to MP4 video files.")

    # Required arguments
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the root directory of the dataset.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory where output videos will be saved.")

    # Optional arguments
    parser.add_argument("--split", type=str, default="test", help="Dataset split to process (default: 'test').")
    # Fix: help text previously claimed a default of 2 while the actual default is 3.
    parser.add_argument("--fps", type=int, default=3, help="Frames per second for the output video (default: 3).")

    args = parser.parse_args()

    convert(dataset_dir=args.dataset_dir, output_dir=args.output_dir, split=args.split, fps=args.fps)
src/gecora/db/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/db/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes). View file
 
src/gecora/db/__pycache__/hf_jsonl.cpython-312.pyc ADDED
Binary file (16.4 kB). View file
 
src/gecora/db/__pycache__/sqlite.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
src/gecora/db/hf_jsonl.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Class for Database handling using HuggingFace Datasets with JSONL storage."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from datetime import datetime
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from huggingface_hub import HfApi, hf_hub_download, upload_file
10
+
11
+
12
class HFJsonlDB:
    """
    Handles database operations using HuggingFace datasets with JSONL files.

    Stores data in JSONL files on HuggingFace datasets for persistent storage in HF Spaces.

    Attributes:
        repo_id (str): HuggingFace repository ID (e.g., "username/dataset-name").
        experiment_name (str): Name of the experiment (used to name files).
        token (Optional[str]): HuggingFace API token for authentication.
        users_filename (str): Filename for users JSONL file.
        preferences_filename (str): Filename for preferences JSONL file.
        logger (logging.Logger): Logger instance for logging database operations.
        hf_api (HfApi): HuggingFace API client.
    """

    def __init__(
        self,
        repo_id: str,
        experiment_name: str = "arena",
        token: Optional[str] = None,
        log_folder_path: Optional[str] = None,
    ):
        """
        Initializes the HFJsonlDB instance.

        Args:
            repo_id (str): HuggingFace repository ID (e.g., "username/dataset-name").
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
            token (Optional[str], optional): HuggingFace API token. If None, uses HF_TOKEN env var.
            log_folder_path (Optional[str], optional): Path for log file. Defaults to current directory.
        """
        self.repo_id = repo_id
        self.experiment_name = experiment_name
        self.token = token or os.environ.get("HF_TOKEN")
        self.users_filename = f"{experiment_name.lower()}_users.jsonl"
        self.preferences_filename = f"{experiment_name.lower()}_preferences.jsonl"

        # Setup logging
        if log_folder_path is None:
            log_folder_path = "."
        self.log_path = os.path.join(log_folder_path, "log_hf_db.txt")
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize HF API
        self.hf_api = HfApi(token=self.token)

        # Cache for data
        self._users_cache: List[Dict[str, Any]] = []
        self._preferences_cache: List[Dict[str, Any]] = []
        self._cache_loaded = False

    def initialize_database(self) -> bool:
        """
        Initializes the database by ensuring the HF dataset exists and files are present.

        Creates empty JSONL files if they don't exist on the HF dataset.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            # Check if repository exists, if not create it
            try:
                self.hf_api.repo_info(repo_id=self.repo_id, repo_type="dataset")
                self.logger.info(f"Repository {self.repo_id} already exists")
            except Exception:
                self.logger.info(f"Creating repository {self.repo_id}")
                self.hf_api.create_repo(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    exist_ok=True,
                    private=True,
                )

            # Try to load existing files or create new ones
            self._load_data()

            # If files don't exist, create them
            if not self._cache_loaded:
                self._save_users([])
                self._save_preferences([])
                self._users_cache = []
                self._preferences_cache = []
                self._cache_loaded = True

            self.logger.info("Database initialized successfully")
            return True

        except Exception as e:
            self.logger.error(f"Initializing the database failed with error: {e}")
            return False

    def _load_data(self):
        """Load data from HF dataset into cache."""
        try:
            # Load users
            try:
                users_path = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.users_filename,
                    repo_type="dataset",
                    token=self.token,
                )
                with open(users_path, "r") as f:
                    self._users_cache = [json.loads(line) for line in f if line.strip()]
            except Exception as e:
                self.logger.info(f"Users file not found, will create new: {e}")
                self._users_cache = []

            # Load preferences
            try:
                prefs_path = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.preferences_filename,
                    repo_type="dataset",
                    token=self.token,
                )
                with open(prefs_path, "r") as f:
                    self._preferences_cache = [json.loads(line) for line in f if line.strip()]
            except Exception as e:
                self.logger.info(f"Preferences file not found, will create new: {e}")
                self._preferences_cache = []

            self._cache_loaded = True

        except Exception as e:
            self.logger.error(f"Error loading data: {e}")
            self._cache_loaded = False

    def _upload_jsonl(self, records: List[Dict[str, Any]], filename: str) -> None:
        """
        Serializes records to a JSONL temp file and uploads it to the HF dataset repo.

        Fix: the previous implementation wrote to a hard-coded "/tmp/<name>" path,
        which is not portable (Windows) and collides when several processes save
        concurrently. tempfile.NamedTemporaryFile gives a unique, portable path,
        and try/finally guarantees cleanup even when the upload fails.
        """
        import tempfile

        with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
            for record in records:
                f.write(json.dumps(record) + "\n")
            temp_path = f.name
        try:
            upload_file(
                path_or_fileobj=temp_path,
                path_in_repo=filename,
                repo_id=self.repo_id,
                repo_type="dataset",
                token=self.token,
            )
        finally:
            os.remove(temp_path)

    def _save_users(self, users: List[Dict[str, Any]]):
        """Save users to HF dataset."""
        self._upload_jsonl(users, self.users_filename)

    def _save_preferences(self, preferences: List[Dict[str, Any]]):
        """Save preferences to HF dataset."""
        self._upload_jsonl(preferences, self.preferences_filename)

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                The user_id of the newly created user, or None if creation failed.
                If first entry is None then the second contains the exception message.
        """
        try:
            # Reload data to ensure we have latest
            self._load_data()

            # Check if user already exists
            for user in self._users_cache:
                if user["username"] == username:
                    msg = f"User '{username}' already exists."
                    self.logger.warning(msg)
                    return None, msg

            # Generate new user_id
            user_id = max([u["user_id"] for u in self._users_cache], default=0) + 1

            # Create new user
            new_user = {
                "user_id": user_id,
                "username": username,
                "created_at": datetime.now().isoformat(),
            }

            self._users_cache.append(new_user)
            self._save_users(self._users_cache)

            self.logger.info(f"User '{username}' created with user_id {user_id}")
            return user_id, username

        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
            return None, msg

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise.
        """
        try:
            # Reload data to ensure we have latest
            self._load_data()

            for user in self._users_cache:
                if user["username"] == username:
                    return user["user_id"]

            return None

        except Exception as e:
            self.logger.error(f"Error checking username '{username}': {e}")
            return None

    def insert_preference(
        self,
        user_id: int,
        reference_id: str,
        model_left_id: str,
        model_right_id: str,
        preferred_side: str,
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: True if insertion was successful, False otherwise with a
                string message describing the exception.
        """
        msg = ""
        if preferred_side not in {"left", "right", "tie"}:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg

        try:
            # Reload data to ensure we have latest
            self._load_data()

            # Generate new preference_id
            preference_id = max([p["preference_id"] for p in self._preferences_cache], default=0) + 1

            # Create new preference
            new_preference = {
                "preference_id": preference_id,
                "user_id": user_id,
                "reference_id": reference_id,
                "model_left_id": model_left_id,
                "model_right_id": model_right_id,
                "preferred_side": preferred_side,
                "timestamp": datetime.now().isoformat(),
            }

            self._preferences_cache.append(new_preference)
            self._save_preferences(self._preferences_cache)

            self.logger.info(f"Preference inserted for user_id {user_id}")
            return True, msg

        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database.

        Returns:
            List[Tuple]: A list of tuples representing all preference entries.
        """
        try:
            self._load_data()

            # Convert dicts to tuples matching SQLite format
            result = []
            for pref in self._preferences_cache:
                result.append(
                    (
                        pref["preference_id"],
                        pref["user_id"],
                        pref["reference_id"],
                        pref["model_left_id"],
                        pref["model_right_id"],
                        pref["preferred_side"],
                        pref["timestamp"],
                    )
                )
            return result

        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences: {e}")
            return []

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: A list of tuples representing the user's preference entries.
        """
        try:
            self._load_data()

            # Filter by user_id and convert to tuples
            result = []
            for pref in self._preferences_cache:
                if pref["user_id"] == user_id:
                    result.append(
                        (
                            pref["preference_id"],
                            pref["user_id"],
                            pref["reference_id"],
                            pref["model_left_id"],
                            pref["model_right_id"],
                            pref["preferred_side"],
                            pref["timestamp"],
                        )
                    )

            self.logger.info(f"Retrieved {len(result)} preferences for user_id {user_id}.")
            return result

        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
            return []

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table.

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        keys = [
            "preference_id",
            "user_id",
            "reference_id",
            "model_left_id",
            "model_right_id",
            "preferred_side",
            "timestamp",
        ]
        return [dict(zip(keys, row)) for row in preferences]
src/gecora/db/sqlite.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Class for Database handling using SQLite."""
2
+
3
+ import logging
4
+ import os
5
+ import sqlite3
6
+ from sqlite3 import Connection
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+
10
class SQLiteDB:
    """
    Handles SQLite database operations for an image generation comparison experiment.

    Attributes:
        db_folder_path (str): Directory where the SQLite database file is stored.
        experiment_name (str): Name of the experiment (used to name the database file).
        db_filename (str): Filename of the SQLite database.
        db_path (str): Full path to the SQLite database file.
        logger (logging.Logger): Logger instance for logging database operations.
    """

    def __init__(self, db_folder_path: str, experiment_name: str = "arena"):
        """
        Initializes the SQLiteDB instance with dataset and database configuration.

        Args:
            db_folder_path (str): Directory where the SQLite database file is stored.
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
        """
        self.experiment_name = experiment_name
        self.db_folder_path = db_folder_path
        self.db_filename = f"{experiment_name.lower()}.db"
        self.db_path = os.path.join(db_folder_path, self.db_filename)
        # Fix: create the folder up front. Previously both the log file below and
        # sqlite3.connect() in initialize_database crashed when the folder was
        # missing, because os.makedirs was only called after connecting.
        os.makedirs(db_folder_path, exist_ok=True)
        self.log_path = os.path.join(db_folder_path, "log_db.txt")
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        # Named logger (consistent with HFJsonlDB) instead of the root logger.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.conn: Optional[Connection] = None

    def __del__(self):
        # Best-effort close; the interpreter may already be tearing down.
        if self.conn is not None:
            try:
                self.conn.close()
            except Exception:
                pass

    def _get_conn(self) -> Connection:
        """
        Lazily opens and caches the SQLite connection.

        Fix: check_same_thread=False is now applied consistently. Previously only
        initialize_database used it, so connections created lazily by the other
        methods had different thread-safety behavior.
        """
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        return self.conn

    def initialize_database(self) -> bool:
        """
        Initializes the SQLite database and creates required tables if they do not exist.

        Tables:
            - User: Stores user information.
            - Preference: Stores user preferences between generated images.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            db_exists = os.path.exists(self.db_path)
            conn = self._get_conn()
            with conn:  # auto commit
                cursor = conn.cursor()

                if db_exists:
                    self.logger.info(f"Database already exists at {self.db_path}")
                    return True

                self.logger.info(f"Creating new database at {self.db_path}")

                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS User (
                        user_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        username TEXT UNIQUE,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS Preference (
                        preference_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER,
                        reference_id TEXT,
                        model_left_id TEXT,
                        model_right_id TEXT,
                        preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')),
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY (user_id) REFERENCES User(user_id)
                    )
                """)
                cursor.close()
        except Exception as e:
            self.logger.info(f"Creating the database failed with following error: {e}")
            return False

        return True

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                The user_id of the newly created user, or None if creation failed.
                If first entry is None then the second contains the exception message.
        """
        ret: Optional[int] = None
        msg = ""
        try:
            conn = self._get_conn()
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT INTO User (username) VALUES (?)
                    """,
                    (username,),
                )
                user_id = cursor.lastrowid
                cursor.close()
            self.logger.info(f"User '{username}' created with user_id {user_id}")
            ret = user_id
            msg = username
        except sqlite3.IntegrityError:
            # UNIQUE constraint on username.
            msg = f"User '{username}' already exists."
            self.logger.warning(msg)
        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
        return ret, msg

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise.
        """
        try:
            cursor = self._get_conn().cursor()
            cursor.execute(
                """
                SELECT user_id FROM User WHERE username = ?
                """,
                (username,),
            )
            result = cursor.fetchone()
            cursor.close()

            return result[0] if result else None
        except Exception as e:
            self.logger.error(f"Error checking username '{username}': {e}")
            return None

    def insert_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: True if insertion was successful, False otherwise with a
                string message describing the exception.
        """
        msg = ""
        # Validate in Python too, so the caller gets a clear message instead of
        # relying solely on the table's CHECK constraint.
        if preferred_side not in {"left", "right", "tie"}:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg

        try:
            conn = self._get_conn()
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT INTO Preference (
                        user_id, reference_id, model_left_id, model_right_id, preferred_side
                    ) VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id, reference_id, model_left_id, model_right_id, preferred_side),
                )
                cursor.close()
            self.logger.info(f"Preference inserted for user_id {user_id}")
            return True, msg
        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database.

        Returns:
            List[Tuple]: A list of tuples representing all preference entries.
        """
        preferences = []
        try:
            conn = self._get_conn()
            with conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM Preference")
                preferences = cursor.fetchall()
                cursor.close()
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences: {e}")
        return preferences

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: A list of tuples representing the user's preference entries.
        """
        preferences = []
        try:
            conn = self._get_conn()
            with conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM Preference WHERE user_id = ?", (user_id,))
                preferences = cursor.fetchall()
                cursor.close()
            self.logger.info(f"Retrieved {len(preferences)} preferences for user_id {user_id}.")
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
        return preferences

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table.

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        keys = [
            "preference_id",
            "user_id",
            "reference_id",
            "model_left_id",
            "model_right_id",
            "preferred_side",
            "timestamp",
        ]
        return [dict(zip(keys, row)) for row in preferences]
src/gecora/logging/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/logging/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes). View file
 
src/gecora/logging/__pycache__/logger.cpython-312.pyc ADDED
Binary file (1.01 kB). View file
 
src/gecora/logging/logger.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+
4
def setup_file_logger(name: str, log_file: str, level=logging.INFO) -> logging.Logger:
    """Return a named logger that writes timestamped records to ``log_file``.

    The level is (re)applied on every call, but the file handler is attached
    only the first time, so repeated calls never duplicate log output.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    if logger.handlers:
        # Already configured by a previous call — reuse as-is.
        return logger

    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
    return logger
src/gecora/logic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/logic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (168 Bytes). View file
 
src/gecora/logic/__pycache__/base.cpython-312.pyc ADDED
Binary file (2.81 kB). View file
 
src/gecora/logic/__pycache__/loomis_painter.cpython-312.pyc ADDED
Binary file (8.98 kB). View file
 
src/gecora/logic/__pycache__/utils.cpython-312.pyc ADDED
Binary file (1.9 kB). View file
 
src/gecora/logic/base.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple
2
+
3
+ from PIL import Image
4
+
5
+
6
class LogicBase:
    """
    Abstract base for a ranking app's logic layer.

    Concrete apps subclass this and wire together the dataset, database,
    ranking system and UI. The attributes below are expected to be provided
    by subclasses — this base class itself stores nothing:

    Attributes:
        root_path (str): Root directory for the application.
        dataset_path (Optional[str]): Path to the dataset, if provided.
        ranking_system: Instance of the model ranking system.
        gecora_db: Database integration (e.g. an SQLite or HF JSONL backend).
        dataset_manager: Dataset manager integration.
        itov_app: UI app integration.
    """

    def __init__(self, root_path: str, dataset_path: Optional[str] = None, dataset_split: str = "test"):
        """Initializes Base Class.

        Intentionally a no-op: subclasses perform all setup themselves and
        are expected to accept at least these constructor arguments.
        """
        pass

    def launch(self):
        """Launches the App."""
        raise NotImplementedError("This method should be implemented by the child class.")

    def set_username(self, username: str) -> Optional[int]:
        """
        Creates the User in the database if not already present.

        Returns:
            Optional[int]: The user's ID, or None if creation failed.
        """
        raise NotImplementedError("This method should be implemented by the child class.")

    def set_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Sets a new preference entry.

        Returns:
            Tuple[bool, str]: Success flag and a status message.
        """
        raise NotImplementedError("This method should be implemented by the child class.")

    def get_next_comparison(
        self, user_id: int
    ) -> Optional[Tuple[Tuple[str, str, str], Tuple[Image.Image, Dict, Dict], Tuple[int, int]]]:
        """
        Selects the next model pair for comparison.

        Args:
            user_id (int): ID of the current user.

        Returns:
            Optional[Tuple[Tuple[str, str, str], Tuple[Image.Image, Dict, Dict], Tuple[int, int]]]:
            ((reference_id, model_left_id, model_right_id),
            (reference_image, left_entry, right_entry),
            (num_preferences, total_num_comparison)), or None when no
            comparison is available.
        """
        raise NotImplementedError("This method should be implemented by the child class.")
src/gecora/logic/loomis_painter.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Any, Optional, Tuple, Union
4
+
5
+ from gecora.app.i_to_v_app import ItoVApp
6
+ from gecora.dataset.base_manager import BaseDatasetManager, DatasetManagerConfig
7
+ from gecora.dataset.video_manager import VideoManager
8
+ from gecora.db.hf_jsonl import HFJsonlDB
9
+ from gecora.db.sqlite import SQLiteDB
10
+ from gecora.logic.base import LogicBase
11
+ from gecora.logic.utils import save_video_from_frames
12
+ from gecora.ranking.ranking_system import RankingSystem
13
+
14
+
15
class LoomisPainterApp(LogicBase):
    """
    Loomis Painter Ranking App.

    Wires together the video dataset manager, the preference database, the
    ELO ranking system and the image-to-video comparison UI for the
    step-by-step painting task.

    Attributes:
        root_path (str): Root directory for the application.
        dataset_path (Optional[str]): Path to the dataset, if provided.
        force_model_id (Optional[str]): Model id that must appear in every comparison, if set.
        desired_num_selections (Optional[int]): Minimum number of comparisons a user should make.
        ranking_system (RankingSystem): Instance of the model ranking system.
        gecora_db (Union[SQLiteDB, HFJsonlDB]): Database instance (SQLite or HF JSONL).
        dataset_manager (BaseDatasetManager): Manager serving reference images and videos.
        itov_app (ItoVApp): UI application driving the comparisons.
    """

    def __init__(
        self,
        root_path: str,
        dataset_path: Optional[str] = None,
        use_hf_db: bool = True,
        hf_repo_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        force_model_id: Optional[str] = None,
        desired_num_selections: Optional[int] = 25,
    ):
        """Initializes LoomisPainterApp.

        Args:
            root_path (str): The root directory containing subfolders with datasets.
            dataset_path (Optional[str]): Path to the dataset; falls back to root_path if None.
            use_hf_db (bool): Whether to use HuggingFace JSONL database instead of SQLite.
            hf_repo_id (Optional[str]): HF repository ID (required if use_hf_db=True).
            hf_token (Optional[str]): HF API token (if not provided, uses HF_TOKEN env var).
            force_model_id (Optional[str]): If set select preferences where force_model_id is always included.
            desired_num_selections (Optional[int]): The desired number of comparison the user should at least select.

        Raises:
            ValueError: If use_hf_db is True and no hf_repo_id is given.
        """
        self.root_path = root_path
        self.dataset_path = dataset_path
        self.force_model_id = force_model_id
        self.desired_num_selections = desired_num_selections

        # Core logic
        self.ranking_system = RankingSystem()

        # Default the dataset location to the root path.
        if dataset_path is None:
            dataset_path = root_path

        self.gecora_db: Union[HFJsonlDB, SQLiteDB]
        # Initialize database based on configuration
        if use_hf_db:
            if hf_repo_id is None:
                raise ValueError("hf_repo_id must be provided when use_hf_db=True")
            self.gecora_db = HFJsonlDB(
                repo_id=hf_repo_id,
                experiment_name="arena",
                token=hf_token,
                log_folder_path=root_path,
            )
        else:
            self.gecora_db = SQLiteDB(
                db_folder_path=root_path,
            )

        self.gecora_db.initialize_database()

        # NOTE(review): "genereated_column_name" mirrors the (misspelled) field
        # name declared by DatasetManagerConfig and must stay in sync with it.
        dm_conf = DatasetManagerConfig(
            root_path=dataset_path,
            reference_column_name="reference_image",
            genereated_column_name="video",
            logging_path=root_path,
        )
        self.dataset_manager: BaseDatasetManager = VideoManager(config=dm_conf)
        task_desc: str = """
        ### 🎨 Choose the Best Step-By-Step Painting video
        You will be shown a **reference painting** in the center, with two generated videos (Left and Right) that depict the painting process.
        Your goal is to decide which step-by-step process is better based on the criteria below.

        Please consider the following when making your decision:
        * **Process Completeness:** Choose the video that best captures the entire painting process, from start to the finished work.
        * **Visual Fidelity:** The final frame of the video should match the reference painting as closely as possible.
        * **Ignore Duration:** The videos may be of different lengths. Please do not let the video duration influence your decision.

        #### How to Vote
        * Click **←** if the **Left** video is better.
        * Click **→** if the **Right** video is better.
        * Click **Tie** if neither video is clearly superior.
        """
        ref_img_label: str = "Reference Painting"
        left_media_label: str = "Left Painting Process"
        right_media_label: str = "Right Painting Process"
        self.itov_app: ItoVApp = ItoVApp(
            logic_class=self,
            task_desc=task_desc,
            ref_img_label=ref_img_label,
            left_media_label=left_media_label,
            right_media_label=right_media_label,
            desired_num_selections=self.desired_num_selections,
        )

        self.log_path = os.path.join(root_path, "log.txt")
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger()

    def launch(self):
        """Launches the App."""
        self.itov_app.launch()

    def set_username(self, username: str) -> Optional[int]:
        """
        Creates the User in the database if not already present.

        Args:
            username (str): Display name of the user.

        Returns:
            Optional[int]: The (existing or newly created) user's ID, or None on failure.
        """
        user_id = self.gecora_db.get_user_id_by_username(username=username)
        if user_id is None:
            user_id, msg = self.gecora_db.create_user(username=username)
            if user_id is None:
                self.logger.error(f"Error while creating user with username {username}: {msg}")
                return None
        return user_id

    def set_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Sets a new preference entry.

        Args:
            user_id (int): ID of the voting user.
            reference_id (str): ID of the reference entry being compared.
            model_left_id (str): Model shown on the left.
            model_right_id (str): Model shown on the right.
            preferred_side (str): Which side the user preferred.

        Returns:
            Tuple[bool, str]: Success flag and a status message from the database.
        """
        return self.gecora_db.insert_preference(
            user_id=user_id,
            reference_id=reference_id,
            model_left_id=model_left_id,
            model_right_id=model_right_id,
            preferred_side=preferred_side,
        )

    def run_ranking_update(self):
        """
        Loads all stored preferences from the database and updates model rankings.
        """
        # Bug fix: the previous implementation called self.load_preferences(),
        # a method that exists neither on this class nor on LogicBase and
        # would raise AttributeError. Fetch the rows from the database backend
        # and map them to dicts, mirroring get_next_comparison().
        preference_rows = self.gecora_db.get_all_preferences()
        preferences = self.gecora_db.map_preferences_to_dicts(preferences=preference_rows)
        self.ranking_system.update_rankings(preferences)

    def get_next_comparison(
        self, user_id: int
    ) -> Optional[Tuple[Tuple[str, str, str], Tuple[Any, Any, Any], Tuple[int, int]]]:
        """
        Selects the next model pair for comparison.

        Args:
            user_id (int): ID of the current user.

        Returns:
            Optional[Tuple[Tuple[str, str, str], Tuple[Any, Any, Any], Tuple[int, int]]]:
            ((reference_id, model_left_id, model_right_id),
            (reference_image, left_entry, right_entry),
            (num_preferences, total_num_comparison)), or None when no valid
            comparison can be produced.
        """
        reference_ids = self.dataset_manager.common_entry_ids
        model_ids = self.dataset_manager.get_dataset_names()
        # Having a data base call here every time may not be ideal. The assumption is that one user
        # does have a small number of set preferences, therefore runtime is still good.
        preferences_tuple = self.gecora_db.get_preferences_by_user(user_id=user_id)
        # Mapping from List[Tuple] to List[Dict[str, Any]]
        preferences = self.gecora_db.map_preferences_to_dicts(preferences=preferences_tuple)
        selected = self.ranking_system.select_next_comparison(
            preferences, reference_ids, model_ids, force_model_id=self.force_model_id
        )
        num_preferences, total_num_comparison = self.ranking_system.calculate_progress(
            preferences, reference_ids, model_ids, force_model_id=self.force_model_id
        )
        try:
            if selected is None:
                self.logger.warning(f"Was not able to get next comparison for user {user_id}")
                return None
            reference_id, model_left_id, model_right_id = selected
            reference_image, left_entry, right_entry = self.dataset_manager.get_entries_by_id(
                entry_id=reference_id, dataset_name1=model_left_id, dataset_name2=model_right_id
            )
            if reference_image is None or left_entry is None or right_entry is None:
                return None

            # Frame lists are encoded to temporary video files for the UI.
            if isinstance(left_entry, list):
                left_entry = save_video_from_frames(
                    left_entry, video_output_path=self.itov_app.tmp_video_path_left, fps=2
                )
            if isinstance(right_entry, list):
                right_entry = save_video_from_frames(
                    right_entry, video_output_path=self.itov_app.tmp_video_path_right, fps=2
                )
        except Exception as e:
            self.logger.error(f"Error in get_next_comparison({user_id}) of LoomisPainter: {e}")
            return None
        return (
            (reference_id, model_left_id, model_right_id),
            (reference_image, left_entry, right_entry),
            (num_preferences, total_num_comparison),
        )
src/gecora/logic/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from typing import List, Optional
4
+
5
+ import imageio.v3 as iio
6
+ from PIL import Image
7
+
8
+
9
def create_temp_file(suffix: str = ".mp4") -> str:
    """Create an empty temporary file on disk and return its path.

    The file is deliberately not auto-deleted; callers own cleanup
    (see ``cleanup`` / ``cleanup_list``).
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        # Exiting the ``with`` closes the handle but keeps the file on disk.
        return handle.name
16
+
17
+
18
def cleanup(temp_file_path: str):
    """Delete the file at ``temp_file_path``; missing paths are a no-op."""
    path_exists = os.path.exists(temp_file_path)
    if path_exists:
        os.remove(temp_file_path)
22
+
23
+
24
def cleanup_list(temp_file_path_list: List[str]):
    """Delete every existing file in ``temp_file_path_list``; missing paths are skipped."""
    for path in temp_file_path_list:
        if os.path.exists(path):
            os.remove(path)
29
+
30
+
31
def save_video_from_frames(video: List[Image.Image], video_output_path: str, fps: int = 2) -> Optional[str]:
    """Encode a list of PIL frames as an H.264 video at ``video_output_path``.

    Args:
        video: Frames to encode, in playback order.
        video_output_path: Destination file path (e.g. a ``.mp4`` file).
        fps: Frames per second of the resulting video.

    Returns:
        The output path on success, or ``None`` if encoding failed — all
        exceptions are swallowed by design (best-effort encoding).
    """
    try:
        iio.imwrite(video_output_path, video, fps=fps, codec="libx264")
        return video_output_path
    except Exception:
        return None
src/gecora/py.typed ADDED
File without changes
src/gecora/ranking/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
src/gecora/ranking/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes). View file
 
src/gecora/ranking/__pycache__/ranking_system.cpython-312.pyc ADDED
Binary file (5.58 kB). View file
 
src/gecora/ranking/ranking_system.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import random
3
+ from collections import defaultdict
4
+ from typing import Dict, List, Optional, Set, Tuple
5
+
6
+
7
class RankingSystem:
    """ELO-based ranking over pairwise model preferences."""

    def __init__(self, base_rating: int = 1500, k: int = 32):
        """
        Initializes the ranking system.

        Args:
            base_rating (int): Initial ELO rating assigned to every model.
            k (int): ELO adjustment factor per match.
        """
        self.base_rating = base_rating
        self.k = k
        # Unseen models default to the base rating on first access.
        self.elo_ratings: Dict[str, float] = defaultdict(lambda: base_rating)
        self.win_counts: Dict[str, int] = defaultdict(int)
        self.match_counts: Dict[str, int] = defaultdict(int)

    def update_rankings(self, preferences: List[Dict[str, str]]) -> None:
        """
        Folds a batch of preference records into the ELO ratings and counters.

        Args:
            preferences (List[Dict[str, str]]): Records carrying
                "model_left_id", "model_right_id" and "preferred_side" keys.
        """

        def _expected(mine: float, theirs: float) -> float:
            # Standard ELO expectation that "mine" beats "theirs".
            return 1 / (1 + 10 ** ((theirs - mine) / 400))

        for record in preferences:
            left_id = record["model_left_id"]
            right_id = record["model_right_id"]
            outcome = record["preferred_side"]

            left_rating = self.elo_ratings[left_id]
            right_rating = self.elo_ratings[right_id]

            left_expected = _expected(left_rating, right_rating)
            right_expected = _expected(right_rating, left_rating)

            if outcome == "left":
                left_score, right_score = 1.0, 0.0
                self.win_counts[left_id] += 1
            elif outcome == "right":
                left_score, right_score = 0.0, 1.0
                self.win_counts[right_id] += 1
            else:
                # A tie gives half a point to each side and no win to either.
                left_score = right_score = 0.5

            self.elo_ratings[left_id] += self.k * (left_score - left_expected)
            self.elo_ratings[right_id] += self.k * (right_score - right_expected)

            self.match_counts[left_id] += 1
            self.match_counts[right_id] += 1

    def get_elo_ratings(self) -> Dict[str, float]:
        """Returns a plain-dict snapshot of the current ELO ratings."""
        return dict(self.elo_ratings)

    def get_winrates(self) -> Dict[str, float]:
        """Returns wins / matches per model (0.0 for models with no matches)."""
        winrates: Dict[str, float] = {}
        for model in self.match_counts:
            played = self.match_counts[model]
            winrates[model] = self.win_counts[model] / played if played > 0 else 0.0
        return winrates

    def calculate_progress(
        self,
        preferences: List[Dict[str, str]],
        reference_ids: List[str],
        model_ids: List[str],
        force_model_id: Optional[str] = None,
    ) -> Tuple[int, int]:
        """
        Calculates how many unique comparisons are done and how many exist in total.

        Args:
            preferences: Preference dicts (assumed to belong to a single user).
            reference_ids: All possible reference IDs.
            model_ids: All possible model IDs.
            force_model_id: If set, only pairs involving this model count.

        Returns:
            Tuple[int, int]: (seen_count, total_possible_count).
        """
        # Preferences are assumed unique, so their count is the progress so far.
        done = len(preferences)

        ref_count = len(reference_ids)
        model_count = len(model_ids)

        # Degenerate inputs allow no comparisons at all.
        if ref_count == 0 or model_count < 2:
            return done, 0

        if force_model_id:
            # Only pairs containing the forced model contribute to the total.
            total = ref_count * (model_count - 1) if force_model_id in model_ids else 0
        else:
            # n-choose-2 unique model pairs per reference.
            total = ref_count * (model_count * (model_count - 1) // 2)

        return done, total

    def select_next_comparison(
        self,
        preferences: List[Dict[str, str]],
        reference_ids: List[str],
        model_ids: List[str],
        force_model_id: Optional[str],
    ) -> Optional[Tuple[str, str, str]]:
        """
        Picks a (reference, model pair) combination the user has not voted on yet.

        The same unordered (reference_id, {model_a, model_b}) combination is
        never offered twice; selection order is randomized.

        Args:
            preferences (List[Dict[str, str]]): Existing preferences of the user.
            reference_ids (List[str]): List of possible reference IDs.
            model_ids (List[str]): List of model IDs to choose from.
            force_model_id (Optional[str]): If set always includes this model_id in the comparison.

        Returns:
            Optional[Tuple[str, str, str]]: (reference_id, model_left_id, model_right_id),
            or None if every combination has already been compared.
        """
        # Track unordered pairs via frozenset so left/right order is irrelevant.
        already_seen: Set[Tuple[str, frozenset]] = set()
        for pref in preferences:
            already_seen.add(
                (pref["reference_id"], frozenset({pref["model_left_id"], pref["model_right_id"]}))
            )

        # Shuffle copies so the inputs are never mutated.
        refs = list(reference_ids)
        random.shuffle(refs)
        models = list(model_ids)
        random.shuffle(models)

        if force_model_id:
            if force_model_id not in model_ids:
                return None

            rivals = [model for model in models if model != force_model_id]
            for ref in refs:
                for rival in rivals:
                    pair = frozenset({force_model_id, rival})
                    if (ref, pair) in already_seen:
                        continue
                    # Randomize left/right placement for UI neutrality.
                    if random.random() > 0.5:
                        return (ref, force_model_id, rival)
                    return (ref, rival, force_model_id)
            return None

        # All unordered model pairs, shuffled so no pair is systematically first.
        candidate_pairs = list(itertools.combinations(models, 2))
        random.shuffle(candidate_pairs)

        for ref in refs:
            for first, second in candidate_pairs:
                if (ref, frozenset({first, second})) not in already_seen:
                    return (ref, first, second)

        return None