File size: 4,669 Bytes
c292e01 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | """
Example usage for the fine-tuned merged model with vLLM.

1) Start the server (from the vLLM docs plus project defaults):

    OMP_NUM_THREADS=1 \
    vllm serve outputs/mimic_qwen3vl_lora_8bit_5_merged \
        --host 0.0.0.0 \
        --port 8000 \
        --dtype bfloat16 \
        --limit-mm-per-prompt.video 0

2) Run this client script. Its default --base-url is http://127.0.0.1:8002/v1,
   so pass the port the server was actually started on:

    python3 vllm_inference.py \
        --model outputs/mimic_qwen3vl_lora_8bit_5_merged \
        --base-url http://127.0.0.1:8000/v1
"""
import argparse
import base64
import mimetypes
import os
import time
from pathlib import Path
from openai import OpenAI
# NOTE(review): this script only sends HTTP requests to the vLLM server and does
# not appear to use a GPU itself, so pinning the visible device here has no
# visible effect within this file -- confirm whether it is still needed.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# Defaults for the command-line options declared in parse_args().
DEFAULT_MODEL_PATH = "outputs/mimic_qwen3vl_lora_8bit_5_merged"
# Must point at the host/port given to `vllm serve`.
DEFAULT_BASE_URL = "http://127.0.0.1:8002/v1"
# The prompt file is looked up next to this script, not in the working directory.
DEFAULT_SYSTEM_PROMPT_PATH = Path(__file__).with_name("new_system_prompt_new.txt")

# Two example images taken from the same dataset directory.
DEFAULT_IMAGE_1 = Path(
    "/home/dgxuser16/NTL/mccarthy/ahmad/cap/dataset/images_1/s50000230/7e962a95-d661c0db-4769286c-e150a106-fb9586c6.jpg"
)
DEFAULT_IMAGE_2 = Path(
    "/home/dgxuser16/NTL/mccarthy/ahmad/cap/dataset/images_1/s50000230/f605b192-2e612578-c5c95dc3-b9d6d13b-e0eee500.jpg"
)
def image_to_data_url(image_path: "str | os.PathLike[str]") -> str:
    """Read an image file and return it as a base64-encoded ``data:`` URL.

    Args:
        image_path: Location of the image. A plain ``str`` or any path-like
            object is accepted; it is normalized to ``pathlib.Path`` first
            (previously only ``Path`` instances worked).

    Returns:
        A string of the form ``"data:<mime>;base64,<payload>"``. ``<mime>`` is
        guessed from the file extension and falls back to
        ``application/octet-stream`` when the extension is unknown.

    Raises:
        FileNotFoundError: If nothing exists at *image_path*.
    """
    # Normalize so str inputs (e.g. from config files) work like Path inputs.
    path = Path(image_path)
    if not path.exists():
        raise FileNotFoundError(f"Image not found: {path}")
    mime_type, _ = mimetypes.guess_type(str(path))
    if mime_type is None:
        mime_type = "application/octet-stream"
    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
def build_messages(system_prompt: str, image_1: Path, image_2: Path) -> list[dict]:
    """Build the chat-completions ``messages`` payload for one request.

    The result is a single ``user`` turn whose content lists both images, as
    inline data URLs and in argument order, followed by the prompt text. The
    prompt is sent as user text, not as a separate ``system`` message.

    Raises:
        FileNotFoundError: Propagated from ``image_to_data_url`` when an
            image is missing.
    """
    parts: list[dict] = [
        {"type": "image_url", "image_url": {"url": image_to_data_url(img)}}
        for img in (image_1, image_2)
    ]
    parts.append({"type": "text", "text": system_prompt})
    return [{"role": "user", "content": parts}]
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the vLLM client.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, which matches the original
            no-argument behavior. Passing a list makes the parser reusable
            from tests or other scripts.

    Returns:
        A namespace with ``base_url``, ``model``, ``system_prompt_path``,
        ``image_1``, ``image_2``, ``max_tokens``, ``temperature`` and
        ``timeout`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run inference against a vLLM server for the fine-tuned Qwen3-VL model.",
        # Show each default in --help. In particular this makes the default
        # server URL/port visible, since it has to match `vllm serve --port`.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base-url",
        default=DEFAULT_BASE_URL,
        help="OpenAI-compatible vLLM base URL.",
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_PATH,
        help="Model identifier served by vLLM (use the same value passed to `vllm serve`).",
    )
    parser.add_argument(
        "--system-prompt-path",
        type=Path,
        default=DEFAULT_SYSTEM_PROMPT_PATH,
        help="Path to prompt text file.",
    )
    parser.add_argument(
        "--image-1",
        type=Path,
        default=DEFAULT_IMAGE_1,
        help="Path to first image.",
    )
    parser.add_argument(
        "--image-2",
        type=Path,
        default=DEFAULT_IMAGE_2,
        help="Path to second image.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum generation tokens.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=3600,
        help="Client timeout in seconds.",
    )
    return parser.parse_args(argv)
def _print_report(response, *, model: str, elapsed: float) -> None:
    """Print latency, token usage, finish reason and generated text.

    *response* is assumed to be the object returned by the OpenAI client's
    ``chat.completions.create``. Only its first choice is reported.
    """
    choice = response.choices[0]
    print(f"Model: {model}")
    print(f"Latency (s): {elapsed:.3f}")
    usage = response.usage
    if usage is not None:
        print(f"Prompt tokens: {usage.prompt_tokens}")
        print(f"Completion tokens: {usage.completion_tokens}")
        print(f"Total tokens: {usage.total_tokens}")
    print(f"Finish reason: {choice.finish_reason}")
    if choice.finish_reason == "length":
        # Generation hit the token budget, so the text below is cut off.
        print("WARNING: output was truncated; consider raising --max-tokens.")
    print("\n--- Generated Output ---")
    # message.content can be None; print a marker instead of the literal "None".
    content = choice.message.content
    print(content if content is not None else "(no text content in response)")


def main() -> None:
    """Send one two-image prompt to the vLLM server and print the result.

    Reads the prompt file and inlines both images as data URLs. It then calls
    the OpenAI-compatible chat-completions endpoint once and prints latency,
    token usage (when reported), the finish reason and the generated text.

    Raises:
        FileNotFoundError: If the prompt file or either image does not exist.
    """
    args = parse_args()
    if not args.system_prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {args.system_prompt_path}")
    system_prompt = args.system_prompt_path.read_text(encoding="utf-8").strip()
    messages = build_messages(system_prompt=system_prompt, image_1=args.image_1, image_2=args.image_2)

    # Falls back to a placeholder key when OPENAI_API_KEY is unset; presumably
    # the local vLLM server does not check it unless configured to -- confirm.
    api_key = os.getenv("OPENAI_API_KEY", "EMPTY")
    client = OpenAI(api_key=api_key, base_url=args.base_url, timeout=args.timeout)

    # Latency covers only the request/response round trip, not image encoding.
    start = time.perf_counter()
    response = client.chat.completions.create(
        model=args.model,
        messages=messages,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
    )
    elapsed = time.perf_counter() - start

    _print_report(response, model=args.model, elapsed=elapsed)
# Run the client only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|