File size: 56,140 Bytes
2b452b4
9e65be3
2b452b4
 
 
 
9e65be3
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e65be3
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e65be3
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e65be3
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
2b452b4
 
 
 
 
 
 
 
 
 
 
 
3857954
 
2b452b4
3857954
2b452b4
 
 
 
 
 
 
 
 
3857954
2b452b4
 
 
 
3857954
2b452b4
3857954
2b452b4
 
 
 
 
 
 
 
3857954
2b452b4
 
 
3857954
2b452b4
 
3857954
 
 
 
 
 
 
 
 
191f0e8
 
 
 
 
57d8ecf
191f0e8
 
 
 
 
57d8ecf
191f0e8
 
 
57d8ecf
191f0e8
2b452b4
 
3857954
 
 
2b452b4
 
3857954
 
2b452b4
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
3857954
 
 
 
 
 
 
 
2b452b4
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
3857954
 
 
 
 
 
 
 
2b452b4
 
 
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
3857954
 
 
 
 
 
 
 
 
2b452b4
3857954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b452b4
3857954
 
 
 
 
 
 
 
 
2b452b4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
# -*- coding: utf-8 -*-
"""demo_bias_v6.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1ByBGxmGTMZkAqDcZ9eFQYf-DBkSWmSfZ

# Demo app to mitigate bias
"""


# install general libraries
import os
import torch.nn as nn
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from facenet_pytorch import MTCNN
from torch.autograd import Function
import torch.nn.functional as F
import gradio as gr
from collections import OrderedDict
from copy import deepcopy
import cv2




