Spaces:
Running on Zero
Running on Zero
File size: 3,278 Bytes
0cb9ad5 86630d8 0cb9ad5 86630d8 0cb9ad5 7821e0a 0cb9ad5 86630d8 0cb9ad5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | import os
from threading import Lock, Thread
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
# Process-wide cache for the loaded model/processor pair. All three values
# are published together under _MODEL_LOCK so that concurrent requests
# trigger at most one load per model path.
_MODEL = None
_PROCESSOR = None
# Path the cached pair was loaded from; compared against incoming requests
# to decide whether a (re)load is needed.
_MODEL_PATH = None
_MODEL_LOCK = Lock()
def _get_attn_implementation():
return os.getenv("ATTN_IMPLEMENTATION", "flash_attention_2")
def _get_model_revision():
return os.getenv("MODEL_REVISION")
def _ensure_model_loaded(model_path):
    """Load (or return the cached) model and processor for *model_path*.

    Uses double-checked locking: a lock-free fast path for cache hits, then a
    re-check under ``_MODEL_LOCK`` so concurrent callers share a single load.

    Returns:
        tuple: ``(model, processor)``.
    """
    global _MODEL, _PROCESSOR, _MODEL_PATH
    # Fast path: cache hit without taking the lock.
    if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
        return _MODEL, _PROCESSOR
    with _MODEL_LOCK:
        # Re-check under the lock: another thread may have just finished loading.
        if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
            return _MODEL, _PROCESSOR
        attn_implementation = _get_attn_implementation()
        revision = _get_model_revision()
        processor_kwargs = {
            "trust_remote_code": True,
        }
        model_kwargs = {
            "trust_remote_code": True,
            "device_map": {"": "cuda:0"},
            "torch_dtype": torch.bfloat16,
            "attn_implementation": attn_implementation,
        }
        if revision:
            processor_kwargs["revision"] = revision
            model_kwargs["revision"] = revision
        # Load into locals first and publish to the module globals only after
        # BOTH loads succeed. The original assigned _MODEL before loading the
        # processor; a processor-load failure then left the cache half-set
        # (_MODEL populated, _PROCESSOR/_MODEL_PATH not), and a retry would
        # load a second model while stranding the first in memory.
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        processor = AutoProcessor.from_pretrained(model_path, **processor_kwargs)
        _MODEL = model
        _PROCESSOR = processor
        _MODEL_PATH = model_path
    return _MODEL, _PROCESSOR
def preload_model(model_path):
    """Eagerly warm the model/processor cache for *model_path*.

    Returns the same ``(model, processor)`` tuple as ``_ensure_model_loaded``.
    """
    loaded_pair = _ensure_model_loaded(model_path)
    return loaded_pair
@spaces.GPU(duration=120)
def _run_generation_stream(payload):
    """Run one generation request on the GPU and yield decoded text pieces.

    Args:
        payload: dict with "model_path", "conversation" (forwarded to the
            processor), and optional "generation_config" whose entries are
            merged into the ``generate`` kwargs.

    Yields:
        str: incremental decoded text from the streamer.

    Raises:
        Exception: re-raises whatever ``model.generate`` raised in the worker
            thread, after the stream has drained.
    """
    model_path = payload["model_path"]
    model, processor = _ensure_model_loaded(model_path)
    # NOTE(review): conversation/add_system_prompt kwargs come from the
    # model's remote-code processor — assumed to match its signature; verify
    # against the checkpoint's processor implementation.
    inputs = processor(
        conversation=payload["conversation"],
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move only tensor values to the GPU; leave non-tensor entries untouched.
    inputs = {k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        # Model weights are bfloat16 (see _ensure_model_loaded), so image
        # tensors must match.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    # User-supplied generation_config entries override/extend the inputs.
    generation_kwargs = {
        **inputs,
        **payload.get("generation_config", {}),
    }
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,          # don't echo the prompt back
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer
    # Worker-to-consumer error channel: a dict mutation is visible across
    # threads, letting us re-raise after the stream ends.
    generation_error = {}
    def _generation_worker():
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            generation_error["exc"] = exc
            # Signal end-of-stream so the consumer loop below unblocks even
            # though generate() died before closing the streamer.
            streamer.on_finalized_text("", stream_end=True)
    # Daemon thread: generation must run concurrently with the consuming
    # loop, and must not block process exit if the stream is abandoned.
    thread = Thread(target=_generation_worker, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    # Surface any worker failure to the caller once streaming is done.
    if "exc" in generation_error:
        raise generation_error["exc"]
class PenguinVLQwen3DirectClient(object):
    """Thin client that streams generations directly from the local model."""

    def __init__(self, model_path):
        # Checkpoint this client targets; injected into every request.
        self.model_path = model_path

    def submit(self, payload):
        """Start a streaming generation and return the token generator.

        Args:
            payload: dict with "conversation" and optional "generation_config".
        """
        request = {
            "model_path": self.model_path,
            "conversation": payload["conversation"],
            "generation_config": payload.get("generation_config", {}),
        }
        return _run_generation_stream(request)
|