File size: 25,605 Bytes
b1af3e6
 
 
87d2508
b1af3e6
87d2508
b1af3e6
 
87d2508
b1af3e6
87d2508
 
a4057b1
b1af3e6
 
 
911da12
87d2508
a4057b1
 
27e5eec
 
87d2508
b2cf79c
8d601ec
24fb34a
87d2508
 
911da12
 
 
 
 
8d601ec
87d2508
911da12
b1af3e6
 
bfa3575
b1af3e6
 
bfa3575
b1af3e6
 
 
 
 
 
 
a4057b1
 
 
0ac8168
f4825b9
505436d
a4057b1
 
 
bfa3575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71cbe5b
505436d
6f155f6
505436d
 
 
bfa3575
 
505436d
 
 
27e5eec
 
 
 
a4057b1
27e5eec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f29aac
 
 
 
27e5eec
 
 
 
 
 
 
 
 
 
b1af3e6
 
a4057b1
 
 
 
0ac8168
f4825b9
a4057b1
 
6f155f6
b1af3e6
a4057b1
bfa3575
 
 
 
 
 
 
b1af3e6
36d5aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f155f6
b1af3e6
36d5aef
 
b1af3e6
 
36d5aef
87d2508
36d5aef
c4a67a5
36d5aef
c4a67a5
36d5aef
c4a67a5
36d5aef
c4a67a5
 
36d5aef
bfa3575
 
 
36d5aef
bfa3575
 
 
 
 
 
 
 
 
36d5aef
a4057b1
bfa3575
b1af3e6
 
 
36d5aef
b1af3e6
 
 
ed1313f
 
b1af3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d601ec
 
 
f4825b9
8d601ec
 
 
 
f4825b9
 
8d601ec
 
 
 
 
 
 
 
 
b1af3e6
a4057b1
b1af3e6
 
bfa3575
 
47ad573
505436d
 
 
b1af3e6
ed1313f
 
 
b1af3e6
 
36d5aef
71cbe5b
6f155f6
bfa3575
505436d
ed1313f
b1af3e6
 
 
 
 
 
505436d
b1af3e6
bfa3575
 
71cbe5b
6f155f6
 
505436d
6f155f6
505436d
 
 
bfa3575
 
 
 
 
a4057b1
505436d
87d2508
 
 
505436d
 
 
bfa3575
 
 
 
 
 
 
 
 
 
 
 
 
6f155f6
b1af3e6
 
 
505436d
b1af3e6
 
 
 
 
 
 
bfa3575
 
 
 
87d2508
 
 
0ac8168
f4825b9
a4057b1
b1af3e6
87d2508
 
 
b1af3e6
 
71cbe5b
b1af3e6
 
ed1313f
b1af3e6
 
 
 
8d601ec
 
 
 
f4825b9
87d2508
b1af3e6
ed1313f
 
505436d
bfa3575
47ad573
505436d
 
f4825b9
8d601ec
47ad573
505436d
 
 
 
 
 
 
b1af3e6
 
 
 
ed1313f
b1af3e6
8d601ec
 
 
 
 
f4825b9
8d601ec
b1af3e6
5a6cec3
bfa3575
b1af3e6
 
 
 
 
 
71cbe5b
bfa3575
 
 
 
 
 
 
b1af3e6
505436d
 
 
 
b1af3e6
8d601ec
 
 
b1af3e6
 
ed1313f
b1af3e6
 
 
 
f4825b9
 
b1af3e6
a4057b1
87d2508
a4057b1
f4825b9
 
a4057b1
 
 
87d2508
f4825b9
bfa3575
f4825b9
8d601ec
a4057b1
b1af3e6
 
87d2508
f4825b9
bfa3575
87d2508
f4825b9
bfa3575
b1af3e6
71cbe5b
d285015
 
71cbe5b
b1af3e6
 
ed1313f
b1af3e6
 
bfa3575
b1af3e6
f4825b9
 
b1af3e6
a4057b1
 
47ad573
f4825b9
 
a4057b1
bfa3575
 
47ad573
bfa3575
 
 
71cbe5b
bfa3575
 
47ad573
bfa3575
 
 
 
 
 
 
 
 
 
 
 
 
71cbe5b
bfa3575
 
 
 
 
 
 
 
 
 
47ad573
bfa3575
 
71cbe5b
b1af3e6
bfa3575
a4057b1
 
 
5a6cec3
f4825b9
f76dc8c
 
 
b1af3e6
 
5a6cec3
d7004d8
47ad573
 
d7004d8
bfa3575
d7004d8
 
bfa3575
 
 
 
 
 
d7004d8
 
 
 
 
 
 
 
b1af3e6
 
 
 
 
453f64d
 
47ad573
453f64d
71cbe5b
6f155f6
 
 
 
 
 
 
 
 
 
