File size: 4,309 Bytes
cb2428f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference\n",
"We have already trained a checkpoint by following the `ocr-sft.ipynb` tutorial; here we use `PtEngine` to run inference on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import some libraries\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
"\n",
"from swift.llm import (\n",
" InferEngine, InferRequest, PtEngine, RequestConfig, get_template, load_dataset, load_image\n",
")\n",
"from swift.utils import get_model_parameter_info, get_logger, seed_everything\n",
"logger = get_logger()\n",
"seed_everything(42)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Inference hyperparameters\n",
"last_model_checkpoint = 'output/checkpoint-xxx'  # checkpoint directory, loaded as an adapter below\n",
"\n",
"# Model settings\n",
"model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct'  # either a model id or a local model path\n",
"system = None  # default system prompt handed to the template\n",
"infer_backend = 'pt'  # NOTE: not read by the cells below\n",
"\n",
"# Dataset settings\n",
"dataset = ['AI-ModelScope/LaTeX_OCR#20000']\n",
"data_seed = 42\n",
"split_dataset_ratio = 0.01\n",
"num_proc = 4\n",
"strict = False\n",
"\n",
"# Generation settings\n",
"max_new_tokens = 512\n",
"temperature = 0\n",
"stream = True  # NOTE: not read by the cells below; infer_stream always streams"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get model and template, and load LoRA weights.\n",
"engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])\n",
"template = get_template(engine.model_meta.template, engine.tokenizer, default_system=system)\n",
"# The default mode of the template is 'pt', so there is no need to make any changes.\n",
"# template.set_mode('pt')\n",
"# Attach the template to the engine. Without this the template built above is never used,\n",
"# so the `system` setting would have no effect on inference.\n",
"engine.default_template = template\n",
"\n",
"# Log the parameter summary of the loaded model.\n",
"model_parameter_info = get_model_parameter_info(engine.model)\n",
"logger.info(f'model_parameter_info: {model_parameter_info}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Due to the data_seed setting, the validation set here is the same as the validation set used during training.\n",
"_, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,\n",
"                              strict=strict, seed=data_seed)\n",
"# Take the first 10 items, or every item when the validation split holds fewer than 10,\n",
"# so that a smaller `dataset` or `split_dataset_ratio` does not make `select` fail.\n",
"val_dataset = val_dataset.select(range(min(10, len(val_dataset))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run streaming inference on the validation samples and save their images to disk.\n",
"# The batch processing code can be found here: https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_mllm.py\n",
"def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
"    \"\"\"Run one request in streaming mode and print the reply as it arrives.\n",
"\n",
"    Args:\n",
"        engine: Inference engine; its `infer` method is called with a one-element list.\n",
"        infer_request: The request to run. The content of its first message is echoed as the query.\n",
"\n",
"    Returns:\n",
"        None. The query and the streamed response are written to stdout.\n",
"    \"\"\"\n",
"    # `max_new_tokens` and `temperature` are the notebook-level generation settings.\n",
"    request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)\n",
"    gen_list = engine.infer([infer_request], request_config)\n",
"    query = infer_request.messages[0]['content']\n",
"    print(f'query: {query}\\nresponse: ', end='')\n",
"    # Only one request was submitted, so read the first (and only) stream.\n",
"    for resp in gen_list[0]:\n",
"        # Skip empty chunks instead of failing on them.\n",
"        if resp is None:\n",
"            continue\n",
"        print(resp.choices[0].delta.content, end='', flush=True)\n",
"    print()\n",
"\n",
"from IPython.display import display\n",
"\n",
"# For every validation sample: save its image, show it inline, then stream the model's answer.\n",
"os.makedirs('images', exist_ok=True)\n",
"for idx, sample in enumerate(val_dataset):\n",
"    image_ref = sample['images'][0]\n",
"    # Use the embedded bytes when present; otherwise fall back to the file path.\n",
"    image_source = image_ref['bytes'] or image_ref['path']\n",
"    loaded_image = load_image(image_source)\n",
"    loaded_image.save(f'images/{idx}.png')\n",
"    display(loaded_image)\n",
"    infer_stream(engine, InferRequest(**sample))\n",
"    print('-' * 50)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test_py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|