import asyncio
import math
import uuid
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
    """
    Async inference backend based on vLLM.
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        temperature: float = 0.6,
        top_p: float = 1.0,
        top_k: int = 5,
        timeout: float = 600,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)

        try:
            from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        except ImportError as exc:
            raise ImportError(
                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
            ) from exc
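        # Keep the SamplingParams class around so each generate call can build
        # its own per-request sampling configuration.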
        self.SamplingParams = SamplingParams

        engine_args = AsyncEngineArgs(
            model=model,
            tensor_parallel_size=int(tensor_parallel_size),
            gpu_memory_utilization=float(gpu_memory_utilization),
            trust_remote_code=kwargs.get("trust_remote_code", True),
            disable_log_stats=False,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.timeout = float(timeout)
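        # Reach into vLLM internals to grab the underlying Hugging Face tokenizer
        # for chat-template formatting; this private attribute chain may change
        # between vLLM versions.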
        self.tokenizer = self.engine.engine.tokenizer.tokenizer

    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
        # Copy the history so the caller's list is not mutated.
        messages = list(history or [])
        messages.append({"role": "user", "content": prompt})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    async def _consume_generator(self, generator):
        """Drain vLLM's streaming generator and return the final, finished output."""
        final_output = None
        async for request_output in generator:
            if request_output.finished:
                final_output = request_output
                break
        return final_output

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_req_{uuid.uuid4()}"
        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature >= 0 else 1.0,
            top_p=self.top_p if self.top_p >= 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 2048),
            repetition_penalty=extra.get("repetition_penalty", 1.05),
        )
        try:
            result_generator = self.engine.generate(
                full_prompt, sp, request_id=request_id
            )
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout,
            )
            if not final_output or not final_output.outputs:
                return ""
            return final_output.outputs[0].text
        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
            # Make sure vLLM stops working on the request before propagating the error.
            await self.engine.abort(request_id)
            raise

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_topk_{uuid.uuid4()}"
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.top_k,
        )
        try:
            result_generator = self.engine.generate(
                full_prompt, sp, request_id=request_id
            )
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout,
            )
            if (
                not final_output
                or not final_output.outputs
                or not final_output.outputs[0].logprobs
            ):
                return []
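            # logprobs holds one dict per generated position, mapping token ids to
            # Logprob objects; only the first position exists here.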
            top_logprobs = final_output.outputs[0].logprobs[0]

            candidate_tokens = []
            for _, logprob_obj in top_logprobs.items():
                tok_str = (
                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
                )
                prob = float(math.exp(logprob_obj.logprob))
                candidate_tokens.append(Token(tok_str, prob))
            candidate_tokens.sort(key=lambda x: -x.prob)

            if candidate_tokens:
                main_token = Token(
                    text=candidate_tokens[0].text,
                    prob=candidate_tokens[0].prob,
                    top_candidates=candidate_tokens,
                )
                return [main_token]
            return []
        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
            # Make sure vLLM stops working on the request before propagating the error.
            await self.engine.abort(request_id)
            raise

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "VLLMWrapper does not support per-token logprobs yet."
        )
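

# Minimal usage sketch (illustrative only): drives the wrapper end to end with
# asyncio. The model id below is a placeholder; any chat-tuned model served by
# vLLM should work, and the exact Token fields come from graphgen.bases.datatypes.
if __name__ == "__main__":

    async def _demo() -> None:
        wrapper = VLLMWrapper(
            model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model id
            tensor_parallel_size=1,
            gpu_memory_utilization=0.9,
        )

        answer = await wrapper.generate_answer("What is a knowledge graph?")
        print(answer)

        # Top-k candidates for the first generated token.
        tokens = await wrapper.generate_topk_per_token("The capital of France is")
        if tokens:
            for cand in tokens[0].top_candidates:
                print(cand.text, cand.prob)

    asyncio.run(_demo())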