b1af3e6
87d2508
b1af3e6
 
 
9f8ddd4
 
 
701a5dd
be358f5
71cbe5b
d45c855
9f8ddd4
d45c855
be358f5
d45c855
 
71cbe5b
d45c855
 
 
 
9f8ddd4
 
be358f5
71cbe5b
d45c855
9f8ddd4
b1af3e6
 
 
 
 
 
 
 
 
 
 
ed1313f
b1af3e6
 
 
 
 
c9e5ad6
 
2ddab66
71cbe5b
c9e5ad6
 
 
 
 
71cbe5b
a4057b1
f4825b9
5a6cec3
87d2508
f4825b9
71cbe5b
f4825b9
5a6cec3
87d2508
f4825b9
71cbe5b
f4825b9
5a6cec3
87d2508
f4825b9
71cbe5b
f4825b9
5a6cec3
0ac8168
f4825b9
71cbe5b
f4825b9
 
 
 
a4057b1
87d2508
8d601ec
 
 
87d2508
71cbe5b
b1af3e6
c974b82
34563fc
a4057b1
712213e
71cbe5b
b1af3e6
 
 
 
 
 
 
ed1313f
b1af3e6
a4057b1
f4825b9
a4057b1
 
b1af3e6
 
 
 
d285015
 
 
 
 
b1af3e6
71cbe5b
d285015
 
b1af3e6
d7004d8
b1af3e6
 
 
 
 
5a6cec3
f4825b9
5a6cec3
f4825b9
 
 
b1af3e6
 
ef5c435
453f64d
ef5c435
b1af3e6
 
 
6f155f6
 
b1af3e6
6f155f6
 
 
b1af3e6
 
 
 
911da12
 
 
 
 
 
bfa3575
0f29aac
911da12
 
 
 
 
b1af3e6
36d5aef
71cbe5b
bfa3575
b1af3e6
bfa3575
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
import os
import json
import hashlib
import random
import threading
import time
from dataclasses import dataclass
from typing import List, Dict, Any

import gradio as gr
from PIL import Image
from huggingface_hub import HfApi, CommitOperationAdd

# ----------------------
# Configuration
# ----------------------
# --- HF Repo ---
# Repo id and token come from the environment; push_results_to_private_repo()
# returns without uploading when either one is unset.
HF_RESULTS_REPO = os.getenv("HF_RESULTS_REPO")
HF_RESULTS_REPO_TYPE = "dataset"
HF_TOKEN = os.getenv("HF_TOKEN")
_hf_api = HfApi(token=HF_TOKEN)

# --- Main settings ---
# Upper bound on how many samples a single reviewer is asked to rate.
TARGET_PER_PERSON = 30
CONTACT_EMAIL = "ffallah@asu.edu"

# --- Paths ---
# Five parallel folders; load_dataset() keeps only filenames present in all five.

GT_MASKED_DIR   = "data/gt_b"   # Image 1
GT_UNMASKED_DIR = "data/adc_b"  # Image 2
SR_DIR          = "data/sr_b"   # Image 3
ORIGINAL_DIR    = "data/lr_b"   # Image 4
IMAGE_5_DIR     = "data/see_b"  # Image 5

# --- Results ---
RESULTS_DIR = "results"
PROGRESS_PATH = os.path.join(RESULTS_DIR, "progress.json")
ALL_RESULTS_JSONL = os.path.join(RESULTS_DIR, "all_results.jsonl")
# When True, the reviewer's name and email are stored in every saved record.
SAVE_PII = True

# Held by save_progress() and append_jsonl() while they write to disk.
WRITE_LOCK = threading.Lock()
# When True, load_dataset() re-checks each per-folder path and skips samples with a missing file.
STRICT_ENFORCEMENT = False

# ----------------------
# Data model
# ----------------------
@dataclass
class Sample:
    """One evaluation item: an id plus the paths of its five image variants."""
    # Filename without extension, as assigned by load_dataset().
    sample_id: str
    masked_gt_path: str   # Image 1
    unmasked_gt_path: str # Image 2
    sr_path: str          # Image 3
    original_path: str    # Image 4
    image_5_path: str     # Image 5

# ----------------------
# Helpers
# ----------------------
# def ensure_sample_objects(samples_input):
#     """
#     Accepts either:
#      - list[Sample] (already objects), or
#      - list[dict] (serialized Sample.__dict__)
#     Returns list[Sample].
#     """
#     if not samples_input:
#         return []
#     if isinstance(samples_input, list):
#         if len(samples_input) == 0:
#             return []
#         first = samples_input[0]
#         if isinstance(first, dict):
#             try:
#                 return [Sample(**s) for s in samples_input]
#             except Exception:
#                 # fall through to returning empty to avoid crashes
#                 return []
#         elif isinstance(first, Sample):
#             return samples_input
#     return []

