test
Browse files
mar.py
CHANGED
|
@@ -103,7 +103,21 @@ class MARBert(nn.Module):
|
|
| 103 |
grad_checkpointing=grad_checkpointing
|
| 104 |
)
|
| 105 |
self.diffusion_batch_mul = diffusion_batch_mul
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
def initialize_weights(self):
|
| 108 |
# parameters
|
| 109 |
torch.nn.init.normal_(self.class_emb.weight, std=.02)
|
|
|
|
| 103 |
grad_checkpointing=grad_checkpointing
|
| 104 |
)
|
| 105 |
self.diffusion_batch_mul = diffusion_batch_mul
|
| 106 |
+
print("test")
|
| 107 |
+
print("test")
|
| 108 |
+
print("test")
|
| 109 |
+
print("test")
|
| 110 |
+
print("test")
|
| 111 |
+
print("test")
|
| 112 |
+
print("test")
|
| 113 |
+
print("test")
|
| 114 |
+
print("test")
|
| 115 |
+
print("test")
|
| 116 |
+
print("test")
|
| 117 |
+
print("test")
|
| 118 |
+
print("test")
|
| 119 |
+
print("test")
|
| 120 |
+
|
| 121 |
def initialize_weights(self):
|
| 122 |
# parameters
|
| 123 |
torch.nn.init.normal_(self.class_emb.weight, std=.02)
|