import asyncio
import math
import uuid
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
    """
    Async inference backend based on vLLM.
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        temperature: float = 0.6,
        top_p: float = 1.0,
        top_k: int = 5,
        timeout: float = 600,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
        try:
            from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        except ImportError as exc:
            raise ImportError(
                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
            ) from exc

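        # Keep a handle on the SamplingParams class: the import is local so
        # that vllm stays an optional dependency of this module.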
        self.SamplingParams = SamplingParams

        engine_args = AsyncEngineArgs(
            model=model,
            tensor_parallel_size=int(tensor_parallel_size),
            gpu_memory_utilization=float(gpu_memory_utilization),
            trust_remote_code=kwargs.get("trust_remote_code", True),
            disable_log_stats=False,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.timeout = float(timeout)
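        # NOTE: reaches through vLLM internals to the underlying HF tokenizer;
        # this attribute path may change across vLLM versions.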
        self.tokenizer = self.engine.engine.tokenizer.tokenizer

    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
        # Copy the history so the caller's message list is not mutated.
        messages = list(history) if history else []
        messages.append({"role": "user", "content": prompt})

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    async def _consume_generator(self, generator):
        """Drain the vLLM result stream and return the final, finished output."""
        final_output = None
        async for request_output in generator:
            if request_output.finished:
                final_output = request_output
                break
        return final_output

    async def generate_answer(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> str:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_req_{uuid.uuid4()}"

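        # Negative values are treated as "unset" and fall back to neutral defaults.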
        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature >= 0 else 1.0,
            top_p=self.top_p if self.top_p >= 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 2048),
            repetition_penalty=extra.get("repetition_penalty", 1.05),
        )

        try:
            result_generator = self.engine.generate(
                full_prompt, sp, request_id=request_id
            )
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout,
            )

            if not final_output or not final_output.outputs:
                return ""

            return final_output.outputs[0].text

        except (Exception, asyncio.CancelledError):
            # asyncio.TimeoutError is already an Exception subclass, while
            # CancelledError is not (Python >= 3.8), so only the latter needs
            # listing explicitly. Abort the in-flight request so the engine
            # stops generating, then re-raise.
            await self.engine.abort(request_id)
            raise

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_topk_{uuid.uuid4()}"

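        # Greedy, single-step decode: we only need the top-k logprobs for the
        # next token, hence temperature=0 and max_tokens=1.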
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.top_k,
        )

        try:
            result_generator = self.engine.generate(
                full_prompt, sp, request_id=request_id
            )
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout,
            )

            if (
                not final_output
                or not final_output.outputs
                or not final_output.outputs[0].logprobs
            ):
                return []

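            # vLLM returns one {token_id: Logprob} mapping per generated
            # position; with max_tokens=1 there is exactly one entry.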
            top_logprobs = final_output.outputs[0].logprobs[0]

            candidate_tokens = []
            for _, logprob_obj in top_logprobs.items():
                tok_str = (
                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
                )
                prob = float(math.exp(logprob_obj.logprob))
                candidate_tokens.append(Token(tok_str, prob))

            candidate_tokens.sort(key=lambda x: -x.prob)

            if candidate_tokens:
                main_token = Token(
                    text=candidate_tokens[0].text,
                    prob=candidate_tokens[0].prob,
                    top_candidates=candidate_tokens,
                )
                return [main_token]
            return []

        except (Exception, asyncio.CancelledError):
            # As in generate_answer: abort the in-flight request, then re-raise.
            await self.engine.abort(request_id)
            raise

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "VLLMWrapper does not support computing token probabilities over "
            "the input text yet."
        )
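

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the wrapper). It assumes
# vLLM is installed, a GPU is available, and that the model id below points
# at a chat model you actually have -- substitute your own.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    async def _demo() -> None:
        llm = VLLMWrapper(model="Qwen/Qwen2.5-7B-Instruct")

        answer = await llm.generate_answer("What is a knowledge graph?")
        print(answer)

        # Inspect the top-k candidates proposed for the next token.
        for token in await llm.generate_topk_per_token("The capital of France is"):
            print(token.text, token.prob)

    asyncio.run(_demo())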