def user_target_count(samples: List[Sample]) -> int:
    """Return how many samples one reviewer should rate.

    This is the dataset size capped at TARGET_PER_PERSON.
    """
    available = len(samples)
    return available if available < TARGET_PER_PERSON else TARGET_PER_PERSON

def user_left_count(user_seen: List[str], samples: List[Sample]) -> int:
    """Return how many more samples the reviewer must rate to reach the target.

    Only ids that belong to the current ``samples`` count as rated; the result
    never goes below zero.
    """
    quota = user_target_count(samples)
    valid_ids = {sample.sample_id for sample in samples}
    rated_valid = set(user_seen or []) & valid_ids
    remaining = quota - len(rated_valid)
    return remaining if remaining > 0 else 0

def _ensure_private_repo(repo_id: str):
    """Create ``repo_id`` as a private repo when looking it up fails."""
    kind = HF_RESULTS_REPO_TYPE
    try:
        _hf_api.repo_info(repo_id, repo_type=kind)
    except Exception:
        # Any lookup failure is treated as "repo does not exist yet".
        _hf_api.create_repo(repo_id=repo_id, repo_type=kind, private=True)

def push_results_to_private_repo(uid: str):
    """Upload the local result files to the HF results repo (best effort).

    Commits the combined results file, the reviewer's own results file and the
    progress file in a single commit. Does nothing when HF_TOKEN or
    HF_RESULTS_REPO is unset. Any failure is reported with a warning and never
    propagated, because this runs in a background thread.

    Args:
        uid: Hashed reviewer id; selects ``<RESULTS_DIR>/<uid>.jsonl``.
    """
    if not HF_TOKEN or not HF_RESULTS_REPO:
        return
    try:
        os.makedirs(RESULTS_DIR, exist_ok=True)
        user_file = os.path.join(RESULTS_DIR, f"{uid}.jsonl")

        candidates = [
            ("results/all_results.jsonl", ALL_RESULTS_JSONL),
            (f"results/users/{uid}.jsonl", user_file),
            ("results/progress.json", PROGRESS_PATH),
        ]
        # CommitOperationAdd rejects a path that is not an existing file, so a
        # single missing file would otherwise abort the upload of the others.
        ops = [
            CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=local_path)
            for repo_path, local_path in candidates
            if os.path.isfile(local_path)
        ]
        if not ops:
            return

        _hf_api.create_commit(
            repo_id=HF_RESULTS_REPO,
            repo_type=HF_RESULTS_REPO_TYPE,
            operations=ops,
            commit_message="Update RTS eval results"
        )
    except Exception as e:
        # Deliberately broad: an upload problem must not break the evaluation flow.
        print("[WARN] push_results_to_private_repo failed:", e)

def ensure_paths():
    """Create the results directory and warn about any missing image folder."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    expected = {
        "GT_MASKED_DIR": GT_MASKED_DIR,
        "GT_UNMASKED_DIR": GT_UNMASKED_DIR,
        "SR_DIR": SR_DIR,
        "ORIGINAL_DIR": ORIGINAL_DIR,
        "IMAGE_5_DIR": IMAGE_5_DIR,
    }
    for label, directory in expected.items():
        if os.path.isdir(directory):
            continue
        print(f"Warning: Directory '{directory}' for {label} not found.")

def load_image(path: str) -> Image.Image:
    """Load ``path`` as an RGB image, or return a gray placeholder.

    The 256x256 gray placeholder is returned when ``path`` is empty, does not
    exist, or cannot be opened or decoded, so the UI never receives ``None``.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        A PIL image in RGB mode.
    """
    if not path or not os.path.exists(path):
        # return a simple placeholder image so UI doesn't crash
        return Image.new("RGB", (256, 256), color="gray")
    try:
        # The context manager closes the underlying file once the pixels have
        # been copied into the converted image.
        with Image.open(path) as source:
            return source.convert("RGB")
    except Exception:
        # Deliberately broad: any unreadable file falls back to the placeholder.
        return Image.new("RGB", (256, 256), color="gray")

def load_dataset(
    gt_masked_dir: str,
    gt_unmasked_dir: str,
    sr_dir: str,
    original_dir: str,
    image_5_dir: str,
) -> List[Sample]:
    """Build the sample list from five parallel image folders.

    A sample is created for every image filename that exists in all five
    folders; its id is the filename without extension. Samples are returned
    sorted by filename. An empty list is returned when the folders share no
    image filename.

    Expected layout, for example:
        data/gt_b/xxx.png
        data/adc_b/xxx.png
        data/sr_b/xxx.png
        data/lr_b/xxx.png
        data/see_b/xxx.png
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")

    def list_images(dir_path: str) -> set:
        # Image filenames in dir_path (extension matched case-insensitively).
        if os.path.isdir(dir_path):
            return {
                entry
                for entry in os.listdir(dir_path)
                if entry.lower().endswith(image_suffixes)
            }
        print(f"Warning: directory not found: {dir_path}")
        return set()

    folders = (gt_masked_dir, gt_unmasked_dir, sr_dir, original_dir, image_5_dir)
    per_folder = [list_images(folder) for folder in folders]

    # Keep only filenames that every folder provides.
    shared_names = set.intersection(*per_folder)
    if not shared_names:
        print("No common image files found in all 5 folders.")
        return []

    print(f"Found {len(shared_names)} common images.")

    samples: List[Sample] = []
    for filename in sorted(shared_names):
        masked, unmasked, sr, original, img5 = (
            os.path.join(folder, filename) for folder in folders
        )
        by_role = {
            "masked": masked,
            "unmasked": unmasked,
            "sr": sr,
            "original": original,
            "img5": img5,
        }

        # Under strict enforcement, drop samples whose files vanished since listing.
        if STRICT_ENFORCEMENT:
            missing = [role for role, location in by_role.items() if not os.path.exists(location)]
            if missing:
                print(f"Skipping {filename}: missing in folders {missing}")
                continue

        samples.append(
            Sample(
                sample_id=os.path.splitext(filename)[0],
                masked_gt_path=masked,
                unmasked_gt_path=unmasked,
                sr_path=sr,
                original_path=original,
                image_5_path=img5,
            )
        )

    return samples


