# lfm_complete_code / merge_model.py
# Techiiot — uploaded via huggingface_hub (commit 27c46c6, verified)
# NOTE: Legacy merge script (LFM2-1.2B paths), superseded by
# merge_and_save_model() below; kept commented out for reference.
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from peft import PeftModel
# import torch
# print("Loading base model...")
# base_model = AutoModelForCausalLM.from_pretrained(
# "./models/LFM2-1.2B",
# torch_dtype=torch.bfloat16,
# device_map="auto",
# trust_remote_code=True
# )
# print("Loading LoRA adapters...")
# model = PeftModel.from_pretrained(base_model, "./counselor_model/final_model")
# print("Merging adapters with base model...")
# merged_model = model.merge_and_unload()
# print("Saving merged model...")
# merged_model.save_pretrained("./counselor_model-merged", safe_serialization=True)
# tokenizer = AutoTokenizer.from_pretrained("./models/LFM2-1.2B")
# tokenizer.save_pretrained("./counselor_model-merged")
# print("Model merge complete!")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import os
def merge_and_save_model(
    base_model_name: str = "LiquidAI/LFM2-2.6B",
    adapter_path: str = "./lfm_minimal_output/final_model",
    output_path: str = "./merged_counselor_minimal_2b",
):
    """Merge LoRA adapter weights into the base model and save the result.

    Loads the base causal-LM, applies the PEFT/LoRA adapter on top of it,
    folds the adapter deltas into the base weights with ``merge_and_unload()``,
    and writes the standalone merged model plus tokenizer to ``output_path``.

    Args:
        base_model_name: HF hub ID or local path of the base model.
        adapter_path: Directory containing the trained LoRA adapter.
        output_path: Directory to write the merged model and tokenizer to.

    Returns:
        Tuple of ``(merged_model, tokenizer)``.
    """
    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        torch_dtype=torch.float16,
    )

    print("Merging weights...")
    # Folds the LoRA deltas into the base weights and returns a plain
    # transformers model (no PEFT wrapper), so it can be saved standalone.
    model = model.merge_and_unload()

    print(f"Saving merged model to {output_path}...")
    # safetensors format, consistent with the legacy merge flow above.
    model.save_pretrained(output_path, safe_serialization=True)

    # Adapter directories often contain only adapter weights; fall back to
    # the base model's tokenizer if the adapter dir has no tokenizer files.
    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    except (OSError, ValueError):
        print(f"No tokenizer found in {adapter_path}; using base model tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.save_pretrained(output_path)

    print("✅ Model merged and saved successfully!")
    return model, tokenizer
# Run the merge with the default base-model/adapter/output paths
# when executed as a script (no effect on import).
if __name__ == "__main__":
    merge_and_save_model()