srSergio commited on
Commit
5ab090b
verified
1 Parent(s): d0990fe

Upload saju_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. saju_inference.py +12 -1
saju_inference.py CHANGED
@@ -98,12 +98,23 @@ class SajuAttentionModel(nn.Module):
98
  self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True, dropout=0.3)
99
  self.alpha = nn.Parameter(torch.tensor(0.5))
100
 
 
 
 
 
 
 
 
 
101
  def forward(self, year_emb, month_emb, day_emb, time_emb):
102
  pillars_seq = torch.stack([year_emb, month_emb, day_emb, time_emb], dim=1)
103
  attn_out, attn_weights = self.attention(pillars_seq, pillars_seq, pillars_seq)
104
  saju_context = attn_out.mean(dim=1)
105
  base_saju = pillars_seq.mean(dim=1)
106
- return base_saju + (self.alpha * saju_context), attn_weights
 
 
 
107
 
108
  # -------------------------------------------------------------------
109
  # 3. Genuine Universal Engine (Hugging Face Ready)
 
98
  self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True, dropout=0.3)
99
  self.alpha = nn.Parameter(torch.tensor(0.5))
100
 
101
+ # Red Neuronal Profunda de Proyecci贸n (Non-linear mapping)
102
+ self.projection = nn.Sequential(
103
+ nn.Linear(embedding_dim, embedding_dim * 2),
104
+ nn.GELU(),
105
+ nn.Dropout(0.2),
106
+ nn.Linear(embedding_dim * 2, embedding_dim)
107
+ )
108
+
109
  def forward(self, year_emb, month_emb, day_emb, time_emb):
110
  pillars_seq = torch.stack([year_emb, month_emb, day_emb, time_emb], dim=1)
111
  attn_out, attn_weights = self.attention(pillars_seq, pillars_seq, pillars_seq)
112
  saju_context = attn_out.mean(dim=1)
113
  base_saju = pillars_seq.mean(dim=1)
114
+
115
+ # Conexi贸n residual (Deltas perfectos)
116
+ combined = base_saju + (self.alpha * saju_context)
117
+ return combined + self.projection(combined), attn_weights
118
 
119
  # -------------------------------------------------------------------
120
  # 3. Genuine Universal Engine (Hugging Face Ready)