Mayamaya commited on
Commit
16a1428
·
1 Parent(s): 0acda39

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +729 -0
  2. combined_data/blendshape_annotation_preprocess.py +114 -0
  3. combined_data/combination_to_filename.json +80 -0
  4. combined_data/lapwing/images/000001.png +0 -0
  5. combined_data/lapwing/images/000002.png +0 -0
  6. combined_data/lapwing/images/000003.png +0 -0
  7. combined_data/lapwing/images/000004.png +0 -0
  8. combined_data/lapwing/images/000005.png +0 -0
  9. combined_data/lapwing/images/000006.png +0 -0
  10. combined_data/lapwing/images/000007.png +0 -0
  11. combined_data/lapwing/images/000008.png +0 -0
  12. combined_data/lapwing/images/000009.png +0 -0
  13. combined_data/lapwing/images/000010.png +0 -0
  14. combined_data/lapwing/images/000011.png +0 -0
  15. combined_data/lapwing/images/000012.png +0 -0
  16. combined_data/lapwing/images/000013.png +0 -0
  17. combined_data/lapwing/images/000014.png +0 -0
  18. combined_data/lapwing/images/000015.png +0 -0
  19. combined_data/lapwing/images/000016.png +0 -0
  20. combined_data/lapwing/images/000017.png +0 -0
  21. combined_data/lapwing/images/000018.png +0 -0
  22. combined_data/lapwing/images/000019.png +0 -0
  23. combined_data/lapwing/images/000020.png +0 -0
  24. combined_data/lapwing/images/000021.png +0 -0
  25. combined_data/lapwing/images/000022.png +0 -0
  26. combined_data/lapwing/images/000023.png +0 -0
  27. combined_data/lapwing/images/000024.png +0 -0
  28. combined_data/lapwing/images/000025.png +0 -0
  29. combined_data/lapwing/images/000026.png +0 -0
  30. combined_data/lapwing/images/000027.png +0 -0
  31. combined_data/lapwing/images/000028.png +0 -0
  32. combined_data/lapwing/images/000029.png +0 -0
  33. combined_data/lapwing/images/000030.png +0 -0
  34. combined_data/lapwing/images/000031.png +0 -0
  35. combined_data/lapwing/images/000032.png +0 -0
  36. combined_data/lapwing/images/000033.png +0 -0
  37. combined_data/lapwing/images/000034.png +0 -0
  38. combined_data/lapwing/images/000035.png +0 -0
  39. combined_data/lapwing/images/000036.png +0 -0
  40. combined_data/lapwing/images/000037.png +0 -0
  41. combined_data/lapwing/images/000038.png +0 -0
  42. combined_data/lapwing/images/000039.png +0 -0
  43. combined_data/lapwing/images/000040.png +0 -0
  44. combined_data/lapwing/images/000041.png +0 -0
  45. combined_data/lapwing/images/000042.png +0 -0
  46. combined_data/lapwing/images/000043.png +0 -0
  47. combined_data/lapwing/images/000044.png +0 -0
  48. combined_data/lapwing/images/000045.png +0 -0
  49. combined_data/lapwing/images/000046.png +0 -0
  50. combined_data/lapwing/images/000047.png +0 -0