# ----------------------
# Progress & results I/O
# ----------------------
def hash_user_id(name: str, email: str) -> str:
    """Derive a 16-hex-character reviewer id from name and email.

    Both values are stripped and lower-cased first, so differences in case or
    surrounding whitespace map to the same id. ``None`` is treated as empty.
    """
    parts = [(part or "").strip().lower() for part in (name, email)]
    digest = hashlib.sha256("|".join(parts).encode("utf-8"))
    return digest.hexdigest()[:16]

def load_progress() -> Dict[str, Dict[str, Any]]:
    """Read the progress file and return its mapping of user id to progress.

    Returns an empty dict when the file does not exist, cannot be read or
    parsed, or does not contain a JSON object. Read and parse failures are
    reported with a warning instead of being swallowed silently, because the
    next save would otherwise overwrite the file without any trace.
    """
    try:
        with open(PROGRESS_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # First run: nothing has been saved yet.
        return {}
    except Exception as e:
        # Deliberately broad (best effort), but no longer silent.
        print(f"[WARN] could not read progress file '{PROGRESS_PATH}': {e}")
        return {}

    if not isinstance(data, dict):
        # Callers index the result by user id, so any other JSON value is unusable.
        print(f"[WARN] progress file '{PROGRESS_PATH}' does not contain a JSON object; ignoring it.")
        return {}
    return data

def save_progress(progress: Dict[str, Dict[str, Any]]):
    """Write ``progress`` to the progress file as indented JSON.

    The data is written to a temporary sibling file and then moved over the
    real file with ``os.replace``. Readers that do not hold WRITE_LOCK
    (load_progress, the background upload) therefore never observe a
    truncated or half-written file, and a failed dump leaves the previous
    file intact.

    Args:
        progress: Mapping of user id to that user's progress entry.
    """
    tmp_path = PROGRESS_PATH + ".tmp"
    with WRITE_LOCK:
        # The lock serializes writers, so a fixed temporary name is safe here.
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(progress, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, PROGRESS_PATH)

def append_jsonl(path: str, record: Dict[str, Any]):
    """Append ``record`` to ``path`` as one JSON line, under WRITE_LOCK."""
    # Serialize before taking the lock so the lock only covers the file write.
    serialized = json.dumps(record, ensure_ascii=False) + "\n"
    with WRITE_LOCK, open(path, "a", encoding="utf-8") as handle:
        handle.write(serialized)

# ----------------------
# LOGIC FOR CONVERTING SLIDERS TO RANK
# ----------------------
def convert_scores_to_rank(s1, s2, s3, s4, s5) -> Dict[str, int]:
    """Turn five scores into ranks 1-5, where rank 1 is the highest score.

    Every image gets a distinct rank. Equal scores are ranked in image order
    (image_1 before image_2, and so on) because the sort is stable.

    Returns:
        Dict keyed ``"image_1"`` .. ``"image_5"``, inserted from best to worst.
    """
    labels = ("image_1", "image_2", "image_3", "image_4", "image_5")
    labelled = zip(labels, (s1, s2, s3, s4, s5))
    best_first = sorted(labelled, key=lambda pair: pair[1], reverse=True)
    return {label: position for position, (label, _) in enumerate(best_first, start=1)}

# ----------------------
# App logic
# ----------------------
def pick_next_index(user_seen: List[str], samples: List[Sample]) -> int:
    """Pick a random index into ``samples`` whose id is not in ``user_seen``.

    Returns -1 when every sample has already been seen.
    """
    already_rated = set(user_seen or [])
    candidates = []
    for position, sample in enumerate(samples):
        if sample.sample_id not in already_rated:
            candidates.append(position)
    return random.choice(candidates) if candidates else -1

def start_or_resume(name: str, email: str):
    """Begin or resume a review session for the reviewer with this name and email.

    Loads the dataset, registers the reviewer in the progress file on first
    use, and selects the next unseen sample.

    Returns:
        A 14-tuple matching the ``outputs`` wired to the start button: user id,
        samples, seen ids, current index (-1 when no sample is shown), the five
        images, status text, path of the reviewer's JSONL file, then visibility
        updates for the evaluation panel, the intro text and the start form.

    Raises:
        gr.Error: if name or email is blank, or if no samples could be loaded.
    """
    # hash_user_id() strips both fields, so whitespace-only input would make
    # every such reviewer share one user id; reject it like empty input.
    if not (name or "").strip() or not (email or "").strip():
        raise gr.Error("Please enter your name and email to begin.")

    ensure_paths()
    samples = load_dataset(GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)

    if not samples:
        raise gr.Error("No images found. Please check dataset configuration.")

    uid = hash_user_id(name, email)
    progress = load_progress()
    if uid not in progress:
        progress[uid] = {"seen": []}
        save_progress(progress)

    user_seen: List[str] = progress[uid].get("seen", [])
    target = user_target_count(samples)
    left = user_left_count(user_seen, samples)
    user_file_path = os.path.join(RESULTS_DIR, f"{uid}.jsonl")

    def idle_outputs(status_text: str, results_path: str):
        # Outputs for "nothing to rate": placeholder images (so Gradio never
        # loads None), evaluation panel hidden, intro and start form shown.
        blank = Image.new("RGB", (256, 256), color="gray")
        return (
            uid,
            samples,
            user_seen,
            -1,
            blank, blank, blank, blank, blank,
            status_text,
            results_path,
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    # left == 0 already implies that at least `target` valid ids were seen.
    if left == 0:
        status = (
            f"Welcome back, {name}. You’ve completed all {target} images. 🎉\n"
            f"Your personal results file: {user_file_path}"
        )
        return idle_outputs(status, user_file_path)

    idx = pick_next_index(user_seen, samples)
    if idx == -1:
        return idle_outputs("No more new images available.", "")

    sample = samples[idx]

    status = (
        f"Welcome, {name}. Personal progress — images left: {left} of {target}.\n"
        f"Current sample: {sample.sample_id}"
    )

    return (
        uid,
        samples,
        user_seen,
        idx,
        load_image(sample.masked_gt_path),
        load_image(sample.unmasked_gt_path),
        load_image(sample.sr_path),
        load_image(sample.original_path),
        load_image(sample.image_5_path),
        status,
        user_file_path,
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    )


def _save_record_and_progress(
    name: str,
    email: str,
    uid: str,
    samples: List[Sample],
    user_seen: List[str],
    idx: int,
    score_1: float,
    score_2: float,
    score_3: float,
    score_4: float,
    score_5: float,
    q1_notes: str,
):
    """Persist one rating and mark its sample as seen for this reviewer.

    Appends the record to the reviewer's JSONL file and to the combined
    results file, updates the progress file, and then starts a background
    upload of all three files.

    Args:
        name, email: Reviewer identity; stored in the record only if SAVE_PII.
        uid: Hashed reviewer id.
        samples: Samples of the current session.
        user_seen: Unused; the seen list is re-read from the progress file.
        idx: Index into ``samples`` of the sample being rated.
        score_1 .. score_5: Slider scores for images 1-5.
        q1_notes: Optional free-text notes.

    Returns:
        The progress mapping after the update. It is returned unchanged when
        ``idx`` is out of range or the sample was already recorded.

    Raises:
        gr.Error: if name or email is empty.
    """
    if not name or not email:
        raise gr.Error("Please enter your name and email.")

    # Nothing is on screen (or the index is stale): record nothing.
    if idx is None or idx < 0 or idx >= len(samples):
        return load_progress()

    rank_dict = convert_scores_to_rank(score_1, score_2, score_3, score_4, score_5)

    sample = samples[idx]
    progress = load_progress()
    progress.setdefault(uid, {"seen": []})
    seen = set(progress[uid].get("seen", []))

    # Guard against double submission of the same sample.
    if sample.sample_id in seen:
        return progress

    record: Dict[str, Any] = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "user_id": uid,
        "name": name if SAVE_PII else None,
        "email": email if SAVE_PII else None,
        "sample_id": sample.sample_id,
        "raw_scores": {
            "image_1": score_1,
            "image_2": score_2,
            "image_3": score_3,
            "image_4": score_4,
            "image_5": score_5,
        },
        "responses": {
            "notes": q1_notes or "",
            "image_ranking": rank_dict,
        },
    }

    os.makedirs(RESULTS_DIR, exist_ok=True)
    append_jsonl(os.path.join(RESULTS_DIR, f"{uid}.jsonl"), record)
    append_jsonl(ALL_RESULTS_JSONL, record)

    seen.add(sample.sample_id)
    progress[uid]["seen"] = sorted(seen)
    save_progress(progress)

    # Start the upload only after progress is on disk; otherwise the pushed
    # progress file lags one sample behind or is read while being rewritten.
    try:
        thread = threading.Thread(target=push_results_to_private_repo, args=(uid,))
        thread.daemon = True
        thread.start()
    except Exception as e:
        # The rating is already saved locally; a failed upload start must not fail the submit.
        print("[WARN] could not start background results upload:", e)

    return progress

# ----------------------
# Buttons
# ----------------------
def submit_finish(
    name: str,
    email: str,
    uid: str,
    samples: List[Sample],
    user_seen: List[str],
    idx: int,
    s1: float, s2: float, s3: float, s4: float, s5: float,
    q1_notes: str
):
    """Save the current rating and return updates that clear the evaluation widgets.

    NOTE(review): no event in the UI wiring section of this file calls this
    function, so the intended output components are not visible here.
    """
    try:
        _save_record_and_progress(
            name, email, uid, samples, user_seen, idx,
            s1, s2, s3, s4, s5,
            q1_notes
        )
    except gr.Error:
        # Saving raises gr.Error when name or email is empty; return no-op updates.
        # NOTE(review): this tuple has 14 elements while the success tuple below
        # has 15 -- confirm the intended outputs before wiring this handler.
        return (
            user_seen, idx,
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(),
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(),
        )

    # Success: clear five values, empty one text value, set "Finished!", set five values to 5.
    return (
        user_seen, idx,
        gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None),
        gr.update(value=""),
        gr.update(value="Finished!"),
        gr.update(value=5), gr.update(value=5), gr.update(value=5), gr.update(value=5), gr.update(value=5),
        gr.update(value=None),
    )

