github-actions[bot] committed
Commit b43eb86 · 1 parent: ad9c5d9

Auto-sync from demo at Wed Jan 14 04:31:41 UTC 2026
graphgen/models/llm/local/vllm_wrapper.py CHANGED
```diff
@@ -20,7 +20,7 @@ class VLLMWrapper(BaseLLMWrapper):
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
-        timeout: float = 300,
+        timeout: float = 600,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
@@ -42,25 +42,24 @@ class VLLMWrapper(BaseLLMWrapper):
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.timeout = float(timeout)
+        self.tokenizer = self.engine.engine.tokenizer.tokenizer
 
-    @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
-        msgs = history or []
-        lines = []
-        for m in msgs:
-            if isinstance(m, dict):
-                role = m.get("role", "")
-                content = m.get("content", "")
-                lines.append(f"{role}: {content}")
-            else:
-                lines.append(str(m))
-        lines.append(prompt)
-        return "\n".join(lines)
+    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
+        messages = history or []
+        messages.append({"role": "user", "content": prompt})
+
+        return self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
 
     async def _consume_generator(self, generator):
         final_output = None
         async for request_output in generator:
-            final_output = request_output
+            if request_output.finished:
+                final_output = request_output
+                break
         return final_output
 
     async def generate_answer(
```
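The rewritten `_build_inputs` hands prompt formatting to the model's own chat template instead of joining `role: content` lines by hand, so the engine receives exactly the special tokens the model was trained on, and `_consume_generator` now stops at the output marked `finished` instead of draining the stream. A minimal sketch of what the template call returns, using `transformers` directly (the model name is illustrative, not part of this commit):

```python
# Sketch of apply_chat_template(tokenize=False, add_generation_prompt=True),
# assuming an instruction-tuned model that ships a chat template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is GraphGen?"},
]

# tokenize=False returns the templated string rather than token ids;
# add_generation_prompt=True appends the assistant header so the model
# continues in the assistant role.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```

One side effect worth flagging: `messages = history or []` followed by `messages.append(...)` mutates the caller's `history` list in place; `messages = list(history or [])` would avoid that.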
```diff
@@ -70,14 +69,14 @@ class VLLMWrapper(BaseLLMWrapper):
         request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
-            temperature=self.temperature if self.temperature > 0 else 1.0,
-            top_p=self.top_p if self.temperature > 0 else 1.0,
+            temperature=self.temperature if self.temperature >= 0 else 1.0,
+            top_p=self.top_p if self.top_p >= 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 2048),
+            repetition_penalty=extra.get("repetition_penalty", 1.05),
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
```
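Two fixes land in this hunk besides moving `engine.generate` inside the `try` (so a failure there still reaches the `abort` in the handler): the old `top_p` guard tested `self.temperature`, a copy-paste bug, and the comparisons are relaxed from `>` to `>=` so that `temperature=0` (greedy decoding) passes through instead of being rewritten to 1.0. A hedged sketch of the resulting parameter construction, using vLLM's public `SamplingParams` (the helper function is illustrative, not from the commit):

```python
from vllm import SamplingParams

def make_sampling_params(temperature: float, top_p: float, extra: dict) -> SamplingParams:
    return SamplingParams(
        # 0 now passes through (greedy decoding); only negative values
        # fall back to the neutral 1.0.
        temperature=temperature if temperature >= 0 else 1.0,
        top_p=top_p if top_p >= 0 else 1.0,
        max_tokens=extra.get("max_new_tokens", 2048),
        # Mild anti-repetition default, overridable per request.
        repetition_penalty=extra.get("repetition_penalty", 1.05),
    )
```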
```diff
@@ -89,7 +88,7 @@ class VLLMWrapper(BaseLLMWrapper):
             result_text = final_output.outputs[0].text
             return result_text
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
 
```
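The load-bearing entry in this handler is `asyncio.CancelledError`, which has derived from `BaseException` rather than `Exception` since Python 3.8, so a plain `except Exception` would skip the abort when the surrounding task is cancelled. `asyncio.TimeoutError` (what `asyncio.wait_for` raises) already subclasses `Exception` on supported Python versions, and on 3.11+ it is an alias of the built-in `TimeoutError`, so naming it is self-documenting rather than behavior-changing:

```python
import asyncio

# On Python 3.8+, cancellation escapes `except Exception`:
assert not issubclass(asyncio.CancelledError, Exception)
assert issubclass(asyncio.CancelledError, BaseException)
assert issubclass(asyncio.TimeoutError, Exception)
```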
```diff
@@ -105,14 +104,14 @@ class VLLMWrapper(BaseLLMWrapper):
             logprobs=self.top_k,
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
             )
 
+
             if (
                 not final_output
                 or not final_output.outputs
```
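This hunk is in the logprobs path (the sampling params request `logprobs=self.top_k`), with the same relocation of `engine.generate` into the `try`. For context, a hedged sketch of how per-token candidates come back from a finished request in recent vLLM releases, where `CompletionOutput.logprobs` is a list with one `{token_id: Logprob}` dict per generated token (field names may differ across versions):

```python
# Illustrative helper, not part of the commit.
def first_token_candidates(final_output, k: int = 5):
    completion = final_output.outputs[0]
    if not completion.logprobs:
        return []
    step = completion.logprobs[0]  # candidate dict for the first generated token
    ranked = sorted(step.values(), key=lambda lp: lp.logprob, reverse=True)
    return [(lp.decoded_token, lp.logprob) for lp in ranked[:k]]
```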
```diff
@@ -141,7 +140,7 @@ class VLLMWrapper(BaseLLMWrapper):
                 return [main_token]
             return []
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
 
```
 
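End to end, a call through the wrapper now looks roughly like the following. The `model` argument and the exact `generate_answer` signature are assumptions about code outside this diff:

```python
# Hypothetical usage; constructor and method signatures beyond what the
# diff shows are assumed, not confirmed by this commit.
import asyncio

async def main():
    llm = VLLMWrapper(model="Qwen/Qwen2.5-0.5B-Instruct", temperature=0.6, timeout=600)
    answer = await llm.generate_answer("Summarize GraphGen in one sentence.")
    print(answer)

asyncio.run(main())
```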