Update modeling_neuroclr.py
Browse files- modeling_neuroclr.py +16 -0
modeling_neuroclr.py
CHANGED
|
@@ -56,6 +56,22 @@ class NeuroCLR(nn.Module):
|
|
| 56 |
|
| 57 |
return h, z
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# --------------------------
|
| 61 |
# Your ResNet1D head (verbatim)
|
|
|
|
| 56 |
|
| 57 |
return h, z
|
| 58 |
|
| 59 |
+
class NeuroCLRModel(PreTrainedModel):
|
| 60 |
+
"""
|
| 61 |
+
Loads with:
|
| 62 |
+
AutoModel.from_pretrained(..., trust_remote_code=True)
|
| 63 |
+
"""
|
| 64 |
+
config_class = NeuroCLRConfig
|
| 65 |
+
base_model_prefix = "neuroclr"
|
| 66 |
+
|
| 67 |
+
def __init__(self, config: NeuroCLRConfig):
|
| 68 |
+
super().__init__(config)
|
| 69 |
+
self.neuroclr = NeuroCLR(config)
|
| 70 |
+
self.post_init()
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor, **kwargs):
|
| 73 |
+
h, z = self.neuroclr(x)
|
| 74 |
+
return {"h": h, "z": z}
|
| 75 |
|
| 76 |
# --------------------------
|
| 77 |
# Your ResNet1D head (verbatim)
|