def get_device():
    """Pick the preferred available torch backend and print the choice.

    Preference order is CUDA, then Apple MPS, then CPU. Each backend is only
    probed if every higher-priority backend is unavailable.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # Lambdas keep the probes lazy: MPS is only queried when CUDA is absent.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    selected = next((name for name, is_available in probes if is_available()), "cpu")
    print("Device Selected:", selected)
    return selected

## Constants or Global

# Batch size used by every per-condition DataLoader.
cpu_batch_size = 8

# Torch backend name, chosen once at import time and shared by MTCNN,
# checkpoint loading and inference.
device = get_device()

# Face detector/cropper shared by all datasets. image_size=224 is set here once
# so every crop comes out 224x224 and later callers need not pass the size.
mtcnn = MTCNN(device=device, image_size=224)

## Loading Models
class Vgg_vd_face_sfew_dag(nn.Module):
    """VGG-16-style CNN producing 7 logits (one per emotion class).

    Layout: five conv blocks of 2, 2, 3, 3 and 3 conv3x3+ReLU layers
    (channels 3 -> 64 -> 128 -> 256 -> 512 -> 512), each block ending in a
    2x2/stride-2 max-pool; then ``fc6`` implemented as a 7x7 Conv2d
    (512 -> 4096), and Linear layers ``fc7`` (4096 -> 4096) and ``fc8``
    (4096 -> 7). The class name presumably refers to a VGG-Face ("vd")
    network fine-tuned on SFEW — NOTE(review): provenance not visible here.

    NOTE: attribute registration order matters. ``DANN_VGG`` splits
    ``list(self.children())`` at ``[-4:]``, which must be
    (relu6, fc7, relu7, fc8).
    """

    def __init__(self):
        super(Vgg_vd_face_sfew_dag, self).__init__()
        # Metadata only: forward() never reads it (no mean subtraction is applied here).
        self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
                     'std': [1, 1, 1],
                     'imageSize': [224, 224, 3]}
        # Block 1
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_1 = nn.ReLU()
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_2 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        # Block 2
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_1 = nn.ReLU()
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        # Block 3
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_1 = nn.ReLU()
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_2 = nn.ReLU()
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        # Block 4
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_1 = nn.ReLU()
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_2 = nn.ReLU()
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_3 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        # Block 5
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_1 = nn.ReLU()
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_2 = nn.ReLU()
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_3 = nn.ReLU()
        self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        # Head: "fc6" is a 7x7 convolution, so on a 7x7 feature map it acts as a fully-connected layer.
        self.fc6 = nn.Conv2d(512, 4096, kernel_size=[7, 7], stride=(1, 1))
        self.relu6 = nn.ReLU()
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.relu7 = nn.ReLU()
        self.fc8 = nn.Linear(in_features=4096, out_features=7, bias=True)

    def forward(self, data):
        """Return class logits of shape (N, 7).

        Args:
            data: Image batch of shape (N, 3, H, W). For 224x224 input
                (``meta['imageSize']``), the five pools reduce it to 7x7, so
                ``fc6`` yields (N, 4096, 1, 1) and the flatten below produces
                the (N, 4096) input that ``fc7`` expects.
        """
        x1 = self.conv1_1(data)
        x2 = self.relu1_1(x1)
        x3 = self.conv1_2(x2)
        x4 = self.relu1_2(x3)
        x5 = self.pool1(x4)
        x6 = self.conv2_1(x5)
        x7 = self.relu2_1(x6)
        x8 = self.conv2_2(x7)
        x9 = self.relu2_2(x8)
        x10 = self.pool2(x9)
        x11 = self.conv3_1(x10)
        x12 = self.relu3_1(x11)
        x13 = self.conv3_2(x12)
        x14 = self.relu3_2(x13)
        x15 = self.conv3_3(x14)
        x16 = self.relu3_3(x15)
        x17 = self.pool3(x16)
        x18 = self.conv4_1(x17)
        x19 = self.relu4_1(x18)
        x20 = self.conv4_2(x19)
        x21 = self.relu4_2(x20)
        x22 = self.conv4_3(x21)
        x23 = self.relu4_3(x22)
        x24 = self.pool4(x23)
        x25 = self.conv5_1(x24)
        x26 = self.relu5_1(x25)
        x27 = self.conv5_2(x26)
        x28 = self.relu5_2(x27)
        x29 = self.conv5_3(x28)
        x30 = self.relu5_3(x29)
        x31 = self.pool5(x30)
        x32 = self.fc6(x31) # fc6 is a conv layer; DANN_VGG's feature extractor ends here (its output is the shared feature)
        x33_preflatten = self.relu6(x32)
        # Flatten (N, 4096, 1, 1) -> (N, 4096) for the Linear layers.
        x33 = x33_preflatten.view(x33_preflatten.size(0), -1)
        x34 = self.fc7(x33)
        x35 = self.relu7(x34)
        prediction = self.fc8(x35)
        return prediction


# Backbone instance whose layers DANN_VGG reuses (it is DANN_VGG's default
# `model_pretrained` argument). No weights are loaded into it in the visible
# code; the trained weights come from the checkpoints loaded into each
# DANN_VGG copy.
model_pretrained = Vgg_vd_face_sfew_dag()

class GradientReversalFn(Function):
    """Gradient reversal layer used by the DANN domain branch.

    Forward pass: identity on ``x``. Backward pass: the incoming gradient is
    multiplied by ``-alpha``; ``alpha`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale for backward(). view_as returns a view of x
        # (a new tensor object) rather than x itself as the op's output.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output * ctx.alpha
        # No gradient flows to the non-tensor `alpha` argument.
        return reversed_grad, None

class DANN_VGG(nn.Module):
    """Domain-adversarial classifier built from a VGG backbone's layers.

    The backbone's children are split in two:
      * ``feature_extractor`` - every layer up to and including ``fc6``;
      * ``class_classifier``  - the last four layers (relu6, fc7, relu7, fc8),
        with a Dropout inserted after fc7.
    A separate ``domain_classifier`` (2 outputs) sees the features through a
    gradient reversal layer. Its first Linear starts from the backbone's fc7
    parameters.

    Note: the backbone's layer objects are reused, not copied, so the modules
    are shared with ``model_pretrained``.

    Args:
        model_pretrained: Backbone whose ``children()`` end in
            (..., fc6, relu6, fc7, relu7, fc8) and which has an ``fc7`` Linear.
        num_classes: Accepted but unused; the class count comes from the
            backbone's final layer.
        dropout_rate: Dropout probability used in both classifier heads.
    """

    def __init__(self,
                 model_pretrained = model_pretrained,
                 num_classes=7,
                 dropout_rate = 0.1,
                 ):
        super(DANN_VGG, self).__init__()
        backbone_layers = list(model_pretrained.children())

        # Feature extractor: backbone up to and including fc6.
        self.feature_extractor = nn.Sequential(*backbone_layers[:-4])

        # Class head: relu6, fc7, Dropout, relu7, fc8.
        head_layers = backbone_layers[-4:]
        head_layers.insert(2, nn.Dropout(dropout_rate))
        self.class_classifier = nn.Sequential(*head_layers)

        # Domain head: binary source/target prediction.
        self.domain_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(4096, 2),
        )

        # Seed the domain head's 4096x4096 Linear with the backbone's fc7 parameters.
        domain_fc = self.domain_classifier[1]
        with torch.no_grad():
            domain_fc.weight.copy_(model_pretrained.fc7.weight)
            domain_fc.bias.copy_(model_pretrained.fc7.bias)

    def forward(self, input_data, alpha = 0.0):
        """Run both heads on shared features.

        Args:
            input_data: Image batch for the backbone (224x224 expected with the
                default VGG backbone, giving (N, 4096, 1, 1) features).
            alpha: Gradient-reversal scale applied on the domain branch.

        Returns:
            ``(class_output, domain_output, features)``, where ``features`` is
            the extractor output reshaped to (-1, channels).
        """
        raw_features = self.feature_extractor(input_data)
        features = raw_features.view(-1, raw_features.size(1))

        # The class head runs before the domain head, as in the original.
        class_output = self.class_classifier(features)
        domain_output = self.domain_classifier(GradientReversalFn.apply(features, alpha))

        return class_output, domain_output, features

# Downloading Models
# Every inference model is an independent deep copy of one DANN_VGG skeleton,
# with its own fine-tuned weights loaded from a local checkpoint file.
skeleton_model = DANN_VGG(model_pretrained = model_pretrained, num_classes=7)  # skeleton copy


def _load_inference_model(checkpoint_path):
    """Return a deep copy of ``skeleton_model`` holding the weights in *checkpoint_path*.

    Args:
        checkpoint_path: File written by ``torch.save`` of a model state dict.

    Returns:
        A new DANN_VGG instance with parameters from the checkpoint (tensors
        mapped onto ``device``).
    """
    model = deepcopy(skeleton_model)
    # The checkpoints are state dicts (tensors in dicts), so the restricted,
    # non-code-executing weights_only unpickler is sufficient and safer.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device), weights_only=True)
    model.load_state_dict(state_dict)
    return model


non_dann_model_sfew_expw_inference = _load_inference_model('non_dann_sfew_expw_23_05_wo_se_a.pt')
dann_model_sfew_expw_inference = _load_inference_model('dann_sfew_expw_23_05_wo_se_a.pt')
ewc_dann_model_sfew_expw_inference = _load_inference_model('ewc_dann_sfew_expw_23_05_wo_se_a.pt')

# Load Samples
# This file holds a pickled pandas DataFrame, not just tensors. The
# weights_only unpickler (torch.load's default since PyTorch 2.6) rejects it,
# so full unpickling is requested explicitly.
# SECURITY: full unpickling can execute arbitrary code; only ever load this
# trusted, bundled file here, never user-supplied paths.
df_samples_loaded = torch.load('df_sample_c8.pt', weights_only=False) # updated with 8 conditions

# Split the loaded samples into 8 cases according to whether each model's
# predicted emotion matches the ground truth ("gt_emotion"). The trailing
# numbers in the comments are the sample counts noted for each case.


def _prediction_outcome_mask(non_dann_ok, dann_ok, ewc_dann_ok):
    """Return a boolean row mask over ``df_samples_loaded``.

    Each flag selects, for one model's prediction column, rows where that
    prediction equals (True) or differs from (False) ``gt_emotion``. The
    per-model masks are AND-ed in non_dann, dann, ewc_dann order.
    """
    ground_truth = df_samples_loaded['gt_emotion']
    combined = None
    for column, must_match in (('p_emotion_non_dann', non_dann_ok),
                               ('p_emotion_dann', dann_ok),
                               ('p_emotion_ewc_dann', ewc_dann_ok)):
        predicted = df_samples_loaded[column]
        if must_match:
            matches = ground_truth == predicted
        else:
            matches = ground_truth != predicted
        combined = matches if combined is None else combined & matches
    return combined


# Condition 1: same as GT in non_dann, dann and dann_ewc - 30
condition_1 = _prediction_outcome_mask(True, True, True)
df_condition_1 = df_samples_loaded[condition_1].reset_index(drop=True)

# Condition 2: different from GT in non_dann, same in dann and dann_ewc - 20
condition_2 = _prediction_outcome_mask(False, True, True)
df_condition_2 = df_samples_loaded[condition_2].reset_index(drop=True)

# Condition 3: different from GT in non_dann and dann, same in dann_ewc - 27
condition_3 = _prediction_outcome_mask(False, False, True)
df_condition_3 = df_samples_loaded[condition_3].reset_index(drop=True)

# Condition 4: different from GT in non_dann and dann_ewc, same in dann - 24
condition_4 = _prediction_outcome_mask(False, True, False)
df_condition_4 = df_samples_loaded[condition_4].reset_index(drop=True)

# Condition 5: different from GT in all three models - 40
condition_5 = _prediction_outcome_mask(False, False, False)
df_condition_5 = df_samples_loaded[condition_5].reset_index(drop=True)

# Condition 6: same as GT in non_dann, different in dann and dann_ewc - 30
condition_6 = _prediction_outcome_mask(True, False, False)
df_condition_6 = df_samples_loaded[condition_6].reset_index(drop=True)

# Condition 7: same as GT in non_dann and dann, different in dann_ewc - 25
condition_7 = _prediction_outcome_mask(True, True, False)
df_condition_7 = df_samples_loaded[condition_7].reset_index(drop=True)

# Condition 8: same as GT in non_dann and dann_ewc, different in dann
condition_8 = _prediction_outcome_mask(True, False, True)
df_condition_8 = df_samples_loaded[condition_8].reset_index(drop=True)

# Dataloader

# Emotion class names keyed by the string form of their class index ("0".."6").
labels_map = {
    str(index): name
    for index, name in enumerate(
        ("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")
    )
}

# Class names in index order, so labels[i] is class i.
labels = list(labels_map.values())

# Identity matrix: row i is the one-hot encoding of labels[i].
label_matrix = torch.eye(len(labels))


def get_one_hot_vector(emotion, labels=labels, label_matrix=label_matrix):
    """Return the one-hot row of *label_matrix* for *emotion*.

    Args:
        emotion: Class name to look up (exact match against *labels*).
        labels: Class names in index order.
        label_matrix: Matrix whose row i encodes labels[i].

    Returns:
        ``label_matrix[i]`` where ``labels[i] == emotion``.

    Raises:
        ValueError: If *emotion* is not one of *labels*.
    """
    try:
        position = labels.index(emotion)
    except ValueError:
        raise ValueError(f"Emotion '{emotion}' not found in labels.") from None
    return label_matrix[position]

class CustomImageDataset(Dataset):
    """Dataset of face crops, one-hot emotion labels, full images and race labels.

    Rows come from a DataFrame holding an in-memory PIL image
    (``image_pil_colname``), a ground-truth emotion name
    (``gt_emotion_colname``; lower-cased before lookup in ``labels``) and a
    race value (``race_colname``).

    Each item is ``(face, emotion_one_hot, image, race)``:
      * ``face`` - the MTCNN face crop shifted from [-1, 1] to [0, 1]. If MTCNN
        returns None (no face), the transformed full image is used instead.
      * ``emotion_one_hot`` - tensor from ``get_one_hot_vector``.
      * ``image`` - the full image after ``transform``. When ``transform`` is
        None, a fixed 224x224 resize plus ToTensor is used.
      * ``race`` - the raw value stored in ``race_colname``.
    """

    def __init__(self, dataframe,
                 transform=None,
                 image_file_colname = 'image',
                 race_colname  = 'gt_race',
                 gt_emotion_colname  = 'gt_emotion',
                 image_pil_colname = 'image_pil'):
        """
        Args:
            dataframe: Source rows; re-indexed 0..n-1 so ``idx`` addresses a row.
            transform: Optional callable applied to the full PIL image; expected
                to return a tensor.
            image_file_colname: Name of the image-file column. Kept for API
                compatibility; items are built from ``image_pil_colname``.
            race_colname: Column holding the race value.
            gt_emotion_colname: Column holding the ground-truth emotion name.
            image_pil_colname: Column holding the PIL image.
        """
        self.dataframe = dataframe.reset_index(drop=True)
        # Exact 224x224 (not shortest-side 224), so fallback images have the
        # same 3x224x224 shape as MTCNN crops and can be batched together by
        # the default DataLoader collate function.
        self.basic_transform = transforms.Compose([transforms.Resize((224, 224)),
                                                   transforms.ToTensor()])
        self.transform = transform
        self.image_file_colname = image_file_colname
        self.race_colname = race_colname
        self.gt_emotion_colname = gt_emotion_colname
        self.image_pil_colname = image_pil_colname

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        race = self.dataframe.loc[idx, self.race_colname]
        emotion = self.dataframe.loc[idx, self.gt_emotion_colname]
        # Lower-case to match the names in `labels`.
        emotion_one_hot = get_one_hot_vector(emotion.lower())
        image = self.dataframe.loc[idx, self.image_pil_colname]  # PIL image

        # MTCNN takes the PIL image and returns a 3x224x224 face tensor, or
        # None when cropping fails.
        cropped_image = mtcnn(image)

        # Full image -> tensor, via the caller's transform when one was given.
        if self.transform:
            image_transformed = self.transform(image)
        else:
            image_transformed = self.basic_transform(image)

        if cropped_image is None:
            # No face found: the full image also fills the crop slot.
            return image_transformed, emotion_one_hot, image_transformed, race

        # Shift the crop from [-1, 1] to [0, 1]; no other transform is applied to it.
        cropped_image = (cropped_image + 1) / 2
        return cropped_image, emotion_one_hot, image_transformed, race

# Dataset and Dataloader for Conditions 1 to 8

# Full-image preprocessing shared by every condition: exact 224x224 resize,
# then PIL -> float tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _make_condition_loader(df_condition):
    """Build the dataset and the shuffling DataLoader for one condition's rows."""
    condition_dataset = CustomImageDataset(dataframe=df_condition, transform=transform)
    condition_loader = DataLoader(condition_dataset, batch_size=cpu_batch_size, shuffle=True)
    return condition_dataset, condition_loader


dataset_condition_1, dataloader_condition_1 = _make_condition_loader(df_condition_1)
dataset_condition_2, dataloader_condition_2 = _make_condition_loader(df_condition_2)
dataset_condition_3, dataloader_condition_3 = _make_condition_loader(df_condition_3)
dataset_condition_4, dataloader_condition_4 = _make_condition_loader(df_condition_4)
dataset_condition_5, dataloader_condition_5 = _make_condition_loader(df_condition_5)
dataset_condition_6, dataloader_condition_6 = _make_condition_loader(df_condition_6)
dataset_condition_7, dataloader_condition_7 = _make_condition_loader(df_condition_7)
dataset_condition_8, dataloader_condition_8 = _make_condition_loader(df_condition_8)

# UI and its related functions
#
# Get images function for all conditions

# Tensor (C x H x W) -> PIL image, for showing dataset samples in the UI.
transform_to_pil = transforms.ToPILImage()
# Capitalized emotion names in class-index order ("Angry", "Disgust", ...).
emotion_labels = list(map(str.capitalize, labels_map.values()))

def get_images(dataloader = dataloader_condition_1):
    """Draw one batch from *dataloader* and convert it for display.

    Args:
        dataloader: DataLoader over a CustomImageDataset; defaults to condition 1.

    Returns:
        Four lists with one entry per sample in the batch:
        (cropped-face PIL images, emotion display names, full PIL images,
        race values).

    Raises:
        gr.Error: If *dataloader* yields no batch (its condition has no samples).
    """
    # A fresh iterator on each call; the condition loaders are built with
    # shuffle=True, so repeated calls show different samples.
    batch = next(iter(dataloader), None)
    if batch is None:
        # A bare StopIteration would surface as an opaque failure in the UI callback.
        raise gr.Error("No images are available for this case.")
    cropped_images, emotions, images, races = batch

    list_pil_cropped_images = [transform_to_pil(cropped_img) for cropped_img in cropped_images]
    list_pil_images = [transform_to_pil(img) for img in images]
    # Each emotion is a one-hot vector; its argmax indexes the display name.
    list_emotions = [emotion_labels[torch.argmax(emotion).item()] for emotion in emotions]

    return list_pil_cropped_images, list_emotions, list_pil_images, list(races)

# Zero-argument callbacks, one per case, for UI buttons. Condition 1 needs no
# wrapper because it is get_images' default loader.
def get_images_condition_2():
    """Return a display batch from condition 2's loader (see get_images)."""
    return get_images(dataloader_condition_2)


def get_images_condition_3():
    """Return a display batch from condition 3's loader (see get_images)."""
    return get_images(dataloader_condition_3)


def get_images_condition_4():
    """Return a display batch from condition 4's loader (see get_images)."""
    return get_images(dataloader_condition_4)


def get_images_condition_5():
    """Return a display batch from condition 5's loader (see get_images)."""
    return get_images(dataloader_condition_5)


def get_images_condition_6():
    """Return a display batch from condition 6's loader (see get_images)."""
    return get_images(dataloader_condition_6)


def get_images_condition_7():
    """Return a display batch from condition 7's loader (see get_images)."""
    return get_images(dataloader_condition_7)


def get_images_condition_8():
    """Return a display batch from condition 8's loader (see get_images)."""
    return get_images(dataloader_condition_8)


# Classify Images All models

# Capitalized class names in index order, used to label softmax outputs.
emotion_labels = [name.capitalize() for name in labels_map.values()]

def classify_image_all_models(input_image):
    """Classify *input_image* with the non-DANN, DANN and EWC-DANN models.

    Args:
        input_image: A PIL image (any mode; converted to RGB if needed).

    Returns:
        Three ``OrderedDict`` objects (non-DANN, DANN, EWC-DANN). Each maps the
        two highest-scoring emotion names to their softmax probabilities,
        highest first.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to classify.")

    # The networks' first conv takes 3 channels; grayscale/RGBA/palette
    # inputs would otherwise produce a channel-count mismatch.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    image_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_tensor = image_transforms(input_image).to(device).unsqueeze(0)

    models = [non_dann_model_sfew_expw_inference, dann_model_sfew_expw_inference, ewc_dann_model_sfew_expw_inference]
    list_confidences = []
    for model in models:
        model = model.to(device)
        model.eval()  # disable dropout in both heads
        with torch.no_grad():
            # ToTensor gives values in [0, 1]; the models are fed a 0-255 range
            # here (presumably matching training — confirm).
            logits, _, _ = model(image_tensor * 255)
            output = F.softmax(logits.view(-1), dim=-1)

        ranked = sorted(zip(emotion_labels, output.tolist()), key=lambda pair: pair[1], reverse=True)
        list_confidences.append(OrderedDict(ranked[:2]))

    return list_confidences[0], list_confidences[1], list_confidences[2]

def display_image(path="DBMF.png"):
    """Load the framework diagram shown on the Introduction tab.

    Args:
        path: Image file to load. Defaults to ``"DBMF.png"``, which (as before)
            is resolved relative to the current working directory.

    Returns:
        A PIL image whose pixel data is fully held in memory.
    """
    # Image.open is lazy and keeps the file handle open until the pixels are
    # first read. Copying inside the `with` block loads the data into an
    # independent image and closes the file right away.
    with Image.open(path) as image:
        return image.copy()

# Markdown summary table rendered on the Introduction tab. For each case tab
# (C1-C8) it states whether each model's predicted emotion is the same as, or
# different from, the ground-truth (GT) emotion label.
cases_table = """
| **Cases (Tabs)** | **Baseline / Non DANN** | **DANN without EWC** | **DANN with EWC** |
|--------------|--------------|--------------|--------------|
| **Case 1 (C1)**    | **Same** as GT    | **Same** as GT     | **Same** as GT |
| **Case 2 (C2)**    | **Different** from GT        | **Same** as GT | **Same** as GT |
| **Case 3 (C3)**    | **Different** from GT        | **Different** from GT       | **Same** as GT|
| **Case 4 (C4)**    | **Different** from GT        | **Same** as GT | **Different** from GT   |
| **Case 5 (C5)**    | **Different** from GT        | **Different** from GT  | **Different** from GT   |
| **Case 6 (C6)**    | **Same** as GT       | **Different** from GT  | **Different** from GT   |
| **Case 7 (C7)**    | **Same** as GT       | **Same** as GT | **Different** from GT   |
| **Case 8 (C8)**    | **Same** as GT        | **Different** from GT  | **Same** as GT  |
"""

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Gradio base theme, passed to gr.Blocks for the whole demo.
theme = gr.themes.Base()

with gr.Blocks(theme=theme) as demo:
    with gr.Tab("Introduction"):
        gr.Markdown("## Domain Adaptation in Deep Networks - Practical Implementation of Demographic Bias Mitigation Framework")
        with gr.Row():
            with gr.Column():
                image_output = gr.Image(value=display_image(), label = "Demographic Bias Mitigation Framework",height = 400, width = 600, show_label = True)
            with gr.Column():
                gr.Markdown('''
                    Source - Static Facial Expression in Wild (SFEW) 2.0 or SFEW2.0
                    -------
                    - SFEW 2.0 dataset is split into training (958 samples), validation (436 samples), and test sets (372 samples).
                    - Each image is labelled with one of seven emotions: Angry, Disgust, Fear, Happy, Sad, Surprise, and Neutral.
                    - It is used as racially bias source domain dataset in this demonstration
                    ''')
                gr.Markdown(
                    ''' Target - Expression in-the-Wild or ExpW dataset
                    -------
                    - It comprises of 91,793 manually labeled images without specific Train/Validation/Test splits.
                    - Similar to SFEW 2.0, each image in the ExpW dataset is categorized as: Angry, Disgust, Fear, Happy, Sad, Surprise, and Neutral.
                    - It is used as racially unbiased target domain dataset in this demonstration
                    '''
                )
                gr.Markdown(
                    '''
                    Evaluation Dataset - ExpW (Race annotated)
                    -------
                    - A subset of 8,458 images from the ExpW dataset was randomly selected and annotated with race in addition to the existing emotion labels.
                    - This subset also served as the validation dataset - **The images from this validation set are used**
                    '''
                )

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Markdown(''' **Fig. 1.** **DANN based demographic bias mitigation framework (DBMF)**: Framework supports
                    (a) Non-domain adapted (Baseline, Non-DA) training, (b) Domain adapted (DA) training and (c) bias and task accuracy evaluation steps.
                    Pre-processed data is fed into feature extractor. Fully connected network in Task-specific component has classification/regression heads.
                    Domain classifier comprises of fully connected layers with a binary classification head.
                    Unsupervised DA is supported by a gradient reversal layer, which flips errors during backpropagation, forcing the feature extractor
                    to prioritize domain-invariant features that remain task-specific.
                    For (c) bias and task accuracy evaluation, an evaluation test set is used to understand task accuracy and bias using statistical metrics / tests.
                    ''')
                with gr.Row():
                    gr.Markdown('''
                    Models
                    ------
                    ''')
                with gr.Row():
                    gr.Markdown('''
                    **1. 	Baseline or Non-DA Network**: This network lacks DA components and only includes feature extractor and task-specific component Training only involves biased source domain data (with train, validation and if available, test splits). The network is trained to minimize the loss objective, which is specific to the task 
                    ''')
                with gr.Row():
                    gr.Markdown('''
                    **2. 	DA (without EWC) Network**: This network includes DA specific components like domain classifier and gradient reversal layer in addition to the components of Non-DA (baseline) network. It uses two datasets, i.e., demographic-biased source dataset (same as baseline network with train, validation, and, if available, test splits) and demographic-neutral target dataset, which have less bias / no bias. 
                    ''')
                with gr.Row(): 
                    gr.Markdown('''
                    **3. 	DA (with EWC) Network**: This network is similar to DA (without EWC) network, but enforces a regularization constraint on the parameters of feature extractor and task-specific component using EWC algorithm while training. 
                    ''')
            with gr.Column():
                with gr.Row():
                    gr.Markdown(''' **Table. 1.** Cases (Tabs) showing  predicted emotion in 3 networks as compared to ground truth emotion''')
                with gr.Row():
                    gr.HTML("<b><span style='color: blue;'>Click on the tabs for corresponding cases</span></b>")
                with gr.Row():
                    gr.Markdown(cases_table)
                with gr.Row():
                    gr.HTML("<b><span style='color: blue;'>Click on the tab - Predict Emotion for your own image - to predict emotions for your image(s) </span></b>")
    ################################################
    with gr.Tab("C1: All Same") as tabs1:
        imgs = gr.State()

        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that the predictions of all the models is same.

                The consistency in predictions across all models can be attributed to the robustness of the image features, the clarity and distinctness of the emotions, the effective representation of similar examples in the training data, and the overall quality of the images.

                In case of multiple faces in image, the correct face is cropped for which the emotion label is available.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")

            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 1 (C1): All Same**    | **Same** as emotion GT    | **Same** as emotion GT     | **Same** as emotion GT |

                """
                )


        with gr.Row():
            with gr.Column(scale=1):
                gallery = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C1 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    # selected = gr.Number(show_label=False)
                    selected = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images, list_emotions, big_images, list_races = get_images()

        def get_big_images():
            return big_images, big_images

        def get_select_index(evt: gr.SelectData):
            # return evt.index, cropped_images[evt.index]
            return list_emotions[evt.index], gr.update(visible=True), cropped_images[evt.index], list_races[evt.index], gr.update(visible=True)

        refresh_case1 = gr.Checkbox(visible=False)

        def refresh_gallery():
            return get_big_images()

        refresh_case1.change(refresh_gallery, None, [gallery, imgs])
        gallery.select(get_select_index, None, [selected, selected, cropped_image_display, txtbox_race, txtbox_race])
        button_classify_C1.click(fn=classify_image_all_models, inputs=[cropped_image_display],
                                 outputs=[label_classify_non_dann, label_classify_dann, label_classify_ewc])

    ################################################
    with gr.Tab("C2:Same:DANN, EWC DANN | Diff:Base") as tabs2:
        imgs_2 = gr.State()

        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted same as GT by DANN with and without EWC, but different by Baseline (Non-DANN)

                The key factors contributing to the correct predictions by the DANN models (with and without EWC) and the incorrect predictions by the Baseline (Non-DANN) model include the effectiveness of domain adaptation, feature sensitivity to race, robustness to biases, improved generalization capabilities, and the overall ability to handle domain shifts.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 2 (C2): Emotion predicted same as GT by DANN with and without EWC, but different by Baseline (Non-DANN)**    | **Different** from emotion GT        | **Same** as emotion GT | **Same** as emotion GT |

                """
                )

        with gr.Row():
            with gr.Column(scale=1):
                gallery_2 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_2 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C2 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_2 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_2 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_2 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_2 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_2 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_2, list_emotions_2, big_images_2, list_races_2 = get_images_condition_2()

        def get_big_images_2():
            return big_images_2, big_images_2

        def get_select_index_2(evt: gr.SelectData):
            # return evt.index, cropped_images[evt.index]
            return list_emotions_2[evt.index], gr.update(visible=True), cropped_images_2[evt.index], list_races_2[evt.index], gr.update(visible=True)

        refresh_case2 = gr.Checkbox(visible=False)

        def refresh_gallery_2():
            return get_big_images_2()

        refresh_case2.change(refresh_gallery_2, None, [gallery_2, imgs_2])
        gallery_2.select(get_select_index_2, None, [selected_2, selected_2, cropped_image_display_2, txtbox_race_2, txtbox_race_2])
        button_classify_C2.click(fn=classify_image_all_models, inputs=[cropped_image_display_2],
                                 outputs=[label_classify_non_dann_2, label_classify_dann_2, label_classify_ewc_2])

    ################################################
    with gr.Tab("C3:Same:EWC DANN | Diff:Base, DANN")as tabs3:
        imgs_3 = gr.State()
        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted same as GT by DANN with EWC, but different by Baseline (Non-DANN) and DANN without EWC.

                This can be attributed to the complexity (with respect to assigning emotion) in the images themselves, image quality, incorrect face crop and busy images.

                Another plausible attribution is that The DANN model with EWC might have learned features that are more invariant to irrelevant variations (such as race or other biases) and more sensitive to the actual emotional content of the images.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 3 (C3): Emotion predicted same as GT by DANN with EWC, but differently by Baseline( Non-DANN) and DANN without EWC**    | **Different** from emotion GT        | **Different** from emotion GT       | **Same** as emotion GT|
                """
                )

        with gr.Row():
            with gr.Column(scale=1):
                gallery_3 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_3 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C3 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_3 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_3 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_3 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_3 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_3 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_3, list_emotions_3, big_images_3, list_races_3 = get_images_condition_3()

        def get_big_images_3():
            return big_images_3, big_images_3

        def get_select_index_3(evt: gr.SelectData):
            # return evt.index, cropped_images[evt.index]
            return list_emotions_3[evt.index], gr.update(visible=True), cropped_images_3[evt.index], list_races_3[evt.index], gr.update(visible=True)

        refresh_case3 = gr.Checkbox(visible=False)

        def refresh_gallery_3():
            return get_big_images_3()

        refresh_case3.change(refresh_gallery_3, None, [gallery_3, imgs_3])
        gallery_3.select(get_select_index_3, None, [selected_3, selected_3, cropped_image_display_3, txtbox_race_3, txtbox_race_3])
        button_classify_C3.click(fn=classify_image_all_models, inputs=[cropped_image_display_3],
                                 outputs=[label_classify_non_dann_3, label_classify_dann_3, label_classify_ewc_3])

    ################################################
    with gr.Tab("C4:Same:DANN | Diff:Base, EWC DANN") as tabs4:
        imgs_4 = gr.State()
        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted same as GT by DANN without EWC, but different by Baseline (Non-DANN) and DANN with EWC.

                This can be attributed to the over-regularization by EWC failing adaptation and generalization. The DANN model without EWC might have found a balance between generalization and specificity that was optimal for this particular case.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")

            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 4 (C4): Emotion predicted same as GT by DANN without EWC, but differently by Baseline( Non-DANN) and DANN with EWC **    | **Different** from emotion GT | **Same** as emotion GT  | **Different** from emotion GT    |
                """
                )


        with gr.Row():
            with gr.Column(scale=1):
                gallery_4 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_4 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C4 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_4 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_4 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_4 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_4 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_4 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_4, list_emotions_4, big_images_4, list_races_4 = get_images_condition_4()

        def get_big_images_4():
            return big_images_4, big_images_4

        def get_select_index_4(evt: gr.SelectData):
            return list_emotions_4[evt.index], gr.update(visible=True), cropped_images_4[evt.index], list_races_4[evt.index], gr.update(visible=True)

        refresh_case4 = gr.Checkbox(visible=False)

        def refresh_gallery_4():
            return get_big_images_4()

        refresh_case4.change(refresh_gallery_4, None, [gallery_4, imgs_4])
        gallery_4.select(get_select_index_4, None, [selected_4, selected_4, cropped_image_display_4, txtbox_race_4, txtbox_race_4])
        button_classify_C4.click(fn=classify_image_all_models, inputs=[cropped_image_display_4],
                                 outputs=[label_classify_non_dann_4, label_classify_dann_4, label_classify_ewc_4])

    ################################################
    with gr.Tab("C5: All Diff") as tabs5:
        imgs_5 = gr.State()

        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted are different from GT by all the models.

                The consistent failure of all models to correctly predict the emotions as per the ground truth could be can be attributed to ambiguous or noisy data, inherent model limitation, adaptation failure (significant domain shift), complex emotion expressions, features extracted by the models might not be sensitive or relevant enough to the specific emotional cues present in these images.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 5 (C5) : Emotion predicted are different from GT by all the models**    | **Different** from emotion GT      | **Different** from emotion GT | **Different** from emotion GT |
                """
                )



        with gr.Row():
            with gr.Column(scale=1):
                gallery_5 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_5 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C5 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_5 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_5 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_5 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_5 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_5 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_5, list_emotions_5, big_images_5, list_races_5 = get_images_condition_5()

        def get_big_images_5():
            return big_images_5, big_images_5

        def get_select_index_5(evt: gr.SelectData):
            return list_emotions_5[evt.index], gr.update(visible=True), cropped_images_5[evt.index], list_races_5[evt.index], gr.update(visible=True)

        refresh_case5 = gr.Checkbox(visible=False)

        def refresh_gallery_5():
            return get_big_images_5()

        refresh_case5.change(refresh_gallery_5, None, [gallery_5, imgs_5])
        gallery_5.select(get_select_index_5, None, [selected_5, selected_5, cropped_image_display_5, txtbox_race_5, txtbox_race_5])
        button_classify_C5.click(fn=classify_image_all_models, inputs=[cropped_image_display_5],
                                 outputs=[label_classify_non_dann_5, label_classify_dann_5, label_classify_ewc_5])


    ################################################
    with gr.Tab("C6:Same:Base | Diff:DANN, EWC DANN") as tabs6:
        imgs_6 = gr.State()

        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted same as GT by by Baseline( Non-DANN), but differently  DANN with and without EWC models

                This can be attributed to DANN models overcompensating for the domain shift, leading to misalignment of critical features.

                Another plausible attirbution may be that Baseline (Non-DANN) model may have retained generalizable features relevant for emotion recognition.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 6 (C6): Emotion predicted same as GT by by Baseline( Non-DANN), but differently  DANN with and without EWC models**    | **Same** as emotion GT   | **Different** from emotion GT | **Different** from emotion GT |                """
                )



        with gr.Row():
            with gr.Column(scale=1):
                gallery_6 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_6 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C6 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_6 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_6 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_6 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_6 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_6 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_6, list_emotions_6, big_images_6, list_races_6 = get_images_condition_6()

        def get_big_images_6():
            return big_images_6, big_images_6

        def get_select_index_6(evt: gr.SelectData):
            return list_emotions_6[evt.index], gr.update(visible=True), cropped_images_6[evt.index], list_races_6[evt.index], gr.update(visible=True)

        refresh_case6 = gr.Checkbox(visible=False)

        def refresh_gallery_6():
            return get_big_images_6()

        refresh_case6.change(refresh_gallery_6, None, [gallery_6, imgs_6])
        gallery_6.select(get_select_index_6, None, [selected_6, selected_6, cropped_image_display_6, txtbox_race_6, txtbox_race_6])
        button_classify_C6.click(fn=classify_image_all_models, inputs=[cropped_image_display_6],
                                 outputs=[label_classify_non_dann_6, label_classify_dann_6, label_classify_ewc_6])
    ################################################
    with gr.Tab("C7:Same:Base, DANN | Diff: EWC DANN") as tabs7:
        imgs_7 = gr.State()

        with gr.Row():
            with gr.Column():
                gr.Markdown(''' In this case (scenario), we see that emotion predicted same as GT by by Baseline( Non-DANN) and DANN without EWC, but differently by DANN with models

                This can be attributed to potential over-regularization of DANN with EWC model, emotion expression complexity, label noise, image quality, incorrect face crop.

                EWC’s constraints might hinder the model’s ability to adapt effectively to the target domain features, leading to incorrect predictions.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 7 (C7): Emotion predicted same as GT by by Baseline( Non-DANN) and DANN without EWC, but differently by DANN with models**    | **Same** as emotion GT       | **Same** as emotion GT | **Different** from emotion GT   |                """
                )

        with gr.Row():
            with gr.Column(scale=1):
                gallery_7 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_7 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C7 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    selected_7 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_7 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_7 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_7 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_7 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        cropped_images_7, list_emotions_7, big_images_7, list_races_7 = get_images_condition_7()

        def get_big_images_7():
            return big_images_7, big_images_7

        def get_select_index_7(evt: gr.SelectData):
            return list_emotions_7[evt.index], gr.update(visible=True), cropped_images_7[evt.index], list_races_7[evt.index], gr.update(visible=True)

        refresh_case7 = gr.Checkbox(visible=False)

        def refresh_gallery_7():
            return get_big_images_7()

        refresh_case7.change(refresh_gallery_7, None, [gallery_7, imgs_7])
        gallery_7.select(get_select_index_7, None, [selected_7, selected_7, cropped_image_display_7, txtbox_race_7, txtbox_race_7])
        button_classify_C7.click(fn=classify_image_all_models, inputs=[cropped_image_display_7],
                                 outputs=[label_classify_non_dann_7, label_classify_dann_7, label_classify_ewc_7])
    ################################################
    with gr.Tab("C8:Same:Base, EWC DANN | Diff: DANN") as tabs8:
        # Full-size Case 8 images; filled together with gallery_8 by refresh_gallery_8.
        imgs_8 = gr.State()

        # Scenario description (left) and the expected-outcome table (right).
        with gr.Row():
            with gr.Column():
                gr.Markdown('''
                In this case (scenario), we see that the emotion is predicted the same as the GT by the Baseline (Non-DANN) model and by DANN with EWC, but differently by DANN without EWC.

                This can be attributed to the lack of feature preservation in the DANN model without EWC, which could lead to over-adaptation or loss of crucial features.
                ''')
                gr.HTML("<b><span style='color: blue;'>Select the image from the gallery, <br>Click the Button to predict emotions using 3 models</span></b>")
            with gr.Column():
                gr.Markdown(
                """
                | **Demonstration scenario** | **Prediction Emotion for Baseline (Non-DANN)** | **Prediction Emotion for DANN without EWC** | **Prediction Emotion for DANN with EWC** |
                |--------------|--------------|--------------|--------------|
                | **Case 8 (C8): Emotion predicted same as GT by Baseline (Non-DANN) and DANN with EWC, but differently by DANN without EWC**    | **Same** as emotion GT        | **Different** from emotion GT  | **Same** as emotion GT  |
                """
                )

        # Gallery | selected crop + predict button (+ hidden GT/race boxes) | per-model predictions.
        with gr.Row():
            with gr.Column(scale=1):
                gallery_8 = gr.Gallery(allow_preview=True, rows=2, columns=2)

            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_8 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C8 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')
                with gr.Row():
                    # Hidden until an image is selected; get_select_index_8 makes them visible.
                    selected_8 = gr.Textbox(label="Ground Truth Emotion", visible=False)
                    txtbox_race_8 = gr.Textbox(label="Race", visible=False)

            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_8 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_8 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_8 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)

        # Loaded once at build time; the four sequences are index-aligned
        # (crops, ground-truth emotions, full-size images, races).
        cropped_images_8, list_emotions_8, big_images_8, list_races_8 = get_images_condition_8()

        def get_big_images_8():
            """Return the Case 8 full-size images for gallery_8 and the imgs_8 state."""
            return big_images_8, big_images_8

        def get_select_index_8(evt: gr.SelectData):
            """Return the outputs for the clicked gallery item.

            The values map to [selected_8, selected_8, cropped_image_display_8,
            txtbox_race_8, txtbox_race_8]: each textbox gets its value and then a
            visibility update.
            """
            return list_emotions_8[evt.index], gr.update(visible=True), cropped_images_8[evt.index], list_races_8[evt.index], gr.update(visible=True)

        # Hidden trigger set by tabs8.select; its change listener reloads the gallery.
        refresh_case8 = gr.Checkbox(visible=False)

        def refresh_gallery_8():
            """Reload gallery_8 and imgs_8 with the Case 8 images."""
            return get_big_images_8()

        refresh_case8.change(refresh_gallery_8, None, [gallery_8, imgs_8])
        gallery_8.select(get_select_index_8, None, [selected_8, selected_8, cropped_image_display_8, txtbox_race_8, txtbox_race_8])
        button_classify_C8.click(fn=classify_image_all_models, inputs=[cropped_image_display_8],
                                 outputs=[label_classify_non_dann_8, label_classify_dann_8, label_classify_ewc_8])

    ################################################
    with gr.Tab("Predict Emotion for your own Image") as tabs9:
        # Free-form tab: the user uploads an image, crop_image (wired to the
        # upload's change event) produces the face crop, and the button runs all
        # three models on that crop.

        with gr.Row():
            gr.HTML("<b><span style='color: red;'>Upload an image and predict emotion for the image</span></b>")

        with gr.Row():
            # Column 1: the user's uploaded image (delivered to handlers as PIL).
            with gr.Column(scale=1):
                with gr.Row():
                    big_image = gr.Image(label="Upload your own Image", type="pil", height=224, width=224)

            # Column 2: the face crop shown to the user and fed to the classifiers, plus the predict button.
            with gr.Column(scale=1):
                with gr.Row():
                    cropped_image_display_9 = gr.Image(label="Cropped Image", type="pil", height=224, width=224)
                with gr.Row():
                    button_classify_C9 = gr.Button("Click Button to Predict Emotion", visible=True, size='sm')

            # Column 3: one prediction label per model, each showing its top 2 classes.
            with gr.Column(scale=1):
                with gr.Row():
                    label_classify_non_dann_9 = gr.Label(label="Baseline(Non DANN) Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_dann_9 = gr.Label(label="DANN Predicted Emotion", num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_ewc_9 = gr.Label(label="EWC DANN Predicted Emotion", num_top_classes=2, visible=True)


        def crop_image(big_image):
            """Crop the face out of an uploaded image for display and classification.

            Args:
                big_image: the uploaded image as a PIL image (the input component is
                    ``gr.Image(type="pil")``), or None when the upload has been cleared.

            Returns:
                None if there is no image (this clears the crop display); otherwise a
                PIL image of the face crop returned by ``mtcnn``, or ``big_image``
                unchanged when ``mtcnn`` returns None (no face found).
            """
            if big_image is None:
                # Clearing the upload also triggers the change handler with None;
                # clear the crop instead of handing None to mtcnn, which would fail.
                return None

            # mtcnn takes a PIL image and returns a tensor crop (3x224x224 per the
            # original note), or None when cropping fails.
            cropped_image = mtcnn(big_image)
            if cropped_image is None:
                # Fall back to the original image so the user can still classify it.
                return big_image

            # mtcnn's output is in [-1, 1]; rescale to [0, 1] before converting to PIL,
            # because the display component expects a PIL image.
            cropped_image = (cropped_image + 1) / 2
            return transforms.ToPILImage()(cropped_image)


        # Re-crop every time the uploaded image changes.
        big_image.change(
            crop_image,
            [big_image],
            [cropped_image_display_9],
        )
        # Run the three models on whatever crop is currently displayed.
        button_classify_C9.click(
            classify_image_all_models,
            [cropped_image_display_9],
            [label_classify_non_dann_9, label_classify_dann_9, label_classify_ewc_9],
        )
    ################################################

    def refresh_tab():
        """Return True, written into a case tab's hidden refresh checkbox on selection.

        Each case tab's ``select`` event below stores this value in that tab's
        ``refresh_caseN`` checkbox. Inside the tab, the checkbox's ``change``
        listener reloads the gallery (e.g. ``refresh_case7.change(refresh_gallery_7, ...)``).
        """
        return True

    # The listeners pass inputs=None, so Gradio calls the handler with no
    # arguments. The handler must therefore take no parameters; something like
    # ``lambda x: True`` raises TypeError and never sets the checkbox.
    # NOTE(review): the checkbox stays True after the first selection, so later
    # selections may not fire its change listener again. Confirm this is intended.
    tabs1.select(refresh_tab, None, [refresh_case1])
    tabs2.select(refresh_tab, None, [refresh_case2])
    tabs3.select(refresh_tab, None, [refresh_case3])
    tabs4.select(refresh_tab, None, [refresh_case4])
    tabs5.select(refresh_tab, None, [refresh_case5])
    tabs6.select(refresh_tab, None, [refresh_case6])
    tabs7.select(refresh_tab, None, [refresh_case7])
    tabs8.select(refresh_tab, None, [refresh_case8])

# Start the Gradio app built above.
# NOTE(review): debug=True is kept from the original. Per the Gradio docs it
# blocks the main thread and surfaces errors in the console; confirm this is
# wanted outside development.
demo.launch(debug=True)