## Inference
We trained a checkpoint in the `ocr-sft.ipynb` tutorial; here we use `PtEngine` to run inference with it.

In [None]:
# Import libraries and configure the runtime environment.
import os
# Expose only GPU 0 to this process. This must run before the swift imports below;
# presumably they pull in torch/CUDA, which read this variable at init — TODO confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import (
 InferEngine, InferRequest, PtEngine, RequestConfig, get_template, load_dataset, load_image
)
from swift.utils import get_model_parameter_info, get_logger, seed_everything
logger = get_logger()
# Seed the RNGs with a fixed value (presumably python/numpy/torch via swift's helper) for reproducible runs.
seed_everything(42)

In [2]:
# ---- Inference hyperparameters -------------------------------------------
# LoRA checkpoint written by the training tutorial; replace 'xxx' with the real step.
last_model_checkpoint = 'output/checkpoint-xxx'

# ---- Model ----------------------------------------------------------------
model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct'  # ModelScope model id or a local model path
system = None                                   # forwarded to get_template as default_system
infer_backend = 'pt'                            # matches the PtEngine constructed below

# ---- Dataset --------------------------------------------------------------
# Keep these identical to the training run so the same validation split is reproduced.
dataset = ['AI-ModelScope/LaTeX_OCR#20000']
data_seed, split_dataset_ratio = 42, 0.01
num_proc, strict = 4, False

# ---- Generation -----------------------------------------------------------
max_new_tokens, temperature, stream = 512, 0, True

In [None]:
# Load the base model, attach the trained LoRA adapter, and build the chat template.
engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])
template = get_template(engine.model_meta.template, engine.tokenizer, default_system=system)
# The template's default mode is 'pt', so no template.set_mode(...) call is needed.
# Install the template on the engine: infer_stream calls engine.infer() without a
# template argument, so a template that is only built here (and never handed to the
# engine) would leave the `system` hyperparameter without any effect.
engine.default_template = template

# Log parameter statistics for the loaded model.
model_parameter_info = get_model_parameter_info(engine.model)
logger.info(f'model_parameter_info: {model_parameter_info}')

In [None]:
# With the same dataset, split_dataset_ratio and data_seed as training, this reproduces
# the validation split that was used during training.
_, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,
                              strict=strict, seed=data_seed)
# Run inference on (at most) the first `num_infer_samples` validation rows; clamping to the
# split length keeps select() from indexing past the end of a small split.
num_infer_samples = 10
val_dataset = val_dataset.select(range(min(num_infer_samples, len(val_dataset))))

In [None]:
# Streaming inference over validation samples, saving each sample's image.
# Batched inference is shown in: https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_mllm.py
def infer_stream(engine: InferEngine, infer_request: InferRequest):
    """Stream the model's answer to a single request and echo it to stdout.

    Prints the content of the request's first message as the query, then each
    generated delta as it arrives (flushed immediately), then a final newline.
    Generation limits come from the notebook-level `max_new_tokens` and
    `temperature` hyperparameters.
    """
    config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)
    streams = engine.infer([infer_request], config)
    query = infer_request.messages[0]['content']
    print(f'query: {query}\nresponse: ', end='')
    # One request in, one stream out; None chunks carry nothing to print and are skipped.
    for chunk in (c for c in streams[0] if c is not None):
        print(chunk.choices[0].delta.content, end='', flush=True)
    print()

from IPython.display import display

os.makedirs('images', exist_ok=True)
for i, data in enumerate(val_dataset):
    image = data['images'][0]
    # Dataset images may be {'bytes': ..., 'path': ...} dicts: prefer the raw bytes and
    # fall back to the path. Any non-dict value is handed to load_image unchanged.
    if isinstance(image, dict):
        image = image['bytes'] or image['path']
    image = load_image(image)
    image.save(f'images/{i}.png')
    display(image)
    infer_stream(engine, InferRequest(**data))
    print('-' * 50)