import os
from typing import List

# Pin inference to the first GPU before any CUDA-aware library is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def infer_batch(engine: 'InferEngine', infer_requests: List['InferRequest']):
    """Run batched inference and print the first two query/response pairs.

    Args:
        engine: Inference engine exposing ``infer(requests)``; each response
            carries the text at ``choices[0].message.content``.
        infer_requests: Requests whose first message's ``'content'`` field is
            treated as the query; at least two requests are expected.
    """
    responses = engine.infer(infer_requests)
    # Emit query/response pairs in the same interleaved order for indices 0 and 1.
    for idx in (0, 1):
        query = infer_requests[idx].messages[0]['content']
        answer = responses[idx].choices[0].message.content
        print(f'query{idx}: {query}')
        print(f'response{idx}: {answer}')
if __name__ == '__main__':
    # Imported after CUDA_VISIBLE_DEVICES is set so device selection takes effect.
    from swift.llm import InferEngine, InferRequest, PtEngine, load_dataset, safe_snapshot_download, BaseArguments
    from swift.tuners import Swift

    # Fetch the trained adapter and restore the arguments it was trained with.
    adapter_path = safe_snapshot_download('swift/test_bert')
    args = BaseArguments.from_pretrained(adapter_path)
    args.max_length = 512
    args.truncation_strategy = 'right'

    # Rebuild the base model, apply the adapter, and wrap it in a batched engine.
    model, processor = args.get_model_processor()
    model = Swift.from_pretrained(model, adapter_path)
    template = args.get_template(processor)
    engine = PtEngine.from_model_template(model, template, max_batch_size=64)

    # Batch-infer over a sampled classification dataset.
    dataset = load_dataset(['DAMO_NLP/jd:cls#1000'], seed=42)[0]
    print(f'dataset: {dataset}')
    infer_requests = [InferRequest(messages=data['messages']) for data in dataset]
    infer_batch(engine, infer_requests)

    # Two ad-hoc queries (one positive-, one negative-sentiment utterance).
    infer_batch(engine, [
        InferRequest(messages=[{
            'role': 'user',
            'content': '今天天气真好呀'
        }]),
        InferRequest(messages=[{
            'role': 'user',
            'content': '真倒霉'
        }])
    ])