"""Merge a trained PEFT adapter into its base causal LM and save the result.

The script loads the base model and tokenizer named in the config, attaches
the adapter stored under ``<adapter base path>/ft2``, folds the adapter into
the base weights, and writes the merged model plus tokenizer to disk.
"""
import torch
# import wandb
import os
import yaml
from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader
import time
from typing import List, Tuple
# import prodigyopt

###
import copy
from dataclasses import field, dataclass, asdict
from typing import Sequence, Literal, Dict

import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers import Trainer
from transformers.modeling_utils import *
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.data.data_collator import DataCollator

from transformers.training_args import TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from torch.utils.data import Dataset, IterableDataset
from datasets import load_dataset
##
#from ..pipeline.flux_omini import transformer_forward, encode_images
# from ...omini.rotation import RotationTuner, RotationConfig
# from smpeft.sama import RotationTuner, RotationConfig
from smpeft import PeftModel
from .config import MainConfig, convert_to_trainer_args
import draccus
# from omegaconf import OmegaConf
import argparse

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = ""
DEFAULT_BOS_TOKEN = ""
DEFAULT_UNK_TOKEN = ""
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


@draccus.wrap()
def main(mainCfg: MainConfig):
    """Merge the configured adapter into the base model and save it.

    ``main`` is invoked without arguments at the bottom of this file, so
    ``mainCfg`` is expected to be supplied by the ``draccus.wrap()`` decorator.

    Fields read from ``mainCfg.model``:
        model_name: identifier/path of the base model and tokenizer.
        merge_adapter_path: preferred adapter base directory (may be None).
        adapter_path: fallback adapter base directory, used when
            ``merge_adapter_path`` is None.
        merge_output_path: where to save the merged model; when None, the
            output goes to ``<adapter base directory>/merge``.

    Raises:
        KeyError: if neither ``merge_adapter_path`` nor ``adapter_path`` is set.
    """
    print('='*120)
    model_cfg = mainCfg.model
    model_name = model_cfg.model_name

    # Pick the adapter base directory; the adapter weights live in "<base>/ft2".
    if model_cfg.merge_adapter_path is not None:
        adapter_base = model_cfg.merge_adapter_path
        adapter = adapter_base + "/ft2"
        print(f'Merging... \nfrom mainCfg.model.merge_adapter_path {adapter}')
    elif model_cfg.adapter_path:
        adapter_base = model_cfg.adapter_path
        adapter = adapter_base + "/ft2"
        print(f'From mainCfg.model.adapter_path {adapter}')
    else:
        raise KeyError('No adapter path')

    if model_cfg.merge_output_path is not None:
        output_path = model_cfg.merge_output_path
    else:
        # Default next to the adapter that was actually selected. Using
        # merge_adapter_path unconditionally here would raise TypeError
        # (None + str) whenever the adapter came from adapter_path.
        output_path = adapter_base + "/merge"

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    # device_map is a model-loading option; a tokenizer has no weights to
    # place, so it is not passed here.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Attach the adapter, then fold it into the base weights so the saved
    # checkpoint is a plain model without adapter wrappers.
    model = PeftModel.from_pretrained(model, adapter)
    model = model.merge_and_unload()

    model.save_pretrained(output_path, safe_serialization=False)
    tokenizer.save_pretrained(output_path)

    print(f'The end: merge.py, from {adapter},\n \t \t to {output_path}')


if __name__ == "__main__":
    main()