import os
from typing import Literal


# Pin this process to GPU 0. NOTE(review): presumably this must run before
# torch/vllm are imported (which happens inside the __main__ guard below) for
# the device restriction to take effect — confirm if reordering imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
|
|
def infer_multilora(infer_request: 'InferRequest', infer_backend: Literal['vllm', 'pt']):
    """Demonstrate multi-LoRA inference: run one request against two different
    LoRA adapters and against the bare base model.

    Args:
        infer_request: The request (chat messages) to send to the engine.
        infer_backend: Which inference engine to use — 'pt' (PyTorch) or 'vllm'.

    Raises:
        ValueError: If ``infer_backend`` is not 'pt' or 'vllm'. (The Literal
            annotation is not enforced at runtime, so we check explicitly;
            otherwise ``engine`` would be unbound and fail later with a
            confusing UnboundLocalError.)
    """
    # Download both adapters; the base-model/template config is read from the
    # first adapter's saved training arguments.
    adapter_path = safe_snapshot_download('swift/test_lora')
    adapter_path2 = safe_snapshot_download('swift/test_lora2')
    args = BaseArguments.from_pretrained(adapter_path)
    if infer_backend == 'pt':
        engine = PtEngine(args.model)
    elif infer_backend == 'vllm':
        # Imported lazily so the 'pt' path works without vllm installed.
        from swift.llm import VllmEngine
        engine = VllmEngine(args.model, enable_lora=True, max_loras=1, max_lora_rank=16)
    else:
        raise ValueError(f"infer_backend must be 'pt' or 'vllm', got: {infer_backend!r}")
    template = get_template(args.template, engine.processor, args.system)
    # temperature=0 -> greedy decoding, so outputs are reproducible.
    request_config = RequestConfig(max_tokens=512, temperature=0)
    adapter_request = AdapterRequest('lora1', adapter_path)
    adapter_request2 = AdapterRequest('lora2', adapter_path2)

    # Same request through adapter 1, the bare base model, then adapter 2.
    resp_list = engine.infer([infer_request], request_config, template=template, adapter_request=adapter_request)
    response = resp_list[0].choices[0].message.content
    print(f'lora1-response: {response}')

    resp_list = engine.infer([infer_request], request_config)
    response = resp_list[0].choices[0].message.content
    print(f'response: {response}')

    resp_list = engine.infer([infer_request], request_config, template=template, adapter_request=adapter_request2)
    response = resp_list[0].choices[0].message.content
    print(f'lora2-response: {response}')
|
|
|
|
def infer_lora(infer_request: 'InferRequest'):
    """Apply a single LoRA adapter to its base model and answer one request
    with the PyTorch (PtEngine) backend, printing the generated text."""
    ckpt_dir = safe_snapshot_download('swift/test_lora')
    lora_args = BaseArguments.from_pretrained(ckpt_dir)

    # Build the adapter-merged model and wrap it in an engine.
    base_model, tokenizer = get_model_tokenizer(lora_args.model)
    lora_model = Swift.from_pretrained(base_model, ckpt_dir)
    template = get_template(lora_args.template, tokenizer, lora_args.system)
    engine = PtEngine.from_model_template(lora_model, template)

    # Greedy decoding (temperature=0) for deterministic output.
    request_config = RequestConfig(max_tokens=512, temperature=0)
    resp_list = engine.infer([infer_request], request_config)
    response = resp_list[0].choices[0].message.content
    print(f'lora-response: {response}')
|
|
|
|
if __name__ == '__main__':
    # Project imports are deferred to the entry point — NOTE(review):
    # presumably so the CUDA_VISIBLE_DEVICES setting at the top of the file is
    # applied before torch/swift initialize; confirm before hoisting them.
    from swift.llm import (PtEngine, RequestConfig, AdapterRequest, get_template, BaseArguments, InferRequest,
                           safe_snapshot_download, get_model_tokenizer)
    from swift.tuners import Swift
    # A single chat-style request reused by the demo functions.
    infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
    # Only the multi-LoRA demo (PyTorch backend) runs by default;
    # infer_lora(infer_request) is available as an alternative single-adapter demo.
    infer_multilora(infer_request, 'pt')
|
|