File size: 3,278 Bytes
0cb9ad5
 
 
86630d8
0cb9ad5
 
 
 
 
 
 
 
 
 
 
86630d8
0cb9ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7821e0a
 
0cb9ad5
 
86630d8
0cb9ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
from threading import Lock, Thread

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer


# Module-level cache holding the most recently loaded (model, processor) pair.
# _MODEL_PATH records which checkpoint the cached pair came from, so a call
# with a different path triggers a reload.  _MODEL_LOCK serializes loading so
# concurrent first callers do not load the checkpoint twice.
_MODEL = None
_PROCESSOR = None
_MODEL_PATH = None
_MODEL_LOCK = Lock()


def _get_attn_implementation():
    return os.getenv("ATTN_IMPLEMENTATION", "flash_attention_2")


def _get_model_revision():
    return os.getenv("MODEL_REVISION")


def _ensure_model_loaded(model_path):
    """Load (at most once) and return the (model, processor) pair for *model_path*.

    Uses double-checked locking: the unlocked fast path serves the common
    already-loaded case, and the locked slow path ensures only one thread
    performs the expensive checkpoint load.  Loading a different path than
    the cached one triggers a reload under the lock.

    Returns:
        Tuple of (model, processor) placed on cuda:0 in bfloat16.

    Raises:
        Whatever ``from_pretrained`` raises (download, revision, or CUDA
        errors).  On failure the module cache is left untouched.
    """
    global _MODEL, _PROCESSOR, _MODEL_PATH

    # Fast path: cache hit without taking the lock (safe under the GIL —
    # each global is read atomically and a stale miss just falls through).
    if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
        return _MODEL, _PROCESSOR

    with _MODEL_LOCK:
        # Re-check inside the lock: another thread may have loaded already.
        if _MODEL is not None and _PROCESSOR is not None and _MODEL_PATH == model_path:
            return _MODEL, _PROCESSOR

        attn_implementation = _get_attn_implementation()
        revision = _get_model_revision()

        processor_kwargs = {
            "trust_remote_code": True,
        }
        if revision:
            processor_kwargs["revision"] = revision

        model_kwargs = {
            "trust_remote_code": True,
            "device_map": {"": "cuda:0"},
            "torch_dtype": torch.bfloat16,
            "attn_implementation": attn_implementation,
        }
        if revision:
            model_kwargs["revision"] = revision

        # Load into locals first and commit the globals only after BOTH loads
        # succeed.  Previously the model was assigned to _MODEL before the
        # processor load; a processor failure then left the cache half-filled
        # (model set, processor None, stale path), causing a redundant model
        # reload — and extra GPU memory — on the next call.
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        processor = AutoProcessor.from_pretrained(model_path, **processor_kwargs)

        _MODEL, _PROCESSOR, _MODEL_PATH = model, processor, model_path
        return _MODEL, _PROCESSOR


def preload_model(model_path):
    """Eagerly warm the module-level cache; returns the (model, processor) pair."""
    pair = _ensure_model_loaded(model_path)
    return pair


@spaces.GPU(duration=120)
def _run_generation_stream(payload):
    """Run streaming generation on the GPU, yielding decoded text chunks.

    Expected ``payload`` keys:
        - "model_path": checkpoint identifier (loaded via the module cache).
        - "conversation": chat messages handed to the processor.
        - "generation_config" (optional): extra kwargs merged into generate().

    Yields:
        Decoded text fragments as the model produces them.

    Raises:
        Re-raises, after the stream is drained, any exception the background
        ``generate()`` call hit.
    """
    model_path = payload["model_path"]
    model, processor = _ensure_model_loaded(model_path)

    inputs = processor(
        conversation=payload["conversation"],
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move tensors to the GPU; non-tensor entries pass through unchanged.
    inputs = {k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        # Match the bfloat16 weights loaded in _ensure_model_loaded.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    generation_kwargs = {
        **inputs,
        **payload.get("generation_config", {}),
    }
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer

    generation_error = {}

    def _generation_worker():
        # Record any failure, then close the stream via the public end()
        # API (which also flushes buffered decode text) so the consumer
        # loop below never blocks forever waiting for tokens.  The old
        # code reached into the internal on_finalized_text() callback.
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            generation_error["exc"] = exc
            streamer.end()

    thread = Thread(target=_generation_worker, daemon=True)
    thread.start()

    for token in streamer:
        yield token

    # Wait for the worker to fully finish so generation_error is final
    # (and the thread is not left running) before deciding to re-raise.
    thread.join()
    if "exc" in generation_error:
        raise generation_error["exc"]


class PenguinVLQwen3DirectClient(object):
    """Thin in-process client that streams generations from the local model."""

    def __init__(self, model_path):
        # Checkpoint identifier attached to every submitted request.
        self.model_path = model_path

    def submit(self, payload):
        """Kick off a streaming generation; returns the token generator."""
        request = {
            "model_path": self.model_path,
            "conversation": payload["conversation"],
            "generation_config": payload.get("generation_config", {}),
        }
        return _run_generation_stream(request)