File size: 4,309 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "We fine-tuned a checkpoint in the `ocr-sft.ipynb` tutorial; here we use `PtEngine` to run inference with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import some libraries\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "from swift.llm import (\n",
    "    InferEngine, InferRequest, PtEngine, RequestConfig, get_template, load_dataset, load_image\n",
    ")\n",
    "from swift.utils import get_model_parameter_info, get_logger, seed_everything\n",
    "logger = get_logger()\n",
    "seed_everything(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for inference\n",
    "# Placeholder: replace with the checkpoint directory produced by the training run.\n",
    "last_model_checkpoint = 'output/checkpoint-xxx'\n",
    "\n",
    "# model\n",
    "model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct'  # model_id or model_path\n",
    "system = None  # passed as `default_system` when the template is built\n",
    "infer_backend = 'pt'  # informational only: later cells always construct a PtEngine\n",
    "\n",
    "# dataset\n",
    "# These values feed `load_dataset`; keep them identical to the training notebook so that\n",
    "# the validation split is the same one used during training.\n",
    "dataset = ['AI-ModelScope/LaTeX_OCR#20000']\n",
    "data_seed = 42\n",
    "split_dataset_ratio = 0.01\n",
    "num_proc = 4\n",
    "strict = False\n",
    "\n",
    "# generation_config\n",
    "max_new_tokens = 512\n",
    "temperature = 0\n",
    "stream = True  # not read by later cells: `infer_stream` always requests streaming\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get model and template, and load LoRA weights.\n",
    "engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])\n",
    "template = get_template(engine.model_meta.template, engine.tokenizer, default_system=system)\n",
    "# The default mode of the template is 'pt', so no `template.set_mode('pt')` call is needed.\n",
    "# Attach the template to the engine: without this the template built above is never used\n",
    "# and the `system` hyperparameter has no effect on inference.\n",
    "engine.default_template = template\n",
    "\n",
    "model_parameter_info = get_model_parameter_info(engine.model)\n",
    "logger.info(f'model_parameter_info: {model_parameter_info}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Due to the data_seed setting, the validation set here is the same as the validation set used during training.\n",
    "_, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,\n",
    "                              strict=strict, seed=data_seed)\n",
    "# Take at most the first 10 items; the `min` guard avoids an IndexError when the\n",
    "# validation split holds fewer than 10 rows (e.g. after changing the dataset or split ratio).\n",
    "val_dataset = val_dataset.select(range(min(10, len(val_dataset))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream a response for every validation sample and save each sample's image under `images/`.\n",
    "# The batch processing code can be found here: https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_mllm.py\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
    "    \"\"\"Print the first message of `infer_request`, then the streamed response as it arrives.\"\"\"\n",
    "    stream_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)\n",
    "    response_streams = engine.infer([infer_request], stream_config)\n",
    "    first_message = infer_request.messages[0]['content']\n",
    "    print(f'query: {first_message}\\nresponse: ', end='')\n",
    "    for chunk in response_streams[0]:\n",
    "        # Chunks that are None carry nothing to print and are skipped.\n",
    "        if chunk is not None:\n",
    "            print(chunk.choices[0].delta.content, end='', flush=True)\n",
    "    print()\n",
    "\n",
    "\n",
    "os.makedirs('images', exist_ok=True)\n",
    "for idx, sample in enumerate(val_dataset):\n",
    "    image_record = sample['images'][0]\n",
    "    # Use the embedded bytes when present, otherwise fall back to the stored path.\n",
    "    loaded_image = load_image(image_record['bytes'] or image_record['path'])\n",
    "    loaded_image.save(f'images/{idx}.png')\n",
    "    display(loaded_image)\n",
    "    infer_stream(engine, InferRequest(**sample))\n",
    "    print('-' * 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}