def pause_exit(user_seen, samples):
    """Return both session states unchanged (pass-through for the Exit button)."""
    passthrough = (user_seen, samples)
    return passthrough

def submit_next_image(
    name: str,
    email: str,
    uid: str,
    samples: List[Sample],
    user_seen: List[str],
    idx: int,
    s1: float, s2: float, s3: float, s4: float, s5: float,
    q1_notes: str
):
    """Save the current rating and advance to the next unseen sample.

    Returns:
        A 14-tuple matching the ``outputs`` wired to the submit button: the
        updated seen ids, the next index (-1 when nothing more is shown), five
        images, status text, cleared notes, and the five sliders reset to 5.

    Raises:
        gr.Error: propagated from saving when name or email is empty.
    """
    # gr.Error raised by the save step is left to propagate so Gradio shows it.
    progress = _save_record_and_progress(
        name, email, uid, samples, user_seen, idx,
        s1, s2, s3, s4, s5,
        q1_notes
    )

    seen_list = progress.get(uid, {}).get("seen", [])
    left_after = user_left_count(seen_list, samples)
    target = user_target_count(samples)

    # placeholder image to avoid Gradio trying to load None
    placeholder_img = Image.new("RGB", (256, 256), color="gray")
    placeholders = (placeholder_img,) * 5
    reset_scores = (5,) * 5

    # If user reached the target, return placeholders for images and let the then() chain show thanks
    if left_after == 0:
        status = (
            f"Saved! You’ve completed all {target} images. 🎉 "
            f"Click **Exit** to close this session."
        )
        return (
            seen_list, -1,
            *placeholders,
            gr.update(value=status),
            gr.update(value=""),
            *reset_scores,
        )

    idx_next = pick_next_index(seen_list, samples)
    if idx_next == -1:
        # no more images but target not met (rare). return placeholders too.
        return (
            seen_list, -1,
            *placeholders,
            "No more images.",
            "",
            *reset_scores,
        )

    sample_next = samples[idx_next]

    return (
        seen_list, idx_next,
        load_image(sample_next.masked_gt_path),
        load_image(sample_next.unmasked_gt_path),
        load_image(sample_next.sr_path),
        load_image(sample_next.original_path),
        load_image(sample_next.image_5_path),
        gr.update(value=""),
        gr.update(value=""),
        *reset_scores,
    )


