Spaces:
Running
on
Zero
Running
on
Zero
Update model_merging method to reflect main repo
Browse files
Last version was moving the model more and more towards the base model because the model merging value becomes smaller with every call to `generate`
- NatureLM/models/NatureLM.py +10 -4
NatureLM/models/NatureLM.py
CHANGED
|
@@ -775,10 +775,16 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
|
|
| 775 |
adapter_name (str): The name of the adapter to rescale when merging.
|
| 776 |
"""
|
| 777 |
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
|
| 783 |
@torch.inference_mode()
|
| 784 |
def generate(self, samples, generate_cfg, prompts) -> list[str]:
|
|
|
|
| 775 |
adapter_name (str): The name of the adapter to rescale when merging.
|
| 776 |
"""
|
| 777 |
|
| 778 |
+
# Store original scaling on first call, then always scale relative to original
|
| 779 |
+
if not hasattr(self, "_original_lora_scaling"):
|
| 780 |
+
self._original_lora_scaling = {}
|
| 781 |
+
for name, module in self.llama_model.named_modules():
|
| 782 |
+
if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
|
| 783 |
+
self._original_lora_scaling[name] = module.scaling[adapter_name]
|
| 784 |
+
|
| 785 |
+
for name, module in self.llama_model.named_modules():
|
| 786 |
+
if name in self._original_lora_scaling:
|
| 787 |
+
module.scaling[adapter_name] = merging_alpha * self._original_lora_scaling[name]
|
| 788 |
|
| 789 |
@torch.inference_mode()
|
| 790 |
def generate(self, samples, generate_cfg, prompts) -> list[str]:
|