Spaces:
Running
Running
github-actions[bot]
committed on
Commit
·
3389c47
1
Parent(s):
e358663
Auto-sync from demo at Mon Jan 5 10:28:58 UTC 2026
Browse files
graphgen/models/llm/local/vllm_wrapper.py
CHANGED
|
@@ -18,10 +18,14 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 18 |
gpu_memory_utilization: float = 0.9,
|
| 19 |
temperature: float = 0.6,
|
| 20 |
top_p: float = 1.0,
|
| 21 |
-
|
| 22 |
**kwargs: Any,
|
| 23 |
):
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
try:
|
| 26 |
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
| 27 |
except ImportError as exc:
|
|
@@ -39,9 +43,6 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 39 |
disable_log_stats=False,
|
| 40 |
)
|
| 41 |
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
| 42 |
-
self.temperature = temperature
|
| 43 |
-
self.top_p = top_p
|
| 44 |
-
self.topk = topk
|
| 45 |
|
| 46 |
@staticmethod
|
| 47 |
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
|
|
@@ -89,7 +90,7 @@ class VLLMWrapper(BaseLLMWrapper):
|
|
| 89 |
sp = self.SamplingParams(
|
| 90 |
temperature=0,
|
| 91 |
max_tokens=1,
|
| 92 |
-
logprobs=self.
|
| 93 |
prompt_logprobs=1,
|
| 94 |
)
|
| 95 |
|
|
|
|
| 18 |
gpu_memory_utilization: float = 0.9,
|
| 19 |
temperature: float = 0.6,
|
| 20 |
top_p: float = 1.0,
|
| 21 |
+
top_k: int = 5,
|
| 22 |
**kwargs: Any,
|
| 23 |
):
|
| 24 |
+
temperature = float(temperature)
|
| 25 |
+
top_p = float(top_p)
|
| 26 |
+
top_k = int(top_k)
|
| 27 |
+
|
| 28 |
+
super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
|
| 29 |
try:
|
| 30 |
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
| 31 |
except ImportError as exc:
|
|
|
|
| 43 |
disable_log_stats=False,
|
| 44 |
)
|
| 45 |
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
@staticmethod
|
| 48 |
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
|
|
|
|
| 90 |
sp = self.SamplingParams(
|
| 91 |
temperature=0,
|
| 92 |
max_tokens=1,
|
| 93 |
+
logprobs=self.top_k,
|
| 94 |
prompt_logprobs=1,
|
| 95 |
)
|
| 96 |
|