def to_thanks(name: str, user_seen: List[str], samples: List[Sample]):
    """Build the updates that swap the evaluation panel for the thanks screen.

    The message is a "session paused" note while the reviewer still has images
    left, and an "all done" note once the target is reached.

    Returns:
        Updates that hide the evaluation panel, show the thanks group and set
        the thanks markdown, in that order.
    """
    remaining = user_left_count(user_seen, samples)
    goal = user_target_count(samples)

    if remaining > 0:
        paragraphs = [
            "### ⏸️ Session Paused!",
            f"### ✅ Thanks, {name}! Your progress has been saved.",
            f"We’re grateful for your time and expertise. Our suggested target is {TARGET_PER_PERSON} images per reviewer.",
            f"You have **{remaining}** images left.",
            "You can close this tab and return whenever you like—just use the same Name and Email to **continue where you left off**.",
            f"If you have questions, issues, or suggestions, please email **{CONTACT_EMAIL}**.",
            "Click **Start Again** to evaluate another image.",
        ]
    else:
        paragraphs = [
            f"### ✅ All Done, {name}!",
            f"You’ve completed the target of **{goal}** images. Your responses are securely saved.",
            "We’re extremely grateful for your time and expertise. You are welcome to continue with more images if you wish, or you can finish here.",
            f"If you have questions, issues, or suggestions, please email **{CONTACT_EMAIL}**.",
            "",  # keeps the trailing blank line after the last paragraph
        ]

    message = "\n\n".join(paragraphs)
    hide_eval = gr.update(visible=False)
    show_thanks = gr.update(visible=True)
    return hide_eval, show_thanks, gr.update(value=message)

