def infer_hf():
    """Greedy-decode one chat turn with raw transformers + peft (LoRA) and return the text.

    Downloads the Qwen2.5-7B-Instruct base model and a test LoRA adapter from
    ModelScope, applies the adapter, and generates a deterministic response.
    """
    from modelscope import snapshot_download
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_dir = snapshot_download('Qwen/Qwen2.5-7B-Instruct')
    lora_dir = snapshot_download('swift/test_lora')

    base_model = AutoModelForCausalLM.from_pretrained(
        base_dir, torch_dtype='auto', device_map='auto', trust_remote_code=True)
    lora_model = PeftModel.from_pretrained(base_model, lora_dir)
    tokenizer = AutoTokenizer.from_pretrained(base_dir, trust_remote_code=True)

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'who are you?'},
    ]
    # The chat template already inserts special tokens, so skip them at tokenize time.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(lora_model.device)

    # do_sample=False → greedy decoding (deterministic output).
    outputs = lora_model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Drop the echoed prompt tokens so only the newly generated tokens are decoded.
    completions = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
    response = tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    print(f'response: {response}')
    return response
def infer_swift():
    """Run the same LoRA chat turn through ms-swift's PtEngine and return the text.

    Downloads the Qwen2.5-7B-Instruct base model and a test LoRA adapter from
    ModelScope, wraps them with Swift, and performs greedy inference.
    """
    from modelscope import snapshot_download
    from swift.llm import get_model_tokenizer, get_template, InferRequest, RequestConfig, PtEngine
    from swift.tuners import Swift

    base_dir = snapshot_download('Qwen/Qwen2.5-7B-Instruct')
    lora_dir = snapshot_download('swift/test_lora')

    model, tokenizer = get_model_tokenizer(base_dir, device_map='auto')
    model = Swift.from_pretrained(model, lora_dir)
    template = get_template(model.model_meta.template, tokenizer)
    engine = PtEngine.from_model_template(model, template)

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'who are you?'},
    ]
    # temperature=0 → greedy decoding, matching do_sample=False on the HF path.
    config = RequestConfig(max_tokens=512, temperature=0)
    results = engine.infer([InferRequest(messages=messages)], request_config=config)
    response = results[0].choices[0].message.content
    print(f'response: {response}')
    return response
if __name__ == '__main__':
    # Run both inference paths (HF first, then swift) and verify that the
    # greedy-decoded outputs are identical.
    response, response2 = infer_hf(), infer_swift()
    assert response == response2