Update README.md
Browse files
README.md
CHANGED
|
@@ -56,6 +56,7 @@ final_layer_keys = {
|
|
| 56 |
"final_layer.linear.bias",
|
| 57 |
}
|
| 58 |
|
|
|
|
| 59 |
# a) Contextual downstream architecture
|
| 60 |
# ----------------------------------
|
| 61 |
model = SignalJEPA_Contextual(
|
|
@@ -74,6 +75,7 @@ FILTERED_model_state_dict = {
|
|
| 74 |
k: v for k, v in model_state_dict.items() if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."])
|
| 75 |
}
|
| 76 |
|
|
|
|
| 77 |
# b) Post-local downstream architecture
|
| 78 |
# ----------------------------------
|
| 79 |
model = SignalJEPA_PostLocal(
|
|
@@ -86,6 +88,7 @@ missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict,
|
|
| 86 |
assert unexpected_keys == []
|
| 87 |
assert set(missing_keys) == final_layer_keys
|
| 88 |
|
|
|
|
| 89 |
# c) Pre-local architecture
|
| 90 |
# ----------------------
|
| 91 |
model = SignalJEPA_PreLocal(
|
|
|
|
| 56 |
"final_layer.linear.bias",
|
| 57 |
}
|
| 58 |
|
| 59 |
+
|
| 60 |
# a) Contextual downstream architecture
|
| 61 |
# ----------------------------------
|
| 62 |
model = SignalJEPA_Contextual(
|
|
|
|
| 75 |
k: v for k, v in model_state_dict.items() if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."])
|
| 76 |
}
|
| 77 |
|
| 78 |
+
|
| 79 |
# b) Post-local downstream architecture
|
| 80 |
# ----------------------------------
|
| 81 |
model = SignalJEPA_PostLocal(
|
|
|
|
| 88 |
assert unexpected_keys == []
|
| 89 |
assert set(missing_keys) == final_layer_keys
|
| 90 |
|
| 91 |
+
|
| 92 |
# c) Pre-local architecture
|
| 93 |
# ----------------------
|
| 94 |
model = SignalJEPA_PreLocal(
|