falmuqhim commited on
Commit
0e3e2de
·
verified ·
1 Parent(s): 4b6ee13

Update modeling_neuroclr.py

Browse files
Files changed (1) hide show
  1. modeling_neuroclr.py +16 -0
modeling_neuroclr.py CHANGED
@@ -56,6 +56,22 @@ class NeuroCLR(nn.Module):
56
 
57
  return h, z
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # --------------------------
61
  # Your ResNet1D head (verbatim)
 
56
 
57
  return h, z
58
 
59
+ class NeuroCLRModel(PreTrainedModel):
60
+ """
61
+ Loads with:
62
+ AutoModel.from_pretrained(..., trust_remote_code=True)
63
+ """
64
+ config_class = NeuroCLRConfig
65
+ base_model_prefix = "neuroclr"
66
+
67
+ def __init__(self, config: NeuroCLRConfig):
68
+ super().__init__(config)
69
+ self.neuroclr = NeuroCLR(config)
70
+ self.post_init()
71
+
72
+ def forward(self, x: torch.Tensor, **kwargs):
73
+ h, z = self.neuroclr(x)
74
+ return {"h": h, "z": z}
75
 
76
  # --------------------------
77
  # Your ResNet1D head (verbatim)