github-actions[bot] committed
Commit 4331db7 · 1 Parent(s): f851e18

Auto-sync from demo at Fri Jan 9 03:05:08 UTC 2026

graphgen/bases/base_llm_wrapper.py CHANGED
@@ -26,11 +26,11 @@ class BaseLLMWrapper(abc.ABC):
         **kwargs: Any,
     ):
         self.system_prompt = system_prompt
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.repetition_penalty = repetition_penalty
-        self.top_p = top_p
-        self.top_k = top_k
+        self.temperature = float(temperature)
+        self.max_tokens = int(max_tokens)
+        self.repetition_penalty = float(repetition_penalty)
+        self.top_p = float(top_p)
+        self.top_k = int(top_k)
         self.tokenizer = tokenizer

         for k, v in kwargs.items():
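What this hunk does: the numeric coercion of sampling parameters moves out of the vLLM wrapper (removed in the next file) and into the shared base class, so every backend inherits it. The casts matter because these values often arrive as strings from YAML configs, CLI flags, or environment variables. A minimal, self-contained sketch of the pattern (class name and defaults are illustrative, not GraphGen's actual API):

# Illustrative sketch only -- not GraphGen's actual class. Config values
# frequently arrive as strings; casting once in __init__ keeps every
# downstream comparison and computation type-safe.
class SamplingDefaults:
    def __init__(self, temperature="0.6", max_tokens="512", top_k="5"):
        self.temperature = float(temperature)  # "0.6" -> 0.6
        self.max_tokens = int(max_tokens)      # "512" -> 512
        self.top_k = int(top_k)                # "5"   -> 5

cfg = SamplingDefaults()
assert cfg.temperature == 0.6 and cfg.max_tokens == 512 and cfg.top_k == 5

Coercing once at construction time means downstream arithmetic never has to defend against string-typed parameters.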
graphgen/models/llm/local/vllm_wrapper.py CHANGED
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio

 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,12 +20,9 @@ class VLLMWrapper(BaseLLMWrapper):
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
+        timeout: float = 300,
         **kwargs: Any,
     ):
-        temperature = float(temperature)
-        top_p = float(top_p)
-        top_k = int(top_k)
-
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
@@ -43,6 +41,7 @@ class VLLMWrapper(BaseLLMWrapper):
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.timeout = float(timeout)

     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -58,6 +57,12 @@ class VLLMWrapper(BaseLLMWrapper):
             lines.append(prompt)
         return "\n".join(lines)

+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -72,14 +77,21 @@ class VLLMWrapper(BaseLLMWrapper):

         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if not final_output or not final_output.outputs:
-            return ""
-
-        return final_output.outputs[0].text
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""
+
+            result_text = final_output.outputs[0].text
+            return result_text
+
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise

     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -91,42 +103,47 @@ class VLLMWrapper(BaseLLMWrapper):
             temperature=0,
             max_tokens=1,
             logprobs=self.top_k,
-            prompt_logprobs=1,
         )

         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
-            )
-            return [main_token]
-        return []
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
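What these hunks do: result consumption moves from a bare async for loop (which can wait forever on a stalled engine) into the _consume_generator helper, bounded by asyncio.wait_for with a configurable timeout (default 300 s), and the request is aborted on the engine before any exception propagates; the unused prompt_logprobs=1 is also dropped from generate_topk_per_token, presumably because the method only reads the output logprobs. A self-contained sketch of the timeout-and-abort pattern, runnable without vLLM (the slow stream and abort stub are stand-ins, not vLLM's interface):

import asyncio

# Stand-ins for the engine API -- illustrative, not vLLM's interface.
async def slow_stream():
    for i in range(10):
        await asyncio.sleep(1.0)    # simulate a stalled engine step
        yield f"chunk-{i}"

async def abort(request_id: str) -> None:
    print(f"aborted {request_id}")  # real code frees engine-side state here

async def consume(gen):
    final = None
    async for item in gen:          # keep only the last (cumulative) output
        final = item
    return final

async def main():
    try:
        # Bound total generation time; on expiry, wait_for cancels the
        # consumer and raises asyncio.TimeoutError.
        return await asyncio.wait_for(consume(slow_stream()), timeout=2.5)
    except (Exception, asyncio.CancelledError):
        await abort("req-1")        # mirror the diff: abort, then re-raise
        raise

asyncio.run(main())  # prints "aborted req-1", then TimeoutError propagates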
 
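A note on the exception handler used above: the asyncio.TimeoutError that wait_for raises on expiry is an Exception subclass, so the first clause covers the timeout path, while asyncio.CancelledError has inherited from BaseException since Python 3.8 and must be named explicitly for the abort to also run when the surrounding task is cancelled from outside. Re-raising keeps the failure visible to callers; the abort only ensures the engine stops computing a request whose output nobody will read.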