|
|
import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Tuple, Union

import torch
from transformers.utils import is_flash_attn_2_available

import moe_peft
import moe_peft.adapters
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface.
# `args` is parsed once at import time and read as a module-level global by
# the functions below (load_base_model, init_adapter_config, inference, ...).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="MoE-PEFT main program")

# --- Model selection and run mode -----------------------------------------
parser.add_argument(
    "--base_model", type=str, required=True, help="Path to or name of base model"
)
parser.add_argument(
    "--inference", action="store_true", help="The inference mode (just for test)"
)
parser.add_argument(
    "--evaluate", action="store_true", help="The evaluate mode (just for test)"
)
parser.add_argument(
    "--disable_prompter", action="store_true", help="Disable prompter when inference"
)

# --- Adapter handling -----------------------------------------------------
parser.add_argument(
    "--load_adapter",
    action="store_true",
    help="Load adapter from file instead of init randomly",
)
parser.add_argument(
    "--disable_adapter", action="store_true", help="Disable the adapter modules"
)

# --- Attention / cache implementation -------------------------------------
parser.add_argument(
    "--attn_impl", type=str, help="Specify the implementation of attention"
)
parser.add_argument(
    "--sliding_window",
    action="store_true",
    help="Use sliding window attention (requires flash attention)",
)
parser.add_argument(
    "--disable_cache",
    action="store_true",
    help="Disable cache when inference",
)
parser.add_argument(
    "--cache_implementation",
    type=str,
    help="Specify the implementation of cache",
)

# --- Precision and quantization -------------------------------------------
parser.add_argument(
    "--fp16", action="store_true", help="Load base model in float16 precision"
)
parser.add_argument(
    "--bf16", action="store_true", help="Load base model in bfloat16 precision"
)
parser.add_argument(
    "--tf32", action="store_true", help="Use tfloat32 instead of float32 if available"
)
parser.add_argument(
    "--load_8bit", action="store_true", help="Load base model with 8bit quantization"
)
parser.add_argument(
    "--load_4bit", action="store_true", help="Load base model with 4bit quantization"
)

# --- Runtime environment ---------------------------------------------------
parser.add_argument("--device", type=str, help="Specify which GPU to be used")
parser.add_argument(
    "--config", type=str, required=True, help="Path to finetune configuration"
)
parser.add_argument(
    "--seed", type=int, default=42, help="Random seed in integer, default is 42"
)
parser.add_argument(
    "--dir", type=str, default=".", help="Path to read or save checkpoints"
)

# --- Logging / debugging ---------------------------------------------------
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
parser.add_argument("--log_file", type=str, help="Save log to specific file")
parser.add_argument(
    "--verbose", action="store_true", help="Show extra informations such as parameters"
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Overwrite adapter model when older one existed",
)
parser.add_argument("--debug", action="store_true", help="Enabling debugging mode")
parser.add_argument(
    "--deterministic",
    action="store_true",
    help="Use deterministic algorithms to improve the reproducibility",
)

