# Copyright (c) Kangan Qian. All rights reserved.
# Authors: Kangan Qian (Tsinghua University, Xiaomi Corporation)
# Description: Interface for Qwen2.5-VL model inference with tool integration
import json
import time
import base64
import io
from typing import Any, Callable, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

from scripts.tools.tool_libraries_simple import FuncAgent


def pil_to_base64(pil_image: Image.Image) -> str:
    """
    Convert a PIL image object to a base64-encoded string.

    Args:
        pil_image (Image.Image): PIL image object to convert

    Returns:
        str: Base64-encoded string representation of the image

    Raises:
        RuntimeError: If the image cannot be serialized or encoded
    """
    try:
        binary_stream = io.BytesIO()
        pil_image.save(binary_stream, format="PNG")
        binary_data = binary_stream.getvalue()
        return base64.b64encode(binary_data).decode("utf-8")
    except Exception as e:
        raise RuntimeError(f"Image to base64 conversion failed: {e}") from e


def inference_with_retry(
    inference_func: Callable,
    *args: Any,
    max_retries: int = 3,
    retry_delay: int = 3,
    **kwargs: Any,
) -> str:
    """
    Execute an inference function with automatic retries on failure.

    Args:
        inference_func (Callable): Inference function to call
        *args: Positional arguments for the inference function
        max_retries (int): Maximum number of retry attempts
        retry_delay (int): Delay between retry attempts in seconds
        **kwargs: Keyword arguments for the inference function

    Returns:
        str: Output from the inference function

    Raises:
        RuntimeError: If maximum retries are exceeded without success
    """
    retries = 0
    while retries < max_retries:
        try:
            return inference_func(*args, **kwargs)
        except Exception as e:
            print(f"Inference error: {e}. Retry {retries + 1}/{max_retries}...")
            retries += 1
            time.sleep(retry_delay)
    raise RuntimeError(f"Inference failed after {max_retries} retries")


class Qwen2_5VLInterface:
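    """Minimal wrapper around Qwen2.5-VL for single-image, single-prompt inference."""
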
    def __init__(self, model_path: str) -> None:
        """
        Initialize the Qwen2.5-VL model interface.

        Args:
            model_path (str): Path to pretrained model
        """
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def inference(self, pil_image: Image.Image, prompt: str, max_tokens: int = 4096) -> str:
        """
        Perform inference using the Qwen2.5-VL model.

        Args:
            pil_image (Image.Image): Input image
            prompt (str): Text prompt for the model
            max_tokens (int): Maximum number of tokens to generate

        Returns:
            str: Model output text
        """
        # Encode the image as an inline base64 data URI; qwen_vl_utils'
        # process_vision_info accepts "data:image;base64,..." sources directly.
        image_base64 = pil_to_base64(pil_image)
        image_url = f"data:image;base64,{image_base64}"
        # Prepare the chat-style message list for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_url},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Render the chat template and extract vision inputs
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        # Generate, then strip the prompt tokens so only new tokens are decoded
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
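
# Illustrative one-off usage of the wrapper (hypothetical image path; any local
# or hub checkpoint can stand in for the path used in main() below):
#   iface = Qwen2_5VLInterface("pretrained_model/AgentThink")
#   print(iface.inference(Image.open("frame.jpg"), "Describe the scene."))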


def execute_tool_call(
    func_agent: FuncAgent,
    tool_name: str,
    tool_args: dict,
    verbose: bool = True,
) -> Optional[dict]:
    """
    Execute a tool call using the function agent.

    Each tool on the FuncAgent is expected to return a (prompt, data) tuple.

    Args:
        func_agent (FuncAgent): Function agent instance
        tool_name (str): Name of the tool to execute
        tool_args (dict): Arguments for the tool
        verbose (bool): Whether to print tool execution details

    Returns:
        Optional[dict]: Tool response containing name, arguments, and prompt,
            or None if the tool is missing or fails
    """
    try:
        tool_function = getattr(func_agent, tool_name)
    except AttributeError:
        print(f"Error: Tool '{tool_name}' not found")
        return None
    if not callable(tool_function):
        print(f"Error: '{tool_name}' is not a callable function")
        return None
    try:
        tool_prompt, _tool_result_data = tool_function(**tool_args)
    except Exception as e:
        print(f"Error executing tool '{tool_name}': {e}")
        return None
    if tool_prompt is None:
        tool_prompt = ""
    tool_response = {
        "name": tool_name,
        "args": tool_args,
        "prompt": tool_prompt,
    }
    if verbose:
        print(f"Tool: {tool_name}")
        print(f"Arguments: {tool_args}")
        print(f"Prompt: {tool_prompt}")
    return tool_response


def run_chat_model_inference(
    model_interface: Qwen2_5VLInterface,
    image_path: str,
    prompt: str,
) -> str:
    """
    Run inference using the chat model.

    Args:
        model_interface (Qwen2_5VLInterface): Loaded model interface
        image_path (str): Path to input image file
        prompt (str): Text prompt for the model

    Returns:
        str: Model output text
    """
    image = Image.open(image_path)
    return inference_with_retry(
        model_interface.inference,
        image,
        prompt,
        max_retries=3,
        retry_delay=3,
    )


def main():
    """Main function to process JSON data and run model inference."""
    # Initialize the function agent, and load the model once up front
    # (reloading the checkpoint for every sample would be prohibitively slow)
    func_agent = FuncAgent()
    model_interface = Qwen2_5VLInterface("pretrained_model/AgentThink")
    # Load JSON data
    json_file = "./Inference/inference_demo_data_drivemllm.json"
    with open(json_file, "r", encoding="utf-8") as file:
        json_data = json.load(file)
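    # Expected sample schema, inferred from the field accesses below (keys
    # beyond these may exist in the demo JSON but are unused here):
    # {
    #   "image": ["<path to input image>"],
    #   "system_prompts": "<system prompt text>",
    #   "question": "<user question>",
    #   "tool_result": [{"name": "<tool name>", "args": {...}}, ...]
    # }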
    # Process each sample in the JSON data
    for sample in json_data:
        image_path = sample["image"][0]
        tool_chain = sample["tool_result"]
        system_prompt = sample["system_prompts"]
        question = sample["question"]
        # Build the tool prompt by executing each tool in the chain
        tool_prompt = ""
        for tool_node in tool_chain:
            tool_name = tool_node["name"]
            tool_args = tool_node["args"]
            tool_response = execute_tool_call(func_agent, tool_name, tool_args)
            if tool_response:
                tool_prompt += tool_response["prompt"]
        # Construct the full prompt for model inference
        full_prompt = f"{system_prompt}\n{question}\nTool results:{tool_prompt}"
        # Run model inference
        model_output = run_chat_model_inference(model_interface, image_path, full_prompt)
        print("Model output:", model_output)


if __name__ == "__main__":
    main()