Efficient instruction fine-tuning of baichuan-7b with QLoRA and one million training samples


For more details, see the GitHub project: [Firefly(流萤): a Chinese conversational large language model (full fine-tuning + QLoRA)](https://github.com/yangjianxin1/Firefly)
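
QLoRA trains low-rank adapters on top of a 4-bit-quantized base model, which is what makes instruction-tuning a 7B model on a million samples feasible on a single GPU. Below is a minimal sketch of how such a setup is typically configured with `transformers` + `peft` + `bitsandbytes`; the base checkpoint name, LoRA hyperparameters, and `target_modules` are illustrative assumptions, not the project's exact training config:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization with double quantization, as in the QLoRA paper
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    'baichuan-inc/Baichuan-7B',  # assumed base checkpoint
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map='auto',
)
model = prepare_model_for_kbit_training(model)

# Attach trainable low-rank adapters; module names and ranks are illustrative
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM',
    target_modules=['W_pack', 'o_proj', 'gate_proj', 'down_proj', 'up_proj'],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapters are trainable
```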

Single-turn dialogue script:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
device = 'cuda'
# Firefly wraps each query as <s>{query}</s>
input_pattern = '<s>{}</s>'

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'  # accelerate already places the weights on the GPU
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

text = input('User:')
while True:
    prompt = input_pattern.format(text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    outputs = model.generate(
        input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
        top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens; this is more robust than
    # string-replacing the prompt out of the full decoded output
    response_ids = outputs[:, input_ids.size(1):]
    output = tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
    print("Firefly:{}".format(output))
    text = input('User:')
```
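
The checkpoint name above ends in `-merge`, suggesting the trained QLoRA adapter has already been merged back into the base model, so inference needs no `peft` dependency. For reference, a minimal sketch of such a merge with `peft` (the adapter path and output directory are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_name = 'baichuan-inc/Baichuan-7B'  # assumed base checkpoint
adapter_path = 'path/to/qlora-adapter'        # placeholder: your trained adapter
output_dir = 'firefly-baichuan-7b-merged'     # placeholder output directory

# Load the fp16 base model, attach the LoRA adapter, and fold the adapter
# weights into the base weights so the result is a plain standalone model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, trust_remote_code=True, torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()

model.save_pretrained(output_dir)
AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True).save_pretrained(output_dir)
```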

Multi-turn dialogue script:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = 'cuda'
model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'  # accelerate already places the weights on the GPU
)
model.eval()

# Token-id history of the whole conversation, seeded with the <s> prefix
history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
# Maximum number of history tokens fed to the model
history_max_len = 1000

user_input = input('User:')
while True:
    # Append the new query, terminated by </s>, to the history
    user_input_ids = tokenizer('{}</s>'.format(user_input), return_tensors="pt").input_ids
    history_token_ids = torch.cat((history_token_ids, user_input_ids), dim=1)
    # Feed only the most recent history_max_len tokens to the model
    model_input_ids = history_token_ids[:, -history_max_len:].to(device)
    outputs = model.generate(
        input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
        temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
    )
    # Everything after the input is the new response; append it to the history
    response_ids = outputs[:, model_input_ids.size(1):]
    history_token_ids = torch.cat((history_token_ids, response_ids.cpu()), dim=1)
    response = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
    print("Firefly:" + response[0].strip())
    user_input = input('User:')
```
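
With this loop the token history grows as `<s>q1</s>a1</s>q2</s>a2</s>…`: every user query and every generated response is terminated by `</s>`. Note that only the most recent `history_max_len` tokens are fed to the model, so long conversations silently drop their earliest turns; raising `history_max_len` buys more context at the cost of memory and latency.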