Spaces:
Running
Running
github-actions[bot] commited on
Commit ·
a84e551
1
Parent(s): c91a105
Auto-sync from demo at Tue Feb 10 08:19:41 UTC 2026
Browse files
graphgen/models/llm/local/vllm_wrapper.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
|
|
| 1 |
import math
|
| 2 |
import uuid
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
-
import asyncio
|
| 5 |
|
| 6 |
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
|
| 7 |
from graphgen.bases.datatypes import Token
|
|
@@ -43,6 +43,7 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 43 |
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
| 44 |
self.timeout = float(timeout)
|
| 45 |
self.tokenizer = self.engine.engine.tokenizer.tokenizer
|
|
|
|
| 46 |
|
| 47 |
def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
|
| 48 |
messages = history or []
|
|
@@ -51,7 +52,8 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 51 |
return self.tokenizer.apply_chat_template(
|
| 52 |
messages,
|
| 53 |
tokenize=False,
|
| 54 |
-
add_generation_prompt=True
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
async def _consume_generator(self, generator):
|
|
@@ -76,10 +78,11 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 76 |
)
|
| 77 |
|
| 78 |
try:
|
| 79 |
-
result_generator = self.engine.generate(
|
|
|
|
|
|
|
| 80 |
final_output = await asyncio.wait_for(
|
| 81 |
-
self._consume_generator(result_generator),
|
| 82 |
-
timeout=self.timeout
|
| 83 |
)
|
| 84 |
|
| 85 |
if not final_output or not final_output.outputs:
|
|
@@ -105,13 +108,13 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 105 |
)
|
| 106 |
|
| 107 |
try:
|
| 108 |
-
result_generator = self.engine.generate(
|
|
|
|
|
|
|
| 109 |
final_output = await asyncio.wait_for(
|
| 110 |
-
self._consume_generator(result_generator),
|
| 111 |
-
timeout=self.timeout
|
| 112 |
)
|
| 113 |
|
| 114 |
-
|
| 115 |
if (
|
| 116 |
not final_output
|
| 117 |
or not final_output.outputs
|
|
@@ -124,7 +127,9 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 124 |
candidate_tokens = []
|
| 125 |
for _, logprob_obj in top_logprobs.items():
|
| 126 |
tok_str = (
|
| 127 |
-
logprob_obj.decoded_token.strip()
|
|
|
|
|
|
|
| 128 |
)
|
| 129 |
prob = float(math.exp(logprob_obj.logprob))
|
| 130 |
candidate_tokens.append(Token(tok_str, prob))
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
import math
|
| 3 |
import uuid
|
| 4 |
from typing import Any, List, Optional
|
|
|
|
| 5 |
|
| 6 |
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
|
| 7 |
from graphgen.bases.datatypes import Token
|
|
|
|
| 43 |
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
| 44 |
self.timeout = float(timeout)
|
| 45 |
self.tokenizer = self.engine.engine.tokenizer.tokenizer
|
| 46 |
+
self.enable_thinking = kwargs.get("enable_thinking", False)
|
| 47 |
|
| 48 |
def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
|
| 49 |
messages = history or []
|
|
|
|
| 52 |
return self.tokenizer.apply_chat_template(
|
| 53 |
messages,
|
| 54 |
tokenize=False,
|
| 55 |
+
add_generation_prompt=True,
|
| 56 |
+
enable_thinking=self.enable_thinking,
|
| 57 |
)
|
| 58 |
|
| 59 |
async def _consume_generator(self, generator):
|
|
|
|
| 78 |
)
|
| 79 |
|
| 80 |
try:
|
| 81 |
+
result_generator = self.engine.generate(
|
| 82 |
+
full_prompt, sp, request_id=request_id
|
| 83 |
+
)
|
| 84 |
final_output = await asyncio.wait_for(
|
| 85 |
+
self._consume_generator(result_generator), timeout=self.timeout
|
|
|
|
| 86 |
)
|
| 87 |
|
| 88 |
if not final_output or not final_output.outputs:
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
try:
|
| 111 |
+
result_generator = self.engine.generate(
|
| 112 |
+
full_prompt, sp, request_id=request_id
|
| 113 |
+
)
|
| 114 |
final_output = await asyncio.wait_for(
|
| 115 |
+
self._consume_generator(result_generator), timeout=self.timeout
|
|
|
|
| 116 |
)
|
| 117 |
|
|
|
|
| 118 |
if (
|
| 119 |
not final_output
|
| 120 |
or not final_output.outputs
|
|
|
|
| 127 |
candidate_tokens = []
|
| 128 |
for _, logprob_obj in top_logprobs.items():
|
| 129 |
tok_str = (
|
| 130 |
+
logprob_obj.decoded_token.strip()
|
| 131 |
+
if logprob_obj.decoded_token
|
| 132 |
+
else ""
|
| 133 |
)
|
| 134 |
prob = float(math.exp(logprob_obj.logprob))
|
| 135 |
candidate_tokens.append(Token(tok_str, prob))
|