import importlib
import importlib.util
import os
import subprocess
import sys
from threading import Lock, Thread

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

# Module-level singletons: the model and processor are loaded once per process
# and reused across requests.
_MODEL = None
_PROCESSOR = None
_MODEL_PATH = None
_MODEL_LOCK = Lock()
_FLASH_ATTN_LOCK = Lock()

_FLASH_ATTN_PACKAGE = "flash_attn"
_FLASH_ATTN_REQUIREMENT = os.getenv("FLASH_ATTN_REQUIREMENT", "flash-attn==2.8.3")
def _get_attn_implementation():
    return os.getenv("ATTN_IMPLEMENTATION", "sdpa")


def _get_model_revision():
    return os.getenv("MODEL_REVISION")
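
# Environment variables read by this module (the values shown are illustrative,
# not defaults beyond those encoded above):
#   ATTN_IMPLEMENTATION     attention backend passed to from_pretrained,
#                           e.g. "sdpa" (default) or "flash_attention_2"
#   MODEL_REVISION          optional Hugging Face revision (branch, tag, or commit)
#   FLASH_ATTN_REQUIREMENT  pip requirement installed when flash-attn is missing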

def ensure_flash_attn_installed():
    """Install flash-attn at runtime if it is not already importable."""
    if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
        return
    with _FLASH_ATTN_LOCK:
        # Re-check under the lock in case another thread installed it first.
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is not None:
            return
        install_cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            _FLASH_ATTN_REQUIREMENT,
            "--no-build-isolation",
        ]
        print(f"Installing {_FLASH_ATTN_REQUIREMENT} with --no-build-isolation...")
        subprocess.check_call(install_cmd, env=os.environ.copy())
        importlib.invalidate_caches()
        if importlib.util.find_spec(_FLASH_ATTN_PACKAGE) is None:
            raise RuntimeError(f"Failed to import {_FLASH_ATTN_PACKAGE} after installation.")

def _ensure_model_loaded(model_path):
    """Load the model and processor once, caching them at module level."""
    global _MODEL, _PROCESSOR, _MODEL_PATH
    if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
        return _MODEL, _PROCESSOR
    with _MODEL_LOCK:
        # Double-checked locking: another thread may have finished loading
        # while we were waiting on the lock.
        if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
            return _MODEL, _PROCESSOR
        ensure_flash_attn_installed()
        attn_implementation = _get_attn_implementation()
        revision = _get_model_revision()
        processor_kwargs = {
            "trust_remote_code": True,
        }
        if revision:
            processor_kwargs["revision"] = revision
        model_kwargs = {
            "trust_remote_code": True,
            "device_map": {"": "cuda:0"},
            "torch_dtype": torch.bfloat16,
            "attn_implementation": attn_implementation,
        }
        if revision:
            model_kwargs["revision"] = revision
        _MODEL = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        _PROCESSOR = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
        _MODEL_PATH = model_path
    return _MODEL, _PROCESSOR

def preload_model(model_path):
    """Eagerly load the model (e.g. at process startup) before the first request."""
    return _ensure_model_loaded(model_path)

def _run_generation_stream(payload):
    """Run generation in a background thread and yield decoded text chunks."""
    model_path = payload["model_path"]
    model, processor = _ensure_model_loaded(model_path)
    inputs = processor(
        conversation=payload["conversation"],
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move tensors to the GPU; non-tensor entries are passed through unchanged.
    inputs = {k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    generation_kwargs = {
        **inputs,
        **payload.get("generation_config", {}),
    }
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer

    generation_error = {}

    def _generation_worker():
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            generation_error["exc"] = exc
            # Signal end-of-stream so the consuming loop below does not block forever.
            streamer.on_finalized_text("", stream_end=True)

    thread = Thread(target=_generation_worker, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    if "exc" in generation_error:
        raise generation_error["exc"]

class PenguinVLQwen3DirectClient:
    """Thin client that streams generations for a fixed local model path."""

    def __init__(self, model_path):
        self.model_path = model_path

    def submit(self, payload):
        """Return a generator of streamed text chunks for the given conversation."""
        return _run_generation_stream({
            "model_path": self.model_path,
            "conversation": payload["conversation"],
            "generation_config": payload.get("generation_config", {}),
        })
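
# Minimal usage sketch. The model path, conversation schema, and generation
# parameters below are illustrative assumptions; the actual checkpoint's
# processor defines the exact message format it accepts.
if __name__ == "__main__":
    client = PenguinVLQwen3DirectClient("path/to/local-checkpoint")  # hypothetical path
    demo_payload = {
        "conversation": [
            {"role": "user", "content": [{"type": "text", "text": "Describe this image."}]},
        ],
        "generation_config": {"max_new_tokens": 128, "do_sample": False},
    }
    for chunk in client.submit(demo_payload):
        print(chunk, end="", flush=True)
    print()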