Yujivus committed on
Commit
70a8cc0
·
verified ·
1 Parent(s): e139839

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -9,7 +9,7 @@ datasets:
9
  metrics:
10
  - bleu
11
  model-index:
12
- - name: Yujivus/PRISM-Protomolecule
13
  results:
14
  - task:
15
  type: translation
 
9
  metrics:
10
  - bleu
11
  model-index:
12
+ - name: Yujivus/PRISM-Molecule
13
  results:
14
  - task:
15
  type: translation
__pycache__/modeling_prism_gated.cpython-312.pyc CHANGED
Binary files a/__pycache__/modeling_prism_gated.cpython-312.pyc and b/__pycache__/modeling_prism_gated.cpython-312.pyc differ
 
config.json CHANGED
@@ -8,5 +8,5 @@
8
  "num_encoder_layers": 6,
9
  "num_refining_layers": 0,
10
  "num_decoder_layers": 6,
11
- "architecture": "PRISM_Protomolecule"
12
  }
 
8
  "num_encoder_layers": 6,
9
  "num_refining_layers": 0,
10
  "num_decoder_layers": 6,
11
+ "architecture": "PRISM_Molecule"
12
  }
modeling_prism_gated.py CHANGED
@@ -81,7 +81,7 @@ class PRISMLayer(nn.Module):
81
  self.filter_len = max_len
82
 
83
  # 1. THE GATE (Data Dependency)
84
- self.gate_proj = nn.Linear(d_model * 2, d_model * 2)
85
 
86
  # 2. THE FILTER (Global Pattern)
87
  self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)
@@ -89,7 +89,7 @@ class PRISMLayer(nn.Module):
89
  # 3. INPUT MIXING
90
  self.mix_real = nn.Linear(d_model, d_model)
91
  self.mix_imag = nn.Linear(d_model, d_model)
92
-
93
  # 4. OUTPUT PROJECTION
94
  self.out_real = nn.Linear(d_model, d_model)
95
  self.out_imag = nn.Linear(d_model, d_model)
@@ -116,7 +116,7 @@ class PRISMLayer(nn.Module):
116
  x_cat = torch.cat([x_norm.real, x_norm.imag], dim=-1)
117
  gates = torch.sigmoid(self.gate_proj(x_cat))
118
  gate_r, gate_i = gates.chunk(2, dim=-1)
119
-
120
  # B. FILTER
121
  B, L, D = x_norm.shape
122
  x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
@@ -235,7 +235,7 @@ class PRISMHybrid_RoPE(nn.Module):
235
  output = self.decoder(tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask)
236
 
237
  return self.final_linear(output)
238
-
239
  # ... (generate function remains the same) ...
240
  @torch.no_grad()
241
  def generate(self, src, max_length, num_beams=5):
 
81
  self.filter_len = max_len
82
 
83
  # 1. THE GATE (Data Dependency)
84
+ self.gate_proj = nn.Linear(d_model * 2, d_model * 2)
85
 
86
  # 2. THE FILTER (Global Pattern)
87
  self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)
 
89
  # 3. INPUT MIXING
90
  self.mix_real = nn.Linear(d_model, d_model)
91
  self.mix_imag = nn.Linear(d_model, d_model)
92
+
93
  # 4. OUTPUT PROJECTION
94
  self.out_real = nn.Linear(d_model, d_model)
95
  self.out_imag = nn.Linear(d_model, d_model)
 
116
  x_cat = torch.cat([x_norm.real, x_norm.imag], dim=-1)
117
  gates = torch.sigmoid(self.gate_proj(x_cat))
118
  gate_r, gate_i = gates.chunk(2, dim=-1)
119
+
120
  # B. FILTER
121
  B, L, D = x_norm.shape
122
  x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
 
235
  output = self.decoder(tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask)
236
 
237
  return self.final_linear(output)
238
+
239
  # ... (generate function remains the same) ...
240
  @torch.no_grad()
241
  def generate(self, src, max_length, num_beams=5):