Update modeling_phi.py
Browse files- modeling_phi.py +65 -18
modeling_phi.py
CHANGED
|
@@ -928,27 +928,74 @@ class PhiModel(PhiPreTrainedModel):
|
|
| 928 |
return hidden_states
|
| 929 |
|
| 930 |
group_definitions = [
|
| 931 |
-
list(range(0,
|
| 932 |
-
list(range(
|
| 933 |
-
list(range(
|
| 934 |
-
list(range(
|
| 935 |
-
list(range(
|
| 936 |
-
list(range(
|
| 937 |
-
list(range(
|
| 938 |
-
list(range(
|
| 939 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 940 |
]
|
| 941 |
|
| 942 |
repetitions = [
|
| 943 |
-
1,
|
| 944 |
-
2,
|
| 945 |
-
2,
|
| 946 |
-
2,
|
| 947 |
-
2,
|
| 948 |
-
2,
|
| 949 |
-
2,
|
| 950 |
-
|
| 951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
]
|
| 953 |
|
| 954 |
class AdvancedSharedLayerModule(nn.Module):
|
|
|
|
| 928 |
return hidden_states
|
| 929 |
|
| 930 |
group_definitions = [
|
| 931 |
+
list(range(0, 1)), # Indices for the 1st group
|
| 932 |
+
list(range(1, 2)),
|
| 933 |
+
list(range(2, 3)),
|
| 934 |
+
list(range(3, 4)),
|
| 935 |
+
list(range(4, 5)),
|
| 936 |
+
list(range(5, 6)),
|
| 937 |
+
list(range(6, 7)),
|
| 938 |
+
list(range(7, 8)),
|
| 939 |
+
list(range(8, 9)),
|
| 940 |
+
list(range(9, 10)),
|
| 941 |
+
list(range(10, 11)),
|
| 942 |
+
list(range(11, 12)),
|
| 943 |
+
list(range(12, 13)),
|
| 944 |
+
list(range(13, 14)),
|
| 945 |
+
list(range(14, 15)),
|
| 946 |
+
list(range(15, 16)),
|
| 947 |
+
list(range(16, 17)),
|
| 948 |
+
list(range(17, 18)),
|
| 949 |
+
list(range(18, 19)),
|
| 950 |
+
list(range(19, 20)),
|
| 951 |
+
list(range(20, 21)),
|
| 952 |
+
list(range(21, 22)),
|
| 953 |
+
list(range(22, 23)),
|
| 954 |
+
list(range(23, 24)),
|
| 955 |
+
list(range(24, 25)),
|
| 956 |
+
list(range(25, 26)),
|
| 957 |
+
list(range(26, 27)),
|
| 958 |
+
list(range(27, 28)),
|
| 959 |
+
list(range(28, 29)),
|
| 960 |
+
list(range(29, 30)),
|
| 961 |
+
list(range(30, 31)),
|
| 962 |
+
list(range(31, 32)),
|
| 963 |
]
|
| 964 |
|
| 965 |
repetitions = [
|
| 966 |
+
1,
|
| 967 |
+
2,
|
| 968 |
+
2,
|
| 969 |
+
2,
|
| 970 |
+
2,
|
| 971 |
+
2,
|
| 972 |
+
2,
|
| 973 |
+
2,
|
| 974 |
+
2,
|
| 975 |
+
2,
|
| 976 |
+
2,
|
| 977 |
+
2,
|
| 978 |
+
2,
|
| 979 |
+
2,
|
| 980 |
+
2,
|
| 981 |
+
2,
|
| 982 |
+
2,
|
| 983 |
+
2,
|
| 984 |
+
2,
|
| 985 |
+
2,
|
| 986 |
+
2,
|
| 987 |
+
2,
|
| 988 |
+
2,
|
| 989 |
+
2,
|
| 990 |
+
2,
|
| 991 |
+
2,
|
| 992 |
+
2,
|
| 993 |
+
2,
|
| 994 |
+
2,
|
| 995 |
+
2,
|
| 996 |
+
2,
|
| 997 |
+
1,
|
| 998 |
+
|
| 999 |
]
|
| 1000 |
|
| 1001 |
class AdvancedSharedLayerModule(nn.Module):
|