import importlib
import importlib.util
import os
import subprocess
import sys
from threading import Lock, Thread

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

# Process-wide singletons guarded by _MODEL_LOCK so the model is loaded only once.
_MODEL = None
_PROCESSOR = None
_MODEL_PATH = None
_MODEL_LOCK = Lock()
_FLASH_ATTN_LOCK = Lock()
_FLASH_ATTN_PACKAGE = "flash_attn"
_FLASH_ATTN_REQUIREMENT = os.getenv("FLASH_ATTN_REQUIREMENT", "flash-attn==2.8.3")


def _get_attn_implementation():
    return os.getenv("ATTN_IMPLEMENTATION", "sdpa")


def _get_model_revision():
    return os.getenv("MODEL_REVISION")


def ensure_flash_attn_installed():
    """Install flash-attn lazily at runtime if it is not already importable."""
    if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
        return
    with _FLASH_ATTN_LOCK:
        # Double-checked locking: another thread may have installed it while we waited.
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
            return
        install_cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            _FLASH_ATTN_REQUIREMENT,
            "--no-build-isolation",
        ]
        print(f"Installing {_FLASH_ATTN_REQUIREMENT} with --no-build-isolation...")
        subprocess.check_call(install_cmd, env=os.environ.copy())
        importlib.invalidate_caches()
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is None:
            raise RuntimeError(f"Failed to import {_FLASH_ATTN_PACKAGE} after installation.")


def _ensure_model_loaded(model_path):
    """Load (or return the cached) model and processor for ``model_path``."""
    global _MODEL, _PROCESSOR, _MODEL_PATH
    if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
        return _MODEL, _PROCESSOR
    with _MODEL_LOCK:
        if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
            return _MODEL, _PROCESSOR
        ensure_flash_attn_installed()
        attn_implementation = _get_attn_implementation()
        revision = _get_model_revision()
        processor_kwargs = {
            "trust_remote_code": True,
        }
        if revision:
            processor_kwargs["revision"] = revision
        model_kwargs = {
            "trust_remote_code": True,
            "device_map": {"": "cuda:0"},
            "torch_dtype": torch.bfloat16,
            "attn_implementation": attn_implementation,
        }
        if revision:
            model_kwargs["revision"] = revision
        _MODEL = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        _PROCESSOR = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
        _MODEL_PATH = model_path
        return _MODEL, _PROCESSOR


def preload_model(model_path):
    """Warm up the model cache ahead of the first request."""
    return _ensure_model_loaded(model_path)


def _run_generation_stream(payload):
    """Run generation on a background thread and yield decoded text chunks."""
    model_path = payload["model_path"]
    model, processor = _ensure_model_loaded(model_path)
    inputs = processor(
        conversation=payload["conversation"],
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move tensors to the GPU; pixel values must match the model's bfloat16 dtype.
    inputs = {
        k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    generation_kwargs = {
        **inputs,
        **payload.get("generation_config", {}),
    }
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer
    generation_error = {}

    def _generation_worker():
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            # Record the failure and unblock the consumer waiting on the streamer.
            generation_error["exc"] = exc
            streamer.on_finalized_text("", stream_end=True)

    thread = Thread(target=_generation_worker, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    if "exc" in generation_error:
        raise generation_error["exc"]


class PenguinVLQwen3DirectClient(object):
    def __init__(self, model_path):
        self.model_path = model_path

    def submit(self, payload):
        return _run_generation_stream(
            {
                "model_path": self.model_path,
                "conversation": payload["conversation"],
                "generation_config": payload.get("generation_config", {}),
            }
        )
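

if __name__ == "__main__":
    # Usage sketch (not part of the original module): shows how the streaming
    # client might be driven end to end. The model path, MODEL_PATH environment
    # variable, and the conversation schema below are assumptions for
    # illustration; the processor loaded via trust_remote_code may expect a
    # different message format.
    client = PenguinVLQwen3DirectClient(
        os.getenv("MODEL_PATH", "/models/penguin-vl-qwen3")  # assumed default path
    )
    demo_payload = {
        "conversation": [
            {"role": "user", "content": [{"type": "text", "text": "Hello, who are you?"}]}
        ],
        "generation_config": {"max_new_tokens": 64, "do_sample": False},
    }
    # submit() returns a generator; print chunks as they arrive.
    for chunk in client.submit(demo_payload):
        print(chunk, end="", flush=True)
    print()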