def hide_thanks():
    """Return an update that hides the thanks group."""
    hidden = gr.update(visible=False)
    return hidden

def maybe_show_thanks(name: str, seen: List[str], samples: List[Sample]):
    """Show the thanks screen once the reviewer has reached their target.

    Completion is decided by user_left_count(), i.e. against the per-dataset
    target min(len(samples), TARGET_PER_PERSON). Comparing against
    TARGET_PER_PERSON alone never triggers when the dataset holds fewer
    samples than that.

    Returns:
        Updates for the evaluation panel, the thanks group and the thanks
        markdown. When the target is not reached, the evaluation panel stays
        visible and the thanks group stays hidden.
    """
    # `samples` is empty only before a session has started; never show thanks then.
    if samples and user_left_count(seen, samples) == 0:
        return to_thanks(name, seen, samples)
    return gr.update(visible=True), gr.update(visible=False), gr.update()

def reset_to_start():
    """Return updates that bring the UI back to the initial start screen.

    Order: name box, email box, start group, intro, evaluation panel, thanks group.
    """
    name_box = gr.update(value="")           # clear name
    email_box = gr.update(value="")          # clear email
    start_panel = gr.update(visible=True)    # show start group
    intro = gr.update(visible=True)          # show intro
    eval_view = gr.update(visible=False)     # hide evaluation panel
    thanks_view = gr.update(visible=False)   # hide thanks group
    return name_box, email_box, start_panel, intro, eval_view, thanks_view

