ASomeoneWhoInterestedWithAI commited on
Commit
766a703
·
verified ·
1 Parent(s): 904ded4

Update CleanedCode.md

Browse files
Files changed (1) hide show
  1. CleanedCode.md +506 -0
CleanedCode.md CHANGED
@@ -1,4 +1,6 @@
1
  # Cleaned code
 
 
2
  ```python
3
  import os
4
  import math
@@ -665,4 +667,508 @@ print(
665
  f"Final model size: "
666
  f"{os.path.getsize('LookThem_STL.pth') / (1024*1024):.2f} MB"
667
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
  ```
 
1
  # Cleaned code
2
+ ## Training
3
+
4
  ```python
5
  import os
6
  import math
 
667
  f"Final model size: "
668
  f"{os.path.getsize('LookThem_STL.pth') / (1024*1024):.2f} MB"
669
  )
670
+ ```
671
+
672
+ ## Inference
673
+ ```python
674
+ import torch
675
+ import torch.nn as nn
676
+ import torch.nn.functional as F
677
+ import torchvision.transforms as transforms
678
+
679
+ from PIL import Image
680
+ import math
681
+
682
+
683
+ # =========================================================
684
+ # 1. LOOKTHEM CORE LAYER
685
+ # =========================================================
686
+
687
+ class LookThemLayer(nn.Module):
688
+ """
689
+ Relational token-processing layer used by
690
+ the LookThem STL architecture.
691
+ """
692
+
693
+ def __init__(self, num_tokens, in_features, hidden_dim):
694
+ super(LookThemLayer, self).__init__()
695
+
696
+ self.num_tokens = num_tokens
697
+ self.in_features = in_features
698
+
699
+ # -------------------------------------------------
700
+ # Branch 1
701
+ # -------------------------------------------------
702
+ self.mod1_w1 = nn.Parameter(
703
+ torch.randn(num_tokens, in_features, hidden_dim)
704
+ )
705
+
706
+ self.mod1_b1 = nn.Parameter(
707
+ torch.zeros(num_tokens, hidden_dim)
708
+ )
709
+
710
+ self.mod1_w2 = nn.Parameter(
711
+ torch.randn(num_tokens, hidden_dim, 1)
712
+ )
713
+
714
+ self.mod1_b2 = nn.Parameter(
715
+ torch.zeros(num_tokens, 1)
716
+ )
717
+
718
+ # -------------------------------------------------
719
+ # Branch 2
720
+ # -------------------------------------------------
721
+ self.mod2_w1 = nn.Parameter(
722
+ torch.randn(num_tokens, in_features, hidden_dim)
723
+ )
724
+
725
+ self.mod2_b1 = nn.Parameter(
726
+ torch.zeros(num_tokens, hidden_dim)
727
+ )
728
+
729
+ self.mod2_w2 = nn.Parameter(
730
+ torch.randn(num_tokens, hidden_dim, 1)
731
+ )
732
+
733
+ self.mod2_b2 = nn.Parameter(
734
+ torch.zeros(num_tokens, 1)
735
+ )
736
+
737
+ # -------------------------------------------------
738
+ # Relational transformation
739
+ # -------------------------------------------------
740
+ self.trans_w = nn.Parameter(
741
+ torch.randn(num_tokens, 1, 1)
742
+ )
743
+
744
+ self.trans_b = nn.Parameter(
745
+ torch.zeros(num_tokens, 1)
746
+ )
747
+
748
+ self._init_weights()
749
+
750
+ def _init_weights(self):
751
+
752
+ for w in [
753
+ self.mod1_w1,
754
+ self.mod2_w1,
755
+ self.mod1_w2,
756
+ self.mod2_w2,
757
+ self.trans_w
758
+ ]:
759
+ nn.init.kaiming_uniform_(
760
+ w,
761
+ a=math.sqrt(5)
762
+ )
763
+
764
+ def forward(self, x):
765
+
766
+ N = self.num_tokens
767
+
768
+ # =================================================
769
+ # Branch 1
770
+ # =================================================
771
+ h1 = (
772
+ torch.einsum(
773
+ 'bti,tij->btj',
774
+ x,
775
+ self.mod1_w1
776
+ )
777
+ + self.mod1_b1
778
+ )
779
+
780
+ out_m1 = (
781
+ torch.einsum(
782
+ 'btj,tjk->btk',
783
+ F.gelu(h1),
784
+ self.mod1_w2
785
+ )
786
+ + self.mod1_b2
787
+ )
788
+
789
+ # =================================================
790
+ # Branch 2
791
+ # =================================================
792
+ h2 = (
793
+ torch.einsum(
794
+ 'bti,tij->btj',
795
+ x,
796
+ self.mod2_w1
797
+ )
798
+ + self.mod2_b1
799
+ )
800
+
801
+ out_m2 = (
802
+ torch.einsum(
803
+ 'btj,tjk->btk',
804
+ F.gelu(h2),
805
+ self.mod2_w2
806
+ )
807
+ + self.mod2_b2
808
+ )
809
+
810
+ # Numerical stabilization
811
+ out_m2_safe = out_m2 + 1e-5
812
+
813
+ # =================================================
814
+ # Pairwise comparison
815
+ # =================================================
816
+ compare = torch.tanh(
817
+ out_m1.unsqueeze(2) /
818
+ out_m2_safe.unsqueeze(1)
819
+ )
820
+
821
+ compare2 = torch.tanh(
822
+ out_m1.unsqueeze(1) /
823
+ out_m2_safe.unsqueeze(2)
824
+ )
825
+
826
+ # =================================================
827
+ # Relational transformation
828
+ # =================================================
829
+ bias_reshaped = self.trans_b.view(
830
+ 1,
831
+ 1,
832
+ N,
833
+ 1
834
+ )
835
+
836
+ trans_compare = (
837
+ torch.einsum(
838
+ 'bije,jef->bijf',
839
+ compare,
840
+ self.trans_w
841
+ )
842
+ + bias_reshaped
843
+ )
844
+
845
+ trans_compare2 = (
846
+ torch.einsum(
847
+ 'bije,jef->bijf',
848
+ compare2,
849
+ self.trans_w
850
+ )
851
+ + bias_reshaped
852
+ )
853
+
854
+ # =================================================
855
+ # Interaction fusion
856
+ # =================================================
857
+ interaction = (
858
+ trans_compare * x.unsqueeze(2)
859
+ + trans_compare2 * x.unsqueeze(1)
860
+ ) / 2
861
+
862
+ # Remove self-interaction
863
+ mask = 1.0 - torch.eye(
864
+ N,
865
+ device=x.device
866
+ )
867
+
868
+ interaction_masked = (
869
+ interaction *
870
+ mask.view(1, N, N, 1)
871
+ )
872
+
873
+ return (
874
+ interaction_masked.sum(dim=2)
875
+ / (N - 1.0)
876
+ )
877
+
878
+
879
+ # =========================================================
880
+ # 2. LOOKTHEM STL MODEL
881
+ # =========================================================
882
+
883
+ class LookThemSTLV1(nn.Module):
884
+
885
+ def __init__(self):
886
+ super(LookThemSTLV1, self).__init__()
887
+
888
+ # =================================================
889
+ # STREAM A — MACRO STRUCTURE
890
+ # =================================================
891
+ self.stream_a = nn.Sequential(
892
+
893
+ nn.Conv2d(
894
+ 3,
895
+ 16,
896
+ kernel_size=3,
897
+ stride=2,
898
+ padding=1
899
+ ),
900
+ nn.BatchNorm2d(16),
901
+ nn.GELU(),
902
+
903
+ nn.Conv2d(
904
+ 16,
905
+ 32,
906
+ kernel_size=3,
907
+ stride=2,
908
+ padding=1
909
+ ),
910
+ nn.BatchNorm2d(32),
911
+ nn.GELU(),
912
+
913
+ nn.Conv2d(
914
+ 32,
915
+ 64,
916
+ kernel_size=3,
917
+ stride=2,
918
+ padding=1
919
+ ),
920
+ nn.BatchNorm2d(64),
921
+ nn.GELU(),
922
+
923
+ nn.AdaptiveMaxPool2d((8, 8))
924
+ )
925
+
926
+ # =================================================
927
+ # STREAM B — MICRO DETAIL
928
+ # =================================================
929
+ self.stream_b = nn.Sequential(
930
+
931
+ nn.Conv2d(
932
+ 3,
933
+ 16,
934
+ kernel_size=3,
935
+ stride=1,
936
+ padding=1
937
+ ),
938
+ nn.BatchNorm2d(16),
939
+ nn.GELU(),
940
+
941
+ nn.Conv2d(
942
+ 16,
943
+ 32,
944
+ kernel_size=3,
945
+ stride=1,
946
+ padding=1
947
+ ),
948
+ nn.BatchNorm2d(32),
949
+ nn.GELU(),
950
+
951
+ nn.Conv2d(
952
+ 32,
953
+ 64,
954
+ kernel_size=3,
955
+ stride=2,
956
+ padding=1
957
+ ),
958
+ nn.BatchNorm2d(64),
959
+ nn.GELU(),
960
+
961
+ nn.AdaptiveMaxPool2d((8, 8))
962
+ )
963
+
964
+ # =================================================
965
+ # RELATIONAL PROCESSORS
966
+ # =================================================
967
+ self.lookthemA = LookThemLayer(
968
+ num_tokens=64,
969
+ in_features=64,
970
+ hidden_dim=16
971
+ )
972
+
973
+ self.lookthemB = LookThemLayer(
974
+ num_tokens=64,
975
+ in_features=64,
976
+ hidden_dim=16
977
+ )
978
+
979
+ self.lookthem = LookThemLayer(
980
+ num_tokens=64,
981
+ in_features=128,
982
+ hidden_dim=32
983
+ )
984
+
985
+ # =================================================
986
+ # TOKEN COMPRESSOR
987
+ # =================================================
988
+ self.compressor = nn.AdaptiveAvgPool1d(32)
989
+
990
+ # =================================================
991
+ # CLASSIFIER HEAD
992
+ # =================================================
993
+ self.classifier = nn.Sequential(
994
+
995
+ nn.Flatten(),
996
+
997
+ nn.Linear(64 * 32, 512),
998
+ nn.ReLU(),
999
+ nn.Dropout(0.4),
1000
+
1001
+ nn.Linear(512, 256),
1002
+ nn.ReLU(),
1003
+ nn.Dropout(0.2),
1004
+
1005
+ nn.Linear(256, 10)
1006
+ )
1007
+
1008
+ def forward(self, x):
1009
+
1010
+ batch_size = x.size(0)
1011
+
1012
+ # =================================================
1013
+ # STREAM A
1014
+ # =================================================
1015
+ feat_a = self.stream_a(x)
1016
+
1017
+ feat_a_flat = feat_a.view(
1018
+ batch_size,
1019
+ 64,
1020
+ 64
1021
+ )
1022
+
1023
+ feat_a_tokens = feat_a_flat.transpose(1, 2)
1024
+
1025
+ feat_a_lt = self.lookthemA(feat_a_tokens)
1026
+
1027
+ # =================================================
1028
+ # STREAM B
1029
+ # =================================================
1030
+ feat_b = self.stream_b(x)
1031
+
1032
+ feat_b_tokens = (
1033
+ feat_b
1034
+ .view(batch_size, 64, 64)
1035
+ .transpose(1, 2)
1036
+ )
1037
+
1038
+ feat_b_lt = self.lookthemB(feat_b_tokens)
1039
+
1040
+ # =================================================
1041
+ # FEATURE FUSION
1042
+ # =================================================
1043
+ tokens_combined = torch.cat(
1044
+ [feat_a_lt, feat_b_lt],
1045
+ dim=2
1046
+ )
1047
+
1048
+ # =================================================
1049
+ # RELATIONAL COGNITION
1050
+ # =================================================
1051
+ out_lookthem = self.lookthem(tokens_combined)
1052
+
1053
+ compressed = self.compressor(out_lookthem)
1054
+
1055
+ return self.classifier(compressed)
1056
+
1057
+
1058
+ # =========================================================
1059
+ # 3. DEVICE SETUP
1060
+ # =========================================================
1061
+
1062
+ device = torch.device(
1063
+ "cuda" if torch.cuda.is_available() else "cpu"
1064
+ )
1065
+
1066
+ print(f"Using device: {device}")
1067
+
1068
+
1069
+ # =========================================================
1070
+ # 4. CLASS LABELS
1071
+ # =========================================================
1072
+
1073
+ classes = [
1074
+ "airplane",
1075
+ "bird",
1076
+ "car",
1077
+ "cat",
1078
+ "deer",
1079
+ "dog",
1080
+ "horse",
1081
+ "monkey",
1082
+ "ship",
1083
+ "truck"
1084
+ ]
1085
+
1086
+
1087
+ # =========================================================
1088
+ # 5. IMAGE TRANSFORM
1089
+ # =========================================================
1090
+
1091
+ transform = transforms.Compose([
1092
+
1093
+ transforms.Resize((96, 96)),
1094
+
1095
+ transforms.ToTensor(),
1096
+
1097
+ transforms.Normalize(
1098
+ (0.4914, 0.4822, 0.4465),
1099
+ (0.2470, 0.2435, 0.2616)
1100
+ )
1101
+ ])
1102
+
1103
+
1104
+ # =========================================================
1105
+ # 6. LOAD MODEL
1106
+ # =========================================================
1107
+
1108
+ model = LookThemSTLV1().to(device)
1109
+
1110
+ model.load_state_dict(
1111
+ torch.load(
1112
+ "LookThem_STL.pth",
1113
+ map_location=device
1114
+ )
1115
+ )
1116
+
1117
+ model.eval()
1118
+
1119
+ print("Model loaded successfully!")
1120
+
1121
+
1122
+ # =========================================================
1123
+ # 7. LOAD IMAGE
1124
+ # =========================================================
1125
+
1126
+ # Replace with your image path
1127
+ image_path = "test.jpg"
1128
+
1129
+ image = Image.open(image_path).convert("RGB")
1130
+
1131
+ input_tensor = transform(image)
1132
+
1133
+ # Add batch dimension
1134
+ input_tensor = input_tensor.unsqueeze(0).to(device)
1135
+
1136
+
1137
+ # =========================================================
1138
+ # 8. INFERENCE
1139
+ # =========================================================
1140
+
1141
+ with torch.no_grad():
1142
+
1143
+ output = model(input_tensor)
1144
+
1145
+ probabilities = F.softmax(output, dim=1)
1146
+
1147
+ confidence, predicted = torch.max(
1148
+ probabilities,
1149
+ dim=1
1150
+ )
1151
+
1152
+ predicted_class = classes[predicted.item()]
1153
+
1154
+ confidence_score = confidence.item() * 100
1155
+
1156
+
1157
+ # =========================================================
1158
+ # 9. RESULT
1159
+ # =========================================================
1160
+
1161
+ print("\n===== INFERENCE RESULT =====")
1162
+
1163
+ print(f"Predicted Class : {predicted_class}")
1164
+
1165
+ print(f"Confidence : {confidence_score:.2f}%")
1166
+
1167
+ print("\n===== CLASS PROBABILITIES =====")
1168
+
1169
+ for idx, class_name in enumerate(classes):
1170
+
1171
+ prob = probabilities[0][idx].item() * 100
1172
+
1173
+ print(f"{class_name:<10} : {prob:.2f}%")
1174
  ```