sattwik21 commited on
Commit
a681cee
·
verified ·
1 Parent(s): bea330b

Upload modeling_gestr_jepa.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_gestr_jepa.py +44 -0
modeling_gestr_jepa.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PretrainedConfig, PreTrainedModel
4
+
5
+ # --- 1. Configuration Class ---
6
+ class GestrJEPAConfig(PretrainedConfig):
7
+ model_type = "gestr-jepa"
8
+
9
+ def __init__(
10
+ self,
11
+ input_dim=16,
12
+ embed_dim=64,
13
+ hidden_dim=256,
14
+ **kwargs
15
+ ):
16
+ self.input_dim = input_dim
17
+ self.embed_dim = embed_dim
18
+ self.hidden_dim = hidden_dim
19
+ super().__init__(**kwargs)
20
+
21
+ # --- 2. The Model Wrapper ---
22
+ class GestrJEPAForClassification(PreTrainedModel):
23
+ config_class = GestrJEPAConfig
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+ self.encoder = nn.Sequential(
28
+ nn.Linear(config.input_dim, config.hidden_dim),
29
+ nn.GELU(),
30
+ nn.Linear(config.hidden_dim, config.hidden_dim),
31
+ nn.GELU(),
32
+ nn.Linear(config.hidden_dim, config.embed_dim),
33
+ nn.LayerNorm(config.embed_dim)
34
+ )
35
+ self.classifier = nn.Linear(config.embed_dim, config.num_labels)
36
+
37
+ def forward(self, sensor_values, labels=None):
38
+ embeddings = self.encoder(sensor_values)
39
+ logits = self.classifier(embeddings)
40
+ loss = None
41
+ if labels is not None:
42
+ loss_fct = nn.CrossEntropyLoss()
43
+ loss = loss_fct(logits, labels.view(-1))
44
+ return {"loss": loss, "logits": logits, "hidden_states": embeddings}