| """ |
| Example usage for the fine-tuned merged model with vLLM. |
| |
| 1) Start the server (from docs + project defaults): |
| |
| OMP_NUM_THREADS=1 \ |
| vllm serve outputs/mimic_qwen3vl_lora_8bit_5_merged \ |
| --host 0.0.0.0 \ |
| --port 8000 \ |
| --dtype bfloat16 \ |
| --limit-mm-per-prompt.video 0 |
| |
| 2) Run this client script: |
| |
| python3 vllm_inference.py --model outputs/mimic_qwen3vl_lora_8bit_5_merged |
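
3) If the client cannot connect, verify the server is ready first (vLLM
serves the standard OpenAI-compatible model list endpoint):

curl http://127.0.0.1:8000/v1/models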
| """ |
|
|
| import argparse |
| import base64 |
| import mimetypes |
| import os |
| import time |
from pathlib import Path

from openai import OpenAI

# The GPU pin below only affects CUDA work inside this process; generation runs
# on the vLLM server, whose GPU is chosen where `vllm serve` is launched.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"


DEFAULT_MODEL_PATH = "outputs/mimic_qwen3vl_lora_8bit_5_merged"
DEFAULT_BASE_URL = "http://127.0.0.1:8000/v1"
| DEFAULT_SYSTEM_PROMPT_PATH = Path(__file__).with_name("new_system_prompt_new.txt") |
| DEFAULT_IMAGE_1 = Path( |
| "/home/dgxuser16/NTL/mccarthy/ahmad/cap/dataset/images_1/s50000230/7e962a95-d661c0db-4769286c-e150a106-fb9586c6.jpg" |
| ) |
| DEFAULT_IMAGE_2 = Path( |
| "/home/dgxuser16/NTL/mccarthy/ahmad/cap/dataset/images_1/s50000230/f605b192-2e612578-c5c95dc3-b9d6d13b-e0eee500.jpg" |
| ) |
|
|
|
|
| def image_to_data_url(image_path: Path) -> str: |
| if not image_path.exists(): |
| raise FileNotFoundError(f"Image not found: {image_path}") |
|
|
| mime_type, _ = mimetypes.guess_type(str(image_path)) |
| if mime_type is None: |
| mime_type = "application/octet-stream" |
|
|
| encoded = base64.b64encode(image_path.read_bytes()).decode("utf-8") |
| return f"data:{mime_type};base64,{encoded}" |
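
# Example (hypothetical file): image_to_data_url(Path("frontal.jpg")) returns
# "data:image/jpeg;base64,/9j/...", a form OpenAI-style chat clients accept
# wherever an image URL is expected.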


def build_messages(system_prompt: str, image_1: Path, image_2: Path) -> list[dict]:
    """Build the chat payload: both images followed by the prompt text.

    The prompt travels inside the user turn rather than as a separate
    ``system`` message.
    """
    return [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_to_data_url(image_1)}},
                {"type": "image_url", "image_url": {"url": image_to_data_url(image_2)}},
                {"type": "text", "text": system_prompt},
            ],
        }
    ]


| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description="Run inference against a vLLM server for the fine-tuned Qwen3-VL model." |
| ) |
| parser.add_argument( |
| "--base-url", |
| default=DEFAULT_BASE_URL, |
| help="OpenAI-compatible vLLM base URL.", |
| ) |
| parser.add_argument( |
| "--model", |
| default=DEFAULT_MODEL_PATH, |
| help="Model identifier served by vLLM (use the same value passed to `vllm serve`).", |
| ) |
| parser.add_argument( |
| "--system-prompt-path", |
| type=Path, |
| default=DEFAULT_SYSTEM_PROMPT_PATH, |
| help="Path to prompt text file.", |
| ) |
| parser.add_argument( |
| "--image-1", |
| type=Path, |
| default=DEFAULT_IMAGE_1, |
| help="Path to first image.", |
| ) |
| parser.add_argument( |
| "--image-2", |
| type=Path, |
| default=DEFAULT_IMAGE_2, |
| help="Path to second image.", |
| ) |
| parser.add_argument( |
| "--max-tokens", |
| type=int, |
| default=2048, |
| help="Maximum generation tokens.", |
| ) |
| parser.add_argument( |
| "--temperature", |
| type=float, |
| default=0.0, |
| help="Sampling temperature.", |
| ) |
| parser.add_argument( |
| "--timeout", |
| type=float, |
| default=3600, |
| help="Client timeout in seconds.", |
| ) |
| return parser.parse_args() |
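
# Example invocation overriding the defaults (image paths are placeholders):
#   python3 vllm_inference.py --base-url http://127.0.0.1:8000/v1 \
#       --image-1 frontal.jpg --image-2 lateral.jpg --max-tokens 1024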


def main() -> None:
    args = parse_args()

    if not args.system_prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {args.system_prompt_path}")

    system_prompt = args.system_prompt_path.read_text(encoding="utf-8").strip()
    messages = build_messages(system_prompt=system_prompt, image_1=args.image_1, image_2=args.image_2)

    # vLLM's OpenAI-compatible server ignores the API key unless it was started
    # with --api-key; "EMPTY" is the conventional placeholder.
    api_key = os.getenv("OPENAI_API_KEY", "EMPTY")
    client = OpenAI(api_key=api_key, base_url=args.base_url, timeout=args.timeout)

    start = time.perf_counter()
    response = client.chat.completions.create(
        model=args.model,
        messages=messages,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
    )
    elapsed = time.perf_counter() - start

    output_text = response.choices[0].message.content

    print(f"Model: {args.model}")
    print(f"Latency (s): {elapsed:.3f}")

    usage = response.usage
    if usage is not None:
        print(f"Prompt tokens: {usage.prompt_tokens}")
        print(f"Completion tokens: {usage.completion_tokens}")
        print(f"Total tokens: {usage.total_tokens}")

    print("\n--- Generated Output ---")
    print(output_text)
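

def stream_completion(client: OpenAI, model: str, messages: list[dict],
                      max_tokens: int, temperature: float) -> str:
    """Hedged sketch of a streaming variant using the same OpenAI client.

    ``stream=True`` and chunk iteration are standard OpenAI SDK behavior; this
    helper is an optional alternative to the blocking call in ``main`` and is
    not wired into the CLI.
    """
    parts: list[str] = []
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        stream=True,
    )
    for chunk in stream:
        # Some chunks (e.g. usage-only frames) can arrive without choices.
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
            parts.append(chunk.choices[0].delta.content)
    print()
    return "".join(parts)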


| if __name__ == "__main__": |
| main() |