app.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import json
5
+ import random
6
+ import re
7
+ from functools import partial
8
+ from datetime import datetime
9
+ from collections import defaultdict, Counter
10
+
11
+ import gradio as gr
12
+ from loguru import logger
13
+
14
+ # --- Global State (unchanged) ---
15
+ # --- Global State (unchanged) ---
16
+ GLOBAL_STATE = {
17
+ "participant_id": None,
18
+ "data_loaded": False,
19
+ "all_eval_data": [],
20
+ "shuffled_indices": [],
21
+ "current_prompt_index": 0,
22
+ "current_criterion_index": 0,
23
+ "image_mapping": {},
24
+ "image_dir": "",
25
+ "evaluation_results": {},
26
+ "image_orders": {},
27
+ "start_time": None,
28
+ "end_time": None,
29
+ "current_ranks": {},
30
+ "current_absolute_score": None,
31
+ # ▼▼▼ 追加 ▼▼▼
32
+ "current_absolute_score_worst": None,
33
+ }
34
+
35
+ # --- Configuration (unchanged) ---
36
+ BASE_RESULTS_DIR = "./results"
37
+ LOG_DIR = "./logs"
38
+ COMBINED_DATA_DIR = "./combined_data"
39
+ IMAGE_SUBDIR = os.path.join("lapwing", "images")
40
+ MAPPING_FILENAME = "combination_to_filename.json"
41
+ CONDITIONS = ["Ours", "w_o_Proto_Loss", "w_o_HitL", "w_o_Tuning", "LLM-based"]
42
+ CRITERIA = ["Alignment", "Naturalness", "Attractiveness"]
43
+
44
+ CRITERIA_GUIDANCE_JP = [
45
+ "テキストと表情がどれだけ一致しているか",
46
+ "テキストの感情に沿ったセリフを言っていると想像したとき、表情がどれだけ自然か",
47
+ "テキストの感情に沿ったセリフを言っていると想像したとき、表情がどれだけ魅力的か"
48
+ ]
49
+ CRITERIA_GUIDANCE_EN = [
50
+ "how well the expression aligns with the text",
51
+ "imagining the character is speaking a line that matches the emotion of the text, how natural the facial expression is",
52
+ "imagining the character is speaking a line that matches the emotion of the text, how attractive the facial expression is"
53
+ ]
54
+ IMAGE_LABELS = ['A', 'B', 'C', 'D', 'E']
55
+
56
+
57
+ # --- Helper Functions ---
58
+ def get_image_path_from_prediction(prediction: dict) -> str:
59
+ if not GLOBAL_STATE["image_mapping"]:
60
+ logger.error("Image mapping is not loaded.")
61
+ return ""
62
+ indices = prediction.get("blendshape_index", {})
63
+ if not isinstance(indices, dict):
64
+ logger.error(f"blendshape_index is not a dictionary: {indices}")
65
+ return ""
66
+ sorted_indices = sorted(indices.items(), key=lambda item: int(item[0]))
67
+ key = ",".join(str(idx) for _, idx in sorted_indices)
68
+ filename = GLOBAL_STATE["image_mapping"].get(key)
69
+ if not filename:
70
+ logger.warning(f"No image found for blendshape key: {key}")
71
+ return ""
72
+ return os.path.join(GLOBAL_STATE["image_dir"], filename)
73
+
74
+
75
+ # ▼▼▼ 2. prompt_categoryを読み込むように修正 ▼▼▼
76
+ def load_evaluation_data(participant_id: str):
77
+ mapping_path = os.path.join(COMBINED_DATA_DIR, MAPPING_FILENAME)
78
+ if not os.path.exists(mapping_path):
79
+ return f"<p class='feedback red'>Error: Mapping file not found at {mapping_path}</p>", gr.update(
80
+ interactive=True), gr.update(interactive=False)
81
+
82
+ with open(mapping_path, 'r', encoding='utf-8') as f:
83
+ GLOBAL_STATE["image_mapping"] = json.load(f)["mapping"]
84
+ GLOBAL_STATE["image_dir"] = os.path.join(COMBINED_DATA_DIR, IMAGE_SUBDIR)
85
+ logger.info(f"Successfully loaded image mapping. Image directory: {GLOBAL_STATE['image_dir']}")
86
+
87
+ participant_dir = os.path.join(BASE_RESULTS_DIR, participant_id)
88
+ if not os.path.isdir(participant_dir):
89
+ return f"<p class='feedback red'>Error: Participant directory not found: {participant_dir}</p>", gr.update(
90
+ interactive=True), gr.update(interactive=False)
91
+
92
+ merged_data = defaultdict(lambda: {"predictions": {}, "category": None})
93
+ found_files = 0
94
+ for cond in CONDITIONS:
95
+ cond_dir = os.path.join(participant_dir, cond)
96
+ pattern = os.path.join(cond_dir, f"{participant_id}_{cond}_*.jsonl")
97
+ files = glob.glob(pattern)
98
+ if not files:
99
+ logger.warning(f"No prediction file found for condition '{cond}' with pattern: {pattern}")
100
+ continue
101
+ found_files += 1
102
+ with open(files[0], 'r', encoding='utf-8') as f:
103
+ for line in f:
104
+ data = json.loads(line)
105
+ prompt = data["text_prompt"]
106
+ merged_data[prompt]["predictions"][cond] = data["prediction"]
107
+ if not merged_data[prompt]["category"]:
108
+ merged_data[prompt]["category"] = data.get("prompt_category")
109
+
110
+ if found_files != len(CONDITIONS):
111
+ return f"<p class='feedback red'>Error: Found prediction files for only {found_files}/{len(CONDITIONS)} conditions.</p>", gr.update(
112
+ interactive=True), gr.update(interactive=False)
113
+
114
+ GLOBAL_STATE["all_eval_data"] = [
115
+ {"prompt": p, "predictions": d["predictions"], "category": d["category"]}
116
+ for p, d in merged_data.items() if len(d["predictions"]) == len(CONDITIONS)
117
+ ]
118
+ # ▲▲▲ END OF UPDATE ▲▲▲
119
+
120
+ if not GLOBAL_STATE["all_eval_data"]:
121
+ return "<p class='feedback red'>Error: No valid evaluation data could be loaded.</p>", gr.update(
122
+ interactive=True), gr.update(interactive=False)
123
+
124
+ GLOBAL_STATE["shuffled_indices"] = list(range(len(GLOBAL_STATE["all_eval_data"])))
125
+ random.shuffle(GLOBAL_STATE["shuffled_indices"])
126
+ GLOBAL_STATE["current_prompt_index"] = 0
127
+ GLOBAL_STATE["current_criterion_index"] = 0
128
+ GLOBAL_STATE["data_loaded"] = True
129
+ GLOBAL_STATE["start_time"] = datetime.now()
130
+ for i in range(len(GLOBAL_STATE["all_eval_data"])):
131
+ prompt_text = GLOBAL_STATE["all_eval_data"][i]["prompt"]
132
+ GLOBAL_STATE["evaluation_results"][prompt_text] = {}
133
+
134
+ logger.info(f"Loaded and merged data for {len(GLOBAL_STATE['all_eval_data'])} prompts.")
135
+ done_msg = "<p class='feedback green'>Data loaded successfully. Please proceed to the 'Evaluation' tab. / データの読み込みに成功しました。「評価」タブに進んでください。</p>"
136
+ return done_msg, gr.update(interactive=False, visible=False), gr.update(interactive=True)
137
+
138
+
139
+ # --- Core Logic ---
140
+
141
+ def _create_button_updates():
142
+ updates = []
143
+ for img_label in IMAGE_LABELS:
144
+ selected_rank = GLOBAL_STATE["current_ranks"].get(img_label)
145
+ for rank_val in range(1, 6):
146
+ if rank_val == selected_rank:
147
+ updates.append(gr.update(variant='primary'))
148
+ else:
149
+ updates.append(gr.update(variant='secondary'))
150
+ return updates
151
+
152
+
153
+ def handle_rank_button_click(image_label, rank):
154
+ if GLOBAL_STATE["current_ranks"].get(image_label) == rank:
155
+ GLOBAL_STATE["current_ranks"][image_label] = None
156
+ else:
157
+ GLOBAL_STATE["current_ranks"][image_label] = rank
158
+ return _create_button_updates()
159
+
160
+
161
+ def handle_absolute_score_click(score):
162
+ if GLOBAL_STATE["current_absolute_score"] == score:
163
+ GLOBAL_STATE["current_absolute_score"] = None
164
+ else:
165
+ GLOBAL_STATE["current_absolute_score"] = score
166
+
167
+ updates = []
168
+ for i in range(1, 8):
169
+ if i == GLOBAL_STATE["current_absolute_score"]:
170
+ updates.append(gr.update(variant='primary'))
171
+ else:
172
+ updates.append(gr.update(variant='secondary'))
173
+ return updates
174
+
175
+
176
+ # ▼▼▼ 追加 ▼▼▼
177
+ def handle_absolute_score_worst_click(score):
178
+ if GLOBAL_STATE["current_absolute_score_worst"] == score:
179
+ GLOBAL_STATE["current_absolute_score_worst"] = None
180
+ else:
181
+ GLOBAL_STATE["current_absolute_score_worst"] = score
182
+
183
+ updates = []
184
+ for i in range(1, 8):
185
+ if i == GLOBAL_STATE["current_absolute_score_worst"]:
186
+ updates.append(gr.update(variant='primary'))
187
+ else:
188
+ updates.append(gr.update(variant='secondary'))
189
+ return updates
190
+
191
+ # ▼▼▼ 1. UIフリーズ問題を修正 ▼▼▼
192
+
193
+
194
+ # ▼▼▼ 修正後の display_current_prompt_and_criterion 関数 ▼▼▼
195
+ def display_current_prompt_and_criterion():
196
+ if not GLOBAL_STATE["data_loaded"] or GLOBAL_STATE["current_prompt_index"] >= len(GLOBAL_STATE["all_eval_data"]):
197
+ done_msg = "<p class='feedback green' style='text-align: center; font-size: 1.2em;'>All prompts have been evaluated! Please proceed to the 'Export' tab. <br>すべてのプロンプトの評価が完了しました!「エクスポート」タブに進んでください。</p>"
198
+ empty_button_updates = [gr.update(variant='secondary')] * 25
199
+ empty_abs_updates = [gr.update(variant='secondary')] * 7
200
+ return [
201
+ gr.update(value="Finished! / 完了!"),
202
+ gr.update(value=""),
203
+ gr.update(value=done_msg),
204
+ gr.update(value="", visible=False),
205
+ *[gr.update(value=None)] * 5,
206
+ *empty_button_updates,
207
+ gr.update(visible=False), # abs_group_best
208
+ *empty_abs_updates,
209
+ gr.update(visible=False), # abs_group_worst
210
+ *empty_abs_updates,
211
+ gr.update(interactive=False),
212
+ gr.update(interactive=False)
213
+ ]
214
+
215
+ prompt_idx = GLOBAL_STATE["shuffled_indices"][GLOBAL_STATE["current_prompt_index"]]
216
+ criterion_idx = GLOBAL_STATE["current_criterion_index"]
217
+
218
+ current_data = GLOBAL_STATE["all_eval_data"][prompt_idx]
219
+ prompt_text = current_data["prompt"]
220
+ criterion_name = CRITERIA[criterion_idx]
221
+
222
+ progress_text = f"Prompt {GLOBAL_STATE['current_prompt_index'] + 1} / {len(GLOBAL_STATE['all_eval_data'])} - **{criterion_name}**"
223
+ prompt_display_text = f"## \"{prompt_text}\""
224
+ guidance_text = f"### Please rank the 5 images based on **{CRITERIA_GUIDANCE_EN[criterion_idx]}**.<br>5つの画像を、**「{CRITERIA_GUIDANCE_JP[criterion_idx]}」**を基準にランキング付けしてください。"
225
+
226
+ if criterion_idx == 0:
227
+ GLOBAL_STATE["image_orders"] = {}
228
+
229
+ if criterion_name not in GLOBAL_STATE["image_orders"]:
230
+ conditions_shuffled = random.sample(CONDITIONS, len(CONDITIONS))
231
+ GLOBAL_STATE["image_orders"][criterion_name] = conditions_shuffled
232
+
233
+ current_image_order = GLOBAL_STATE["image_orders"][criterion_name]
234
+ image_updates = []
235
+ for cond_name in current_image_order:
236
+ prediction = current_data["predictions"][cond_name]
237
+ img_path = get_image_path_from_prediction(prediction)
238
+ image_updates.append(gr.update(value=img_path if img_path and os.path.exists(img_path) else None))
239
+
240
+ saved_ranks_dict = GLOBAL_STATE["evaluation_results"].get(prompt_text, {}).get("ranks", {}).get(criterion_name)
241
+ if saved_ranks_dict:
242
+ label_to_condition = {label: cond for label, cond in zip(IMAGE_LABELS, current_image_order)}
243
+ condition_to_label = {v: k for k, v in label_to_condition.items()}
244
+ GLOBAL_STATE["current_ranks"] = {
245
+ condition_to_label[cond]: rank for cond, rank in saved_ranks_dict.items() if cond in condition_to_label
246
+ }
247
+ else:
248
+ GLOBAL_STATE["current_ranks"] = {label: None for label in IMAGE_LABELS}
249
+
250
+ button_updates = _create_button_updates()
251
+
252
+ # --- Absolute Score (Best) ---
253
+ is_alignment_criterion = (criterion_name == "Alignment")
254
+ abs_group_update = gr.update(visible=is_alignment_criterion)
255
+ saved_abs_score = GLOBAL_STATE["evaluation_results"].get(prompt_text, {}).get("absolute_score")
256
+ GLOBAL_STATE["current_absolute_score"] = saved_abs_score if is_alignment_criterion else None
257
+
258
+ abs_button_updates = []
259
+ for i in range(1, 8):
260
+ variant = 'primary' if i == GLOBAL_STATE["current_absolute_score"] else 'secondary'
261
+ abs_button_updates.append(gr.update(variant=variant))
262
+
263
+ # --- Absolute Score (Worst) ---
264
+ abs_group_worst_update = gr.update(visible=is_alignment_criterion)
265
+ saved_abs_score_worst = GLOBAL_STATE["evaluation_results"].get(prompt_text, {}).get("absolute_score_worst")
266
+ GLOBAL_STATE["current_absolute_score_worst"] = saved_abs_score_worst if is_alignment_criterion else None
267
+
268
+ abs_button_worst_updates = []
269
+ for i in range(1, 8):
270
+ variant = 'primary' if i == GLOBAL_STATE["current_absolute_score_worst"] else 'secondary'
271
+ abs_button_worst_updates.append(gr.update(variant=variant))
272
+
273
+ return [
274
+ gr.update(value=progress_text),
275
+ gr.update(value=prompt_display_text),
276
+ gr.update(value=guidance_text),
277
+ gr.update(value="", visible=False),
278
+ *image_updates,
279
+ *button_updates,
280
+ abs_group_update,
281
+ *abs_button_updates,
282
+ abs_group_worst_update,
283
+ *abs_button_worst_updates,
284
+ gr.update(
285
+ interactive=(GLOBAL_STATE["current_prompt_index"] > 0 or GLOBAL_STATE["current_criterion_index"] > 0)),
286
+ gr.update(interactive=True)
287
+ ]
288
+
289
+
290
+ # ▼▼▼ 修正後の validate_and_navigate 関数 ▼▼▼
291
+ def validate_and_navigate():
292
+ ranks = GLOBAL_STATE["current_ranks"]
293
+ error_msg = None
294
+ criterion_name = CRITERIA[GLOBAL_STATE["current_criterion_index"]]
295
+ is_alignment_criterion = (criterion_name == "Alignment")
296
+
297
+ # --- Validation ---
298
+ if any(r is None for r in ranks.values()):
299
+ error_msg = "Please rank all 5 images. / 5つすべての画像を評価してください。"
300
+ elif 1 not in ranks.values():
301
+ error_msg = "You must assign a rank of '1' to at least one image. / 最低1つは「1位」を付けてください。"
302
+ elif is_alignment_criterion and GLOBAL_STATE["current_absolute_score"] is None:
303
+ error_msg = "Please provide an absolute score for the BEST matching image (1-7). / 最も一致している画像について、絶対評価(1~7)を選択してください。"
304
+ elif is_alignment_criterion and GLOBAL_STATE["current_absolute_score_worst"] is None:
305
+ error_msg = "Please provide an absolute score for the WORST matching image (1-7). / 最も一致していない画像について、絶対評価(1~7)を選択してください。"
306
+
307
+ if error_msg:
308
+ # The number of components to update is now 53 (1 tab + 52 eval components)
309
+ no_change_updates = [gr.update()] * 53
310
+ no_change_updates[4] = gr.update( # error_display is the 5th component (index 4)
311
+ value=f"<p class='feedback red' style='font-size: 1.2em; text-align: center;'>{error_msg}</p>",
312
+ visible=True)
313
+ return no_change_updates
314
+
315
+ # ... (Rank tie-breaking validation logic is unchanged) ...
316
+ sorted_ranks = sorted(list(ranks.values()))
317
+ rank_counts = Counter(sorted_ranks)
318
+ i = 0
319
+ while i < len(sorted_ranks):
320
+ current_rank = sorted_ranks[i]
321
+ count = rank_counts[current_rank]
322
+ if i + count < len(sorted_ranks):
323
+ next_rank = sorted_ranks[i + count]
324
+ expected_next_rank = current_rank + count
325
+ if next_rank < expected_next_rank:
326
+ error_msg = f"Ranking rule violation (tie-breaking). After {count} instance(s) of rank '{current_rank}', the next rank must be >= {expected_next_rank}, but it is '{next_rank}'. / 順位付けのルール違反です。'{current_rank}'位が{count}つあるため、次の順位は{expected_next_rank}位以上である必要がありますが、'{next_rank}'位が入力されています。"
327
+ break
328
+ i += count
329
+ if error_msg:
330
+ no_change_updates = [gr.update()] * 53
331
+ no_change_updates[4] = gr.update(
332
+ value=f"<p class='feedback red' style='font-size: 1.2em; text-align: center;'>{error_msg}</p>",
333
+ visible=True)
334
+ return no_change_updates
335
+ # --- End of Validation ---
336
+
337
+ prompt_idx = GLOBAL_STATE["shuffled_indices"][GLOBAL_STATE["current_prompt_index"]]
338
+ current_data = GLOBAL_STATE["all_eval_data"][prompt_idx]
339
+ prompt_text = current_data["prompt"]
340
+ current_image_order = GLOBAL_STATE["image_orders"][criterion_name]
341
+
342
+ label_to_condition = {label: cond for label, cond in zip(IMAGE_LABELS, current_image_order)}
343
+ ranks_by_condition = {label_to_condition[label]: rank for label, rank in ranks.items()}
344
+
345
+ if "ranks" not in GLOBAL_STATE["evaluation_results"][prompt_text]:
346
+ GLOBAL_STATE["evaluation_results"][prompt_text]["ranks"] = {}
347
+ if "orders" not in GLOBAL_STATE["evaluation_results"][prompt_text]:
348
+ GLOBAL_STATE["evaluation_results"][prompt_text]["orders"] = {}
349
+
350
+ GLOBAL_STATE["evaluation_results"][prompt_text]["ranks"][criterion_name] = ranks_by_condition
351
+ GLOBAL_STATE["evaluation_results"][prompt_text]["orders"][criterion_name] = current_image_order
352
+
353
+ if is_alignment_criterion:
354
+ GLOBAL_STATE["evaluation_results"][prompt_text]["absolute_score"] = GLOBAL_STATE["current_absolute_score"]
355
+ GLOBAL_STATE["evaluation_results"][prompt_text]["absolute_score_worst"] = GLOBAL_STATE[
356
+ "current_absolute_score_worst"]
357
+
358
+ logger.info(
359
+ f"Saved rank for P:{GLOBAL_STATE['participant_id']}, Prompt:'{prompt_text}', Criterion:{criterion_name}, Ranks:{ranks_by_condition}")
360
+
361
+ GLOBAL_STATE["current_criterion_index"] += 1
362
+ if GLOBAL_STATE["current_criterion_index"] >= len(CRITERIA):
363
+ GLOBAL_STATE["current_criterion_index"] = 0
364
+ GLOBAL_STATE["current_prompt_index"] += 1
365
+
366
+ if GLOBAL_STATE["current_prompt_index"] >= len(GLOBAL_STATE["all_eval_data"]):
367
+ GLOBAL_STATE["end_time"] = datetime.now()
368
+ eval_panel_updates = display_current_prompt_and_criterion()
369
+ # Activate export tab on completion
370
+ return [gr.update(interactive=True)] + eval_panel_updates
371
+ else:
372
+ # Keep export tab state as is
373
+ return [gr.update()] + display_current_prompt_and_criterion()
374
+
375
+
376
+
377
+ def navigate_previous():
378
+ GLOBAL_STATE["current_criterion_index"] -= 1
379
+ if GLOBAL_STATE["current_criterion_index"] < 0:
380
+ GLOBAL_STATE["current_criterion_index"] = len(CRITERIA) - 1
381
+ GLOBAL_STATE["current_prompt_index"] -= 1
382
+ GLOBAL_STATE["current_prompt_index"] = max(0, GLOBAL_STATE["current_prompt_index"])
383
+ return display_current_prompt_and_criterion()
384
+
385
+ # ▼▼▼ 修正後の export_results 関数 ▼▼▼
386
+ def export_results(participant_id, alignment_reason, naturalness_reason, attractiveness_reason, optional_comment):
387
+ if not alignment_reason.strip() or not naturalness_reason.strip() or not attractiveness_reason.strip():
388
+ error_msg = "<p class='feedback red'>Please fill in the reasoning for all three criteria (Alignment, Naturalness, Attractiveness). / 3つの評価基準(一致度, 自然さ, 魅力度)すべての判断理由を記入してください。</p>"
389
+ return None, error_msg
390
+
391
+ if not participant_id:
392
+ return None, "<p class='feedback red'>Participant ID is missing. / 参加者IDがありません。</p>"
393
+
394
+ output_dir = os.path.join(BASE_RESULTS_DIR, participant_id)
395
+ os.makedirs(output_dir, exist_ok=True)
396
+ filename = f"evaluation_results_{participant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
397
+ filepath = os.path.join(output_dir, filename)
398
+
399
+ duration = (GLOBAL_STATE["end_time"] - GLOBAL_STATE["start_time"]).total_seconds() if GLOBAL_STATE.get(
400
+ "start_time") and GLOBAL_STATE.get("end_time") else None
401
+
402
+ prompt_to_category = {item["prompt"]: item["category"] for item in GLOBAL_STATE["all_eval_data"]}
403
+
404
+ final_results_list = []
405
+ for prompt, data in GLOBAL_STATE["evaluation_results"].items():
406
+ if not data: continue
407
+
408
+ ranks_data = data.get("ranks", {})
409
+ orders_data = data.get("orders", {})
410
+
411
+ final_results_list.append({
412
+ "prompt": prompt,
413
+ "prompt_category": prompt_to_category.get(prompt),
414
+ "image_order_alignment": orders_data.get("Alignment", []),
415
+ "image_order_naturalness": orders_data.get("Naturalness", []),
416
+ "image_order_attractiveness": orders_data.get("Attractiveness", []),
417
+ "alignment_ranks": ranks_data.get("Alignment", {}),
418
+ "naturalness_ranks": ranks_data.get("Naturalness", {}),
419
+ "attractiveness_ranks": ranks_data.get("Attractiveness", {}),
420
+ "alignment_absolute_score": data.get("absolute_score"),
421
+ # ▼▼▼ 追加 ▼▼▼
422
+ "alignment_absolute_score_worst": data.get("absolute_score_worst")
423
+ })
424
+
425
+ export_data = {
426
+ "metadata": {
427
+ "participant_id": participant_id,
428
+ "export_timestamp": datetime.now().isoformat(),
429
+ "total_prompts_evaluated": len(final_results_list),
430
+ "evaluation_duration_seconds": duration,
431
+ "reasoning": {
432
+ "alignment": alignment_reason,
433
+ "naturalness": naturalness_reason,
434
+ "attractiveness": attractiveness_reason,
435
+ },
436
+ "optional_comment": optional_comment,
437
+ },
438
+ "results": final_results_list
439
+ }
440
+
441
+ try:
442
+ with open(filepath, 'w', encoding='utf-8') as f:
443
+ json.dump(export_data, f, ensure_ascii=False, indent=2)
444
+ logger.info(f"Successfully exported results to: {filepath}")
445
+ except Exception as e:
446
+ logger.error(f"Failed to write export file: {e}")
447
+ return None, f"<p class='feedback red'>An error occurred during file export: {e}</p>"
448
+
449
+ upload_link = "https://drive.google.com/drive/folders/1ujIPF-67Y6OG8qBm1TYG3FsmuYxqSAcR?usp=drive_link"
450
+ status_message = f"""
451
+ <div class='feedback green' style='text-align: left;'>
452
+ <p><b>エクスポートが完了しました。/ Export complete.</b></p>
453
+ <p>上のボタンからJSONファイルをダウンロードし、指定された場所にアップロードして実験を終了してください。ご協力ありがとうございました。</p>
454
+ <p>Please download the JSON file and upload it to the designated location. Thank you for your cooperation.</p>
455
+ <p><b>アップロード先 / Upload to:</b> <a href='{upload_link}' target='_blank'>{upload_link}</a></p>
456
+ </div>"""
457
+ return gr.update(value=filepath, visible=True), status_message
458
+
459
+
460
+ ## ▼▼▼ 修正後の create_gradio_interface 関数 ▼▼▼
461
+ def create_gradio_interface():
462
+ css = """
463
+ .gradio-container { font-family: 'Arial', sans-serif; }
464
+ .feedback { padding: 10px; border-radius: 5px; font-weight: bold; text-align: center; margin-top: 10px; }
465
+ .feedback.green { background-color: #e6ffed; color: #2f6f4a; }
466
+ .feedback.red { background-color: #ffe6e6; color: #b30000; }
467
+ .image-label { font-size: 2.5em; font-weight: bold; margin-bottom: 10px; color: #333; }
468
+ .prompt-display { text-align: center; margin-bottom: 5px; padding: 15px; background-color: #f0f8ff; border-radius: 8px; }
469
+ .prompt-sub-guidance { text-align: center; font-size: 0.9em; color: #555; margin-top: 5px; margin-bottom: 15px; }
470
+ .rank-instruction {
471
+ color: #D32F2F;
472
+ font-size: 1.1em;
473
+ text-align: left;
474
+ margin-bottom: 20px;
475
+ padding: 15px;
476
+ border: 1px solid #f5c6cb;
477
+ border-radius: 8px;
478
+ background-color: #f8d7da;
479
+ line-height: 1.6;
480
+ }
481
+ .rank-instruction ul { padding-left: 20px; margin: 0; }
482
+ .rank-guidance { text-align: center; margin-bottom: 10px; font-size: 1.2em; }
483
+ .rank-btn-row { justify-content: center; gap: 5px !important; }
484
+ .rank-btn {
485
+ min-width: 45px !important;
486
+ max-width: 45px !important;
487
+ height: 45px !important;
488
+ font-size: 1.2em !important;
489
+ font-weight: bold !important;
490
+ border-radius: 8px !important;
491
+ border: 1px solid #ccc !important;
492
+ }
493
+ .rank-btn.secondary {
494
+ background: #f0f0f0 !important;
495
+ color: #333 !important;
496
+ }
497
+ .rank-btn.secondary:hover {
498
+ background: #e0e0e0 !important;
499
+ border-color: #bbb !important;
500
+ }
501
+ .absolute-eval-group {
502
+ border: 1px solid #ddd;
503
+ border-radius: 8px;
504
+ padding: 15px;
505
+ margin-top: 20px;
506
+ }
507
+ """
508
+
509
+ with gr.Blocks(title="Expression Evaluation Experiment", css=css) as app:
510
+ gr.Markdown("# Text-to-Expression Evaluation Experiment / テキストからの表情生成 評価実験")
511
+
512
+ with gr.Tabs() as tabs:
513
+ with gr.TabItem("1. Setup / セットアップ") as tab_setup:
514
+ gr.Markdown("## (A) Participant Information / 参加者情報")
515
+ gr.Markdown("Please enter your participant ID and click 'Confirm'. / 参加者IDを入力して「確定」を押してください。")
516
+ with gr.Row():
517
+ participant_id_input = gr.Textbox(label="Participant ID", placeholder="e.g., P01")
518
+ confirm_id_btn = gr.Button("Confirm / 確定", variant="primary")
519
+ setup_warning = gr.Markdown(visible=False)
520
+ with gr.Group(visible=False) as setup_main_group:
521
+ gr.Markdown("---")
522
+ gr.Markdown("## (B) Instructions & Data Loading / 注意事項とデータ読み込み")
523
+ gr.Markdown(
524
+ """<div style='padding: 15px; border: 1px solid #f0ad4e; border-radius: 5px; background-color: #fcf8e3;'>
525
+ <h4>注意事項 / Instructions</h4>
526
+ <ul>
527
+ <li><b>この作業はPCで行ってください。/ Please perform this task on a PC.</b></li>
528
+ <li>途中で止めずに最後まで続けてください。ファイルをアップロードして完了となります。/ Please continue until the end. The experiment is complete when you upload the file.</li>
529
+ <li>ブラウザーをリロードしないでください (データが破損します)。/ Do not reload the browser (this will corrupt the data).</li>
530
+ </ul></div>""")
531
+ gr.Markdown(
532
+ "Click the button below to load your evaluation data. / 下のボタンを押して、評価データを読み込んでください。")
533
+ load_data_btn = gr.Button("Load Data / データ読み込み", variant="primary")
534
+ setup_status = gr.Markdown("Waiting to start...")
535
+
536
+ with gr.TabItem("2. Evaluation / 評価", interactive=False) as tab_evaluation:
537
+ progress_text = gr.Markdown("Prompt 0 / 0")
538
+
539
+ image_components = []
540
+ rank_buttons = []
541
+ with gr.Row(equal_height=False):
542
+ for label in IMAGE_LABELS:
543
+ with gr.Column(scale=1):
544
+ with gr.Group():
545
+ gr.Markdown(f"<div class='image-label' style='text-align: center;'>{label}</div>")
546
+ img = gr.Image(type="filepath", show_label=False, height=300)
547
+ image_components.append(img)
548
+ with gr.Row(elem_classes="rank-btn-row"):
549
+ for rank_val in range(1, 6):
550
+ btn = gr.Button(str(rank_val), variant='secondary', elem_classes="rank-btn")
551
+ rank_buttons.append(btn)
552
+
553
+ prompt_display = gr.Markdown("## \"Prompt Text Here\"", elem_classes="prompt-display")
554
+ gr.Markdown(
555
+ "<p class='prompt-sub-guidance'>You may use AI or web search for the meaning of the text. However, please do not ask an AI about the emotion of the image itself.<br>意味についてはAIに聞いたりネット検索しても構いません。ただし、画像そのものの感情をAIに尋ねるのを止めてください。</p>")
556
+ guidance_display = gr.Markdown("### Guidance", elem_classes="rank-guidance")
557
+ error_display = gr.Markdown(visible=False)
558
+
559
+ gr.Markdown(
560
+ """
561
+ <b>ランキングの付け方 / How to Rank:</b>
562
+ <ul>
563
+ <li><b>全く同じ表情の画像には、同じ順位</b>を付けてください。(Assign the <b>same rank</b> to identical expressions.)</li>
564
+ <li><b>少しでも違う表情の画像には、違う順位</b>を付けてください。(Assign <b>different ranks</b> to different expressions.)</li>
565
+ <li><b>必ず1位から</b>順位を付けてください。(You <b>must</b> assign a rank of '1' to at least one image.)</li>
566
+ <li>同順位がある場合、<b>その人数分だけ次の順位を飛ばしてください</b>。(When you have ties, <b>skip the next rank(s) accordingly</b>.)
567
+ <ul>
568
+ <li>例1: 1位が2つある場合、次は3位になります (Ex. 1: If there are two '1st' places, the next rank is '3rd'. e.g., <code>1, 1, 3, 4, 5</code>).</li>
569
+ <li>例2: 1位が1つ、2位が3つある場合、次は5位になります (Ex. 2: If there is one '1st' and three '2nd' places, the next rank is '5th'. e.g., <code>1, 2, 2, 2, 5</code>).</li>
570
+ </ul>
571
+ </li>
572
+ </ul>
573
+ """,
574
+ elem_classes="rank-instruction"
575
+ )
576
+
577
+ # ▼▼▼ 修正: 絶対評価(Best)のUI ▼▼▼
578
+ with gr.Group(visible=False, elem_classes="absolute-eval-group") as absolute_eval_group_best:
579
+ gr.Markdown("---")
580
+ gr.Markdown(
581
+ "#### 絶対評価 (Best) / Absolute Score (Best)\n最もテキストと一致している画像について、どのていど一致しているかを評価してください。\n(Please evaluate the degree of alignment for the image that **best** matches the text.)")
582
+ absolute_score_buttons = []
583
+ with gr.Row():
584
+ with gr.Column(scale=1):
585
+ gr.Markdown(
586
+ "<p style='text-align: right; margin-top: 10px;'>1 (全く一致してない / Not at all)</p>")
587
+ with gr.Column(scale=3):
588
+ with gr.Row(elem_classes="rank-btn-row"):
589
+ for i in range(1, 8):
590
+ btn = gr.Button(str(i), variant='secondary', elem_classes="rank-btn")
591
+ absolute_score_buttons.append(btn)
592
+ with gr.Column(scale=1):
593
+ gr.Markdown("<p style='text-align: left; margin-top: 10px;'>7 (完全に一致 / Absolutely)</p>")
594
+
595
+ # ▼▼▼ 追加: 絶対評価(Worst)のUI ▼▼▼
596
+ with gr.Group(visible=False, elem_classes="absolute-eval-group") as absolute_eval_group_worst:
597
+ gr.Markdown(
598
+ "#### 絶対評価 (Worst) / Absolute Score (Worst)\n最もテキストと一致していない画像について、どのていど一致していないかを評価してください。\n(Please evaluate the degree of alignment for the image that **least** matches the text.)")
599
+ absolute_score_worst_buttons = []
600
+ with gr.Row():
601
+ with gr.Column(scale=1):
602
+ gr.Markdown(
603
+ "<p style='text-align: right; margin-top: 10px;'>1 (全く一致してない / Not at all)</p>")
604
+ with gr.Column(scale=3):
605
+ with gr.Row(elem_classes="rank-btn-row"):
606
+ for i in range(1, 8):
607
+ btn = gr.Button(str(i), variant='secondary', elem_classes="rank-btn")
608
+ absolute_score_worst_buttons.append(btn)
609
+ with gr.Column(scale=1):
610
+ gr.Markdown("<p style='text-align: left; margin-top: 10px;'>7 (完全に一致 / Absolutely)</p>")
611
+
612
+ with gr.Row():
613
+ prev_btn = gr.Button("← Previous / 前へ", interactive=False)
614
+ next_btn = gr.Button("Save & Next / 保存して次へ →", variant="primary")
615
+
616
+ with gr.TabItem("3. Export / エクスポート", interactive=False) as tab_export:
617
+ gr.Markdown("## (C) Final Comments & Export / 最終コメントとエクスポート")
618
+ gr.Markdown(
619
+ "Thank you for completing the evaluation. Please provide the reasoning for your judgments for each criterion below. / 評価お疲れ様でした。以下の各評価基準について、判断の理由をご記入ください。")
620
+
621
+ with gr.Group():
622
+ gr.Markdown("#### Reasoning for Judgments (Required) / 判断理由(必須)")
623
+ alignment_reason_box = gr.Textbox(label="Alignment / 一致度", lines=3,
624
+ placeholder="Why did you rank them this way for alignment? / なぜ一致度について、このような順位付けをしましたか?")
625
+ naturalness_reason_box = gr.Textbox(label="Naturalness / 自然さ", lines=3,
626
+ placeholder="Why did you rank them this way for naturalness? / なぜ自然さについて、このような順位付けをしましたか?")
627
+ attractiveness_reason_box = gr.Textbox(label="Attractiveness / 魅力度", lines=3,
628
+ placeholder="Why did you rank them this way for attractiveness? / なぜ魅力度について、このような順位付けをしましたか?")
629
+
630
+ with gr.Group():
631
+ gr.Markdown("#### Overall Comments (Optional) / 全体的な感想(任意)")
632
+ optional_comment_box = gr.Textbox(label="Any other comments? / その他、実験全体に関するご意見・ご感想",
633
+ lines=4,
634
+ placeholder="e.g., 'Image B often looked the most natural.' / 例:「Bの画像が最も自然に見えることが多かったです。」")
635
+
636
+ gr.Markdown("---")
637
+ gr.Markdown(
638
+ "Finally, click the button below to export your results. / 最後に、下のボタンを押して結果をエクスポートしてください。")
639
+ export_btn = gr.Button("Export Results / 結果をエクスポート", variant="primary")
640
+ download_file = gr.File(label="Download JSON", visible=False)
641
+ export_status = gr.Markdown()
642
+
643
+ # --- Event Handlers ---
644
+ def check_and_confirm_id(pid):
645
+ pid = pid.strip()
646
+ if re.fullmatch(r"P\d{2}", pid):
647
+ GLOBAL_STATE["participant_id"] = pid
648
+ return gr.update(visible=False), gr.update(visible=True)
649
+ else:
650
+ error_msg = "<p class='feedback red'>Invalid ID. Must be 'P' followed by two digits (e.g., P01). / 無効なIDです。「P」と数字2桁の形式(例: P01)で入力してください。</p>"
651
+ return gr.update(value=error_msg, visible=True), gr.update(visible=False)
652
+
653
+ confirm_id_btn.click(check_and_confirm_id, [participant_id_input], [setup_warning, setup_main_group])
654
+ load_data_btn.click(load_evaluation_data, [participant_id_input], [setup_status, load_data_btn, tab_evaluation])
655
+
656
+ # ▼▼▼ 修正: all_eval_outputs に新しいUIコンポーネントを追加 ▼▼▼
657
+ all_eval_outputs = [
658
+ progress_text, prompt_display, guidance_display, error_display, *image_components,
659
+ *rank_buttons,
660
+ absolute_eval_group_best, *absolute_score_buttons,
661
+ absolute_eval_group_worst, *absolute_score_worst_buttons,
662
+ prev_btn, next_btn
663
+ ]
664
+
665
+ btn_idx = 0
666
+ for label in IMAGE_LABELS:
667
+ for rank_val in range(1, 6):
668
+ btn = rank_buttons[btn_idx]
669
+ btn.click(
670
+ partial(handle_rank_button_click, label, rank_val),
671
+ [],
672
+ rank_buttons
673
+ )
674
+ btn_idx += 1
675
+
676
+ for i, btn in enumerate(absolute_score_buttons):
677
+ btn.click(
678
+ partial(handle_absolute_score_click, i + 1),
679
+ [],
680
+ absolute_score_buttons
681
+ )
682
+
683
+ # ▼▼▼ 追加: 新しいボタンのイベントハンドラを接続 ▼▼▼
684
+ for i, btn in enumerate(absolute_score_worst_buttons):
685
+ btn.click(
686
+ partial(handle_absolute_score_worst_click, i + 1),
687
+ [],
688
+ absolute_score_worst_buttons
689
+ )
690
+
691
+ tab_evaluation.select(display_current_prompt_and_criterion, [], all_eval_outputs)
692
+
693
+ # ▼▼▼ 修正: next_btn の出力に tab_export を追加 ▼▼▼
694
+ next_btn.click(validate_and_navigate, [], [tab_export, *all_eval_outputs])
695
+
696
+ prev_btn.click(navigate_previous, [], all_eval_outputs)
697
+
698
+ export_tab_interactive_components = [alignment_reason_box, naturalness_reason_box, attractiveness_reason_box,
699
+ optional_comment_box, export_btn]
700
+
701
+ def on_select_export_tab():
702
+ # end_time is set only when all evaluations are complete
703
+ if GLOBAL_STATE.get("end_time"):
704
+ return [gr.update(interactive=True)] * 5
705
+ # This logic is now handled by next_btn click, but kept as a fallback.
706
+ return [gr.update(interactive=False)] * 5
707
+
708
+ tab_export.select(on_select_export_tab, [], export_tab_interactive_components)
709
+
710
+ export_btn.click(
711
+ export_results,
712
+ [participant_id_input, alignment_reason_box, naturalness_reason_box, attractiveness_reason_box,
713
+ optional_comment_box],
714
+ [download_file, export_status]
715
+ )
716
+
717
+ return app
718
+
719
+ if __name__ == "__main__":
720
+ os.makedirs(LOG_DIR, exist_ok=True)
721
+ log_file_path = os.path.join(LOG_DIR, "evaluation_ui_log_{time}.log")
722
+
723
+ random.seed(datetime.now().timestamp())
724
+ logger.remove()
725
+ logger.add(sys.stderr, level="INFO")
726
+ logger.add(log_file_path, rotation="10 MB")
727
+
728
+ app = create_gradio_interface()
729
+ app.launch(share=True, debug=True)
combined_data/blendshape_annotation_preprocess.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from loguru import logger
8
+ from typing import Dict, List, Tuple
9
+
10
+ # ---- あなたの定義済み Dataset, transforms があるなら import ----
11
+ # from your_dataset import BlendShapeDataset, image_transform
12
+
13
+ def unify_index(idx: int, group_size: int) -> int:
14
+ """
15
+ BlendShape の選択 index が group_size と同じ場合、-1 (none) に変換して返す。
16
+ """
17
+ if idx == group_size:
18
+ return -1
19
+ return idx
20
+
21
+ def build_combination_to_filename(
22
+ teacher_data_file: str,
23
+ meta_file: str,
24
+ ) -> Tuple[Dict[Tuple[int, ...], str], List[int]]:
25
+ """
26
+ BlendShapeData.json を読み込み、
27
+ 例:
28
+ { (g0_idx, g1_idx, g2_idx): "000001.png", ... }
29
+ のような辞書を作成して返す。
30
+
31
+ さらに各グループのサイズ (blendShapeNames数 + 1) のリスト group_sizes も返す。
32
+ """
33
+ with open(meta_file, "r", encoding="utf-8") as f:
34
+ meta = json.load(f)
35
+ blend_shape_groups = meta["blendShapeGroupsMeta"]
36
+ # group_sizes[i] = len(そのグループの blendShapeNames) + 1(none枠)
37
+ group_sizes = [len(g["blendShapeNames"]) + 1 for g in blend_shape_groups]
38
+
39
+ # teacher_data_file 読み込み
40
+ with open(teacher_data_file, "r", encoding="utf-8") as f:
41
+ teacher_data = json.load(f)
42
+ data_list = teacher_data["dataList"]
43
+
44
+ combination_to_filename = {}
45
+ for data in data_list:
46
+ photo_filename = data["photoFileName"]
47
+ blendShapeSelections = data["blendShapeSelectionsPerGroup"]
48
+
49
+ # グループ順に selectedBlendShapeIndex を取得しつつ、-1の場合は-1のまま、
50
+ # group_sizeと一致していたら-1へ変換(理想的には -1 しか登場しない想定だが一応対応)
51
+ combo = []
52
+ for group_idx, selection in enumerate(blendShapeSelections):
53
+ sel_idx = selection["selectedBlendShapeIndex"]
54
+ # group_sizes[group_idx] と同じなら none として -1
55
+ sel_idx = unify_index(sel_idx, group_sizes[group_idx])
56
+ combo.append(sel_idx)
57
+
58
+ combo = tuple(combo) # dictのキーにするのでtuple化
59
+ combination_to_filename[combo] = photo_filename
60
+
61
+ return combination_to_filename, group_sizes
62
+
63
+
64
+
65
+ def main_offline_precompute():
66
+ """
67
+ 1. Dataset(JSON)から (組み合わせ -> filename) を作る
68
+ 2. CLIP モデルロード
69
+ 3. filename -> embedding
70
+ 4. ペアワイズ類似度
71
+ 5. 保存
72
+ """
73
+ import argparse
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument("--image_dir", type=str, default="lapwing/images")
76
+ parser.add_argument("--meta_file", type=str, default="lapwing/texts/BlendShapeGroupsMeta.json")
77
+ parser.add_argument("--teacher_data_file", type=str, default="lapwing/texts/BlendShapeData.json")
78
+ parser.add_argument("--clip_model_name", type=str, default="ViT-L-14")
79
+ parser.add_argument("--clip_pretrained", type=str, default="openai")
80
+ parser.add_argument("--batch_size", type=int, default=8)
81
+ parser.add_argument("--device", type=str, default="cuda")
82
+ parser.add_argument("--out_comb2fn", type=str, default="combination_to_filename.json")
83
+ parser.add_argument("--out_sims", type=str, default="pairwise_clip_sims.json")
84
+
85
+ args = parser.parse_args()
86
+
87
+ # 1. 組み合わせ->filename辞書の作成
88
+ combination_to_filename, group_sizes = build_combination_to_filename(
89
+ teacher_data_file=args.teacher_data_file,
90
+ meta_file=args.meta_file,
91
+ )
92
+
93
+
94
+ # 5. 保存 (JSON形式)
95
+ # 5.1 (combination -> filename)
96
+ # group_sizes も保存しておくと後段のオンライン時に参照しやすい
97
+ # tupleは文字列化する必要あり
98
+ comb2fn_dict = {
99
+ "group_sizes": group_sizes,
100
+ "mapping": {
101
+ ",".join(map(str, comb)): fn
102
+ for comb, fn in combination_to_filename.items()
103
+ }
104
+ }
105
+ with open(args.out_comb2fn, "w", encoding="utf-8") as f:
106
+ json.dump(comb2fn_dict, f, ensure_ascii=False, indent=2)
107
+
108
+ # データ出力
109
+ for combo, filename in combination_to_filename.items():
110
+ logger.info(f"Combination: {combo}, Filename: {filename}")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main_offline_precompute()
combined_data/combination_to_filename.json ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "group_sizes": [
3
+ 8,
4
+ 9
5
+ ],
6
+ "mapping": {
7
+ "-1,-1": "000001.png",
8
+ "-1,0": "000002.png",
9
+ "-1,1": "000003.png",
10
+ "-1,2": "000004.png",
11
+ "-1,3": "000005.png",
12
+ "-1,4": "000006.png",
13
+ "-1,5": "000007.png",
14
+ "-1,6": "000008.png",
15
+ "-1,7": "000009.png",
16
+ "0,-1": "000010.png",
17
+ "0,0": "000011.png",
18
+ "0,1": "000012.png",
19
+ "0,2": "000013.png",
20
+ "0,3": "000014.png",
21
+ "0,4": "000015.png",
22
+ "0,5": "000016.png",
23
+ "0,6": "000017.png",
24
+ "0,7": "000018.png",
25
+ "1,-1": "000019.png",
26
+ "1,0": "000020.png",
27
+ "1,1": "000021.png",
28
+ "1,2": "000022.png",
29
+ "1,3": "000023.png",
30
+ "1,4": "000024.png",
31
+ "1,5": "000025.png",
32
+ "1,6": "000026.png",
33
+ "1,7": "000027.png",
34
+ "2,-1": "000028.png",
35
+ "2,0": "000029.png",
36
+ "2,1": "000030.png",
37
+ "2,2": "000031.png",
38
+ "2,3": "000032.png",
39
+ "2,4": "000033.png",
40
+ "2,5": "000034.png",
41
+ "2,6": "000035.png",
42
+ "2,7": "000036.png",
43
+ "3,-1": "000037.png",
44
+ "3,0": "000038.png",
45
+ "3,1": "000039.png",
46
+ "3,2": "000040.png",
47
+ "3,3": "000041.png",
48
+ "3,4": "000042.png",
49
+ "3,5": "000043.png",
50
+ "3,6": "000044.png",
51
+ "3,7": "000045.png",
52
+ "4,-1": "000046.png",
53
+ "4,0": "000047.png",
54
+ "4,1": "000048.png",
55
+ "4,2": "000049.png",
56
+ "4,3": "000050.png",
57
+ "4,4": "000051.png",
58
+ "4,5": "000052.png",
59
+ "4,6": "000053.png",
60
+ "4,7": "000054.png",
61
+ "5,-1": "000055.png",
62
+ "5,0": "000056.png",
63
+ "5,1": "000057.png",
64
+ "5,2": "000058.png",
65
+ "5,3": "000059.png",
66
+ "5,4": "000060.png",
67
+ "5,5": "000061.png",
68
+ "5,6": "000062.png",
69
+ "5,7": "000063.png",
70
+ "6,-1": "000064.png",
71
+ "6,0": "000065.png",
72
+ "6,1": "000066.png",
73
+ "6,2": "000067.png",
74
+ "6,3": "000068.png",
75
+ "6,4": "000069.png",
76
+ "6,5": "000070.png",
77
+ "6,6": "000071.png",
78
+ "6,7": "000072.png"
79
+ }
80
+ }
combined_data/lapwing/images/000001.png ADDED
combined_data/lapwing/images/000002.png ADDED
combined_data/lapwing/images/000003.png ADDED
combined_data/lapwing/images/000004.png ADDED
combined_data/lapwing/images/000005.png ADDED
combined_data/lapwing/images/000006.png ADDED
combined_data/lapwing/images/000007.png ADDED
combined_data/lapwing/images/000008.png ADDED
combined_data/lapwing/images/000009.png ADDED
combined_data/lapwing/images/000010.png ADDED
combined_data/lapwing/images/000011.png ADDED
combined_data/lapwing/images/000012.png ADDED
combined_data/lapwing/images/000013.png ADDED
combined_data/lapwing/images/000014.png ADDED
combined_data/lapwing/images/000015.png ADDED
combined_data/lapwing/images/000016.png ADDED
combined_data/lapwing/images/000017.png ADDED
combined_data/lapwing/images/000018.png ADDED
combined_data/lapwing/images/000019.png ADDED
combined_data/lapwing/images/000020.png ADDED
combined_data/lapwing/images/000021.png ADDED
combined_data/lapwing/images/000022.png ADDED
combined_data/lapwing/images/000023.png ADDED
combined_data/lapwing/images/000024.png ADDED
combined_data/lapwing/images/000025.png ADDED
combined_data/lapwing/images/000026.png ADDED
combined_data/lapwing/images/000027.png ADDED
combined_data/lapwing/images/000028.png ADDED
combined_data/lapwing/images/000029.png ADDED
combined_data/lapwing/images/000030.png ADDED
combined_data/lapwing/images/000031.png ADDED
combined_data/lapwing/images/000032.png ADDED
combined_data/lapwing/images/000033.png ADDED
combined_data/lapwing/images/000034.png ADDED
combined_data/lapwing/images/000035.png ADDED
combined_data/lapwing/images/000036.png ADDED
combined_data/lapwing/images/000037.png ADDED
combined_data/lapwing/images/000038.png ADDED
combined_data/lapwing/images/000039.png ADDED
combined_data/lapwing/images/000040.png ADDED
combined_data/lapwing/images/000041.png ADDED
combined_data/lapwing/images/000042.png ADDED
combined_data/lapwing/images/000043.png ADDED
combined_data/lapwing/images/000044.png ADDED
combined_data/lapwing/images/000045.png ADDED
combined_data/lapwing/images/000046.png ADDED
combined_data/lapwing/images/000047.png ADDED