gagannarula committed on
Commit
084684d
·
verified ·
1 Parent(s): 93182a6

Update model_merging method to reflect main repo

Browse files

The previous version drifted the model progressively toward the base model: the merging value was applied on top of the already-rescaled LoRA scaling, so it compounded and became smaller with every call to `generate`. The fix stores the original scaling once and always rescales relative to it.

Files changed (1) hide show
  1. NatureLM/models/NatureLM.py +10 -4
NatureLM/models/NatureLM.py CHANGED
@@ -775,10 +775,16 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
775
  adapter_name (str): The name of the adapter to rescale when merging.
776
  """
777
 
778
- for module in self.llama_model.modules():
779
- # Check if the module is a LoRA layer and has the specified adapter
780
- if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
781
- module.scaling[adapter_name] = merging_alpha * module.scaling[adapter_name]
 
 
 
 
 
 
782
 
783
  @torch.inference_mode()
784
  def generate(self, samples, generate_cfg, prompts) -> list[str]:
 
775
  adapter_name (str): The name of the adapter to rescale when merging.
776
  """
777
 
778
+ # Store original scaling on first call, then always scale relative to original
779
+ if not hasattr(self, "_original_lora_scaling"):
780
+ self._original_lora_scaling = {}
781
+ for name, module in self.llama_model.named_modules():
782
+ if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
783
+ self._original_lora_scaling[name] = module.scaling[adapter_name]
784
+
785
+ for name, module in self.llama_model.named_modules():
786
+ if name in self._original_lora_scaling:
787
+ module.scaling[adapter_name] = merging_alpha * self._original_lora_scaling[name]
788
 
789
  @torch.inference_mode()
790
  def generate(self, samples, generate_cfg, prompts) -> list[str]: