import ast
import asyncio
import base64
import json
import os
import shutil
import tempfile
import uuid
from multiprocessing import cpu_count
from typing import List, Optional, Tuple, Union

import numpy as np
import requests as url_requests
from accelerate import Accelerator, DistributedType
from tqdm import tqdm

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

try:
    # decord is optional; nothing in this module fails at import time without it.
    from decord import VideoReader, cpu
except ImportError:
    pass

from dotenv import find_dotenv, load_dotenv
from loguru import logger as eval_logger
from openai import AsyncOpenAI
from PIL import Image

from lmms_eval.mcp import MCPClient
from lmms_eval.models.simple.openai_compatible import (
    OpenAICompatible as OpenAICompatibleSimple,
)
from lmms_eval.protocol import ChatMessages

load_dotenv(verbose=True)

# Image modes that cannot be written as JPEG and are converted to RGB first.
_NON_JPEG_MODES = ("RGBA", "LA", "P", "PA")


def _parse_tool_arguments(raw_arguments):
    """
    Decode the argument string of a tool call into a Python object.

    The string is parsed as JSON first. Payloads that are not valid JSON but are
    valid Python literals (for example a single-quoted dict) are accepted through
    ``ast.literal_eval``. Unlike ``eval``, neither path can execute code that a
    model embedded in its output.

    :param raw_arguments: The ``function.arguments`` string of a tool call. ``None`` or an empty string means "no arguments".
    :return: The decoded arguments; an empty dict when nothing was supplied.
    :raises ValueError: If the string is neither JSON nor a Python literal (``ast.literal_eval`` may also raise ``SyntaxError``).
    """
    if not raw_arguments:
        return {}
    try:
        return json.loads(raw_arguments)
    except json.JSONDecodeError:
        return ast.literal_eval(raw_arguments)


@register_model("async_openai_compatible_chat")
class AsyncOpenAIChat(lmms):
    """
    Chat model that sends requests to an OpenAI-compatible endpoint with asyncio.

    Requests of one ``generate_until`` call are issued concurrently, bounded by
    ``num_cpus``. When ``mcp_server_path`` is given, the tools exposed by that MCP
    server are offered to the model and tool calls are executed until the model
    stops requesting them.
    """

    is_simple = False

    def __init__(
        self,
        model_version: str = "grok-2-latest",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 600,
        max_retries: int = 5,
        max_size_in_mb: int = 20,
        mcp_server_path: Optional[str] = None,
        num_cpus: Optional[int] = None,
        work_dir: Optional[str] = None,
        fps: Optional[int] = None,
        nframes: Optional[int] = 64,
        max_frames: Optional[int] = 768,
        max_pixels: Optional[int] = 151200,
        min_pixels: Optional[int] = 28 * 28,
        is_qwen3_vl: bool = False,
        **kwargs,
    ) -> None:
        """
        Build the API client, the optional MCP client and the Accelerate state.

        :param model_version: Model name sent as ``model`` in every request.
        :param base_url: Endpoint URL; falls back to the ``OPENAI_API_BASE`` environment variable.
        :param api_key: API key; falls back to the ``OPENAI_API_KEY`` environment variable.
        :param timeout: Timeout handed to the ``AsyncOpenAI`` client.
        :param max_retries: Retry budget handed to the ``AsyncOpenAI`` client.
        :param max_size_in_mb: Stored on the instance; not used by the code in this class.
        :param mcp_server_path: When set, an ``MCPClient`` is created and tool calling is enabled.
        :param num_cpus: Maximum number of in-flight requests; defaults to half the CPU count, never less than one.
        :param work_dir: Directory where images are written for MCP tools; a fresh temporary directory when omitted. NOTE: this directory is removed after each ``generate_until`` call that uses MCP.
        :param fps: Video sampling rate; when ``None``, ``nframes`` is used instead.
        :param nframes: Number of video frames to sample when ``fps`` is not set.
        :param max_frames: Optional ``max_frames`` entry forwarded in the video kwargs.
        :param max_pixels: ``max_pixels`` entry forwarded in the video kwargs.
        :param min_pixels: ``min_pixels`` entry forwarded in the video kwargs.
        :param is_qwen3_vl: Selects ``to_qwen3_vl_openai_messages`` instead of ``to_openai_messages``.
        :param kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.model_version = model_version
        self.timeout = timeout
        self.max_retries = max_retries
        self.max_size_in_mb = max_size_in_mb  # some models have a limit on the size of the image
        # A semaphore of size 0 would block every request forever (single-core hosts),
        # so the default is clamped to at least one concurrent request.
        self.num_cpus = max(1, cpu_count() // 2) if num_cpus is None else num_cpus
        self.work_dir = work_dir if work_dir is not None else tempfile.mkdtemp()
        self.fps = fps
        self.nframes = nframes
        self.base_url = base_url if base_url is not None else os.getenv("OPENAI_API_BASE")
        self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY")
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=timeout,
            max_retries=max_retries,
        )
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_frames = max_frames
        self.is_qwen3_vl = is_qwen3_vl

        if mcp_server_path is not None:
            self.mcp_client = MCPClient(mcp_server_path)
            os.makedirs(self.work_dir, exist_ok=True)
        else:
            self.mcp_client = None

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
        # Rank and world size are read from Accelerate in both the single- and
        # multi-process case.
        self.accelerator = accelerator
        self._rank = accelerator.local_process_index
        self._world_size = accelerator.num_processes
        self.device = accelerator.device

    @property
    def model(self):
        # The "model" of this wrapper is the remote API client.
        return self.client

    @property
    def batch_size(self):
        # NOTE(review): nothing in this class assigns ``batch_size_per_gpu``; fall
        # back to 1 instead of raising AttributeError when it was never set.
        return getattr(self, "batch_size_per_gpu", 1)

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this model; always raises ``NotImplementedError``."""
        # A real exception instead of ``assert False``: asserts vanish under ``python -O``.
        raise NotImplementedError("TODO, not implemented")

    def generate_until_multi_round(self, requests) -> List[str]:
        """Not supported by this model; always raises ``NotImplementedError``."""
        raise NotImplementedError("TODO: Implement multi-round generation for AsyncOpenAIChat")

    async def maybe_forward_with_tool(self, request: Instance, idx: int) -> Tuple[str, int]:
        """
        Forward the request to the OpenAI API, using tools if available.

        The chat messages are built from the request, sent to the endpoint and, as
        long as the reply finishes with ``tool_calls``, the requested tools are run
        through the MCP client and their results are appended to the conversation
        before asking again. The returned text is the concatenation of every text
        reply, every ``"<name> <arguments>"`` tool invocation and every text part
        of the tool results, in the order they occurred.

        Missing ``max_new_tokens`` / ``temperature`` / ``top_p`` / ``do_sample``
        entries are filled into ``gen_kwargs`` in place.

        :param request: The request instance containing the chat messages and other parameters.
        :param idx: The index of the request in the batch. (Use to restore the original order of responses)
        :return: ``(response_text, idx)``.
        """
        ctx, doc_to_messages, gen_kwargs, doc_id, task, split = request.args
        # NOTE(review): ``self.task_dict`` is not assigned in this class; it is
        # presumably injected by the evaluation framework.
        raw_messages = doc_to_messages(self.task_dict[task][split][doc_id])
        chat_messages: ChatMessages = ChatMessages(**{"messages": raw_messages})

        video_kwargs = {"max_pixels": self.max_pixels, "min_pixels": self.min_pixels}
        if self.fps is not None:
            video_kwargs["fps"] = self.fps
        else:
            video_kwargs["nframes"] = self.nframes
        if self.max_frames is not None:
            video_kwargs["max_frames"] = self.max_frames

        if self.is_qwen3_vl:
            messages = chat_messages.to_qwen3_vl_openai_messages(video_kwargs)
        else:
            messages = chat_messages.to_openai_messages(video_kwargs)

        images, videos, audios = chat_messages.extract_media()
        if self.mcp_client is not None:
            # Tools receive file paths, so tell the model where each media item lives.
            for image_idx, image in enumerate(images):
                image_path = os.path.join(self.work_dir, f"{uuid.uuid4()}.jpg")
                # JPEG has no alpha/palette support; saving such images as-is raises.
                if getattr(image, "mode", None) in _NON_JPEG_MODES:
                    image = image.convert("RGB")
                image.save(image_path)
                messages[-1]["content"].append({"type": "text", "text": f"\nImage {image_idx} has image path: {image_path}"})
            for video_idx, video in enumerate(videos):
                messages[-1]["content"].append({"type": "text", "text": f"\nVideo {video_idx} has video path: {video}"})

        gen_kwargs.setdefault("max_new_tokens", 1024)
        gen_kwargs.setdefault("temperature", 0)
        gen_kwargs.setdefault("top_p", None)
        gen_kwargs.setdefault("do_sample", False)

        payload = {
            "messages": messages,
            "model": self.model_version,
            "max_tokens": gen_kwargs["max_new_tokens"],
            "temperature": gen_kwargs["temperature"],
        }

        functions = None
        if self.mcp_client is not None:
            # get the function list from the MCP server
            functions = await self.mcp_client.get_function_list()
            payload["tools"] = functions
            payload["tool_choice"] = "auto"

        response = await self.client.chat.completions.create(**payload)
        reply = response.choices[0].message

        transcript = []
        # ``content`` is None for pure tool-call turns; skip it rather than
        # leaking an exception message into the answer.
        if isinstance(reply.content, str):
            transcript.append(reply.content)

        while self.mcp_client is not None and response.choices[0].finish_reason == "tool_calls":
            # One assistant turn carrying both the text and the tool calls.
            assistant_turn = {"role": "assistant", "content": reply.content}
            if reply.tool_calls:
                assistant_turn["tool_calls"] = reply.tool_calls
                eval_logger.debug("Calling tool with MCP client")
            messages.append(assistant_turn)

            for call in reply.tool_calls or []:
                eval_logger.debug(f"Calling {call.function.name}...")
                result = await self.mcp_client.run_tool(call.function.name, _parse_tool_arguments(call.function.arguments))
                transcript.append(f"{call.function.name} {call.function.arguments}")

                tool_turn = {"role": "tool", "tool_call_id": call.id, "name": call.function.name, "content": []}
                for item in result.content:
                    parts = self.mcp_client.convert_result_to_openai_format(item)
                    # Only text parts are recorded in the transcript; image parts
                    # are still forwarded to the model below.
                    transcript.extend(part["text"] for part in parts if part["type"] == "text")
                    tool_turn["content"].extend(parts)
                # Keep tool results in the history so later rounds still see them.
                messages.append(tool_turn)

            response = await self.client.chat.completions.create(
                model=self.model_version,
                messages=messages,
                max_tokens=gen_kwargs["max_new_tokens"],
                temperature=gen_kwargs["temperature"],
                tools=functions,
                tool_choice="auto",
            )
            reply = response.choices[0].message
            if isinstance(reply.content, str):
                transcript.append(reply.content)

        all_response = "".join(transcript)
        self.add_request_response_to_cache(request, all_response)
        return all_response, idx

    def generate_until(self, requests) -> List[str]:
        """
        Generate a response for every request, running uncached ones concurrently.

        Cached responses are fetched first; the remaining requests are processed
        with at most ``num_cpus`` in flight and their responses are appended, in
        the order of the remaining requests, after the cached ones.

        When MCP is enabled, ``work_dir`` is (re)created before the run and removed
        afterwards, including when the run fails.

        :param requests: The request instances to answer.
        :return: Cached responses followed by the newly generated ones.
        """
        self.load_cache()
        results, requests = self.get_response_from_cache(requests)

        if self.mcp_client is not None:
            # A previous call removed the directory; make sure it exists again.
            os.makedirs(self.work_dir, exist_ok=True)

        async def run():
            limiter = asyncio.Semaphore(max(1, int(self.num_cpus)))

            async def _process(req, idx):
                async with limiter:
                    return await self.maybe_forward_with_tool(req, idx)

            tasks = [asyncio.create_task(_process(req, idx)) for idx, req in enumerate(requests)]
            finished = []
            with tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") as pbar:
                for pending in asyncio.as_completed(tasks):
                    content, idx = await pending
                    finished.append((content, idx))
                    pbar.update(1)
            return finished

        try:
            eval_results = asyncio.run(run())
        finally:
            if self.mcp_client is not None:
                shutil.rmtree(self.work_dir, ignore_errors=True)

        eval_results.sort(key=lambda pair: pair[1])  # Sort by index to restore original order
        return results + [content for content, _ in eval_results]