# Penguin-VL/inference/server/direct_client.py
import importlib
import importlib.util
import os
import subprocess
import sys
from threading import Lock, Thread

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

# Lazily initialized singletons shared across requests.
_MODEL = None
_PROCESSOR = None
_MODEL_PATH = None
_MODEL_LOCK = Lock()
_FLASH_ATTN_LOCK = Lock()

# flash-attn is installed on demand; the pinned requirement can be overridden
# via the FLASH_ATTN_REQUIREMENT environment variable.
_FLASH_ATTN_PACKAGE = "flash_attn"
_FLASH_ATTN_REQUIREMENT = os.getenv("FLASH_ATTN_REQUIREMENT", "flash-attn==2.8.3")


def _get_attn_implementation():
    """Attention backend passed to from_pretrained (defaults to PyTorch SDPA)."""
    return os.getenv("ATTN_IMPLEMENTATION", "sdpa")


def _get_model_revision():
    """Optional model revision (branch, tag, or commit) from the environment."""
    return os.getenv("MODEL_REVISION")


def ensure_flash_attn_installed():
    """Install flash-attn at runtime if it is not already importable."""
    if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
        return
    # Double-checked locking: only one thread runs the pip install.
    with _FLASH_ATTN_LOCK:
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
            return
        install_cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            _FLASH_ATTN_REQUIREMENT,
            "--no-build-isolation",
        ]
        print(f"Installing {_FLASH_ATTN_REQUIREMENT} with --no-build-isolation...")
        subprocess.check_call(install_cmd, env=os.environ.copy())
        # Refresh the import machinery so the new package is visible.
        importlib.invalidate_caches()
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is None:
            raise RuntimeError(
                f"Failed to import {_FLASH_ATTN_PACKAGE} after installation."
            )


def _ensure_model_loaded(model_path):
    """Load the model and processor once, reusing them for subsequent calls."""
    global _MODEL, _PROCESSOR, _MODEL_PATH
    # Fast path: everything is already loaded for this model path.
    if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
        return _MODEL, _PROCESSOR
    with _MODEL_LOCK:
        # Re-check under the lock in case another thread finished loading first.
        if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
            return _MODEL, _PROCESSOR
        ensure_flash_attn_installed()
        attn_implementation = _get_attn_implementation()
        revision = _get_model_revision()
        processor_kwargs = {
            "trust_remote_code": True,
        }
        if revision:
            processor_kwargs["revision"] = revision
        model_kwargs = {
            "trust_remote_code": True,
            "device_map": {"": "cuda:0"},
            "torch_dtype": torch.bfloat16,
            "attn_implementation": attn_implementation,
        }
        if revision:
            model_kwargs["revision"] = revision
        _MODEL = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        _PROCESSOR = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
        _MODEL_PATH = model_path
        return _MODEL, _PROCESSOR


def preload_model(model_path):
    """Warm up the model cache (e.g. at server startup) before the first request."""
    return _ensure_model_loaded(model_path)


def _run_generation_stream(payload):
    """Yield generated text chunks for the given conversation payload."""
    model_path = payload["model_path"]
    model, processor = _ensure_model_loaded(model_path)
    inputs = processor(
        conversation=payload["conversation"],
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move tensors to the GPU; pixel values must match the model's bfloat16 dtype.
    inputs = {
        k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    generation_kwargs = {
        **inputs,
        **payload.get("generation_config", {}),
    }
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer
    generation_error = {}

    def _generation_worker():
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            # Record the error and unblock the consumer loop below.
            generation_error["exc"] = exc
            streamer.on_finalized_text("", stream_end=True)

    # Run generation in a background thread so tokens stream as they arrive.
    thread = Thread(target=_generation_worker, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    # Make sure the worker has finished (and recorded any error) before checking.
    thread.join()
    if "exc" in generation_error:
        raise generation_error["exc"]


class PenguinVLQwen3DirectClient:
    """In-process client that streams generations from a locally loaded model."""

    def __init__(self, model_path):
        self.model_path = model_path

    def submit(self, payload):
        """Return a generator of text chunks for the given request payload."""
        return _run_generation_stream(
            {
                "model_path": self.model_path,
                "conversation": payload["conversation"],
                "generation_config": payload.get("generation_config", {}),
            }
        )
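

# A minimal usage sketch, not part of the server itself: it shows how the client,
# preload helper, and streaming generator fit together. The conversation schema
# below follows the common Qwen-VL style message format, but the exact structure
# expected by this model's remote-code processor is an assumption, as is the
# model path.
if __name__ == "__main__":
    client = PenguinVLQwen3DirectClient("path/to/Penguin-VL")  # hypothetical path
    preload_model(client.model_path)  # optional warm-up before the first request
    demo_payload = {
        "conversation": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Describe penguins briefly."}],
            }
        ],
        "generation_config": {"max_new_tokens": 128, "do_sample": False},
    }
    for chunk in client.submit(demo_payload):
        print(chunk, end="", flush=True)
    print()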