# ----------------------
# UI
# ----------------------
# Layout: intro text, start form, evaluation panel (hidden at first) and thanks
# group (hidden at first), followed by the event wiring between them.
with gr.Blocks(title="RTS Human Evaluation", theme=gr.themes.Soft()) as demo:
    intro_md = gr.Markdown(
    f"""
    # Retrogressive Thaw Slump (RTS) Human Evaluation

    ### 👋 Welcome, and thanks for lending your expertise!
    We’re inviting domain experts to help evaluate satellite image patches for RTS.

    ---

    ### 📋 Instructions
    * **Suggested target:** ~{TARGET_PER_PERSON} images per reviewer.
    * **The Task:** For each set, you will see 5 variations of the same satellite image.
    * **Rating:** Rate each image from **1 (Poor)** to **10 (Excellent)** based on how clearly the RTS feature (indicated by the **Red Box**) is depicted.

    ### ⏸️ Saving & Resuming
    * **Automatic Saving:** Your progress is saved automatically after every "Submit".
    * **Take a Break:** You can close this tab at any time.
    * **How to Resume:** Simply return here and enter the **exact same Name and Email**. The system will pick up exactly where you left off.

    ---
    **Questions or issues?** Email **{CONTACT_EMAIL}** — we appreciate your feedback and suggestions.

    **Ready?** Enter your details below to begin.
    """
    )

    # Hidden states
    # uid, sample list, seen ids and current sample index, passed between callbacks.
    state_uid = gr.State("")
    state_samples = gr.State([])
    state_seen = gr.State([])
    state_idx = gr.State(-1)

    with gr.Group() as start_group:
        with gr.Row():
            name = gr.Textbox(label="Full name", placeholder="Jane Doe", autofocus=True)
            email = gr.Textbox(label="Email address", placeholder="jane@example.com")
            start_btn = gr.Button("Start / Resume", variant="primary")
        status = gr.Markdown("\n")

    eval_panel = gr.Group(visible=False)
    with eval_panel:
        gr.Markdown(
            """
            Focus your attention on the area inside the **Red Box**. This marks the potential location of the RTS. Compare the five images below. Rate how clearly and realistically each image depicts the **RTS** feature.

            **Rating Scale (1 - 10):**
            * **10 (Excellent):** The RTS feature is sharp, distinct, and clearly visible.
            * **1 (Poor):** The RTS feature is blurry, distorted, or impossible to distinguish.
            """
        )

        # One column per image variant: caption, read-only image, 1-10 slider (default 5).
        with gr.Row():
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div style='text-align:center; font-weight:600;'>Image 1</div>")
                image_1 = gr.Image(show_label=False, interactive=False, height=256, show_download_button=False)
                score_1 = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Score (1-10)")

            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div style='text-align:center; font-weight:600;'>Image 2</div>")
                image_2 = gr.Image(show_label=False, interactive=False, height=256, show_download_button=False)
                score_2 = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Score (1-10)")

            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div style='text-align:center; font-weight:600;'>Image 3</div>")
                image_3 = gr.Image(show_label=False, interactive=False, height=256, show_download_button=False)
                score_3 = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Score (1-10)")

            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div style='text-align:center; font-weight:600;'>Image 4</div>")
                image_4 = gr.Image(show_label=False, interactive=False, height=256, show_download_button=False)
                score_4 = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Score (1-10)")

            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div style='text-align:center; font-weight:600;'>Image 5</div>")
                image_5 = gr.Image(show_label=False, interactive=False, height=256, show_download_button=False)
                score_5 = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Score (1-10)")

        notes_q1 = gr.Textbox(
            label="Notes (Optional)",
            lines=2,
            placeholder="If there are multiple RTS or ambiguities, please note here."
        )

        with gr.Row():
            submit_next_btn = gr.Button("Submit & Next Image", variant="primary")
            pause_exit_btn = gr.Button("Exit", variant="secondary")

        # Receives the reviewer's JSONL path from start_or_resume; not displayed.
        your_jsonl_path = gr.State()

    with gr.Group(visible=False) as thanks_group:
        thanks_md = gr.Markdown("### ✅ Thanks! Your responses were saved.\n\nClick **Start Again** to evaluate another image.")
        restart_btn = gr.Button("Start Again", variant="primary")

    # --- Wiring ---
    # The outputs list must stay in the same order as the tuple returned by start_or_resume.
    start_event = start_btn.click(
        start_or_resume,
        inputs=[name, email],
        outputs=[
            state_uid, state_samples, state_seen, state_idx,
            image_1, image_2, image_3, image_4, image_5,
            status, your_jsonl_path,
            eval_panel, intro_md, start_group
        ],
    )
    start_event.then(hide_thanks, inputs=None, outputs=[thanks_group])

    # 1. When Pause is clicked, just pass the state through
    pause_event = pause_exit_btn.click(
        pause_exit,
        inputs=[state_seen, state_samples],
        outputs=[state_seen, state_samples],
    )

    # 2. Then show the "Thanks/Resume" screen with the 'how many left' message
    pause_event.then(
        to_thanks,
        inputs=[name, state_seen, state_samples],
        outputs=[eval_panel, thanks_group, thanks_md],
    )

    # Submit saves the rating and loads the next sample; the outputs list must
    # stay in the same order as the tuple returned by submit_next_image.
    nextimg_event = submit_next_btn.click(
        submit_next_image,
        inputs=[name, email, state_uid, state_samples, state_seen, state_idx,
                score_1, score_2, score_3, score_4, score_5, notes_q1],
        outputs=[state_seen, state_idx,
                image_1, image_2, image_3, image_4, image_5,
                status, notes_q1,
                score_1, score_2, score_3, score_4, score_5],
    )
    # After each submit, decide whether to switch to the thanks screen.
    nextimg_event.then(
        maybe_show_thanks,
        inputs=[name, state_seen, state_samples],
        outputs=[eval_panel, thanks_group, thanks_md],
    )

    # "Start Again" clears name and email and returns to the start screen.
    restart_event = restart_btn.click(
        reset_to_start,
        inputs=[],
        outputs=[
            name, email,
            start_group, intro_md,
            eval_panel, thanks_group
        ],
    )

if __name__ == "__main__":
    # When a results repo is configured, first pull its data/ and results/
    # files into the working directory.
    if HF_RESULTS_REPO:
        from huggingface_hub import snapshot_download
        try:
            snapshot_download(
                repo_id=HF_RESULTS_REPO,
                repo_type="dataset",
                local_dir=".",
                allow_patterns=["data/*", "results/*"],
                token=HF_TOKEN
            )
        except Exception as e:
            # A failed download is reported but does not stop the launch.
            print(f"Error reading from HF: {e}")

    ensure_paths()
    # Result is discarded; the call only prints the dataset diagnostics at startup.
    _ = load_dataset(GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)

    print("✅ Launching app.")
    demo.queue()
    demo.launch()