args = parser.parse_args()
|
|
|
|
|
|
|
|
def query_yes_no(question, default="no"):
    """Ask *question* on stdout and return True for yes, False for no.

    Keeps prompting until the user types a recognizable answer. An empty
    answer selects *default* ("yes", "no", or None to force an explicit
    choice).

    Raises:
        ValueError: if *default* is not "yes", "no", or None.
    """
    answers = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    prompts = {None: " [y/n] ", "yes": " [Y/n] ", "no": " [y/N] "}

    if default not in prompts:
        raise ValueError("invalid default answer: '%s'" % default)
    prompt = prompts[default]

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == "":
            return answers[default]
        if choice in answers:
            return answers[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
|
|
|
|
|
|
|
|
def load_base_model() -> Tuple[moe_peft.Tokenizer, moe_peft.LLMModel]:
    """Build the tokenizer and base LLM according to the parsed CLI args.

    Quantization bits are derived from --load_8bit / --load_4bit and the
    load dtype from --bf16 / --fp16 (float32 otherwise).
    """
    logging.info("Initializing pre-trained model.")

    # --load_8bit takes precedence over --load_4bit; neither means full precision.
    if args.load_8bit:
        quant_bits = 8
    elif args.load_4bit:
        quant_bits = 4
    else:
        quant_bits = None

    # --bf16 takes precedence over --fp16.
    if args.bf16:
        load_dtype = torch.bfloat16
    elif args.fp16:
        load_dtype = torch.float16
    else:
        load_dtype = torch.float32

    model = moe_peft.LLMModel.from_pretrained(
        name_or_path=args.base_model,
        device=args.device,
        attn_impl=args.attn_impl,
        use_sliding_window=args.sliding_window,
        bits=quant_bits,
        load_dtype=load_dtype,
    )
    tokenizer = moe_peft.Tokenizer(args.base_model)

    return tokenizer, model
|
|
|
|
|
|
|
|
def init_adapter_config(
    config: Dict[str, Any],
    llm_model: moe_peft.LLMModel,
) -> List[Union[moe_peft.GenerateConfig, moe_peft.TrainConfig]]:
    """Initialize every adapter listed in ``config["lora"]`` on *llm_model*.

    Depending on the CLI mode, returns a list of GenerateConfig (inference),
    EvaluateConfig (evaluate) or TrainConfig (train) objects — one or more
    per adapter entry.

    Side effects:
        - Mutates ``config["cutoff_len"]`` in place when it is -1.
        - Loads (``--load_adapter``) or freshly initializes each adapter on
          *llm_model*.
        - May prompt the user and ``exit(0)`` on a checkpoint-file conflict.

    Args:
        config: Parsed finetune configuration (JSON). Expected keys include
            "cutoff_len" and "lora" — a list of per-adapter dicts with at
            least a "name" key.
        llm_model: The already-loaded base model to attach adapters to.
    """
    config_list = []

    # A cutoff_len of -1 means "use the model's maximum sequence length".
    if config["cutoff_len"] == -1:
        config["cutoff_len"] = llm_model.config_.max_seq_len_
        logging.info(
            f"Setting cutoff_len to {llm_model.config_.max_seq_len_} automatically."
        )

    for lora_config in config["lora"]:
        adapter_name = lora_config["name"]
        adapter_path = f"{args.dir}{os.sep}{adapter_name}"
        # Training would clobber an existing checkpoint: warn when --overwrite
        # was given, otherwise ask the user before proceeding.
        if not args.load_adapter and os.path.exists(adapter_path):
            if args.overwrite:
                logging.warning(
                    f"Overwriting existed adapter model file: {adapter_path}"
                )
            elif not query_yes_no(
                f"Existed adapter model file detected: {adapter_path}\n" + "Overwrite?"
            ):
                logging.info("User canceled training due to file conflict.")
                exit(0)

        if args.load_adapter:
            llm_model.load_adapter(adapter_path, adapter_name)
        else:
            llm_model.init_adapter(moe_peft.adapters.lora_config_factory(lora_config))

        if args.inference:
            config_class = moe_peft.GenerateConfig(adapter_name=adapter_name)
            if not args.disable_prompter:
                config_class.prompt_template = lora_config.get("prompt", None)
            config_list.append(config_class)
        elif args.evaluate:
            config_list.extend(moe_peft.EvaluateConfig.from_config(lora_config))
        else:
            config_list.append(moe_peft.TrainConfig.from_config(lora_config))

        # Guard against an empty list: the evaluate branch may extend with
        # zero entries, which previously made config_list[-1] raise IndexError.
        if args.verbose and config_list:
            logging.info(config_list[-1].__dict__)

    return config_list
|
|
|
|
|
|
|
|
def inference_callback(cur_pos, outputs):
    """Streaming callback passed to moe_peft.generate.

    Prints the current decoding position and, for every adapter, the first
    (partial) output sequence generated so far.
    """
    print(f"POSITION: {cur_pos}")
    for name, sequences in outputs.items():
        print(f"{name} OUTPUT: {sequences[0]}")
|
|
|
|
|
|
|
|
def inference(
    model: moe_peft.LLMModel,
    tokenizer: moe_peft.Tokenizer,
    configs: List[moe_peft.GenerateConfig],
    concurrent_jobs: int,
):
    """Interactive generation loop.

    Reads raw prompts from stdin, runs generation for every adapter config,
    and prints each adapter's output. Typing "QUIT" exits the loop.
    """
    # Stream partial outputs while decoding unless logging is disabled.
    # Hoisted out of the loop: args does not change between iterations.
    stream_callback = None if args.disable_log else inference_callback
    separator = f"\n{'='*10}\n"

    while True:
        raw_prompt = input("INPUT WITHOUT PROMPT: ")
        if raw_prompt == "QUIT":
            return
        # Every adapter answers the same single prompt.
        for generate_config in configs:
            generate_config.prompts = [raw_prompt]
        outputs = moe_peft.generate(
            model,
            tokenizer,
            configs,
            max_gen_len=128,
            use_cache=not args.disable_cache,
            concurrent_jobs=concurrent_jobs,
            cache_implementation=args.cache_implementation,
            stream_callback=stream_callback,
        )
        print(separator)
        print(f"PROMPT: {raw_prompt}")
        for adapter_name, output in outputs.items():
            print(f"{adapter_name} OUTPUT:")
            print(output[0])
        print(separator)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Enable autograd anomaly detection for easier debugging of NaNs/backward
    # errors (slow; only with --debug).
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # Inference and evaluation both require trained adapter weights on disk,
    # so force --load_adapter in those modes.
    if args.inference or args.evaluate:
        args.load_adapter = True
        inference_mode = True
    else:
        inference_mode = False

    moe_peft.setup_logging("INFO", args.log_file)

    moe_peft_executor = moe_peft.executor

    # Abort early when no usable compute backend is available.
    if not moe_peft_executor.check_available():
        exit(-1)

    # Auto-select the attention implementation when not specified:
    # flash attention only for inference/evaluation on CUDA with flash-attn 2
    # installed; the eager (reference) implementation otherwise.
    if args.attn_impl is None:
        if (
            inference_mode
            and moe_peft_executor.device_name() == "cuda"
            and is_flash_attn_2_available()
        ):
            args.attn_impl = "flash_attn"
        else:
            args.attn_impl = "eager"

    if args.device is None:
        args.device = moe_peft.executor.default_device_name()

    # Reproducibility / numeric settings from the CLI flags.
    moe_peft_executor.use_deterministic_algorithms(args.deterministic)
    moe_peft_executor.allow_tf32(args.tf32)
    moe_peft_executor.manual_seed(args.seed)

    # Finetune configuration is plain JSON.
    with open(args.config, "r", encoding="utf8") as fp:
        config = json.load(fp)

    tokenizer, model = load_base_model()
    adapters = init_adapter_config(config, model)

    # Release temporary allocations made while loading adapters.
    moe_peft_executor.empty_cache()

    # NOTE(review): the env var only toggles this log line here; presumably
    # moe_peft itself also reads MOE_PEFT_EVALUATE_MODE — confirm.
    if os.getenv("MOE_PEFT_EVALUATE_MODE") is None:
        logging.info("Using efficient operators.")
    else:
        logging.info("Using deterministic operators.")

    # Dispatch on run mode: interactive inference, batch evaluation, or
    # fine-tuning (the default).
    if args.inference:
        inference(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
        )
    elif args.evaluate:
        moe_peft.evaluate(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
            retrying_steps=config.get("eval_rollback_retrying_steps", 20),
            max_seq_len=config["cutoff_len"],
            save_file=config.get("evaluate_result", None),
            require_attention = -1,
            require_hide = -1,
        )
    else:
        moe_peft.train(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
            strategy=config["train_strategy"],
            cutoff_len=config["cutoff_len"],
            save_step=config["save_step"],
            save_dir=args.dir,
        )
|
|
|