Update instruction_template_retriever.py
Browse files
instruction_template_retriever.py
CHANGED
|
@@ -142,6 +142,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
|
|
| 142 |
sigma (float): Standard deviation for Gaussian weighting.
|
| 143 |
alpha (float): Weighting factor for merging with standard mean pooling.
|
| 144 |
"""
|
|
|
|
| 145 |
if isinstance(m[1], GaussianCoveragePooling):
|
| 146 |
m = unuse_gaussian_coverage_pooling(m)
|
| 147 |
word_embedding_model = m[0]
|
|
@@ -151,6 +152,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
|
|
| 151 |
old_pooling = m[1]
|
| 152 |
new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
|
| 153 |
new_m.old_pooling = {"old_pooling": old_pooling}
|
|
|
|
| 154 |
return new_m
|
| 155 |
|
| 156 |
|
|
|
|
| 142 |
sigma (float): Standard deviation for Gaussian weighting.
|
| 143 |
alpha (float): Weighting factor for merging with standard mean pooling.
|
| 144 |
"""
|
| 145 |
+
old_device = m.device
|
| 146 |
if isinstance(m[1], GaussianCoveragePooling):
|
| 147 |
m = unuse_gaussian_coverage_pooling(m)
|
| 148 |
word_embedding_model = m[0]
|
|
|
|
| 152 |
old_pooling = m[1]
|
| 153 |
new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
|
| 154 |
new_m.old_pooling = {"old_pooling": old_pooling}
|
| 155 |
+
new_m = new_m.to(old_device)
|
| 156 |
return new_m
|
| 157 |
|
| 158 |
|