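"""MoE-PEFT main program: train, evaluate, or run inference with LoRA adapters
on a pre-trained base model.

Example invocations (model and config names are illustrative only):

    python c2cite.py --base_model meta-llama/Llama-2-7b-hf --config moe_peft.json
    python c2cite.py --base_model meta-llama/Llama-2-7b-hf --config moe_peft.json --inference

--inference and --evaluate imply --load_adapter, so train the adapters first.
"""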
import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Tuple, Union
import torch
from transformers.utils import is_flash_attn_2_available
import moe_peft
import moe_peft.adapters
# Command Line Arguments
parser = argparse.ArgumentParser(description="MoE-PEFT main program")
parser.add_argument(
"--base_model", type=str, required=True, help="Path to or name of base model"
)
parser.add_argument(
    "--inference", action="store_true", help="Run in inference mode (for testing)"
)
parser.add_argument(
    "--evaluate", action="store_true", help="Run in evaluation mode (for testing)"
)
parser.add_argument(
    "--disable_prompter",
    action="store_true",
    help="Disable the prompter during inference",
)
parser.add_argument(
"--load_adapter",
action="store_true",
help="Load adapter from file instead of init randomly",
)
parser.add_argument(
"--disable_adapter", action="store_true", help="Disable the adapter modules"
)
parser.add_argument(
"--attn_impl", type=str, help="Specify the implementation of attention"
)
parser.add_argument(
"--sliding_window",
action="store_true",
help="Use sliding window attention (requires flash attention)",
)
parser.add_argument(
"--disable_cache",
action="store_true",
help="Disable cache when inference",
)
parser.add_argument(
"--cache_implementation",
type=str,
help="Specify the implementation of cache",
)
parser.add_argument(
"--fp16", action="store_true", help="Load base model in float16 precision"
)
parser.add_argument(
"--bf16", action="store_true", help="Load base model in bfloat16 precision"
)
parser.add_argument(
"--tf32", action="store_true", help="Use tfloat32 instead of float32 if available"
)
parser.add_argument(
"--load_8bit", action="store_true", help="Load base model with 8bit quantization"
)
parser.add_argument(
"--load_4bit", action="store_true", help="Load base model with 4bit quantization"
)
parser.add_argument("--device", type=str, help="Specify which GPU to be used")
parser.add_argument(
"--config", type=str, required=True, help="Path to finetune configuration"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed in integer, default is 42"
)
parser.add_argument(
"--dir", type=str, default=".", help="Path to read or save checkpoints"
)
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
parser.add_argument("--log_file", type=str, help="Save log to specific file")
parser.add_argument(
"--verbose", action="store_true", help="Show extra informations such as parameters"
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite adapter model when older one existed",
)
parser.add_argument("--debug", action="store_true", help="Enabling debugging mode")
parser.add_argument(
"--deterministic",
action="store_true",
help="Use deterministic algorithms to improve the reproducibility",
)
args = parser.parse_args()
def query_yes_no(question, default="no"):
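    """Ask `question` on stdout and return the user's yes/no answer as a bool.

    `default` is assumed when the user just presses Enter; pass None to
    require an explicit answer.
    """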
valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
if default is None:
prompt = " [y/n] "
elif default == "yes":
prompt = " [Y/n] "
elif default == "no":
prompt = " [y/N] "
else:
raise ValueError("invalid default answer: '%s'" % default)
while True:
sys.stdout.write(question + prompt)
choice = input().lower()
if default is not None and choice == "":
return valid[default]
elif choice in valid:
return valid[choice]
else:
sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
def load_base_model() -> Tuple[moe_peft.Tokenizer, moe_peft.LLMModel]:
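    """Load the base model and tokenizer according to the parsed CLI arguments."""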
logging.info("Initializing pre-trained model.")
model = moe_peft.LLMModel.from_pretrained(
name_or_path=args.base_model,
device=args.device,
attn_impl=args.attn_impl,
use_sliding_window=args.sliding_window,
bits=(8 if args.load_8bit else (4 if args.load_4bit else None)),
load_dtype=(
torch.bfloat16
if args.bf16
else (torch.float16 if args.fp16 else torch.float32)
),
)
tokenizer = moe_peft.Tokenizer(args.base_model)
return tokenizer, model
def init_adapter_config(
    config: Dict[str, Any],
    llm_model: moe_peft.LLMModel,
) -> List[
    Union[moe_peft.GenerateConfig, moe_peft.EvaluateConfig, moe_peft.TrainConfig]
]:
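    """Build a generate/evaluate/train config for every adapter in the JSON config.

    Each adapter is loaded from disk (--load_adapter) or freshly initialized on
    `llm_model`; before training, the user is asked to confirm overwriting an
    existing checkpoint unless --overwrite is given.
    """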
config_list = []
if config["cutoff_len"] == -1:
config["cutoff_len"] = llm_model.config_.max_seq_len_
logging.info(f"Setting cutoff_len to {llm_model.config_.max_seq_len_} automatically.")
for lora_config in config["lora"]:
adapter_name = lora_config["name"]
adapter_path = f"{args.dir}{os.sep}{adapter_name}"
if not args.load_adapter and os.path.exists(adapter_path):
if args.overwrite:
logging.warning(
f"Overwriting existed adapter model file: {adapter_path}"
)
elif not query_yes_no(
f"Existed adapter model file detected: {adapter_path}\n" + "Overwrite?"
):
logging.info("User canceled training due to file conflict.")
                sys.exit(0)
if args.load_adapter:
llm_model.load_adapter(adapter_path, adapter_name)
else:
llm_model.init_adapter(moe_peft.adapters.lora_config_factory(lora_config))
if args.inference:
config_class = moe_peft.GenerateConfig(adapter_name=adapter_name)
if not args.disable_prompter:
config_class.prompt_template = lora_config.get("prompt", None)
config_list.append(config_class)
elif args.evaluate:
config_list.extend(moe_peft.EvaluateConfig.from_config(lora_config))
else:
config_list.append(moe_peft.TrainConfig.from_config(lora_config))
if args.verbose:
logging.info(config_list[-1].__dict__)
return config_list
def inference_callback(cur_pos, outputs):
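    """Streaming callback: print each adapter's partial output at `cur_pos`."""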
print(f"POSITION: {cur_pos}")
for adapter_name, output in outputs.items():
print(f"{adapter_name} OUTPUT: {output[0]}")
def inference(
model: moe_peft.LLMModel,
tokenizer: moe_peft.Tokenizer,
configs: List[moe_peft.GenerateConfig],
concurrent_jobs: int,
):
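    """Interactive prompt loop: generate with every adapter config and print
    the outputs; type QUIT to exit."""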
while True:
input_raw = input("INPUT WITHOUT PROMPT: ")
if input_raw == "QUIT":
return
for config in configs:
config.prompts = [input_raw]
callback = None if args.disable_log else inference_callback
outputs = moe_peft.generate(
model,
tokenizer,
configs,
max_gen_len=128,
use_cache=not args.disable_cache,
concurrent_jobs=concurrent_jobs,
cache_implementation=args.cache_implementation,
stream_callback=callback,
)
print(f"\n{'='*10}\n")
print(f"PROMPT: {input_raw}")
for adapter_name, output in outputs.items():
print(f"{adapter_name} OUTPUT:")
print(output[0])
print(f"\n{'='*10}\n")
# Main Function
if __name__ == "__main__":
if args.debug:
torch.autograd.set_detect_anomaly(True)
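    # Inference and evaluation both run on previously trained adapters, so
    # force loading them from disk.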
if args.inference or args.evaluate:
args.load_adapter = True
inference_mode = True
else:
inference_mode = False
moe_peft.setup_logging("INFO", args.log_file)
moe_peft_executor = moe_peft.executor
if not moe_peft_executor.check_available():
        sys.exit(-1)
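    # Auto-select the attention implementation when none was given: flash
    # attention for CUDA inference when available, eager otherwise.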
if args.attn_impl is None:
if (
inference_mode
and moe_peft_executor.device_name() == "cuda"
and is_flash_attn_2_available()
):
args.attn_impl = "flash_attn"
else:
args.attn_impl = "eager"
if args.device is None:
args.device = moe_peft.executor.default_device_name()
moe_peft_executor.use_deterministic_algorithms(args.deterministic)
moe_peft_executor.allow_tf32(args.tf32)
moe_peft_executor.manual_seed(args.seed)
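    # The JSON config must provide at least "cutoff_len" and a "lora" list
    # (one entry per adapter, each with a "name"); training additionally reads
    # "train_strategy" and "save_step". The *_simultaneously_num keys below
    # are optional concurrency limits.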
with open(args.config, "r", encoding="utf8") as fp:
config = json.load(fp)
tokenizer, model = load_base_model()
adapters = init_adapter_config(config, model)
moe_peft_executor.empty_cache()
if os.getenv("MOE_PEFT_EVALUATE_MODE") is None:
logging.info("Using efficient operators.")
else:
logging.info("Using deterministic operators.")
if args.inference:
inference(
model=model,
tokenizer=tokenizer,
configs=adapters,
concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
)
elif args.evaluate:
moe_peft.evaluate(
model=model,
tokenizer=tokenizer,
configs=adapters,
max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
retrying_steps=config.get("eval_rollback_retrying_steps", 20),
max_seq_len=config["cutoff_len"],
save_file=config.get("evaluate_result", None),
            require_attention=-1,
            require_hide=-1,
)
else:
moe_peft.train(
model=model,
tokenizer=tokenizer,
configs=adapters,
max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
strategy=config["train_strategy"],
cutoff_len=config["cutoff_len"],
save_step=config["save_step"],
save_dir=args.dir,
)